from pathlib import Path

import cv2
import numpy as np
from ultralytics import YOLO

# ------------------- Globals -------------------
# Lazily-loaded model instance shared by every caller of init_model().
GLOBAL_MODEL = None
# Path the cached model was loaded from, so init_model() can notice a change.
_GLOBAL_MODEL_PATH = None

# Class-index -> label mapping (binary classifier). The labels are runtime
# data: they are used as output directory names and inside output file names.
# "夹紧" means "clamped", "打开" means "open".
CLASS_NAMES = {
    0: "夹紧",
    1: "打开",
}

# ROI as (x, y, w, h) in source-image pixels; edit only this to move the crop.
ROI = (818, 175, 1381, 1271)

# Lower-case file suffixes that batch_classify_images() treats as images.
IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"})


# ------------------- Model initialization -------------------
def init_model(model_path):
    """Return the shared YOLO model, loading it on first use.

    The model is loaded when nothing is cached yet, and reloaded when
    *model_path* differs from the path the cached model was loaded from.
    A model assigned to ``GLOBAL_MODEL`` from outside this function (no
    recorded path) is reused as-is.

    Args:
        model_path: Path to the weights file, passed straight to ``YOLO``.

    Returns:
        The cached ``YOLO`` instance.
    """
    global GLOBAL_MODEL, _GLOBAL_MODEL_PATH
    requested = str(model_path)
    path_changed = (
        _GLOBAL_MODEL_PATH is not None and _GLOBAL_MODEL_PATH != requested
    )
    if GLOBAL_MODEL is None or path_changed:
        GLOBAL_MODEL = YOLO(model_path)
        _GLOBAL_MODEL_PATH = requested
        print(f"[INFO] 模型加载成功: {model_path}")
    return GLOBAL_MODEL


# ------------------- ROI crop + resize -------------------
def preprocess(img, target_size=640):
    """Crop the module-level ROI out of *img* and resize it to a square.

    Args:
        img: Image array indexed as ``img[row, col]`` (for example the
            result of ``cv2.imread``).
        target_size: Side length in pixels of the square output.

    Returns:
        The cropped region resized to ``target_size x target_size`` with
        ``cv2.INTER_AREA``. The crop's aspect ratio is not preserved.

    Raises:
        ValueError: If the ROI does not overlap the image at all, which
            would otherwise hand an empty array to ``cv2.resize``.
    """
    x, y, w, h = ROI
    roi_img = img[y:y + h, x:x + w]
    if roi_img.size == 0:
        img_h, img_w = img.shape[:2]
        raise ValueError(f"ROI {ROI} 超出图像范围 (图像尺寸 {img_w}x{img_h})")
    return cv2.resize(
        roi_img, (target_size, target_size), interpolation=cv2.INTER_AREA
    )


# ------------------- Single-image inference -------------------
def classify_image(img, model):
    """Classify one image and return its label and confidence.

    Args:
        img: Image passed unchanged to ``model``.
        model: Callable whose first result exposes ``probs.data``
            (presumably an Ultralytics classification model -- confirm
            against the weights in use).

    Returns:
        ``(class_name, confidence)``: the label of the highest-probability
        class and that probability as a float. Indices missing from
        ``CLASS_NAMES`` get the label ``"未知类别(<id>)"``.
    """
    results = model(img)
    pred_probs = results[0].probs.data.cpu().numpy().flatten()
    class_id = int(pred_probs.argmax())
    confidence = float(pred_probs[class_id])
    class_name = CLASS_NAMES.get(class_id, f"未知类别({class_id})")
    return class_name, confidence


# ------------------- Batch inference -------------------
def batch_classify_images(model_path, input_folder, output_root, target_size=640):
    """Classify every image in a folder and sort the ROI crops by class.

    Each readable image directly inside *input_folder* is cropped and
    resized by ``preprocess``, classified, and the crop (not the original
    image) is written to ``output_root/<class_name>/`` with the label and
    confidence appended to the file stem.

    Args:
        model_path: Weights path forwarded to ``init_model``.
        input_folder: Directory scanned (non-recursively) for images.
        output_root: Directory that receives one sub-directory per class;
            created if missing.
        target_size: Side length in pixels of the square crop.
    """
    model = init_model(model_path)
    input_folder = Path(input_folder)
    output_root = Path(output_root)
    output_root.mkdir(parents=True, exist_ok=True)

    # Pre-create one directory per known class.
    class_dirs = {name: output_root / name for name in CLASS_NAMES.values()}
    for class_dir in class_dirs.values():
        class_dir.mkdir(exist_ok=True)

    # Sorted so that processing order (and the log) is deterministic.
    for img_path in sorted(input_folder.glob("*.*")):
        if img_path.suffix.lower() not in IMAGE_SUFFIXES:
            continue

        img = cv2.imread(str(img_path))
        if img is None:
            # Unreadable or non-image file: skip it.
            continue

        try:
            roi_img = preprocess(img, target_size)
        except ValueError as exc:
            print(f"[WARN] 跳过 {img_path.name}: {exc}")
            continue

        class_name, confidence = classify_image(roi_img, model)
        suffix = f"_{class_name}_conf{confidence:.2f}"

        # The model may predict an index that CLASS_NAMES does not list;
        # create that label's directory on demand instead of raising KeyError.
        dst_dir = class_dirs.get(class_name)
        if dst_dir is None:
            dst_dir = output_root / class_name
            dst_dir.mkdir(exist_ok=True)
            class_dirs[class_name] = dst_dir

        dst_path = dst_dir / f"{img_path.stem}{suffix}{img_path.suffix}"
        # Pass a str, as for cv2.imread above: not every OpenCV build accepts
        # pathlib.Path. imwrite reports failure via its return value.
        if not cv2.imwrite(str(dst_path), roi_img):
            print(f"[WARN] 保存失败: {dst_path}")
            continue

        print(f"{img_path.name}{suffix} -> {class_name} (confidence={confidence:.4f})")
# ------------------- Example invocation -------------------
if __name__ == "__main__":
    # Script entry point: classify every image in the dataset folder with the
    # weights in best1.pt and sort the ROI crops under ./classified_results.
    batch_classify_images(
        model_path="best1.pt",
        input_folder="/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/ailai_cls/train/class1",
        output_root="./classified_results",
        target_size=640,
    )