Files
zjsh_yolov11/zhuangtai_class_cls/resize_main.py
琉璃月光 df7c0730f5 bushu
2025-10-21 14:11:52 +08:00

107 lines
3.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import shutil
from pathlib import Path
from ultralytics import YOLO
import cv2
# ---------------------------
# ROI 裁剪函数
# ---------------------------
def load_global_rois(txt_path):
    """Load the global ROI rectangles from a text file.

    Every non-blank line is expected to hold four comma-separated integers
    ``x,y,w,h``. Lines that cannot be parsed are reported on stdout and
    skipped, so one bad line does not discard the rest of the file.

    Args:
        txt_path: Path of the ROI text file.

    Returns:
        A list of ``(x, y, w, h)`` integer tuples in file order. The list is
        empty when the file does not exist or contains no valid line.
    """
    rois = []
    if not os.path.exists(txt_path):
        print(f"❌ ROI 文件不存在: {txt_path}")
        return rois
    # utf-8-sig drops a leading byte-order mark, which would otherwise make
    # int() reject the first coordinate of the first line. Undecodable bytes
    # are replaced so that only the affected line is skipped as unparseable.
    with open(txt_path, 'r', encoding='utf-8-sig', errors='replace') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                # ValueError covers both a non-integer field and a line that
                # does not contain exactly four fields.
                x, y, w, h = map(int, line.split(','))
            except ValueError as e:
                print(f"⚠️ 无法解析 ROI 行: {line}, 错误: {e}")
                continue
            rois.append((x, y, w, h))
            print(f"📌 加载 ROI: (x={x}, y={y}, w={w}, h={h})")
    return rois
def crop_and_resize(img, rois, target_size=640):
    """Crop every ROI out of ``img`` and resize each crop to a square.

    Args:
        img: Image array indexed as ``img[row, col]``; its height and width
            are taken from ``img.shape[:2]``.
        rois: Iterable of ``(x, y, w, h)`` rectangles in pixel coordinates.
        target_size: Edge length, in pixels, of the square output crops.

    Returns:
        A list of ``(resized_crop, roi_index)`` tuples, where ``roi_index``
        is the position of the ROI in ``rois``. ROIs that are degenerate or
        fall outside the image are reported on stdout and omitted, so the
        indices may have gaps.
    """
    crops = []
    # The image size does not change between ROIs, so read it once.
    h_img, w_img = img.shape[:2]
    for i, (x, y, w, h) in enumerate(rois):
        if w <= 0 or h <= 0:
            # A non-positive extent yields an empty slice, which cv2.resize
            # rejects with an error; skip it like any other unusable ROI.
            print(f"⚠️ ROI 尺寸无效,跳过: {x},{y},{w},{h}")
            continue
        if x < 0 or y < 0 or x + w > w_img or y + h > h_img:
            print(f"⚠️ ROI 越界,跳过: {x},{y},{w},{h}")
            continue
        roi_img = img[y:y + h, x:x + w]
        roi_resized = cv2.resize(
            roi_img, (target_size, target_size), interpolation=cv2.INTER_AREA
        )
        crops.append((roi_resized, i))
    return crops
# ---------------------------
# 分类函数
# ---------------------------
def classify_and_save_images(model_path, input_folder, output_root, roi_file, target_size=640):
    """Classify the ROI crops of every image in a folder and sort them by class.

    For each supported image directly inside ``input_folder``, the ROIs
    listed in ``roi_file`` are cropped and resized, each crop is passed to
    the model, and the crop is written to ``output_root/class<id>/`` where
    ``<id>`` is the index of the highest-probability class.

    Args:
        model_path: Weights file passed to ``YOLO``. The model must produce
            classification output, since ``results[0].probs`` is read.
        input_folder: Folder whose image files are processed (non-recursive).
        output_root: Folder that receives the ``class<id>`` sub-folders; it
            is created if missing.
        roi_file: ROI text file understood by ``load_global_rois``.
        target_size: Edge length of the square crops fed to the model.

    Returns:
        None. Progress and errors are reported on stdout.
    """
    # Load the model.
    model = YOLO(model_path)

    # Make sure the output root exists.
    output_root = Path(output_root)
    output_root.mkdir(parents=True, exist_ok=True)

    # Pre-create class0..class4 as before. Folders for any further class id
    # predicted by the model are created on demand below, so a model with
    # more than five classes no longer fails with an IndexError.
    class_dirs = {}
    for i in range(5):
        class_dir = output_root / f"class{i}"
        class_dir.mkdir(exist_ok=True)
        class_dirs[i] = class_dir

    # Load the ROIs; without any there is nothing to crop.
    rois = load_global_rois(roi_file)
    if not rois:
        print("❌ 没有有效 ROI退出")
        return

    # '.tiff' is accepted alongside '.tif'; both name the same format.
    image_suffixes = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}

    for img_path in Path(input_folder).glob("*.*"):
        if img_path.suffix.lower() not in image_suffixes:
            continue
        # Deliberate per-image boundary: any failure is reported and the
        # batch carries on with the next image.
        try:
            img = cv2.imread(str(img_path))
            if img is None:
                print(f"❌ 无法读取图像: {img_path}")
                continue

            crops = crop_and_resize(img, rois, target_size)
            # The ROI index is appended only when an image yields several
            # crops, so single-crop outputs keep the source file name.
            multiple_crops = len(crops) > 1
            for roi_img, roi_idx in crops:
                results = model(roi_img)
                pred = results[0].probs.data  # class probability distribution
                class_id = int(pred.argmax())

                class_dir = class_dirs.get(class_id)
                if class_dir is None:
                    class_dir = output_root / f"class{class_id}"
                    class_dir.mkdir(exist_ok=True)
                    class_dirs[class_id] = class_dir

                suffix = f"_roi{roi_idx}" if multiple_crops else ""
                dst_path = class_dir / f"{img_path.stem}{suffix}{img_path.suffix}"
                # Pass a plain string (not every OpenCV build accepts
                # pathlib.Path) and check the result, because imwrite
                # signals failure by returning False rather than raising.
                if not cv2.imwrite(str(dst_path), roi_img):
                    print(f"❌ 无法保存图像: {dst_path}")
                    continue
                print(f"Processed {img_path.name}{suffix} -> Class {class_id}")
        except Exception as e:
            print(f"Error processing {img_path.name}: {str(e)}")
# ---------------------------
# 主程序
# ---------------------------
if __name__ == "__main__":
    # Script entry point: run the classifier once with a fixed configuration.
    run_config = {
        "model_path": r"best.pt",
        "input_folder": "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/f6",
        "output_root": "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/class44",
        # ROI file that was used during training.
        "roi_file": "./roi_coordinates/1_rois.txt",
        "target_size": 640,
    }
    classify_and_save_images(**run_config)