更新加入料带目标检测,判断料带到位,以及控制滚筒逻辑

This commit is contained in:
琉璃月光
2025-12-30 17:29:49 +08:00
parent d6918e90f2
commit 2028a96819
27 changed files with 1499 additions and 1224 deletions

View File

@ -6,71 +6,79 @@ import cv2
# ======================
# 配置参数
# ======================
MODEL_PATH = '/home/hx/开发/ailai_image_obb/ailai_pc/best12.pt'
IMG_PATH = '1.jpg'
MODEL_PATH = '/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_ailai_detect2/weights/best.pt'
IMG_PATH = '4.jpg'
OUTPUT_PATH = 'output_pt.jpg'
CONF_THRESH = 0.5
IOU_THRESH = 0.45
CLASS_NAMES = ['bag']
CLASS_NAMES = ['bag', 'bag35']
# ======================
# 主函数(优化版)
# 主函数
# ======================
def main():
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"✅ 使用设备: {device}")
# 加载模型
model = YOLO(MODEL_PATH)
model.to(device)
model = YOLO(MODEL_PATH).to(device)
# 推理:获取原始结果(不立即解析)
print("➡️ 开始推理...")
results = model(IMG_PATH, imgsz=640, conf=CONF_THRESH, device=device, verbose=True)
# 获取第一张图的结果
r = results[0]
pred = r.boxes.data # GPU tensor [N,6]
# 🚀 关键:使用原始 tensor 在 GPU 上处理
# pred: [x1, y1, x2, y2, conf, cls] 形状为 [num_boxes, 6]
pred = r.boxes.data # 已经在 GPU 上,类型: torch.Tensor
# 🔍 在 GPU 上做 NMS这才是正确姿势
# 注意non_max_suppression 输入是 [batch, num_boxes, 6]
det = non_max_suppression(
pred.unsqueeze(0), # 增加 batch 维度
pred.unsqueeze(0),
conf_thres=CONF_THRESH,
iou_thres=IOU_THRESH,
classes=None,
agnostic=False,
max_det=100
)[0] # 取第一个也是唯一一个batch
)[0]
# ✅ 此时所有后处理已完成,现在才从 GPU 拷贝到 CPU
if det is not None and len(det):
det = det.cpu().numpy() # ← 只拷贝一次!
else:
det = []
if det is None or len(det) == 0:
print("❌ 未检测到任何目标")
return
# 读取图像
det = det.cpu().numpy() # 只拷贝一次
# ======================
# ⭐ 关键:取置信度最高的结果
# ======================
best_det = max(det, key=lambda x: x[4])
x1, y1, x2, y2, conf, cls_id = best_det
x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
cls_id = int(cls_id)
cls_name = CLASS_NAMES[cls_id]
print("\n🏆 置信度最高结果:")
print(f" 类别: {cls_name}")
print(f" 置信度: {conf:.3f}")
print(f" 框: [{x1}, {y1}, {x2}, {y2}]")
# ======================
# 可视化(只画最高的)
# ======================
img = cv2.imread(IMG_PATH)
if img is None:
raise FileNotFoundError(f"无法读取图像: {IMG_PATH}")
print("\n📋 检测结果:")
for *xyxy, conf, cls_id in det:
x1, y1, x2, y2 = map(int, xyxy)
cls_name = CLASS_NAMES[int(cls_id)]
print(f" 类别: {cls_name}, 置信度: {conf:.3f}, 框: [{x1}, {y1}, {x2}, {y2}]")
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
label = f"{cls_name} {conf:.2f}"
cv2.putText(
img,
label,
(x1, max(y1 - 10, 0)),
cv2.FONT_HERSHEY_SIMPLEX,
0.9,
(0, 255, 0),
2
)
# 画框和标签
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
label = f"{cls_name} {conf:.2f}"
cv2.putText(img, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
# 保存结果
cv2.imwrite(OUTPUT_PATH, img)
print(f"\n🖼️ 可视化结果已保存: {OUTPUT_PATH}")
if __name__ == '__main__':
main()
main()