# Source listing: zjsh_yolov11/ailai_cls/resize_tuili_file.py
# Last commit: 8b263167f8 by 琉璃月光, message "更新" ("update"), 2025-12-11 08:37:09 +08:00
# Listing metadata: 81 lines, 2.6 KiB, Python

import cv2
import numpy as np
from pathlib import Path
from ultralytics import YOLO
# ------------------- Global state -------------------
# Cached YOLO model shared by init_model(); stays None until the first load.
GLOBAL_MODEL = None

# Class-id -> display-name mapping for the two-class classifier.
# batch_classify_images() also uses these names as output sub-directory names.
CLASS_NAMES = {
    0: "夹紧",  # literally "clamped"
    1: "打开",  # literally "open"
}

# Crop box (x, y, w, h) in source-image pixels used by preprocess().
# This is the single place to edit when the framing changes.
ROI = (818, 175, 1381, 1271)
# ------------------- Model initialisation -------------------
# Weights path of the model currently cached in GLOBAL_MODEL (None until loaded).
_GLOBAL_MODEL_PATH = None


def init_model(model_path):
    """Load the YOLO model once and return the cached instance.

    The model is cached in the module-level ``GLOBAL_MODEL`` so repeated calls
    do not reload the weights. If a different ``model_path`` is requested
    later, the cache is replaced rather than silently returning the model that
    was loaded first.

    Args:
        model_path: Weights file passed to ``YOLO(...)``.

    Returns:
        The cached ``YOLO`` model for ``model_path``.
    """
    global GLOBAL_MODEL, _GLOBAL_MODEL_PATH
    requested = str(model_path)
    # Reload when nothing is cached yet, or when the caller asks for other weights.
    if GLOBAL_MODEL is None or _GLOBAL_MODEL_PATH != requested:
        GLOBAL_MODEL = YOLO(model_path)
        _GLOBAL_MODEL_PATH = requested
        print(f"[INFO] 模型加载成功: {model_path}")
    return GLOBAL_MODEL
# ------------------- ROI crop + resize -------------------
def preprocess(img, target_size=640, roi=None):
    """Crop the region of interest from ``img`` and resize it to a square.

    Args:
        img: Image array as returned by ``cv2.imread`` (rows x cols [x channels]).
        target_size: Side length in pixels of the square output image.
        roi: Optional ``(x, y, w, h)`` crop box; defaults to the module-level ``ROI``.

    Returns:
        The crop resized to ``(target_size, target_size)`` with ``cv2.INTER_AREA``.
        The aspect ratio is not preserved.

    Raises:
        ValueError: If the ROI has a non-positive size or does not overlap the image.
    """
    x, y, w, h = ROI if roi is None else roi
    if w <= 0 or h <= 0:
        raise ValueError(f"ROI width/height must be positive, got {(x, y, w, h)}")
    img_h, img_w = img.shape[:2]
    # Clamp to the image. Negative starts would otherwise wrap around through
    # numpy negative indexing, and an empty crop makes cv2.resize fail with an
    # opaque assertion error.
    x0, y0 = max(x, 0), max(y, 0)
    x1, y1 = min(x + w, img_w), min(y + h, img_h)
    if x1 <= x0 or y1 <= y0:
        raise ValueError(
            f"ROI {(x, y, w, h)} lies outside the image of size {img_w}x{img_h}"
        )
    roi_img = img[y0:y1, x0:x1]
    return cv2.resize(roi_img, (target_size, target_size), interpolation=cv2.INTER_AREA)
# ------------------- Single-image inference -------------------
def classify_image(img, model):
    """Run ``model`` on one preprocessed image and return its top-1 prediction.

    Args:
        img: Image passed straight to ``model(...)``.
        model: Callable model whose first result exposes ``.probs.data``
            (here the YOLO model returned by ``init_model``).

    Returns:
        ``(class_name, confidence)``. ``class_name`` is looked up in
        ``CLASS_NAMES``, falling back to ``"未知类别(<id>)"`` for an unmapped id.
        ``confidence`` is the winning entry of ``probs.data`` as a float.
    """
    first = model(img)[0]
    scores = first.probs.data.cpu().numpy().flatten()
    top_id = int(scores.argmax())
    top_score = float(scores[top_id])
    name = CLASS_NAMES.get(top_id, f"未知类别({top_id})")
    return name, top_score
# ------------------- Batch inference -------------------
# Lower-case file suffixes treated as images by batch_classify_images().
_IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}


def batch_classify_images(model_path, input_folder, output_root, target_size=640):
    """Classify every image in ``input_folder`` and file the ROI crops by class.

    Only files directly inside ``input_folder`` are processed (non-recursive),
    and only those whose suffix is in ``_IMAGE_SUFFIXES``. Each image is:
    - cropped and resized with ``preprocess``;
    - classified with ``classify_image``;
    - saved as the preprocessed crop (not the original image) to
      ``output_root/<class_name>/<stem>_<class_name>_conf<x.xx><suffix>``.

    Args:
        model_path: Weights file passed to ``init_model``.
        input_folder: Directory containing the images to classify.
        output_root: Directory under which one sub-directory per class is created.
        target_size: Side length of the square crop fed to the model and saved.

    Raises:
        FileNotFoundError: If ``input_folder`` is not an existing directory.
    """
    input_folder = Path(input_folder)
    output_root = Path(output_root)
    # Validate before loading the model or creating any output directories.
    if not input_folder.is_dir():
        raise FileNotFoundError(
            f"Input folder does not exist or is not a directory: {input_folder}"
        )

    output_root.mkdir(parents=True, exist_ok=True)
    class_dirs = {name: output_root / name for name in CLASS_NAMES.values()}
    for d in class_dirs.values():
        d.mkdir(exist_ok=True)

    model = init_model(model_path)

    # sorted() makes the processing order (and the console log) reproducible.
    for img_path in sorted(input_folder.iterdir()):
        if not img_path.is_file() or img_path.suffix.lower() not in _IMAGE_SUFFIXES:
            continue
        img = cv2.imread(str(img_path))
        if img is None:
            print(f"[WARN] Could not read image, skipped: {img_path}")
            continue
        roi_img = preprocess(img, target_size)
        class_name, confidence = classify_image(roi_img, model)
        # classify_image can return a name outside CLASS_NAMES (unmapped class id).
        # Create its directory on demand instead of failing with KeyError.
        class_dir = class_dirs.get(class_name)
        if class_dir is None:
            class_dir = output_root / class_name
            class_dir.mkdir(exist_ok=True)
            class_dirs[class_name] = class_dir
        suffix = f"_{class_name}_conf{confidence:.2f}"
        dst_path = class_dir / f"{img_path.stem}{suffix}{img_path.suffix}"
        # Pass a str path for compatibility with older OpenCV builds.
        # imwrite reports failure only through its return value.
        if not cv2.imwrite(str(dst_path), roi_img):
            print(f"[WARN] Failed to write: {dst_path}")
            continue
        print(f"{img_path.name}{suffix} -> {class_name} (confidence={confidence:.4f})")
# ------------------- Example invocation -------------------
if __name__ == "__main__":
    # Local run settings; adjust the paths for your environment.
    model_path = "best1.pt"  # classification weights
    input_folder = (
        "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0"
        "/dataset/ailai_cls/train/class1"
    )
    output_root = "./classified_results"
    target_size = 640
    batch_classify_images(
        model_path=model_path,
        input_folder=input_folder,
        output_root=output_root,
        target_size=target_size,
    )