import os import cv2 from ultralytics import YOLO class ROIClassifier: """ 封装 YOLO 分类模型 + ROI txt 推理 """ def __init__(self, model_path): if not os.path.exists(model_path): raise FileNotFoundError(f"模型文件不存在: {model_path}") self.model = YOLO(model_path) print(f"[INFO] 成功加载 YOLO 分类模型: {model_path}") def load_roi(self, roi_txt_path): """读取 ROI txt,可自动支持逗号/空格分隔""" if not os.path.exists(roi_txt_path): raise FileNotFoundError(f"ROI 文件不存在: {roi_txt_path}") with open(roi_txt_path, "r") as f: text = f.read().strip() # 把逗号替换成空格,再 split for ch in [",", ";"]: text = text.replace(ch, " ") parts = text.split() if len(parts) != 4: raise ValueError(f"ROI txt 格式错误,应为4个数字,解析得到: {parts}\n文件: {roi_txt_path}") x1, y1, x2, y2 = map(int, parts) return x1, y1, x2, y2 def classify(self, img_np, roi_txt_path): """对 ROI 区域做分类,返回 0/1""" h, w = img_np.shape[:2] x1, y1, x2, y2 = self.load_roi(roi_txt_path) # -------- ROI 边界安全裁剪 -------- x1 = max(0, min(x1, w - 1)) x2 = max(0, min(x2, w - 1)) y1 = max(0, min(y1, h - 1)) y2 = max(0, min(y2, h - 1)) if x2 <= x1 or y2 <= y1: raise ValueError(f"ROI坐标无效: {x1, y1, x2, y2}") # -------- 1. 裁剪 ROI -------- roi_img = img_np[y1:y2, x1:x2] # -------- 2. resize 到 640×640(强制送给模型)-------- roi_img = cv2.resize(roi_img, (640, 640)) # -------- 3. YOLO 分类推理 -------- results = self.model.predict(roi_img, verbose=False) cls = int(results[0].probs.top1) # 0 或 1 return cls def class_xiantiao(): # ================== 配置 ================== model_path = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls/exp_xiantiao_cls/weights/best.pt" img_path = "1.png" roi_txt_path = "/home/hx/开发/ML_xiantiao/class_xiantiao/roi_1/1/1_rois1.txt" # ================== 1. 加载模型 ================== classifier = ROIClassifier(model_path) # ================== 2. 加载图像 ================== if not os.path.exists(img_path): raise FileNotFoundError(f"图片不存在: {img_path}") img_np = cv2.imread(img_path) if img_np is None: raise ValueError("图像加载失败,可能路径错误或图像文件损坏。") # ================== 3. 推理 ================== result = classifier.classify(img_np, roi_txt_path) # ================== 4. 输出 ================== print(f"\n===== 推理结果 =====") print(f"ROI 文件:{roi_txt_path}") print(f"分类结果:{result} (0=异常 / 1=正常)\n") if __name__ == "__main__": class_xiantiao()