Files
ailai_image_point_diff/ailai_pc/detet_pc.py
琉璃月光 1ec9bbab60 2.0
2025-10-22 17:52:29 +08:00

72 lines
2.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# detect_pt.py
import cv2
import torch
from ultralytics import YOLO
# ======================
# 配置参数
# ======================
MODEL_PATH = 'best.pt' # 你的训练模型路径yolov8n.pt 或你自己训练的)
#IMG_PATH = '/home/hx/开发/ailai_image_obb/ailai_pc/train/192.168.0.234_01_202510141514352.jpg' # 测试图像路径
IMG_PATH = '1.jpg'
OUTPUT_PATH = '/home/hx/开发/ailai_image_obb/ailai_pc/output_pt.jpg' # 可视化结果保存路径
CONF_THRESH = 0.5 # 置信度阈值
CLASS_NAMES = ['bag'] # 你的类别名列表(按训练时顺序)
# 是否显示窗口(适合有 GUI 的 PC
SHOW_IMAGE = True
# ======================
# 主函数
# ======================
def main():
# 检查 CUDA
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"✅ 使用设备: {device}")
# 加载模型
print("➡️ 加载 YOLO 模型...")
model = YOLO(MODEL_PATH) # 自动加载架构和权重
model.to(device)
# 推理
print("➡️ 开始推理...")
results = model(IMG_PATH, imgsz=640, conf=CONF_THRESH, device=device)
# 获取第一张图的结果
r = results[0]
# 获取原始图像BGR
img = cv2.imread(IMG_PATH)
if img is None:
raise FileNotFoundError(f"无法读取图像: {IMG_PATH}")
print("\n📋 检测结果:")
for box in r.boxes:
# 获取数据
xyxy = box.xyxy[0].cpu().numpy() # [x1, y1, x2, y2]
conf = box.conf.cpu().numpy()[0] # 置信度
cls_id = int(box.cls.cpu().numpy()[0]) # 类别 ID
cls_name = CLASS_NAMES[cls_id] # 类别名
x1, y1, x2, y2 = map(int, xyxy)
print(f" 类别: {cls_name}, 置信度: {conf:.3f}, 框: [{x1}, {y1}, {x2}, {y2}]")
# 画框
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
# 画标签
label = f"{cls_name} {conf:.2f}"
cv2.putText(img, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
# 保存结果
cv2.imwrite(OUTPUT_PATH, img)
print(f"\n🖼️ 可视化结果已保存: {OUTPUT_PATH}")
# 显示(可选)
if SHOW_IMAGE:
cv2.imshow("YOLOv8 Detection", img)
cv2.waitKey(0)
cv2.destroyAllWindows()
if __name__ == '__main__':
main()