Files
xiantiao_CV/class_xiantiao_pc/detect.py
琉璃月光 8506c3af79 first commit
2025-12-16 15:12:02 +08:00

115 lines
4.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
from ultralytics import YOLO
class ObjectDetector:
    """Wrap an Ultralytics YOLO detection model used to find defects in images."""

    def __init__(self, model_path):
        """Load the YOLO weights stored at ``model_path``.

        Args:
            model_path (str): Path to the trained weights file (e.g. ``best.pt``).

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件不存在: {model_path}")
        self.model = YOLO(model_path)
        print(f"[INFO] 成功加载 YOLO 目标检测模型: {model_path}")

    def detect(self, img_np, conf_threshold=0.5):
        """Run the model on one image and return the boxes scoring >= ``conf_threshold``.

        Args:
            img_np: Image handed straight to ``YOLO.predict`` (``detect_quexian``
                in this file passes the array returned by ``cv2.imread``).
            conf_threshold (float): Minimum confidence for a box to be kept.

        Returns:
            list[dict]: One dict per kept box, with keys
                ``'bbox'`` (the box's ``xyxy`` row: x1, y1, x2, y2, as a NumPy array),
                ``'confidence'`` (float) and ``'class_id'`` (int).
        """
        # Let the predictor drop low-confidence candidates before NMS instead of
        # requesting every box with conf=0.0. NMS never lets a lower-scoring box
        # suppress a higher-scoring one, so the boxes at or above the threshold
        # are the same either way; this only avoids running NMS on all candidates.
        results = self.model.predict(img_np, conf=conf_threshold, verbose=False)
        detections = []
        for result in results:
            for box in result.boxes.cpu().numpy():
                confidence = box.conf.item()
                # Keep the Python-side check as a safeguard so the returned list
                # honours the documented ">= conf_threshold" contract regardless
                # of how the predictor applies its own threshold.
                if confidence >= conf_threshold:
                    detections.append({
                        'bbox': box.xyxy[0],
                        'confidence': confidence,
                        'class_id': int(box.cls.item()),
                    })
        return detections
def detect_quexian(
    img_path="1.png",
    model_path="/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls/exp_xiantiao_cls/weights/best.pt",
    conf_threshold=0.5,
    debug=False,
    show=True,
    detector=None,
):
    """Detect hole/crack defects in a wood-strip image and report whether it is good.

    Only the highest-confidence box of each class is kept.

    Args:
        img_path (str): Path of the input image.
        model_path (str): Path of the YOLO weights; must be a model trained for
            the ``detect`` task. Ignored when ``detector`` is given.
            NOTE(review): this default points into a ``runs/train/cls/...``
            (classification) run, whereas the ``__main__`` example uses
            ``runs/train/exp_detect/...``. A classification model presumably
            yields no boxes, so confirm which default is intended.
        conf_threshold (float): Minimum confidence for a box to count as a defect.
        debug (bool): If True, print the kept box of each class.
        show (bool): If True (default, same as before), open an OpenCV window with
            the kept boxes drawn and block until a key is pressed. Pass False
            for headless or batch use.
        detector (ObjectDetector | None): Already-loaded detector to reuse. When
            None, a new detector is loaded from ``model_path`` on every call.

    Returns:
        bool: True if no defect box was found (good part), False otherwise.

    Raises:
        FileNotFoundError: If ``img_path`` does not exist, or if a model has to
            be loaded and ``model_path`` does not exist.
        ValueError: If OpenCV fails to read the image.
    """
    # 1. Load the image.
    if not os.path.exists(img_path):
        raise FileNotFoundError(f"图片不存在: {img_path}")
    img_np = cv2.imread(img_path)
    if img_np is None:
        raise ValueError("图像加载失败,可能路径错误或图像文件损坏。")

    # 2. Detect: every box with confidence >= conf_threshold.
    if detector is None:
        detector = ObjectDetector(model_path)
    all_detections = detector.detect(img_np, conf_threshold=conf_threshold)

    # 3. Keep only the best box of each class.
    top_detections = _best_per_class(all_detections)

    # 4. Any remaining box means the part is defective.
    is_good = not top_detections

    label_map = {0: "hole", 1: "crack"}

    # 5. Debug printout (only when debug=True).
    if debug:
        print(f"\n===== 缺陷检测结果 (DEBUG 模式) =====")
        print(f"置信度阈值: {conf_threshold}")
        print(f"有效类别数量: {len(top_detections)}")
        for det in top_detections:
            cls_name = label_map.get(det['class_id'], '未知')
            bbox_int = det['bbox'].astype(int).tolist()
            print(f" - 类别 '{cls_name}' 最高置信度框: 置信度={det['confidence']:.3f}, bbox={bbox_int}")

    # 6. Show the annotated image; waitKey(0) blocks until a key is pressed.
    if show:
        vis_img = _draw_detections(img_np, top_detections, label_map)
        cv2.imshow('Detection Results', vis_img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    return is_good


def _best_per_class(detections):
    """Return the highest-confidence detection of each class, in first-seen class order."""
    best = {}
    for det in detections:
        current = best.get(det['class_id'])
        # Strict '>' keeps the earliest box when two boxes tie.
        if current is None or det['confidence'] > current['confidence']:
            best[det['class_id']] = det
    return list(best.values())


def _draw_detections(img_np, detections, label_map):
    """Return a copy of ``img_np`` with each detection's box and label drawn in red."""
    vis_img = img_np.copy()
    for det in detections:
        x1, y1, x2, y2 = map(int, det['bbox'])
        # cv2.putText's Hershey fonts cannot render non-ASCII text (it shows up
        # as '?'), so the drawn fallback label must be ASCII.
        name = label_map.get(det['class_id'], "unknown")
        label = f"{name} {det['confidence']:.2f}"
        cv2.rectangle(vis_img, (x1, y1), (x2, y2), (0, 0, 255), 2)
        # Draw the label above the box, but keep its baseline inside the image
        # when the box touches the top edge.
        text_y = max(y1 - 10, 15)
        cv2.putText(vis_img, label, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
    return vis_img
if __name__ == "__main__":
    # Example run with the debug printout enabled. Setting debug=False only
    # silences the printed details; the result window is still shown and
    # waits for a key press. Validation images live under test_image/val/.
    is_good_product = detect_quexian(
        img_path="/home/hx/开发/ML_xiantiao/class_xiantiao_pc/test_image/train/微信图片_20251216095823_227.jpg",
        model_path="/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_detect/weights/best.pt",
        conf_threshold=0.5,
        debug=True,
    )
    # Final, concise verdict for the operator.
    print("产品合格" if is_good_product else "产品存在缺陷")