import os from pathlib import Path import cv2 import shutil from ultralytics import YOLO # --------------------------- # 三分类类别定义(必须与模型训练时的顺序一致!) # --------------------------- CLASS_NAMES = { 0: "模具车 1", 1: "模具车 2", 2: "有遮挡" } # --------------------------- # 单张图片推理函数(直接输入原图) # --------------------------- def classify_full_image(image_numpy, model): """ 直接将整张原图送入模型推理 输入:numpy 数组 (BGR) 输出:(类别名称,置信度) """ # YOLO classification 模型会自动将输入图像 resize 到训练时的大小 (如 224x224 或 640x640) results = model(image_numpy) # 获取概率分布 pred_probs = results[0].probs.data.cpu().numpy().flatten() # 获取最大概率的类别索引 class_id = int(pred_probs.argmax()) confidence = float(pred_probs[class_id]) class_name = CLASS_NAMES.get(class_id, f"未知类别 ({class_id})") return class_name, confidence # --------------------------- # 批量推理主函数 (直接推理原图并移动) # --------------------------- def batch_classify_full_images(model_path, input_folder, output_root): print(f"🚀 开始加载模型:{model_path}") model = YOLO(model_path) # 创建输出目录结构 output_root = Path(output_root) output_root.mkdir(parents=True, exist_ok=True) class_dirs = {} for name in CLASS_NAMES.values(): d = output_root / name d.mkdir(exist_ok=True) class_dirs[name] = d print(f"✅ 准备输出目录:{d}") input_folder = Path(input_folder) image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'} processed_count = 0 error_count = 0 # 获取所有图片文件 image_files = [f for f in input_folder.iterdir() if f.suffix.lower() in image_extensions] print(f"\n📂 发现 {len(image_files)} 张图片,开始全图推理...\n") for img_path in image_files: try: print(f"📄 处理:{img_path.name}") # 1. 读取原图 img = cv2.imread(str(img_path)) if img is None: print(f"❌ 无法读取图像 (可能是损坏或格式不支持): {img_path.name}") error_count += 1 continue # 2. 【核心变化】直接对整张图进行推理,不再裁剪 ROI final_class, conf = classify_full_image(img, model) print(f" 🔍 全图识别结果:{final_class} (conf={conf:.2f})") # 3. 确定目标目录 dst_dir = class_dirs[final_class] dst_path = dst_dir / img_path.name # 处理文件名冲突 (如果目标文件夹已有同名文件) if dst_path.exists(): stem = img_path.stem suffix = img_path.suffix counter = 1 while True: new_name = f"{stem}_{counter}{suffix}" dst_path = dst_dir / new_name if not dst_path.exists(): break counter += 1 print(f" ⚠️ 目标文件已存在,重命名为:{dst_path.name}") # 4. 移动【原图】到对应分类文件夹 shutil.move(str(img_path), str(dst_path)) print(f" ✅ 成功移动原图 -> [{final_class}]") processed_count += 1 except Exception as e: print(f"❌ 处理失败 {img_path.name}: {e}") import traceback traceback.print_exc() error_count += 1 print(f"\n🎉 批量处理完成!") print(f" 成功移动:{processed_count} 张") print(f" 失败/跳过:{error_count} 张") # --------------------------- # 主程序入口 # --------------------------- if __name__ == "__main__": # 配置路径 model_path = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls_resize_muju/exp_cls2/weights/best.pt" input_folder = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/61/浇筑满" output_root = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/61" # 注意:不再需要 roi_file 和 target_size 参数 batch_classify_full_images( model_path=model_path, input_folder=input_folder, output_root=output_root )