import os import cv2 from ultralytics import YOLO class ObjectDetector: """封装 YOLO 目标检测模型,检测图像中的缺陷""" def __init__(self, model_path): if not os.path.exists(model_path): raise FileNotFoundError(f"模型文件不存在: {model_path}") self.model = YOLO(model_path) print(f"[INFO] 成功加载 YOLO 目标检测模型: {model_path}") def detect(self, img_np, conf_threshold=0.5): # 注意:这里先不设 conf 阈值,以便后续按类别筛选最高分 results = self.model.predict(img_np, conf=0.0, verbose=False) # 获取所有预测 detections = [] for result in results: boxes = result.boxes.cpu().numpy() for box in boxes: if box.conf.item() >= conf_threshold: # 在 Python 层过滤 detection_info = { 'bbox': box.xyxy[0], 'confidence': box.conf.item(), 'class_id': int(box.cls.item()) } detections.append(detection_info) return detections def detect_quexian(img_path="1.png", model_path="/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls/exp_xiantiao_cls/weights/best.pt", conf_threshold=0.5, debug=False): """ 检测木条图像中的孔洞/裂缝缺陷,并返回是否为良品。 每个类别仅保留置信度最高的一个框。 Args: img_path (str): 输入图像路径 model_path (str): YOLO 检测模型路径(必须是 detect 任务训练的) conf_threshold (float): 置信度阈值 debug (bool): 是否启用调试模式(打印详细信息) Returns: bool: True 表示无缺陷(良品),False 表示有缺陷(不良品) """ # 1. 加载图像 if not os.path.exists(img_path): raise FileNotFoundError(f"图片不存在: {img_path}") img_np = cv2.imread(img_path) if img_np is None: raise ValueError("图像加载失败,可能路径错误或图像文件损坏。") # 2. 加载模型并检测(获取所有 ≥ conf_threshold 的框) detector = ObjectDetector(model_path) all_detections = detector.detect(img_np, conf_threshold=conf_threshold) # 3. 按类别分组,取每个类别中置信度最高的框 best_per_class = {} for det in all_detections: cls_id = det['class_id'] if cls_id not in best_per_class or det['confidence'] > best_per_class[cls_id]['confidence']: best_per_class[cls_id] = det # 转为列表(用于后续处理) top_detections = list(best_per_class.values()) # 4. 判定是否有缺陷:只要有一个类别有框,就算有缺陷 has_defect = len(top_detections) > 0 is_good = not has_defect # 5. 可视化:只绘制每个类别的最高置信度框 label_map = {0: "hole", 1: "crack"} vis_img = img_np.copy() for det in top_detections: x1, y1, x2, y2 = map(int, det['bbox']) conf = det['confidence'] cls_id = det['class_id'] label = f"{label_map.get(cls_id, '未知')} {conf:.2f}" cv2.rectangle(vis_img, (x1, y1), (x2, y2), (0, 0, 255), 2) cv2.putText(vis_img, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2) # 6. 调试信息(仅在 debug=True 时输出) if debug: print(f"\n===== 缺陷检测结果 (DEBUG 模式) =====") print(f"置信度阈值: {conf_threshold}") print(f"有效类别数量: {len(top_detections)}") for i, det in enumerate(top_detections): cls_name = label_map.get(det['class_id'], '未知') bbox_int = det['bbox'].astype(int).tolist() print(f" - 类别 '{cls_name}' 最高置信度框: 置信度={det['confidence']:.3f}, bbox={bbox_int}") # 7. 显示结果图像 cv2.imshow('Detection Results', vis_img) cv2.waitKey(0) cv2.destroyAllWindows() return is_good if __name__ == "__main__": # 示例:启用 debug 模式 is_good_product = detect_quexian( #img_path="/home/hx/开发/ML_xiantiao/class_xiantiao_pc/test_image/val/1.jpg", img_path="/home/hx/开发/ML_xiantiao/class_xiantiao_pc/test_image/train/微信图片_20251216095823_227.jpg", model_path="/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_detect/weights/best.pt", conf_threshold=0.5, debug=True # 改为 False 即静默模式 ) # 主程序最终输出(简洁版) if is_good_product: print("产品合格") else: print("产品存在缺陷")