Files
zjsh_yolov11/image/class5.py
琉璃月光 df7c0730f5 bushu
2025-10-21 14:11:52 +08:00

133 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from pathlib import Path
import cv2
import numpy as np
from ultralytics import YOLO
import shutil
# ---------------------------
# Global configuration
# ---------------------------
# The single region of interest, given as (x, y, w, h) in pixels.
# Adjust these values to match your actual camera coordinates!
SINGLE_ROI = (859, 810, 696, 328)

# Class index -> label. The labels are runtime strings: they name the output
# folders and are embedded in the copied file names, so they stay as-is.
CLASS_NAMES = {
    0: "未堆料",    # no material piled
    1: "小堆料",    # small pile
    2: "大堆料",    # large pile
    3: "未浇筑满",  # not fully poured
    4: "浇筑满",    # fully poured
}

# Side length (pixels) of the square image handed to the model.
TARGET_SIZE = 640

# Threshold applied to the weighted class-1 / class-2 score.
THRESHOLD = 0.4
# ---------------------------
# Weighted decision between class 1 and class 2
# ---------------------------
def weighted_small_large(pred_probs, threshold=0.4, w1=0.3, w2=0.7):
    """Choose between "小堆料" (class 1) and "大堆料" (class 2) via a weighted score.

    The score is the weighted mean of the two class probabilities, normalised
    by their sum so that only the balance between class 1 and class 2 matters:

        score = (w1 * p1 + w2 * p2) / (p1 + p2)

    With the default weights the score ranges from 0.3 (all mass on class 1)
    to 0.7 (all mass on class 2), so the default threshold of 0.4 selects
    "大堆料" once class 2 holds at least a quarter of the combined mass.

    Args:
        pred_probs: Indexable sequence of class probabilities; indices 1 and 2
            are read and converted with ``float()``.
        threshold: Minimum score for the result to be "大堆料".
        w1: Weight applied to the class-1 probability.
        w2: Weight applied to the class-2 probability.

    Returns:
        A tuple ``(final_class, score, p1, p2)`` where ``final_class`` is
        "大堆料" or "小堆料" and ``p1`` / ``p2`` are the probabilities read
        from ``pred_probs``.
    """
    p1 = float(pred_probs[1])
    p2 = float(pred_probs[2])
    total = p1 + p2

    # Guard the division: with no probability mass on either class the score
    # falls back to 0.0, which resolves to "小堆料" for any positive threshold.
    if total > 0:
        score = (w1 * p1 + w2 * p2) / total
    else:
        score = 0.0

    final_class = "大堆料" if score >= threshold else "小堆料"
    return final_class, score, p1, p2
# ---------------------------
# Batch inference entry point (single ROI, batched processing)
# ---------------------------
def _collect_roi_crops(input_folder, roi, target_size):
    """Crop and resize the ROI of every readable image directly inside *input_folder*.

    Returns two parallel lists ``(crops, img_paths)``: the resized ROI crops
    and the paths of the images they were taken from.
    """
    x, y, w, h = roi
    supported_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}
    crops = []
    img_paths = []

    # sorted() gives a reproducible processing order; glob order is arbitrary.
    for img_path in sorted(Path(input_folder).glob("*.*")):
        if img_path.suffix.lower() not in supported_exts:
            continue

        img = cv2.imread(str(img_path))
        if img is None:
            # cv2.imread signals an unreadable file by returning None.
            print(f"⚠️ 无法读取图片, 已跳过: {img_path.name}")
            continue

        crop = img[y:y + h, x:x + w]
        if crop.size == 0:
            # The ROI lies entirely outside this image; resizing an empty
            # array would raise inside OpenCV, so skip the file instead.
            print(f"⚠️ ROI 超出图像范围, 已跳过: {img_path.name}")
            continue

        # Resize the crop to the square input size expected by the model.
        resized = cv2.resize(crop, (target_size, target_size), interpolation=cv2.INTER_AREA)
        crops.append(resized)
        img_paths.append(img_path)

    return crops, img_paths


