Files
琉璃月光 8b263167f8 更新
2025-12-11 08:37:09 +08:00

190 lines
6.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from pathlib import Path
import cv2
import numpy as np
import shutil
from ultralytics import YOLO
# ---------------------------
# 类别映射
# ---------------------------
# Maps the classifier's integer class index to its label. The labels are
# runtime values (they are compared against and used as output folder names
# elsewhere in this file), so they must stay exactly as written.
# Approximate English glosses are given for readers only.
CLASS_NAMES = dict(
    enumerate(
        (
            "未堆料",    # 0: no material pile
            "小堆料",    # 1: small pile
            "大堆料",    # 2: large pile
            "未浇筑满",  # 3: not fully poured
            "浇筑满",    # 4: fully poured
        )
    )
)
# ---------------------------
# 加载 ROI 列表
# ---------------------------
def load_global_rois(txt_path):
    """Read ROI rectangles from a text file.

    Each non-blank line is expected to hold four comma-separated integers
    ``x,y,w,h``. Lines that cannot be parsed (non-integer field or wrong
    number of fields) are reported on stdout and skipped.

    Args:
        txt_path: Path of the ROI text file.

    Returns:
        A list of ``(x, y, w, h)`` integer tuples in file order. The list is
        empty when the file does not exist or holds no valid line.
    """
    if not os.path.exists(txt_path):
        print(f"❌ ROI 文件不存在: {txt_path}")
        return []

    parsed = []
    with open(txt_path, 'r') as handle:
        for raw_line in handle:
            entry = raw_line.strip()
            if not entry:
                # Blank lines are silently ignored.
                continue
            try:
                # Unpacking enforces exactly four fields; int() enforces
                # that every field is an integer. Both raise ValueError.
                x, y, w, h = map(int, entry.split(','))
            except ValueError as err:
                print(f"⚠️ 无法解析 ROI 行 '{entry}': {err}")
            else:
                parsed.append((x, y, w, h))
    return parsed
# ---------------------------
# 裁剪并 resize ROI
# ---------------------------
def crop_and_resize(img, rois, target_size=640):
    """Crop every usable ROI out of ``img`` and resize it to a square.

    Args:
        img: Image array; only ``img.shape[:2]`` (height, width) and 2-D
            slicing are used here.
        rois: Iterable of ``(x, y, w, h)`` rectangles in pixel coordinates.
        target_size: Side length of the square output passed to
            ``cv2.resize``.

    Returns:
        A list of ``(resized_crop, roi_index)`` tuples, where ``roi_index``
        is the position of the rectangle in ``rois`` (skipped rectangles
        leave gaps in the numbering).

    ROIs that are skipped (with a message on stdout):
        * rectangles with a non-positive width or height -- these would
          slice to an empty array and make ``cv2.resize`` raise;
        * rectangles that extend past the image borders.
    """
    img_h, img_w = img.shape[:2]
    crops = []
    for idx, (x, y, w, h) in enumerate(rois):
        if w <= 0 or h <= 0:
            # An empty crop cannot be resized; skip it instead of letting
            # one bad ROI abort the processing of the whole image.
            print(f"⚠️ ROI 尺寸无效,跳过: ({x}, {y}, {w}, {h})")
            continue
        if x < 0 or y < 0 or x + w > img_w or y + h > img_h:
            print(f"⚠️ ROI 超出图像边界,跳过: ({x}, {y}, {w}, {h})")
            continue
        patch = img[y:y + h, x:x + w]
        resized = cv2.resize(
            patch, (target_size, target_size), interpolation=cv2.INTER_AREA
        )
        crops.append((resized, idx))
    return crops
# ---------------------------
# class1/class2 加权判断
# ---------------------------
def weighted_small_large(pred_probs, threshold=0.4, w1=0.3, w2=0.7):
    """Decide between the "small pile" and "large pile" labels.

    Only the probabilities at index 1 and index 2 of ``pred_probs`` are
    used. Their weighted sum is normalised by their total, and the result is
    compared with ``threshold``.

    Args:
        pred_probs: Indexable sequence of probabilities (index 1 and 2 are
            read).
        threshold: Score at or above which "大堆料" is returned.
        w1: Weight applied to the index-1 probability.
        w2: Weight applied to the index-2 probability.

    Returns:
        ``(final_class, score, p1, p2)`` where ``final_class`` is "大堆料"
        or "小堆料", ``score`` is the normalised weighted score (``0.0``
        when both probabilities sum to zero or less), and ``p1`` / ``p2``
        are the two probabilities as Python floats.
    """
    p1, p2 = float(pred_probs[1]), float(pred_probs[2])
    total = p1 + p2

    if total > 0:
        score = (w1 * p1 + w2 * p2) / total
    else:
        # Avoid dividing by zero when neither class has any probability.
        score = 0.0

    if score >= threshold:
        final_class = "大堆料"
    else:
        final_class = "小堆料"
    return final_class, score, p1, p2
