Files
琉璃月光 67883f1a50 最新推送
2026-03-10 14:14:14 +08:00

131 lines
4.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from pathlib import Path
import cv2
import shutil
from ultralytics import YOLO
# ---------------------------
# 三分类类别定义(必须与模型训练时的顺序一致!)
# ---------------------------
CLASS_NAMES = {
0: "模具车 1",
1: "模具车 2",
2: "有遮挡"
}
# ---------------------------
# 单张图片推理函数(直接输入原图)
# ---------------------------
def classify_full_image(image_numpy, model):
"""
直接将整张原图送入模型推理
输入numpy 数组 (BGR)
输出:(类别名称,置信度)
"""
# YOLO classification 模型会自动将输入图像 resize 到训练时的大小 (如 224x224 或 640x640)
results = model(image_numpy)
# 获取概率分布
pred_probs = results[0].probs.data.cpu().numpy().flatten()
# 获取最大概率的类别索引
class_id = int(pred_probs.argmax())
confidence = float(pred_probs[class_id])
class_name = CLASS_NAMES.get(class_id, f"未知类别 ({class_id})")
return class_name, confidence
# ---------------------------
# 批量推理主函数 (直接推理原图并移动)
# ---------------------------
def batch_classify_full_images(model_path, input_folder, output_root):
print(f"🚀 开始加载模型:{model_path}")
model = YOLO(model_path)
# 创建输出目录结构
output_root = Path(output_root)
output_root.mkdir(parents=True, exist_ok=True)
class_dirs = {}
for name in CLASS_NAMES.values():
d = output_root / name
d.mkdir(exist_ok=True)
class_dirs[name] = d
print(f"✅ 准备输出目录:{d}")
input_folder = Path(input_folder)
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}
processed_count = 0
error_count = 0
# 获取所有图片文件
image_files = [f for f in input_folder.iterdir() if f.suffix.lower() in image_extensions]
print(f"\n📂 发现 {len(image_files)} 张图片,开始全图推理...\n")
for img_path in image_files:
try:
print(f"📄 处理:{img_path.name}")
# 1. 读取原图
img = cv2.imread(str(img_path))
if img is None:
print(f"❌ 无法读取图像 (可能是损坏或格式不支持): {img_path.name}")
error_count += 1
continue
# 2. 【核心变化】直接对整张图进行推理,不再裁剪 ROI
final_class, conf = classify_full_image(img, model)
print(f" 🔍 全图识别结果:{final_class} (conf={conf:.2f})")
# 3. 确定目标目录
dst_dir = class_dirs[final_class]
dst_path = dst_dir / img_path.name
# 处理文件名冲突 (如果目标文件夹已有同名文件)
if dst_path.exists():
stem = img_path.stem
suffix = img_path.suffix
counter = 1
while True:
new_name = f"{stem}_{counter}{suffix}"
dst_path = dst_dir / new_name
if not dst_path.exists():
break
counter += 1
print(f" ⚠️ 目标文件已存在,重命名为:{dst_path.name}")
# 4. 移动【原图】到对应分类文件夹
shutil.move(str(img_path), str(dst_path))
print(f" ✅ 成功移动原图 -> [{final_class}]")
processed_count += 1
except Exception as e:
print(f"❌ 处理失败 {img_path.name}: {e}")
import traceback
traceback.print_exc()
error_count += 1
print(f"\n🎉 批量处理完成!")
print(f" 成功移动:{processed_count}")
print(f" 失败/跳过:{error_count}")
# ---------------------------
# 主程序入口
# ---------------------------
if __name__ == "__main__":
# 配置路径
model_path = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls_resize_muju/exp_cls2/weights/best.pt"
input_folder = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/61/浇筑满"
output_root = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/61"
# 注意:不再需要 roi_file 和 target_size 参数
batch_classify_full_images(
model_path=model_path,
input_folder=input_folder,
output_root=output_root
)