def batch_classify_images(model_path, input_folder, output_root, target_size=640, threshold=0.4,
                          *, roi=None, device=0, batch_size=None):
    """Classify the ROI of every image in a folder and sort the originals by class.

    Each supported image directly inside ``input_folder`` is cropped to the
    ROI, resized to ``target_size`` x ``target_size`` and classified. The
    original (uncropped) file is then copied into ``output_root/<class name>``
    with the decision, score and class-1 / class-2 probabilities appended to
    its file name. Predictions of class 1 or 2 are re-decided by
    ``weighted_small_large``; every other class keeps its arg-max label.

    Args:
        model_path: Path of the model weights handed to ``YOLO``.
        input_folder: Folder whose images are classified (not searched recursively).
        output_root: Folder under which one sub-folder per class is created.
        target_size: Side length of the resized crop and the model ``imgsz``.
        threshold: Threshold forwarded to ``weighted_small_large``.
        roi: Optional ``(x, y, w, h)`` region; defaults to ``SINGLE_ROI``.
        device: Device argument forwarded to the model call. The default ``0``
            is the value that was previously hard-coded (first GPU, per the
            original comment).
        batch_size: Optional maximum number of crops per model call. ``None``
            (default) sends all crops in a single call, as before; a positive
            integer bounds how many crops are handed to the model at once.

    Returns:
        None. Progress and per-image results are printed.

    Raises:
        ValueError: If ``batch_size`` is given but smaller than 1.
    """
    if batch_size is not None and batch_size < 1:
        raise ValueError("batch_size must be a positive integer or None")

    model = YOLO(model_path)
    output_root = Path(output_root)
    output_root.mkdir(parents=True, exist_ok=True)

    # Create one output directory per known class up front.
    class_dirs = {}
    for name in CLASS_NAMES.values():
        class_dir = output_root / name
        class_dir.mkdir(exist_ok=True)
        class_dirs[name] = class_dir

    print("🔍 正在裁剪所有图片的 ROI 区域...")
    active_roi = SINGLE_ROI if roi is None else roi
    crops, img_paths = _collect_roi_crops(input_folder, active_roi, target_size)

    if not crops:
        print("❌ 没有有效图片可供处理")
        return

    total = len(crops)
    print(f"✅ 共准备 {total} 张图片,开始批量推理...")

    # Batched inference. With batch_size=None the loop runs exactly once and
    # passes every crop in a single model call (the original behaviour).
    step = total if batch_size is None else batch_size
    results = []
    for start in range(0, total, step):
        results.extend(model(
            source=crops[start:start + step],
            verbose=False,
            imgsz=target_size,
            half=False,  # the original note suggests half=True may be enabled on GPU
            device=device,
        ))

    # Post-process the predictions and copy the originals into place.
    print("📦 正在处理结果并保存...")
    for original_path, result in zip(img_paths, results):
        pred_probs = result.probs.data.cpu().numpy().flatten()
        class_id = int(pred_probs.argmax())

        if class_id in (1, 2):
            # Small vs. large pile is decided by the weighted score, not by arg-max.
            final_class, score, p1, p2 = weighted_small_large(pred_probs, threshold=threshold)
        else:
            final_class = CLASS_NAMES.get(class_id, f"未知类别({class_id})")
            score = float(pred_probs[class_id])
            p1 = float(pred_probs[1])
            p2 = float(pred_probs[2])

        # A class id missing from CLASS_NAMES has no pre-created folder;
        # create it on demand instead of failing with KeyError.
        dst_dir = class_dirs.get(final_class)
        if dst_dir is None:
            dst_dir = output_root / final_class
            dst_dir.mkdir(exist_ok=True)
            class_dirs[final_class] = dst_dir

        # Build the destination path: original stem + decision details + original extension.
        suffix = f"_roi0_{final_class}_score{score:.2f}_p1{p1:.2f}_p2{p2:.2f}"
        dst_path = dst_dir / f"{original_path.stem}{suffix}{original_path.suffix}"

        # Copy the original image (not the crop) into its class folder.
        shutil.copy2(str(original_path), str(dst_path))
        print(f"{original_path.name} -> {final_class} (score={score:.2f}, p1={p1:.2f}, p2={p2:.2f})")

    print(f"\n🎉 分类完成!共处理 {total} 张图片")
# ---------------------------
# Usage example
# ---------------------------
if __name__ == "__main__":
    # Model weights and folders; relative paths are resolved against the
    # current working directory.
    MODEL_PATH = r"cls5.pt"
    INPUT_FOLDER = r"./test_image"
    OUTPUT_ROOT = r"./classified_images"

    # Run the batch classification with the module-level defaults.
    batch_classify_images(MODEL_PATH, INPUT_FOLDER, OUTPUT_ROOT, TARGET_SIZE, THRESHOLD)