Files
zjsh_yolov11/yolo11-mobilenetv4/detet_pc.py
琉璃月光 67883f1a50 最新推送
2026-03-10 14:14:14 +08:00

76 lines
2.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from ultralytics import YOLO
from ultralytics.utils.ops import non_max_suppression
import torch
import cv2
# ======================
# 配置参数
# ======================
MODEL_PATH = '/home/hx/yolo11-jz/runs/train/exp4/weights/best.pt'
IMG_PATH = '1.jpg'
OUTPUT_PATH = 'output_pt.jpg'
CONF_THRESH = 0.5
IOU_THRESH = 0.45
CLASS_NAMES = ['bag']
# ======================
# 主函数(优化版)
# ======================
def main():
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"✅ 使用设备: {device}")
# 加载模型
model = YOLO(MODEL_PATH)
model.to(device)
# 推理:获取原始结果(不立即解析)
print("➡️ 开始推理...")
results = model(IMG_PATH, imgsz=640, conf=CONF_THRESH, device=device, verbose=True)
# 获取第一张图的结果
r = results[0]
# 🚀 关键:使用原始 tensor 在 GPU 上处理
# pred: [x1, y1, x2, y2, conf, cls] 形状为 [num_boxes, 6]
pred = r.boxes.data # 已经在 GPU 上,类型: torch.Tensor
# 🔍 在 GPU 上做 NMS这才是正确姿势
# 注意non_max_suppression 输入是 [batch, num_boxes, 6]
det = non_max_suppression(
pred.unsqueeze(0), # 增加 batch 维度
conf_thres=CONF_THRESH,
iou_thres=IOU_THRESH,
classes=None,
agnostic=False,
max_det=100
)[0] # 取第一个也是唯一一个batch
# ✅ 此时所有后处理已完成,现在才从 GPU 拷贝到 CPU
if det is not None and len(det):
det = det.cpu().numpy() # ← 只拷贝一次!
else:
det = []
# 读取图像
img = cv2.imread(IMG_PATH)
if img is None:
raise FileNotFoundError(f"无法读取图像: {IMG_PATH}")
print("\n📋 检测结果:")
for *xyxy, conf, cls_id in det:
x1, y1, x2, y2 = map(int, xyxy)
cls_name = CLASS_NAMES[int(cls_id)]
print(f" 类别: {cls_name}, 置信度: {conf:.3f}, 框: [{x1}, {y1}, {x2}, {y2}]")
# 画框和标签
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
label = f"{cls_name} {conf:.2f}"
cv2.putText(img, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
# 保存结果
cv2.imwrite(OUTPUT_PATH, img)
print(f"\n🖼️ 可视化结果已保存: {OUTPUT_PATH}")
if __name__ == '__main__':
main()