2025-10-22 17:52:29 +08:00
|
|
|
from ultralytics import YOLO
|
2025-11-03 16:10:50 +08:00
|
|
|
from ultralytics.utils.ops import non_max_suppression
|
|
|
|
|
import torch
|
|
|
|
|
import cv2
|
2025-10-22 17:52:29 +08:00
|
|
|
|
|
|
|
|
# ======================
|
|
|
|
|
# Configuration parameters
|
|
|
|
|
# ======================
|
2025-12-30 17:29:49 +08:00
|
|
|
# Trained weights to load. NOTE: absolute, machine-specific path.
MODEL_PATH = '/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_ailai_detect2/weights/best.pt'

# Input image and the file the annotated copy is written to.
IMG_PATH = '4.jpg'
OUTPUT_PATH = 'output_pt.jpg'

# Detection thresholds: minimum confidence and NMS IoU.
CONF_THRESH = 0.5
IOU_THRESH = 0.45

# Display names indexed by predicted class id.
CLASS_NAMES = ['bag', 'bag35']
|
2025-10-22 17:52:29 +08:00
|
|
|
|
|
|
|
|
# ======================
|
2025-12-30 17:29:49 +08:00
|
|
|
# Main function
|
2025-10-22 17:52:29 +08:00
|
|
|
# ======================
|
|
|
|
|
def main():
    """Detect objects in IMG_PATH and save the single most confident one.

    Loads the YOLO weights at MODEL_PATH, runs inference on IMG_PATH,
    picks the detection with the highest confidence, prints it, draws it
    on the image and writes the result to OUTPUT_PATH. Returns None; when
    nothing is detected a message is printed and no file is written.

    Raises:
        FileNotFoundError: if OpenCV cannot read IMG_PATH.
        OSError: if the annotated image cannot be written to OUTPUT_PATH.
    """
    # Validate the input before paying for model loading and inference.
    img = cv2.imread(IMG_PATH)
    if img is None:
        raise FileNotFoundError(f"无法读取图像: {IMG_PATH}")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"✅ 使用设备: {device}")

    model = YOLO(MODEL_PATH).to(device)

    print("➡️ 开始推理...")
    # The predictor performs NMS itself, so the thresholds are passed here
    # rather than re-running non_max_suppression on the already-filtered
    # [N, 6] boxes (which is not the raw layout that function expects).
    results = model(
        IMG_PATH,
        imgsz=640,
        conf=CONF_THRESH,
        iou=IOU_THRESH,
        max_det=100,
        device=device,
        verbose=True,
    )

    r = results[0]
    boxes = r.boxes
    if boxes is None or len(boxes) == 0:
        print("❌ 未检测到任何目标")
        return

    # ======================
    # Key step: keep only the highest-confidence detection
    # ======================
    det = boxes.data  # rows are (x1, y1, x2, y2, conf, cls)
    # Select the row on-device, then copy just that one row to the host.
    best_det = det[det[:, 4].argmax()].cpu().tolist()

    x1, y1, x2, y2, conf, cls_id = best_det
    x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
    cls_id = int(cls_id)
    if 0 <= cls_id < len(CLASS_NAMES):
        cls_name = CLASS_NAMES[cls_id]
    else:
        # Weights with more classes than CLASS_NAMES: fall back to the
        # names stored with the result instead of raising IndexError.
        cls_name = r.names.get(cls_id, str(cls_id))

    print("\n🏆 置信度最高结果:")
    print(f" 类别: {cls_name}")
    print(f" 置信度: {conf:.3f}")
    print(f" 框: [{x1}, {y1}, {x2}, {y2}]")

    # ======================
    # Visualization (top detection only)
    # ======================
    cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
    label = f"{cls_name} {conf:.2f}"
    cv2.putText(
        img,
        label,
        (x1, max(y1 - 10, 0)),  # keep the label inside the top edge
        cv2.FONT_HERSHEY_SIMPLEX,
        0.9,
        (0, 255, 0),
        2,
    )

    # imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(OUTPUT_PATH, img):
        raise OSError(f"无法写入图像: {OUTPUT_PATH}")
    print(f"\n🖼️ 可视化结果已保存: {OUTPUT_PATH}")
|
2025-10-22 17:52:29 +08:00
|
|
|
|
|
|
|
|
# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|