Files
zjsh_yolov11/zhuangtai_class_cls/resize_tuili_file.py
琉璃月光 df7c0730f5 bushu
2025-10-21 14:11:52 +08:00

144 lines
4.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from pathlib import Path
import cv2
import numpy as np
from ultralytics import YOLO
# ---------------------------
# 类别映射
# ---------------------------
CLASS_NAMES = {
    # Class index -> label string. The labels are runtime data: they name the
    # output sub-directories and are embedded in the output file names.
    0: "未堆料",    # roughly: no material piled
    1: "小堆料",    # roughly: small pile
    2: "大堆料",    # roughly: large pile
    3: "未浇筑满",  # roughly: not fully poured
    4: "浇筑满",    # roughly: fully poured
}
# ---------------------------
# 加载 ROI 列表
# ---------------------------
def load_global_rois(txt_path):
    """Read ROI rectangles from a text file.

    Every non-blank line is expected to contain four comma-separated
    integers ``x,y,w,h``. A line that cannot be parsed is reported on stdout
    and skipped. If the file does not exist, a message is printed and an
    empty list is returned.

    Args:
        txt_path: Path of the ROI text file.

    Returns:
        A list of ``(x, y, w, h)`` integer tuples in file order.
    """
    rois = []
    if not os.path.exists(txt_path):
        print(f"❌ ROI 文件不存在: {txt_path}")
        return rois
    # "utf-8-sig" strips a leading byte-order mark if one is present, so the
    # first line of a BOM-prefixed file still parses as integers.
    with open(txt_path, 'r', encoding='utf-8-sig') as f:
        for line in f:
            s = line.strip()
            if not s:
                continue
            try:
                # ValueError covers both a non-integer field and a wrong
                # number of fields (tuple-unpacking mismatch).
                x, y, w, h = map(int, s.split(','))
            except ValueError as e:
                print(f"⚠️ 无法解析 ROI 行 '{s}': {e}")
                continue
            rois.append((x, y, w, h))
    return rois
# ---------------------------
# 裁剪并 resize ROI
# ---------------------------
def crop_and_resize(img, rois, target_size=640):
    """Cut each ROI out of ``img`` and resize it to a square.

    ROIs that are degenerate (non-positive width or height) or that extend
    past the image borders are silently skipped.

    Args:
        img: Image array; its first two ``shape`` entries are used as
            height and width.
        rois: Iterable of ``(x, y, w, h)`` rectangles.
        target_size: Side length of the square output crops.

    Returns:
        A list of ``(resized_crop, roi_index)`` tuples, where ``roi_index``
        is the position of the ROI in ``rois`` (skipped ROIs keep their
        index, so indices may be non-contiguous).
    """
    crops = []
    h_img, w_img = img.shape[:2]
    for i, (x, y, w, h) in enumerate(rois):
        # An empty crop would make cv2.resize raise, and the caller would
        # then lose the remaining ROIs of the same image.
        if w <= 0 or h <= 0:
            continue
        if x < 0 or y < 0 or x + w > w_img or y + h > h_img:
            continue
        roi = img[y:y + h, x:x + w]
        roi_resized = cv2.resize(roi, (target_size, target_size), interpolation=cv2.INTER_AREA)
        crops.append((roi_resized, i))
    return crops
# ---------------------------
# class1/class2 加权判断
# ---------------------------
def weighted_small_large(pred_probs, threshold=0.4, w1=0.3, w2=0.7):
    """Decide between the "small pile" and "large pile" labels.

    The probabilities at indices 1 and 2 of ``pred_probs`` are combined into
    ``(w1 * p1 + w2 * p2) / (p1 + p2)``. The large-pile label is chosen when
    that score is at least ``threshold``.

    Args:
        pred_probs: Indexable sequence of per-class probabilities.
        threshold: Minimum score for the large-pile label.
        w1: Weight applied to the index-1 probability.
        w2: Weight applied to the index-2 probability.

    Returns:
        ``(label, score, p1, p2)`` with ``p1``/``p2`` as Python floats.
    """
    prob_small = float(pred_probs[1])
    prob_large = float(pred_probs[2])
    prob_sum = prob_small + prob_large

    # Guard the division: when both probabilities are zero the score is 0.0.
    weighted = (w1 * prob_small + w2 * prob_large) / prob_sum if prob_sum > 0 else 0.0

    if weighted >= threshold:
        label = "大堆料"
    else:
        label = "小堆料"
    return label, weighted, prob_small, prob_large
