first commit
This commit is contained in:
115
class_xiantiao_pc/detect.py
Normal file
115
class_xiantiao_pc/detect.py
Normal file
@ -0,0 +1,115 @@
|
||||
import os
|
||||
import cv2
|
||||
from ultralytics import YOLO
|
||||
|
||||
|
||||
class ObjectDetector:
|
||||
"""封装 YOLO 目标检测模型,检测图像中的缺陷"""
|
||||
|
||||
def __init__(self, model_path):
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"模型文件不存在: {model_path}")
|
||||
self.model = YOLO(model_path)
|
||||
print(f"[INFO] 成功加载 YOLO 目标检测模型: {model_path}")
|
||||
|
||||
def detect(self, img_np, conf_threshold=0.5):
|
||||
# 注意:这里先不设 conf 阈值,以便后续按类别筛选最高分
|
||||
results = self.model.predict(img_np, conf=0.0, verbose=False) # 获取所有预测
|
||||
detections = []
|
||||
for result in results:
|
||||
boxes = result.boxes.cpu().numpy()
|
||||
for box in boxes:
|
||||
if box.conf.item() >= conf_threshold: # 在 Python 层过滤
|
||||
detection_info = {
|
||||
'bbox': box.xyxy[0],
|
||||
'confidence': box.conf.item(),
|
||||
'class_id': int(box.cls.item())
|
||||
}
|
||||
detections.append(detection_info)
|
||||
return detections
|
||||
|
||||
|
||||
def detect_quexian(img_path="1.png", model_path="/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls/exp_xiantiao_cls/weights/best.pt",
|
||||
conf_threshold=0.5, debug=False):
|
||||
"""
|
||||
检测木条图像中的孔洞/裂缝缺陷,并返回是否为良品。
|
||||
每个类别仅保留置信度最高的一个框。
|
||||
|
||||
Args:
|
||||
img_path (str): 输入图像路径
|
||||
model_path (str): YOLO 检测模型路径(必须是 detect 任务训练的)
|
||||
conf_threshold (float): 置信度阈值
|
||||
debug (bool): 是否启用调试模式(打印详细信息)
|
||||
|
||||
Returns:
|
||||
bool: True 表示无缺陷(良品),False 表示有缺陷(不良品)
|
||||
"""
|
||||
# 1. 加载图像
|
||||
if not os.path.exists(img_path):
|
||||
raise FileNotFoundError(f"图片不存在: {img_path}")
|
||||
img_np = cv2.imread(img_path)
|
||||
if img_np is None:
|
||||
raise ValueError("图像加载失败,可能路径错误或图像文件损坏。")
|
||||
|
||||
# 2. 加载模型并检测(获取所有 ≥ conf_threshold 的框)
|
||||
detector = ObjectDetector(model_path)
|
||||
all_detections = detector.detect(img_np, conf_threshold=conf_threshold)
|
||||
|
||||
# 3. 按类别分组,取每个类别中置信度最高的框
|
||||
best_per_class = {}
|
||||
for det in all_detections:
|
||||
cls_id = det['class_id']
|
||||
if cls_id not in best_per_class or det['confidence'] > best_per_class[cls_id]['confidence']:
|
||||
best_per_class[cls_id] = det
|
||||
|
||||
# 转为列表(用于后续处理)
|
||||
top_detections = list(best_per_class.values())
|
||||
|
||||
# 4. 判定是否有缺陷:只要有一个类别有框,就算有缺陷
|
||||
has_defect = len(top_detections) > 0
|
||||
is_good = not has_defect
|
||||
|
||||
# 5. 可视化:只绘制每个类别的最高置信度框
|
||||
label_map = {0: "hole", 1: "crack"}
|
||||
vis_img = img_np.copy()
|
||||
for det in top_detections:
|
||||
x1, y1, x2, y2 = map(int, det['bbox'])
|
||||
conf = det['confidence']
|
||||
cls_id = det['class_id']
|
||||
label = f"{label_map.get(cls_id, '未知')} {conf:.2f}"
|
||||
cv2.rectangle(vis_img, (x1, y1), (x2, y2), (0, 0, 255), 2)
|
||||
cv2.putText(vis_img, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
|
||||
|
||||
# 6. 调试信息(仅在 debug=True 时输出)
|
||||
if debug:
|
||||
print(f"\n===== 缺陷检测结果 (DEBUG 模式) =====")
|
||||
print(f"置信度阈值: {conf_threshold}")
|
||||
print(f"有效类别数量: {len(top_detections)}")
|
||||
for i, det in enumerate(top_detections):
|
||||
cls_name = label_map.get(det['class_id'], '未知')
|
||||
bbox_int = det['bbox'].astype(int).tolist()
|
||||
print(f" - 类别 '{cls_name}' 最高置信度框: 置信度={det['confidence']:.3f}, bbox={bbox_int}")
|
||||
|
||||
# 7. 显示结果图像
|
||||
cv2.imshow('Detection Results', vis_img)
|
||||
cv2.waitKey(0)
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
return is_good
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 示例:启用 debug 模式
|
||||
is_good_product = detect_quexian(
|
||||
#img_path="/home/hx/开发/ML_xiantiao/class_xiantiao_pc/test_image/val/1.jpg",
|
||||
img_path="/home/hx/开发/ML_xiantiao/class_xiantiao_pc/test_image/train/微信图片_20251216095823_227.jpg",
|
||||
model_path="/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_detect/weights/best.pt",
|
||||
conf_threshold=0.5,
|
||||
debug=True # 改为 False 即静默模式
|
||||
)
|
||||
|
||||
# 主程序最终输出(简洁版)
|
||||
if is_good_product:
|
||||
print("产品合格")
|
||||
else:
|
||||
print("产品存在缺陷")
|
||||
Reference in New Issue
Block a user