# ---------------------------
# 单张图片推理函数
# ---------------------------
def classify_image_weighted(image, model, threshold=0.5):
    """Classify one image and refine the small/large pile decision.

    The model is called with ``image``; the first result's
    ``probs.data`` is moved to the CPU and flattened into a NumPy vector.
    When the top-scoring class is index 1 or 2, the final label comes from
    ``weighted_small_large`` instead of the raw argmax.

    Args:
        image: Input passed unchanged to ``model``.
        model: Callable returning a sequence whose first item exposes
            ``.probs.data`` (presumably an Ultralytics classification
            model -- confirm against the caller).
        threshold: Forwarded to ``weighted_small_large``.

    Returns:
        ``(final_class, score, p1, p2)``. For classes 1 and 2 this is the
        tuple produced by ``weighted_small_large``. Otherwise
        ``final_class`` is the ``CLASS_NAMES`` label (or
        ``"未知类别(<id>)"`` for an unmapped index) and ``score`` is the
        top class probability. ``p1`` / ``p2`` are always the index-1 and
        index-2 probabilities.
    """
    results = model(image)
    probs = results[0].probs.data.cpu().numpy().flatten()
    top_id = int(probs.argmax())

    if top_id in (1, 2):
        # Small vs. large pile is settled by the weighted score.
        return weighted_small_large(probs, threshold=threshold)

    label = CLASS_NAMES.get(top_id, f"未知类别({top_id})")
    return label, float(probs[top_id]), float(probs[1]), float(probs[2])
# ---------------------------
# ⭐ 批量推理主函数(增加保存送入模型的图片)
# ---------------------------
def batch_classify_images(model_path, input_folder, output_root, roi_file, target_size=640, threshold=0.5,
                          *, save_debug=True):
    """Classify every image in a folder by its ROIs and sort it into a class folder.

    For each image directly inside ``input_folder`` the configured ROIs are
    cropped and resized, each crop is classified, and the original file is
    MOVED (not copied) into ``output_root/<label>`` where ``<label>`` is the
    ROI label with the lowest rank in ``severity_rank``.

    Args:
        model_path: Weights file handed to ``YOLO``.
        input_folder: Folder scanned (non-recursively) for images.
        output_root: Root folder; one sub-folder per ``CLASS_NAMES`` label is
            created beneath it.
        roi_file: ROI text file read by ``load_global_rois``.
        target_size: Square side length the crops are resized to.
        threshold: Forwarded to ``classify_image_weighted``.
        save_debug: When true (the default, matching the previous behaviour)
            every crop fed to the model is also written to
            ``output_root/debug_inputs``.

    Returns:
        None. Progress and errors are reported on stdout; a failure on one
        image is caught and does not stop the batch.
    """
    model = YOLO(model_path)
    output_root = Path(output_root)
    output_root.mkdir(parents=True, exist_ok=True)

    # One destination folder per class label.
    class_dirs = {name: (output_root / name) for name in CLASS_NAMES.values()}
    for class_dir in class_dirs.values():
        class_dir.mkdir(exist_ok=True)

    # Folder holding the exact crops that were sent to the model.
    debug_dir = output_root / "debug_inputs"
    if save_debug:
        debug_dir.mkdir(exist_ok=True)

    rois = load_global_rois(roi_file)
    if not rois:
        print("❌ 没有有效 ROI退出程序")
        return

    # When the ROIs of one image disagree, the label with the LOWEST rank
    # decides where the image goes (unknown labels rank last).
    severity_rank = {
        "未堆料": 0,
        "大堆料": 1,
        "小堆料": 2,
        "未浇筑满": 3,
        "浇筑满": 4
    }

    input_folder = Path(input_folder)
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}

    # Snapshot the listing up front: files are moved out of the folder while
    # we loop, and sorting gives a reproducible processing order.
    image_paths = sorted(
        p for p in input_folder.glob("*.*")
        if p.suffix.lower() in image_extensions
    )

    processed_count = 0
    for img_path in image_paths:
        try:
            print(f"\n📄 处理: {img_path.name}")
            img = cv2.imread(str(img_path))
            if img is None:
                print(f"❌ 无法读取图像: {img_path.name}")
                continue

            crops = crop_and_resize(img, rois, target_size)
            if not crops:
                print(f"⚠️ 无有效 ROI 裁剪区域: {img_path.name}")
                continue

            detected_classes = []
            for roi_img, roi_idx in crops:
                if save_debug:
                    # Keep a copy of the exact model input for inspection.
                    debug_path = debug_dir / f"{img_path.stem}_ROI{roi_idx}.jpg"
                    if not cv2.imwrite(str(debug_path), roi_img):
                        # imwrite reports failure via its return value only.
                        print(f"⚠️ 调试图像保存失败: {debug_path}")

                final_class, score, p1, p2 = classify_image_weighted(roi_img, model, threshold=threshold)
                detected_classes.append(final_class)
                print(f" 🔍 ROI{roi_idx}: {final_class} (score={score:.2f})")

            most_severe_class = min(detected_classes, key=lambda x: severity_rank.get(x, 99))

            # Move the original image into the folder of the chosen class.
            dst_path = class_dirs[most_severe_class] / img_path.name
            shutil.move(str(img_path), str(dst_path))
            print(f"📦 已移动 -> [{most_severe_class}] {dst_path}")
            processed_count += 1
        except Exception as e:
            # Per-image boundary: report and carry on with the next file.
            print(f"❌ 处理失败 {img_path.name}: {e}")

    print(f"\n🎉 批量处理完成!共处理 {processed_count} 张图像。")


# Script entry point
if __name__ == "__main__":
    model_path = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls_resize/exp_cls11/weights/best.pt"
    input_folder = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/3"
    output_root = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/3c"
    roi_file = "./roi_coordinates/3.txt"
    target_size = 640
    threshold = 0.4
    # Save the crops that are fed to the model.
    save_debug = True
    batch_classify_images(
        model_path=model_path,
        input_folder=input_folder,
        output_root=output_root,
        roi_file=roi_file,
        target_size=target_size,
        threshold=threshold,
        save_debug=save_debug,
    )