# ---------------------------
# 单张图片推理函数
# ---------------------------
def classify_image_weighted(image, model, threshold=0.5):
    """Classify one image and return its label with supporting scores.

    The model is called on ``image`` and the probability vector of the first
    result is read. When the top class is 1 or 2, the decision is delegated
    to ``weighted_small_large``; otherwise the top class's name and its
    probability are reported.

    Args:
        image: Input passed straight to ``model``.
        model: Callable whose first result exposes ``probs.data`` as a
            tensor (presumably an ultralytics classification model; confirm
            against the caller).
        threshold: Forwarded to ``weighted_small_large``.

    Returns:
        ``(final_class, score, p1, p2)`` where ``p1``/``p2`` are the
        probabilities at indices 1 and 2.
    """
    prediction = model(image)[0]
    probs = prediction.probs.data.cpu().numpy().flatten()
    top_idx = int(probs.argmax())

    # Classes 1 and 2 are resolved by the weighted score, not by argmax.
    if top_idx in (1, 2):
        return weighted_small_large(probs, threshold=threshold)

    label = CLASS_NAMES.get(top_idx, f"未知类别({top_idx})")
    return label, float(probs[top_idx]), float(probs[1]), float(probs[2])
# ---------------------------
# 批量推理主函数
# ---------------------------
def batch_classify_images(model_path, input_folder, output_root, roi_file, target_size=640, threshold=0.5):
    """Classify every ROI of every image in a folder and save the crops.

    For each image directly inside ``input_folder`` (non-recursive) with a
    supported suffix, every valid ROI is cropped, resized, classified, and
    written into ``output_root/<label>/``. The output file name carries the
    ROI index, the label, the score, and the index-1/index-2 probabilities.
    Failures while processing one image are printed and do not stop the run.

    Args:
        model_path: Weights file handed to ``YOLO``.
        input_folder: Directory containing the source images.
        output_root: Directory under which per-label folders are created.
        roi_file: Text file read by ``load_global_rois``.
        target_size: Side length of the square crops.
        threshold: Forwarded to ``classify_image_weighted``.

    Returns:
        None. Returns early (after printing a message) if no ROI is loaded.
    """
    # Load the model once for the whole batch.
    model = YOLO(model_path)

    # Make sure the output root exists.
    output_root = Path(output_root)
    output_root.mkdir(parents=True, exist_ok=True)

    # Pre-create one directory per known class label.
    class_dirs = {}
    for name in CLASS_NAMES.values():
        d = output_root / name
        d.mkdir(exist_ok=True)
        class_dirs[name] = d

    rois = load_global_rois(roi_file)
    if not rois:
        print("❌ 没有有效 ROI退出")
        return

    # Built once instead of on every loop iteration.
    image_suffixes = {'.jpg', '.jpeg', '.png', '.bmp', '.tif'}

    for img_path in Path(input_folder).glob("*.*"):
        if img_path.suffix.lower() not in image_suffixes:
            continue
        try:
            img = cv2.imread(str(img_path))
            if img is None:
                continue
            crops = crop_and_resize(img, rois, target_size)
            for roi_resized, roi_idx in crops:
                final_class, score, p1, p2 = classify_image_weighted(roi_resized, model, threshold=threshold)

                # File name records ROI index, label, weighted score and the
                # index-1/index-2 probabilities.
                suffix = f"_roi{roi_idx}_{final_class}_score{score:.2f}_p1{p1:.2f}_p2{p2:.2f}"

                # A label outside CLASS_NAMES (the "unknown class" fallback)
                # has no pre-created folder; create it on first use instead
                # of failing with KeyError.
                dst_dir = class_dirs.get(final_class)
                if dst_dir is None:
                    dst_dir = output_root / final_class
                    dst_dir.mkdir(exist_ok=True)
                    class_dirs[final_class] = dst_dir

                dst_path = dst_dir / f"{img_path.stem}{suffix}{img_path.suffix}"

                # Pass a plain string (as done for imread above) and check
                # the result: imwrite signals failure by returning False.
                if not cv2.imwrite(str(dst_path), roi_resized):
                    print(f"⚠️ 保存失败 {dst_path}")
                    continue
                print(f"{img_path.name}{suffix} -> {final_class} (score={score:.2f}, p1={p1:.2f}, p2={p2:.2f})")
        except Exception as e:
            # Best-effort batch: report the failing image and move on.
            print(f"⚠️ 处理失败 {img_path.name}: {e}")
# ---------------------------
# 使用示例
# ---------------------------
if __name__ == "__main__":
    # Example invocation; the paths below are specific to one machine.
    model_path = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls_resize/exp_cls2/weights/best.pt"
    input_folder = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1000"
    output_root = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1000/classified"
    roi_file = "./roi_coordinates/1_rois.txt"
    target_size = 640
    # Adjustable cut-off forwarded to the classification step.
    threshold = 0.4

    batch_classify_images(
        model_path=model_path,
        input_folder=input_folder,
        output_root=output_root,
        roi_file=roi_file,
        target_size=target_size,
        threshold=threshold,
    )