bushu
This commit is contained in:
55
zhuangtai_class_cls/test_cls_file.py
Normal file
55
zhuangtai_class_cls/test_cls_file.py
Normal file
@ -0,0 +1,55 @@
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
|
||||
|
||||
def classify_and_save_images(model_path, input_folder, output_root):
|
||||
# 加载模型
|
||||
model = YOLO(model_path)
|
||||
|
||||
# 确保输出根目录存在
|
||||
output_root = Path(output_root)
|
||||
output_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 创建类别子文件夹 (class0 到 class4)
|
||||
class_dirs = []
|
||||
for i in range(5): # 假设有5个类别 (0-4)
|
||||
class_dir = output_root / f"class{i}"
|
||||
class_dir.mkdir(exist_ok=True)
|
||||
class_dirs.append(class_dir)
|
||||
|
||||
# 遍历输入文件夹中的所有图片
|
||||
for img_path in Path(input_folder).glob("*.*"):
|
||||
if img_path.suffix.lower() not in ['.jpg', '.jpeg', '.png', '.bmp', '.tif']:
|
||||
continue # 跳过非图片文件
|
||||
|
||||
try:
|
||||
# 执行推理
|
||||
results = model(img_path)
|
||||
|
||||
# 获取预测结果 (分类任务通常返回一个包含类别概率的数组)
|
||||
pred = results[0].probs.data # 获取概率分布 (shape: [5])
|
||||
class_id = int(pred.argmax()) # 获取概率最高的类别ID
|
||||
|
||||
# 复制图片到对应类别文件夹
|
||||
dst_path = class_dirs[class_id] / img_path.name
|
||||
shutil.copy2(img_path, dst_path)
|
||||
|
||||
print(f"Processed {img_path.name} -> Class {class_id}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {img_path.name}: {str(e)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 配置路径
|
||||
model_path = r'/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls/exp_cls3/weights/best.pt' # 或直接使用训练好的权重路径如 'runs/train/cls/exp_cls/weights/best.pt'
|
||||
#input_folder = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/datalodad/f13' # 替换为你的测试图片文件夹路径
|
||||
#output_root = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/datalodad' # 输出根目录
|
||||
input_folder = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/f6' # 替换为你的测试图片文件夹路径
|
||||
output_root = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/class11' # 输出根目录
|
||||
|
||||
# 执行分类
|
||||
classify_and_save_images(model_path, input_folder, output_root)
|
||||
Reference in New Issue
Block a user