115 lines
4.6 KiB
Python
115 lines
4.6 KiB
Python
|
|
import os
|
|||
|
|
import cv2
|
|||
|
|
from ultralytics import YOLO
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ObjectDetector:
|
|||
|
|
"""封装 YOLO 目标检测模型,检测图像中的缺陷"""
|
|||
|
|
|
|||
|
|
def __init__(self, model_path):
|
|||
|
|
if not os.path.exists(model_path):
|
|||
|
|
raise FileNotFoundError(f"模型文件不存在: {model_path}")
|
|||
|
|
self.model = YOLO(model_path)
|
|||
|
|
print(f"[INFO] 成功加载 YOLO 目标检测模型: {model_path}")
|
|||
|
|
|
|||
|
|
def detect(self, img_np, conf_threshold=0.5):
|
|||
|
|
# 注意:这里先不设 conf 阈值,以便后续按类别筛选最高分
|
|||
|
|
results = self.model.predict(img_np, conf=0.0, verbose=False) # 获取所有预测
|
|||
|
|
detections = []
|
|||
|
|
for result in results:
|
|||
|
|
boxes = result.boxes.cpu().numpy()
|
|||
|
|
for box in boxes:
|
|||
|
|
if box.conf.item() >= conf_threshold: # 在 Python 层过滤
|
|||
|
|
detection_info = {
|
|||
|
|
'bbox': box.xyxy[0],
|
|||
|
|
'confidence': box.conf.item(),
|
|||
|
|
'class_id': int(box.cls.item())
|
|||
|
|
}
|
|||
|
|
detections.append(detection_info)
|
|||
|
|
return detections
|
|||
|
|
|
|||
|
|
|
|||
|
|
def detect_quexian(img_path="1.png", model_path="/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls/exp_xiantiao_cls/weights/best.pt",
|
|||
|
|
conf_threshold=0.5, debug=False):
|
|||
|
|
"""
|
|||
|
|
检测木条图像中的孔洞/裂缝缺陷,并返回是否为良品。
|
|||
|
|
每个类别仅保留置信度最高的一个框。
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
img_path (str): 输入图像路径
|
|||
|
|
model_path (str): YOLO 检测模型路径(必须是 detect 任务训练的)
|
|||
|
|
conf_threshold (float): 置信度阈值
|
|||
|
|
debug (bool): 是否启用调试模式(打印详细信息)
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
bool: True 表示无缺陷(良品),False 表示有缺陷(不良品)
|
|||
|
|
"""
|
|||
|
|
# 1. 加载图像
|
|||
|
|
if not os.path.exists(img_path):
|
|||
|
|
raise FileNotFoundError(f"图片不存在: {img_path}")
|
|||
|
|
img_np = cv2.imread(img_path)
|
|||
|
|
if img_np is None:
|
|||
|
|
raise ValueError("图像加载失败,可能路径错误或图像文件损坏。")
|
|||
|
|
|
|||
|
|
# 2. 加载模型并检测(获取所有 ≥ conf_threshold 的框)
|
|||
|
|
detector = ObjectDetector(model_path)
|
|||
|
|
all_detections = detector.detect(img_np, conf_threshold=conf_threshold)
|
|||
|
|
|
|||
|
|
# 3. 按类别分组,取每个类别中置信度最高的框
|
|||
|
|
best_per_class = {}
|
|||
|
|
for det in all_detections:
|
|||
|
|
cls_id = det['class_id']
|
|||
|
|
if cls_id not in best_per_class or det['confidence'] > best_per_class[cls_id]['confidence']:
|
|||
|
|
best_per_class[cls_id] = det
|
|||
|
|
|
|||
|
|
# 转为列表(用于后续处理)
|
|||
|
|
top_detections = list(best_per_class.values())
|
|||
|
|
|
|||
|
|
# 4. 判定是否有缺陷:只要有一个类别有框,就算有缺陷
|
|||
|
|
has_defect = len(top_detections) > 0
|
|||
|
|
is_good = not has_defect
|
|||
|
|
|
|||
|
|
# 5. 可视化:只绘制每个类别的最高置信度框
|
|||
|
|
label_map = {0: "hole", 1: "crack"}
|
|||
|
|
vis_img = img_np.copy()
|
|||
|
|
for det in top_detections:
|
|||
|
|
x1, y1, x2, y2 = map(int, det['bbox'])
|
|||
|
|
conf = det['confidence']
|
|||
|
|
cls_id = det['class_id']
|
|||
|
|
label = f"{label_map.get(cls_id, '未知')} {conf:.2f}"
|
|||
|
|
cv2.rectangle(vis_img, (x1, y1), (x2, y2), (0, 0, 255), 2)
|
|||
|
|
cv2.putText(vis_img, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
|
|||
|
|
|
|||
|
|
# 6. 调试信息(仅在 debug=True 时输出)
|
|||
|
|
if debug:
|
|||
|
|
print(f"\n===== 缺陷检测结果 (DEBUG 模式) =====")
|
|||
|
|
print(f"置信度阈值: {conf_threshold}")
|
|||
|
|
print(f"有效类别数量: {len(top_detections)}")
|
|||
|
|
for i, det in enumerate(top_detections):
|
|||
|
|
cls_name = label_map.get(det['class_id'], '未知')
|
|||
|
|
bbox_int = det['bbox'].astype(int).tolist()
|
|||
|
|
print(f" - 类别 '{cls_name}' 最高置信度框: 置信度={det['confidence']:.3f}, bbox={bbox_int}")
|
|||
|
|
|
|||
|
|
# 7. 显示结果图像
|
|||
|
|
cv2.imshow('Detection Results', vis_img)
|
|||
|
|
cv2.waitKey(0)
|
|||
|
|
cv2.destroyAllWindows()
|
|||
|
|
|
|||
|
|
return is_good
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
# 示例:启用 debug 模式
|
|||
|
|
is_good_product = detect_quexian(
|
|||
|
|
#img_path="/home/hx/开发/ML_xiantiao/class_xiantiao_pc/test_image/val/1.jpg",
|
|||
|
|
img_path="/home/hx/开发/ML_xiantiao/class_xiantiao_pc/test_image/train/微信图片_20251216095823_227.jpg",
|
|||
|
|
model_path="/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_detect/weights/best.pt",
|
|||
|
|
conf_threshold=0.5,
|
|||
|
|
debug=True # 改为 False 即静默模式
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 主程序最终输出(简洁版)
|
|||
|
|
if is_good_product:
|
|||
|
|
print("产品合格")
|
|||
|
|
else:
|
|||
|
|
print("产品存在缺陷")
|