75 lines
2.6 KiB
Python
75 lines
2.6 KiB
Python
|
|
import os
|
|||
|
|
import shutil
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
from ultralytics import YOLO
|
|||
|
|
|
|||
|
|
|
|||
|
|
def classify_images_by_model(model_path, image_folder):
|
|||
|
|
"""
|
|||
|
|
使用分类模型对图片进行预测,并将每张图片复制到对应类别的子文件夹中。
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
model_path (str): 分类模型权重路径(.pt 文件)
|
|||
|
|
image_folder (str): 包含待分类图片的文件夹路径
|
|||
|
|
"""
|
|||
|
|
image_folder = Path(image_folder)
|
|||
|
|
if not image_folder.exists():
|
|||
|
|
raise FileNotFoundError(f"图片文件夹不存在: {image_folder}")
|
|||
|
|
|
|||
|
|
# 支持的图片格式
|
|||
|
|
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
|
|||
|
|
image_files = [f for f in image_folder.iterdir() if f.suffix.lower() in image_extensions]
|
|||
|
|
if not image_files:
|
|||
|
|
print(f"在 {image_folder} 中未找到图片文件。")
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
# 加载模型
|
|||
|
|
print(f"正在加载分类模型: {model_path}")
|
|||
|
|
try:
|
|||
|
|
model = YOLO(model_path) # 支持 YOLOv8 分类模型
|
|||
|
|
print("模型加载完成。")
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"加载模型失败: {e}")
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
classified_count = 0
|
|||
|
|
for img_path in image_files:
|
|||
|
|
try:
|
|||
|
|
# 推理
|
|||
|
|
results = model(img_path, verbose=False)
|
|||
|
|
result = results[0]
|
|||
|
|
|
|||
|
|
# 获取预测类别名称
|
|||
|
|
# 注意:分类模型 result.probs.top1 可直接获取类别索引
|
|||
|
|
if hasattr(result.probs, 'top1'):
|
|||
|
|
class_idx = result.probs.top1
|
|||
|
|
class_name = result.names[class_idx]
|
|||
|
|
else:
|
|||
|
|
print(f"[跳过] {img_path.name}: 未获取到有效分类结果")
|
|||
|
|
continue
|
|||
|
|
|
|||
|
|
print(f"[{img_path.name}] 预测类别: {class_name}")
|
|||
|
|
|
|||
|
|
# 创建类别子文件夹
|
|||
|
|
class_folder = image_folder / class_name
|
|||
|
|
class_folder.mkdir(exist_ok=True)
|
|||
|
|
|
|||
|
|
# ✅ 复制图片到对应类别文件夹(保留原图)
|
|||
|
|
dest_path = class_folder / img_path.name
|
|||
|
|
shutil.copy2(str(img_path), str(dest_path)) # ← 关键:使用 copy2
|
|||
|
|
print(f" → 已复制到 {class_name} 文件夹")
|
|||
|
|
classified_count += 1
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"[错误] 处理 {img_path.name} 时出错: {e}")
|
|||
|
|
|
|||
|
|
print(f"\n✅ 处理完成!共复制 {classified_count} 张图片到对应的类别文件夹。")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ================== 使用示例 ==================
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
MODEL_PATH = "cls5.pt" # 替换为你的分类模型 .pt 文件路径
|
|||
|
|
IMAGE_FOLDER = "./test_image" # 替换为你的图片文件夹
|
|||
|
|
|
|||
|
|
classify_images_by_model(MODEL_PATH, IMAGE_FOLDER)
|