81 lines
2.6 KiB
Python
81 lines
2.6 KiB
Python
|
|
import cv2
|
||
|
|
import numpy as np
|
||
|
|
from pathlib import Path
|
||
|
|
from ultralytics import YOLO
|
||
|
|
|
||
|
|
# ------------------- Global state -------------------

# Cached model instance; starts empty and is assigned by init_model().
GLOBAL_MODEL = None

# Class index -> label for the two-class model
# ("夹紧" = clamped, "打开" = open).
CLASS_NAMES = dict(enumerate(("夹紧", "打开")))

# Region of interest as (x, y, w, h) in pixels, used to crop every input
# image. This is the only place it needs to be changed.
ROI = (818, 175, 1381, 1271)


# ------------------- Model initialisation -------------------

# Path (as str) that init_model() last loaded GLOBAL_MODEL from;
# None means GLOBAL_MODEL was not loaded by init_model().
_LOADED_MODEL_PATH = None


def init_model(model_path):
    """Return the YOLO model for *model_path*, loading it only when needed.

    The model is cached in the module-level ``GLOBAL_MODEL``, so repeated
    calls with the same path do not reload the weights. If this function
    previously loaded a model from a *different* path, the cache is replaced
    with a model loaded from *model_path*. A model assigned to
    ``GLOBAL_MODEL`` directly (not through this function) is returned as-is.

    Args:
        model_path: weights path (or model name) passed to ``YOLO(...)``.

    Returns:
        The cached ``YOLO`` model instance.
    """
    global GLOBAL_MODEL, _LOADED_MODEL_PATH

    requested = str(model_path)
    # Only treat the cache as stale when we know which path it came from;
    # otherwise a model injected into GLOBAL_MODEL by a caller is kept.
    stale = _LOADED_MODEL_PATH is not None and _LOADED_MODEL_PATH != requested

    if GLOBAL_MODEL is None or stale:
        GLOBAL_MODEL = YOLO(model_path)
        _LOADED_MODEL_PATH = requested
        print(f"[INFO] 模型加载成功: {model_path}")

    return GLOBAL_MODEL


# ------------------- ROI crop + resize -------------------

def preprocess(img, target_size=640, roi=None):
    """Crop the region of interest from *img* and resize it to a square.

    Args:
        img: image array as returned by ``cv2.imread`` (indexed [row, col]).
        target_size: side length, in pixels, of the square output image.
        roi: optional ``(x, y, w, h)`` override; defaults to the module-level
            ``ROI``.

    Returns:
        The cropped region resized to ``(target_size, target_size)`` with
        ``cv2.INTER_AREA``. If the ROI only partly overlaps the image, the
        overlapping part is used (NumPy clamps the slice bounds).

    Raises:
        ValueError: if the ROI has negative origin, non-positive size, or does
            not overlap the image at all.
    """
    x, y, w, h = ROI if roi is None else roi

    # Negative starts would make NumPy slicing wrap around from the far edge,
    # silently cropping the wrong area.
    if x < 0 or y < 0 or w <= 0 or h <= 0:
        raise ValueError(f"invalid ROI (x, y, w, h) = {(x, y, w, h)}")

    roi_img = img[y:y + h, x:x + w]

    # An ROI entirely outside the image gives an empty array, on which
    # cv2.resize fails with an opaque assertion; report it clearly instead.
    if roi_img.size == 0:
        raise ValueError(
            f"ROI {(x, y, w, h)} does not overlap image of shape {img.shape}"
        )

    return cv2.resize(roi_img, (target_size, target_size), interpolation=cv2.INTER_AREA)


# ------------------- Single-image inference -------------------

def classify_image(img, model, class_names=None):
    """Run *model* on one image and return its top-1 class and confidence.

    Args:
        img: image passed unchanged to ``model(img)``.
        model: a YOLO classification model, or any callable whose result's
            first element exposes ``.probs.data`` as a tensor of class scores.
        class_names: optional mapping from class index to label; defaults to
            the module-level ``CLASS_NAMES``.

    Returns:
        ``(class_name, confidence)``. ``class_name`` falls back to
        ``"未知类别(<id>)"`` when the predicted index is not in the mapping;
        ``confidence`` is the score of the predicted class as a Python float.

    Raises:
        ValueError: if the model output carries no classification
            probabilities (``probs is None``, e.g. a detection model).
    """
    names = CLASS_NAMES if class_names is None else class_names

    results = model(img)
    probs = results[0].probs
    if probs is None:
        raise ValueError(
            "model output has no classification probabilities; "
            "is the loaded model a classification (-cls) model?"
        )

    pred_probs = probs.data.cpu().numpy().flatten()
    class_id = int(pred_probs.argmax())
    confidence = float(pred_probs[class_id])
    class_name = names.get(class_id, f"未知类别({class_id})")
    return class_name, confidence


