import os import cv2 from ultralytics import YOLO class FullImageClassifier: """封装 YOLO 分类模型,对整张图像进行二分类(无缺陷/有缺陷)""" def __init__(self, model_path): if not os.path.exists(model_path): raise FileNotFoundError(f"模型文件不存在: {model_path}") self.model = YOLO(model_path) print(f"[INFO] 成功加载 YOLO 分类模型: {model_path}") def classify(self, img_np): """ 对整张图像进行分类,返回类别 ID(0 或 1) Args: img_np (np.ndarray): BGR 格式的 OpenCV 图像 (H, W, C) Returns: int: 分类结果,{0: "有缺陷", 1: "无缺陷"} """ # 直接 resize 整图到模型输入尺寸(YOLO 默认为 224x224 或 640x640,由训练决定) # Ultralytics YOLO 会自动处理 resize,但显式指定更可控 resized_img = cv2.resize(img_np, (640, 640)) # 推理(verbose=False 关闭进度条) results = self.model.predict(resized_img, verbose=False) cls = int(results[0].probs.top1) # 获取 top-1 类别索引 return cls def cls_quexian(): # ================== 配置 ================== model_path = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls/exp_xiantiao_cls/weights/best.pt" img_path = "1.png" # ================== 1. 加载模型 ================== classifier = FullImageClassifier(model_path) # ================== 2. 加载图像 ================== if not os.path.exists(img_path): raise FileNotFoundError(f"图片不存在: {img_path}") img_np = cv2.imread(img_path) if img_np is None: raise ValueError("图像加载失败,可能路径错误或图像文件损坏。") # ================== 3. 推理(整图)================== result = classifier.classify(img_np) # ================== 4. 输出 ================== label_map = {0: "有缺陷", 1: "无缺陷"} print(f"\n===== 推理结果 =====") print(f"分类结果:{result} → {label_map.get(result, '未知')}\n") if __name__ == "__main__": cls_quexian()