Files
xiantiao_CV/class_xiantiao_pc/detect.py

115 lines
4.6 KiB
Python
Raw Normal View History

2025-12-16 15:12:02 +08:00
import os
import cv2
from ultralytics import YOLO
class ObjectDetector:
"""封装 YOLO 目标检测模型,检测图像中的缺陷"""
def __init__(self, model_path):
if not os.path.exists(model_path):
raise FileNotFoundError(f"模型文件不存在: {model_path}")
self.model = YOLO(model_path)
print(f"[INFO] 成功加载 YOLO 目标检测模型: {model_path}")
def detect(self, img_np, conf_threshold=0.5):
# 注意:这里先不设 conf 阈值,以便后续按类别筛选最高分
results = self.model.predict(img_np, conf=0.0, verbose=False) # 获取所有预测
detections = []
for result in results:
boxes = result.boxes.cpu().numpy()
for box in boxes:
if box.conf.item() >= conf_threshold: # 在 Python 层过滤
detection_info = {
'bbox': box.xyxy[0],
'confidence': box.conf.item(),
'class_id': int(box.cls.item())
}
detections.append(detection_info)
return detections
def detect_quexian(img_path="1.png", model_path="/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls/exp_xiantiao_cls/weights/best.pt",
conf_threshold=0.5, debug=False):
"""
检测木条图像中的孔洞/裂缝缺陷并返回是否为良品
每个类别仅保留置信度最高的一个框
Args:
img_path (str): 输入图像路径
model_path (str): YOLO 检测模型路径必须是 detect 任务训练的
conf_threshold (float): 置信度阈值
debug (bool): 是否启用调试模式打印详细信息
Returns:
bool: True 表示无缺陷良品False 表示有缺陷不良品
"""
# 1. 加载图像
if not os.path.exists(img_path):
raise FileNotFoundError(f"图片不存在: {img_path}")
img_np = cv2.imread(img_path)
if img_np is None:
raise ValueError("图像加载失败,可能路径错误或图像文件损坏。")
# 2. 加载模型并检测(获取所有 ≥ conf_threshold 的框)
detector = ObjectDetector(model_path)
all_detections = detector.detect(img_np, conf_threshold=conf_threshold)
# 3. 按类别分组,取每个类别中置信度最高的框
best_per_class = {}
for det in all_detections:
cls_id = det['class_id']
if cls_id not in best_per_class or det['confidence'] > best_per_class[cls_id]['confidence']:
best_per_class[cls_id] = det
# 转为列表(用于后续处理)
top_detections = list(best_per_class.values())
# 4. 判定是否有缺陷:只要有一个类别有框,就算有缺陷
has_defect = len(top_detections) > 0
is_good = not has_defect
# 5. 可视化:只绘制每个类别的最高置信度框
label_map = {0: "hole", 1: "crack"}
vis_img = img_np.copy()
for det in top_detections:
x1, y1, x2, y2 = map(int, det['bbox'])
conf = det['confidence']
cls_id = det['class_id']
label = f"{label_map.get(cls_id, '未知')} {conf:.2f}"
cv2.rectangle(vis_img, (x1, y1), (x2, y2), (0, 0, 255), 2)
cv2.putText(vis_img, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
# 6. 调试信息(仅在 debug=True 时输出)
if debug:
print(f"\n===== 缺陷检测结果 (DEBUG 模式) =====")
print(f"置信度阈值: {conf_threshold}")
print(f"有效类别数量: {len(top_detections)}")
for i, det in enumerate(top_detections):
cls_name = label_map.get(det['class_id'], '未知')
bbox_int = det['bbox'].astype(int).tolist()
print(f" - 类别 '{cls_name}' 最高置信度框: 置信度={det['confidence']:.3f}, bbox={bbox_int}")
# 7. 显示结果图像
cv2.imshow('Detection Results', vis_img)
cv2.waitKey(0)
cv2.destroyAllWindows()
return is_good
if __name__ == "__main__":
# 示例:启用 debug 模式
is_good_product = detect_quexian(
#img_path="/home/hx/开发/ML_xiantiao/class_xiantiao_pc/test_image/val/1.jpg",
img_path="/home/hx/开发/ML_xiantiao/class_xiantiao_pc/test_image/train/微信图片_20251216095823_227.jpg",
model_path="/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_detect/weights/best.pt",
conf_threshold=0.5,
debug=True # 改为 False 即静默模式
)
# 主程序最终输出(简洁版)
if is_good_product:
print("产品合格")
else:
print("产品存在缺陷")