Files
zjsh_yolov11/zhuangtai_class_cls/test_cls_file.py
琉璃月光 df7c0730f5 bushu
2025-10-21 14:11:52 +08:00

55 lines
2.2 KiB
Python

import os
import shutil
from pathlib import Path
from ultralytics import YOLO
import cv2
def classify_and_save_images(model_path, input_folder, output_root):
# 加载模型
model = YOLO(model_path)
# 确保输出根目录存在
output_root = Path(output_root)
output_root.mkdir(parents=True, exist_ok=True)
# 创建类别子文件夹 (class0 到 class4)
class_dirs = []
for i in range(5): # 假设有5个类别 (0-4)
class_dir = output_root / f"class{i}"
class_dir.mkdir(exist_ok=True)
class_dirs.append(class_dir)
# 遍历输入文件夹中的所有图片
for img_path in Path(input_folder).glob("*.*"):
if img_path.suffix.lower() not in ['.jpg', '.jpeg', '.png', '.bmp', '.tif']:
continue # 跳过非图片文件
try:
# 执行推理
results = model(img_path)
# 获取预测结果 (分类任务通常返回一个包含类别概率的数组)
pred = results[0].probs.data # 获取概率分布 (shape: [5])
class_id = int(pred.argmax()) # 获取概率最高的类别ID
# 复制图片到对应类别文件夹
dst_path = class_dirs[class_id] / img_path.name
shutil.copy2(img_path, dst_path)
print(f"Processed {img_path.name} -> Class {class_id}")
except Exception as e:
print(f"Error processing {img_path.name}: {str(e)}")
if __name__ == "__main__":
# 配置路径
model_path = r'/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls/exp_cls3/weights/best.pt' # 或直接使用训练好的权重路径如 'runs/train/cls/exp_cls/weights/best.pt'
#input_folder = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/datalodad/f13' # 替换为你的测试图片文件夹路径
#output_root = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/datalodad' # 输出根目录
input_folder = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/f6' # 替换为你的测试图片文件夹路径
output_root = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/class11' # 输出根目录
# 执行分类
classify_and_save_images(model_path, input_folder, output_root)