131 lines
4.3 KiB
Python
131 lines
4.3 KiB
Python
|
|
import os
|
|||
|
|
from pathlib import Path
|
|||
|
|
import cv2
|
|||
|
|
import shutil
|
|||
|
|
from ultralytics import YOLO
|
|||
|
|
|
|||
|
|
# ---------------------------
|
|||
|
|
# 三分类类别定义(必须与模型训练时的顺序一致!)
|
|||
|
|
# ---------------------------
|
|||
|
|
CLASS_NAMES = {
|
|||
|
|
0: "模具车 1",
|
|||
|
|
1: "模具车 2",
|
|||
|
|
2: "有遮挡"
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------
|
|||
|
|
# 单张图片推理函数(直接输入原图)
|
|||
|
|
# ---------------------------
|
|||
|
|
def classify_full_image(image_numpy, model):
|
|||
|
|
"""
|
|||
|
|
直接将整张原图送入模型推理
|
|||
|
|
输入:numpy 数组 (BGR)
|
|||
|
|
输出:(类别名称,置信度)
|
|||
|
|
"""
|
|||
|
|
# YOLO classification 模型会自动将输入图像 resize 到训练时的大小 (如 224x224 或 640x640)
|
|||
|
|
results = model(image_numpy)
|
|||
|
|
|
|||
|
|
# 获取概率分布
|
|||
|
|
pred_probs = results[0].probs.data.cpu().numpy().flatten()
|
|||
|
|
|
|||
|
|
# 获取最大概率的类别索引
|
|||
|
|
class_id = int(pred_probs.argmax())
|
|||
|
|
confidence = float(pred_probs[class_id])
|
|||
|
|
|
|||
|
|
class_name = CLASS_NAMES.get(class_id, f"未知类别 ({class_id})")
|
|||
|
|
return class_name, confidence
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------
|
|||
|
|
# 批量推理主函数 (直接推理原图并移动)
|
|||
|
|
# ---------------------------
|
|||
|
|
def batch_classify_full_images(model_path, input_folder, output_root):
|
|||
|
|
print(f"🚀 开始加载模型:{model_path}")
|
|||
|
|
model = YOLO(model_path)
|
|||
|
|
|
|||
|
|
# 创建输出目录结构
|
|||
|
|
output_root = Path(output_root)
|
|||
|
|
output_root.mkdir(parents=True, exist_ok=True)
|
|||
|
|
|
|||
|
|
class_dirs = {}
|
|||
|
|
for name in CLASS_NAMES.values():
|
|||
|
|
d = output_root / name
|
|||
|
|
d.mkdir(exist_ok=True)
|
|||
|
|
class_dirs[name] = d
|
|||
|
|
print(f"✅ 准备输出目录:{d}")
|
|||
|
|
|
|||
|
|
input_folder = Path(input_folder)
|
|||
|
|
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}
|
|||
|
|
processed_count = 0
|
|||
|
|
error_count = 0
|
|||
|
|
|
|||
|
|
# 获取所有图片文件
|
|||
|
|
image_files = [f for f in input_folder.iterdir() if f.suffix.lower() in image_extensions]
|
|||
|
|
print(f"\n📂 发现 {len(image_files)} 张图片,开始全图推理...\n")
|
|||
|
|
|
|||
|
|
for img_path in image_files:
|
|||
|
|
try:
|
|||
|
|
print(f"📄 处理:{img_path.name}")
|
|||
|
|
|
|||
|
|
# 1. 读取原图
|
|||
|
|
img = cv2.imread(str(img_path))
|
|||
|
|
if img is None:
|
|||
|
|
print(f"❌ 无法读取图像 (可能是损坏或格式不支持): {img_path.name}")
|
|||
|
|
error_count += 1
|
|||
|
|
continue
|
|||
|
|
|
|||
|
|
# 2. 【核心变化】直接对整张图进行推理,不再裁剪 ROI
|
|||
|
|
final_class, conf = classify_full_image(img, model)
|
|||
|
|
print(f" 🔍 全图识别结果:{final_class} (conf={conf:.2f})")
|
|||
|
|
|
|||
|
|
# 3. 确定目标目录
|
|||
|
|
dst_dir = class_dirs[final_class]
|
|||
|
|
dst_path = dst_dir / img_path.name
|
|||
|
|
|
|||
|
|
# 处理文件名冲突 (如果目标文件夹已有同名文件)
|
|||
|
|
if dst_path.exists():
|
|||
|
|
stem = img_path.stem
|
|||
|
|
suffix = img_path.suffix
|
|||
|
|
counter = 1
|
|||
|
|
while True:
|
|||
|
|
new_name = f"{stem}_{counter}{suffix}"
|
|||
|
|
dst_path = dst_dir / new_name
|
|||
|
|
if not dst_path.exists():
|
|||
|
|
break
|
|||
|
|
counter += 1
|
|||
|
|
print(f" ⚠️ 目标文件已存在,重命名为:{dst_path.name}")
|
|||
|
|
|
|||
|
|
# 4. 移动【原图】到对应分类文件夹
|
|||
|
|
shutil.move(str(img_path), str(dst_path))
|
|||
|
|
print(f" ✅ 成功移动原图 -> [{final_class}]")
|
|||
|
|
|
|||
|
|
processed_count += 1
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"❌ 处理失败 {img_path.name}: {e}")
|
|||
|
|
import traceback
|
|||
|
|
traceback.print_exc()
|
|||
|
|
error_count += 1
|
|||
|
|
|
|||
|
|
print(f"\n🎉 批量处理完成!")
|
|||
|
|
print(f" 成功移动:{processed_count} 张")
|
|||
|
|
print(f" 失败/跳过:{error_count} 张")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------
|
|||
|
|
# 主程序入口
|
|||
|
|
# ---------------------------
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
# 配置路径
|
|||
|
|
model_path = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls_resize_muju/exp_cls2/weights/best.pt"
|
|||
|
|
input_folder = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/61/浇筑满"
|
|||
|
|
output_root = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/61"
|
|||
|
|
|
|||
|
|
# 注意:不再需要 roi_file 和 target_size 参数
|
|||
|
|
|
|||
|
|
batch_classify_full_images(
|
|||
|
|
model_path=model_path,
|
|||
|
|
input_folder=input_folder,
|
|||
|
|
output_root=output_root
|
|||
|
|
)
|