Files
zjsh_yolov11/zjsh_code/charge_3cls/val/3cls_file.py
琉璃月光 67883f1a50 最新推送
2026-03-10 14:14:14 +08:00

75 lines
2.5 KiB
Python

import os
import shutil
from pathlib import Path
import cv2
from ultralytics import YOLO
# ---------------------------
# 配置路径(请按需修改)
# ---------------------------
MODEL_PATH = "/zjsh_yolov11/ultralytics_yolo11-main/runs/train/charge_cls/exp_cls2/weights/best.pt" # 你的三分类模型
INPUT_FOLDER = "/home/dy/dataset/charge/val/class0" # 输入图像文件夹
OUTPUT_ROOT = "/home/dy//dataset/charge/val/class2/1" # 输出根目录
# 🔥 关键修改:三分类!顺序必须与 data.yaml 中的 names 一致
CLASS_NAMES = {0: "class0", 1: "class1", 2: "class2"} # 示例:第三类叫 "other"
# ---------------------------
# 批量推理函数(移动原图)
# ---------------------------
def batch_classify(model_path, input_folder, output_root):
# 加载模型
model = YOLO(model_path)
print(f"✅ 模型加载成功: {model_path}")
# 创建输出目录(自动为每个类别建文件夹)
output_root = Path(output_root)
for cls_name in CLASS_NAMES.values():
(output_root / cls_name).mkdir(parents=True, exist_ok=True)
# 支持的图像格式
IMG_EXTS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
input_dir = Path(input_folder)
processed = 0
for img_path in input_dir.iterdir():
if img_path.suffix.lower() not in IMG_EXTS:
continue
# 读取图像
img = cv2.imread(str(img_path))
if img is None:
print(f"❌ 无法读取图像(可能已损坏或被占用): {img_path}")
continue
# 推理(整图分类)
results = model(img)
probs = results[0].probs.data.cpu().numpy()
pred_class_id = int(probs.argmax())
pred_label = CLASS_NAMES[pred_class_id]
confidence = float(probs[pred_class_id])
# 移动原图到对应类别文件夹
dst = output_root / pred_label / img_path.name
try:
shutil.move(str(img_path), str(dst))
except Exception as e:
print(f"❌ 移动失败 {img_path}{dst}: {e}")
continue
print(f"{img_path.name}{pred_label} ({confidence:.2f})")
processed += 1
print(f"\n🎉 共处理并移动 {processed} 张图像,结果已保存至: {output_root}")
# ---------------------------
# 运行入口
# ---------------------------
if __name__ == "__main__":
batch_classify(
model_path=MODEL_PATH,
input_folder=INPUT_FOLDER,
output_root=OUTPUT_ROOT
)