Files
zjsh_yolov11/image/class_data.py

75 lines
2.6 KiB
Python
Raw Normal View History

2025-10-21 14:11:52 +08:00
import os
import shutil
from pathlib import Path
from ultralytics import YOLO
def classify_images_by_model(model_path, image_folder):
"""
使用分类模型对图片进行预测并将每张图片复制到对应类别的子文件夹中
Args:
model_path (str): 分类模型权重路径.pt 文件
image_folder (str): 包含待分类图片的文件夹路径
"""
image_folder = Path(image_folder)
if not image_folder.exists():
raise FileNotFoundError(f"图片文件夹不存在: {image_folder}")
# 支持的图片格式
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
image_files = [f for f in image_folder.iterdir() if f.suffix.lower() in image_extensions]
if not image_files:
print(f"{image_folder} 中未找到图片文件。")
return
# 加载模型
print(f"正在加载分类模型: {model_path}")
try:
model = YOLO(model_path) # 支持 YOLOv8 分类模型
print("模型加载完成。")
except Exception as e:
print(f"加载模型失败: {e}")
return
classified_count = 0
for img_path in image_files:
try:
# 推理
results = model(img_path, verbose=False)
result = results[0]
# 获取预测类别名称
# 注意:分类模型 result.probs.top1 可直接获取类别索引
if hasattr(result.probs, 'top1'):
class_idx = result.probs.top1
class_name = result.names[class_idx]
else:
print(f"[跳过] {img_path.name}: 未获取到有效分类结果")
continue
print(f"[{img_path.name}] 预测类别: {class_name}")
# 创建类别子文件夹
class_folder = image_folder / class_name
class_folder.mkdir(exist_ok=True)
# ✅ 复制图片到对应类别文件夹(保留原图)
dest_path = class_folder / img_path.name
shutil.copy2(str(img_path), str(dest_path)) # ← 关键:使用 copy2
print(f" → 已复制到 {class_name} 文件夹")
classified_count += 1
except Exception as e:
print(f"[错误] 处理 {img_path.name} 时出错: {e}")
print(f"\n✅ 处理完成!共复制 {classified_count} 张图片到对应的类别文件夹。")
# ================== 使用示例 ==================
if __name__ == "__main__":
MODEL_PATH = "cls5.pt" # 替换为你的分类模型 .pt 文件路径
IMAGE_FOLDER = "./test_image" # 替换为你的图片文件夹
classify_images_by_model(MODEL_PATH, IMAGE_FOLDER)