2025-12-11 08:37:09 +08:00
|
|
|
import os
|
2025-12-16 15:00:24 +08:00
|
|
|
import shutil
|
2025-12-11 08:37:09 +08:00
|
|
|
from pathlib import Path
|
|
|
|
|
import cv2
|
|
|
|
|
from ultralytics import YOLO
|
|
|
|
|
|
|
|
|
|
# ---------------------------
|
|
|
|
|
# 配置路径(请按需修改)
|
|
|
|
|
# ---------------------------
|
2026-03-10 13:58:21 +08:00
|
|
|
MODEL_PATH = "gaiban.pt" # 你的二分类模型
|
|
|
|
|
INPUT_FOLDER = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/camera02" # 输入图像文件夹
|
|
|
|
|
OUTPUT_ROOT = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/camera02" # 输出根目录
|
2025-12-11 08:37:09 +08:00
|
|
|
|
|
|
|
|
# 类别映射(必须与训练时的 data.yaml 顺序一致)
|
2026-03-10 13:58:21 +08:00
|
|
|
CLASS_NAMES = {0: "noready", 1: "ready"}
|
2025-12-11 08:37:09 +08:00
|
|
|
|
|
|
|
|
# ---------------------------
|
2025-12-16 15:00:24 +08:00
|
|
|
# 批量推理函数(移动原图)
|
2025-12-11 08:37:09 +08:00
|
|
|
# ---------------------------
|
|
|
|
|
def batch_classify(model_path, input_folder, output_root):
|
|
|
|
|
# 加载模型
|
|
|
|
|
model = YOLO(model_path)
|
|
|
|
|
print(f"✅ 模型加载成功: {model_path}")
|
|
|
|
|
|
|
|
|
|
# 创建输出目录
|
|
|
|
|
output_root = Path(output_root)
|
|
|
|
|
for cls_name in CLASS_NAMES.values():
|
|
|
|
|
(output_root / cls_name).mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
# 支持的图像格式
|
|
|
|
|
IMG_EXTS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
|
|
|
|
|
|
|
|
|
|
input_dir = Path(input_folder)
|
|
|
|
|
processed = 0
|
|
|
|
|
|
|
|
|
|
for img_path in input_dir.iterdir():
|
|
|
|
|
if img_path.suffix.lower() not in IMG_EXTS:
|
|
|
|
|
continue
|
|
|
|
|
|
2025-12-16 15:00:24 +08:00
|
|
|
# 读取图像(用于推理)
|
2025-12-11 08:37:09 +08:00
|
|
|
img = cv2.imread(str(img_path))
|
|
|
|
|
if img is None:
|
2025-12-16 15:00:24 +08:00
|
|
|
print(f"❌ 无法读取图像(可能已损坏或被占用): {img_path}")
|
2025-12-11 08:37:09 +08:00
|
|
|
continue
|
|
|
|
|
|
2025-12-16 15:00:24 +08:00
|
|
|
# 推理(整图分类)
|
2025-12-11 08:37:09 +08:00
|
|
|
results = model(img)
|
|
|
|
|
probs = results[0].probs.data.cpu().numpy()
|
|
|
|
|
pred_class_id = int(probs.argmax())
|
|
|
|
|
pred_label = CLASS_NAMES[pred_class_id]
|
|
|
|
|
confidence = float(probs[pred_class_id])
|
|
|
|
|
|
2025-12-16 15:00:24 +08:00
|
|
|
# ⚠️ 关键修改:移动原图(不是复制)
|
2025-12-11 08:37:09 +08:00
|
|
|
dst = output_root / pred_label / img_path.name
|
2025-12-16 15:00:24 +08:00
|
|
|
try:
|
|
|
|
|
shutil.move(str(img_path), str(dst))
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print(f"❌ 移动失败 {img_path} → {dst}: {e}")
|
|
|
|
|
continue
|
2025-12-11 08:37:09 +08:00
|
|
|
|
|
|
|
|
print(f"✅ {img_path.name} → {pred_label} ({confidence:.2f})")
|
|
|
|
|
processed += 1
|
|
|
|
|
|
2025-12-16 15:00:24 +08:00
|
|
|
print(f"\n🎉 共处理并移动 {processed} 张图像,结果已保存至: {output_root}")
|
|
|
|
|
|
2025-12-11 08:37:09 +08:00
|
|
|
|
|
|
|
|
# ---------------------------
|
|
|
|
|
# 运行入口
|
|
|
|
|
# ---------------------------
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
batch_classify(
|
|
|
|
|
model_path=MODEL_PATH,
|
|
|
|
|
input_folder=INPUT_FOLDER,
|
|
|
|
|
output_root=OUTPUT_ROOT
|
|
|
|
|
)
|