# ------------------- Batch inference -------------------

# File extensions (compared lower-cased) that batch_classify_images processes.
_IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"})


def _save_image(path, img):
    """Encode *img* according to *path*'s extension and write it; True on success."""
    # cv2.imwrite is avoided on purpose: older OpenCV bindings accept only str
    # paths (not pathlib.Path), and on Windows it cannot open non-ASCII paths
    # such as the Chinese class folders used here. Encoding in memory and
    # writing the bytes through NumPy sidesteps both issues.
    ok, encoded = cv2.imencode(path.suffix.lower(), img)
    if not ok:
        return False
    try:
        encoded.tofile(str(path))
    except OSError:
        return False
    return True


def batch_classify_images(model_path, input_folder, output_root, target_size=640):
    """Classify every image in *input_folder* and file the ROI crops by class.

    Each file with a supported extension is read, cropped/resized by
    ``preprocess``, and classified by ``classify_image``. The *resized crop*
    (not the original image) is written to
    ``output_root/<class_name>/<stem>_<class_name>_conf<0.00><ext>``.
    Unreadable images and images whose ROI cannot be cropped are skipped
    with a warning.

    Args:
        model_path: weights path passed to ``init_model``.
        input_folder: directory scanned (non-recursively) for images.
        output_root: output directory; created if missing, with one
            sub-folder per entry in ``CLASS_NAMES``.
        target_size: square side length passed to ``preprocess``.

    Raises:
        FileNotFoundError: if *input_folder* is not an existing directory.
    """
    input_folder = Path(input_folder)
    # Previously a typo in the path meant "nothing happens"; fail loudly instead.
    if not input_folder.is_dir():
        raise FileNotFoundError(f"input folder not found: {input_folder}")

    model = init_model(model_path)

    output_root = Path(output_root)
    output_root.mkdir(parents=True, exist_ok=True)

    # Pre-create one folder per known class.
    class_dirs = {name: output_root / name for name in CLASS_NAMES.values()}
    for d in class_dirs.values():
        d.mkdir(exist_ok=True)

    # Sorted for a deterministic processing/log order.
    for img_path in sorted(input_folder.iterdir()):
        if not img_path.is_file() or img_path.suffix.lower() not in _IMAGE_SUFFIXES:
            continue

        img = cv2.imread(str(img_path))
        if img is None:
            print(f"[WARN] cannot read image, skipped: {img_path}")
            continue

        try:
            roi_img = preprocess(img, target_size)
        except (ValueError, cv2.error) as exc:
            # e.g. the image is smaller than, or does not overlap, the ROI
            print(f"[WARN] ROI preprocessing failed, skipped: {img_path} ({exc})")
            continue

        class_name, confidence = classify_image(roi_img, model)

        # Labels outside CLASS_NAMES (classify_image's "未知类别(N)" fallback)
        # have no pre-created folder; create it on demand instead of KeyError.
        class_dir = class_dirs.get(class_name)
        if class_dir is None:
            class_dir = class_dirs[class_name] = output_root / class_name
            class_dir.mkdir(exist_ok=True)

        suffix = f"_{class_name}_conf{confidence:.2f}"
        dst_path = class_dir / f"{img_path.stem}{suffix}{img_path.suffix}"
        if not _save_image(dst_path, roi_img):
            print(f"[WARN] failed to write: {dst_path}")
            continue

        print(f"{img_path.name} -> {class_name}/{dst_path.name} (confidence={confidence:.4f})")


# ------------------- Example invocation -------------------

if __name__ == "__main__":
    # Example run configuration; adjust these values for your own model/data.
    run_config = dict(
        model_path="best1.pt",
        input_folder="/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/ailai_cls/train/class1",
        output_root="./classified_results",
        target_size=640,
    )
    batch_classify_images(**run_config)