Files

131 lines
4.3 KiB
Python
Raw Permalink Normal View History

2026-03-10 13:58:21 +08:00
import os
from pathlib import Path
import cv2
import shutil
from ultralytics import YOLO
# ---------------------------
# 三分类类别定义(必须与模型训练时的顺序一致!)
# ---------------------------
CLASS_NAMES = {
0: "模具车 1",
1: "模具车 2",
2: "有遮挡"
}
# ---------------------------
# 单张图片推理函数(直接输入原图)
# ---------------------------
def classify_full_image(image_numpy, model):
"""
直接将整张原图送入模型推理
输入numpy 数组 (BGR)
输出(类别名称置信度)
"""
# YOLO classification 模型会自动将输入图像 resize 到训练时的大小 (如 224x224 或 640x640)
results = model(image_numpy)
# 获取概率分布
pred_probs = results[0].probs.data.cpu().numpy().flatten()
# 获取最大概率的类别索引
class_id = int(pred_probs.argmax())
confidence = float(pred_probs[class_id])
class_name = CLASS_NAMES.get(class_id, f"未知类别 ({class_id})")
return class_name, confidence
# ---------------------------
# 批量推理主函数 (直接推理原图并移动)
# ---------------------------
def batch_classify_full_images(model_path, input_folder, output_root):
print(f"🚀 开始加载模型:{model_path}")
model = YOLO(model_path)
# 创建输出目录结构
output_root = Path(output_root)
output_root.mkdir(parents=True, exist_ok=True)
class_dirs = {}
for name in CLASS_NAMES.values():
d = output_root / name
d.mkdir(exist_ok=True)
class_dirs[name] = d
print(f"✅ 准备输出目录:{d}")
input_folder = Path(input_folder)
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}
processed_count = 0
error_count = 0
# 获取所有图片文件
image_files = [f for f in input_folder.iterdir() if f.suffix.lower() in image_extensions]
print(f"\n📂 发现 {len(image_files)} 张图片,开始全图推理...\n")
for img_path in image_files:
try:
print(f"📄 处理:{img_path.name}")
# 1. 读取原图
img = cv2.imread(str(img_path))
if img is None:
print(f"❌ 无法读取图像 (可能是损坏或格式不支持): {img_path.name}")
error_count += 1
continue
# 2. 【核心变化】直接对整张图进行推理,不再裁剪 ROI
final_class, conf = classify_full_image(img, model)
print(f" 🔍 全图识别结果:{final_class} (conf={conf:.2f})")
# 3. 确定目标目录
dst_dir = class_dirs[final_class]
dst_path = dst_dir / img_path.name
# 处理文件名冲突 (如果目标文件夹已有同名文件)
if dst_path.exists():
stem = img_path.stem
suffix = img_path.suffix
counter = 1
while True:
new_name = f"{stem}_{counter}{suffix}"
dst_path = dst_dir / new_name
if not dst_path.exists():
break
counter += 1
print(f" ⚠️ 目标文件已存在,重命名为:{dst_path.name}")
# 4. 移动【原图】到对应分类文件夹
shutil.move(str(img_path), str(dst_path))
print(f" ✅ 成功移动原图 -> [{final_class}]")
processed_count += 1
except Exception as e:
print(f"❌ 处理失败 {img_path.name}: {e}")
import traceback
traceback.print_exc()
error_count += 1
print(f"\n🎉 批量处理完成!")
print(f" 成功移动:{processed_count}")
print(f" 失败/跳过:{error_count}")
# ---------------------------
# 主程序入口
# ---------------------------
if __name__ == "__main__":
# 配置路径
model_path = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls_resize_muju/exp_cls2/weights/best.pt"
input_folder = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/61/浇筑满"
output_root = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/61"
# 注意:不再需要 roi_file 和 target_size 参数
batch_classify_full_images(
model_path=model_path,
input_folder=input_folder,
output_root=output_root
)