55 lines
2.2 KiB
Python
55 lines
2.2 KiB
Python
import os
|
|
import shutil
|
|
from pathlib import Path
|
|
from ultralytics import YOLO
|
|
import cv2
|
|
|
|
|
|
def classify_and_save_images(model_path, input_folder, output_root):
|
|
# 加载模型
|
|
model = YOLO(model_path)
|
|
|
|
# 确保输出根目录存在
|
|
output_root = Path(output_root)
|
|
output_root.mkdir(parents=True, exist_ok=True)
|
|
|
|
# 创建类别子文件夹 (class0 到 class4)
|
|
class_dirs = []
|
|
for i in range(2): # 假设有5个类别 (0-4)
|
|
class_dir = output_root / f"class{i}"
|
|
class_dir.mkdir(exist_ok=True)
|
|
class_dirs.append(class_dir)
|
|
|
|
# 遍历输入文件夹中的所有图片
|
|
for img_path in Path(input_folder).glob("*.*"):
|
|
if img_path.suffix.lower() not in ['.jpg', '.jpeg', '.png', '.bmp', '.tif']:
|
|
continue # 跳过非图片文件
|
|
|
|
try:
|
|
# 执行推理
|
|
results = model(img_path)
|
|
|
|
# 获取预测结果 (分类任务通常返回一个包含类别概率的数组)
|
|
pred = results[0].probs.data # 获取概率分布 (shape: [5])
|
|
class_id = int(pred.argmax()) # 获取概率最高的类别ID
|
|
|
|
# 复制图片到对应类别文件夹
|
|
dst_path = class_dirs[class_id] / img_path.name
|
|
shutil.move(img_path, dst_path)
|
|
|
|
print(f"Processed {img_path.name} -> Class {class_id}")
|
|
|
|
except Exception as e:
|
|
print(f"Error processing {img_path.name}: {str(e)}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
# 配置路径
|
|
model_path = r'/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls/exp_zdb_cls2/weights/best.pt' # 或直接使用训练好的权重路径如 'runs/train/cls/exp_cls/weights/best.pt'
|
|
#input_folder = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/datalodad/f13' # 替换为你的测试图片文件夹路径
|
|
#output_root = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/datalodad' # 输出根目录
|
|
input_folder = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/class/class0' # 替换为你的测试图片文件夹路径
|
|
output_root = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/class/class' # 输出根目录
|
|
|
|
# 执行分类
|
|
classify_and_save_images(model_path, input_folder, output_root) |