Files
zjsh_yolov11/zhuangtai_class_cls/tuili_f_yuantusave.py
琉璃月光 df7c0730f5 bushu
2025-10-21 14:11:52 +08:00

198 lines
6.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import shutil
from pathlib import Path

import cv2
import numpy as np
from ultralytics import YOLO
# ---------------------------
# 类别映射
# ---------------------------
# Maps the classifier's integer class id to its label. The labels are listed
# in class-id order, so the position of each entry is its class id.
CLASS_NAMES = dict(enumerate((
    "未堆料",    # 0: no material piled
    "小堆料",    # 1: small pile
    "大堆料",    # 2: large pile
    "未浇筑满",  # 3: not fully poured
    "浇筑满",    # 4: fully poured
)))
# ---------------------------
# 加载 ROI 列表
# ---------------------------
def load_global_rois(txt_path):
    """Read ROI rectangles from a text file holding one ``x,y,w,h`` entry per line.

    Blank lines are ignored. A line that does not consist of exactly four
    comma-separated integers is reported on stdout and skipped, so a single
    bad entry does not discard the rest of the file.

    Args:
        txt_path: Path of the ROI text file.

    Returns:
        A list of ``(x, y, w, h)`` integer tuples in file order. The list is
        empty when the file does not exist or contains no valid entry.
    """
    rois = []
    if not os.path.exists(txt_path):
        print(f"❌ ROI 文件不存在: {txt_path}")
        return rois

    # "utf-8-sig" pins the encoding instead of depending on the locale default
    # and strips a leading BOM; without that, the BOM sticks to the first
    # number and the first ROI of the file fails to parse.
    with open(txt_path, 'r', encoding='utf-8-sig') as f:
        for line in f:
            s = line.strip()
            if not s:
                continue
            try:
                # Both a non-integer field and a wrong field count raise ValueError.
                x, y, w, h = map(int, s.split(','))
            except ValueError as e:
                print(f"⚠️ 无法解析 ROI 行 '{s}': {e}")
                continue
            rois.append((x, y, w, h))
    return rois
# ---------------------------
# 裁剪并 resize ROI
# ---------------------------
def crop_and_resize(img, rois, target_size=640):
    """Cut each ROI out of ``img`` and resize it to a square crop.

    ROIs that cannot produce a crop are reported on stdout and skipped: those
    with a non-positive width or height, and those that extend past the image
    on any side.

    Args:
        img: Image array indexed as ``img[row, column]``; only its first two
            dimensions (height, width) are inspected here.
        rois: Iterable of ``(x, y, w, h)`` rectangles, with ``x``/``y`` the
            top-left corner in pixels.
        target_size: Side length, in pixels, of every returned crop.

    Returns:
        A list of ``(resized_crop, roi_index)`` tuples, where ``roi_index`` is
        the position of the ROI in ``rois``. Skipped ROIs leave gaps in the
        index sequence.
    """
    crops = []
    h_img, w_img = img.shape[:2]
    for i, (x, y, w, h) in enumerate(rois):
        # A non-positive size passes the bounds test below but slices to an
        # empty array, which cannot be resized; reject it up front.
        if w <= 0 or h <= 0:
            print(f"⚠️ ROI 宽高无效,跳过: ({x}, {y}, {w}, {h})")
            continue
        # Bounds check: the whole rectangle must lie inside the image.
        if x < 0 or y < 0 or x + w > w_img or y + h > h_img:
            print(f"⚠️ ROI 超出图像边界,跳过: ({x}, {y}, {w}, {h})")
            continue
        roi = img[y:y + h, x:x + w]
        roi_resized = cv2.resize(roi, (target_size, target_size), interpolation=cv2.INTER_AREA)
        crops.append((roi_resized, i))
    return crops
# ---------------------------
# class1/class2 加权判断
# ---------------------------
def weighted_small_large(pred_probs, threshold=0.4, w1=0.3, w2=0.7):
    """Choose between the small-pile and large-pile labels with a weighted score.

    The score is the weighted mean ``(w1 * p1 + w2 * p2) / (p1 + p2)`` of the
    class-1 and class-2 probabilities; it is ``0.0`` when their sum is not
    positive.

    Args:
        pred_probs: Indexable sequence of class probabilities; only indices 1
            and 2 are read.
        threshold: Minimum score for which the large-pile label is returned.
        w1: Weight applied to the class-1 probability.
        w2: Weight applied to the class-2 probability.

    Returns:
        A tuple ``(label, score, p1, p2)`` where ``label`` is "大堆料" when
        ``score >= threshold`` and "小堆料" otherwise.
    """
    small_prob, large_prob = float(pred_probs[1]), float(pred_probs[2])
    prob_sum = small_prob + large_prob

    # Guard against dividing by zero when neither class has any probability.
    blended = (w1 * small_prob + w2 * large_prob) / prob_sum if prob_sum > 0 else 0.0

    if blended >= threshold:
        label = "大堆料"
    else:
        label = "小堆料"
    return label, blended, small_prob, large_prob
# ---------------------------
# 单张图片推理函数
# ---------------------------
def classify_image_weighted(image, model, threshold=0.5):
    """Classify one image and return its label together with a score.

    When the top-scoring class is 1 or 2, the decision between the two pile
    sizes is delegated to ``weighted_small_large``. For any other class the
    label comes from ``CLASS_NAMES`` (or an "未知类别(<id>)" placeholder when
    the id is not in that mapping) and the score is the top class probability.

    Args:
        image: Image passed straight to ``model``.
        model: Callable whose first result exposes ``probs.data``.
            NOTE(review): presumably an Ultralytics classification model
            returning a tensor there — confirm against the caller.
        threshold: Threshold forwarded to ``weighted_small_large``.

    Returns:
        A tuple ``(final_class, score, p1, p2)`` where ``p1`` and ``p2`` are
        the probabilities of classes 1 and 2.
    """
    # Flatten the first result's probabilities into a 1-D numpy array.
    probs = model(image)[0].probs.data.cpu().numpy().flatten()
    top_id = int(probs.argmax())

    # Small pile / large pile: resolve with the weighted rule.
    if top_id in (1, 2):
        return weighted_small_large(probs, threshold=threshold)

    label = CLASS_NAMES.get(top_id, f"未知类别({top_id})")
    return label, float(probs[top_id]), float(probs[1]), float(probs[2])
# ---------------------------
# 批量推理主函数(保存原图,不改名)
# ---------------------------
def batch_classify_images(model_path, input_folder, output_root, roi_file, target_size=640, threshold=0.5):
    """Classify every image in a folder and copy it into a per-class sub-folder.

    Each image is cropped to the ROIs listed in ``roi_file``, every crop is
    classified, and the image is filed under the most severe class found among
    its crops. The source file is copied as-is under its original name, so the
    stored file is identical to the input rather than a re-encoded version of
    the decoded pixels.

    Args:
        model_path: Weights file handed to ``YOLO``.
        input_folder: Folder scanned (non-recursively) for image files.
        output_root: Folder that receives one sub-folder per class; created
            when missing.
        roi_file: Text file with one ``x,y,w,h`` ROI per line.
        target_size: Side length, in pixels, of the crops fed to the model.
        threshold: Threshold of the small-pile / large-pile weighted decision.

    Returns:
        None. Progress and failures are reported on stdout.
    """
    # Load the model.
    print("🚀 加载模型...")
    model = YOLO(model_path)

    # Make sure the output root exists.
    output_root = Path(output_root)
    output_root.mkdir(parents=True, exist_ok=True)
    print(f"📁 输出目录: {output_root}")

    # Create one sub-folder per known class.
    class_dirs = {}
    for name in CLASS_NAMES.values():
        d = output_root / name
        d.mkdir(exist_ok=True)
        class_dirs[name] = d

    # Load the ROI rectangles; nothing can be classified without them.
    rois = load_global_rois(roi_file)
    if not rois:
        print("❌ 没有有效 ROI退出程序")
        return
    print(f"🎯 已加载 {len(rois)} 个 ROI 区域")

    # Severity ranking: the smaller the number, the more severe the class.
    severity_rank = {
        "未堆料": 0,
        "大堆料": 1,
        "小堆料": 2,
        "未浇筑满": 3,
        "浇筑满": 4
    }

    input_folder = Path(input_folder)
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}
    processed_count = 0

    # Sorted so that runs over the same folder visit files in the same order.
    for img_path in sorted(input_folder.glob("*.*")):
        if img_path.suffix.lower() not in image_extensions:
            continue
        # Per-image boundary: one failing image is reported and must not abort the batch.
        try:
            print(f"\n📄 处理: {img_path.name}")
            img = cv2.imread(str(img_path))
            if img is None:
                print(f"❌ 无法读取图像: {img_path.name}")
                continue

            # Crop and resize the ROIs.
            crops = crop_and_resize(img, rois, target_size)
            if not crops:
                print(f"⚠️ 无有效 ROI 裁剪区域: {img_path.name}")
                continue

            # Classify every ROI crop.
            detected_classes = []
            for roi_img, roi_idx in crops:
                final_class, score, p1, p2 = classify_image_weighted(roi_img, model, threshold=threshold)
                detected_classes.append(final_class)
                print(f" 🔍 ROI{roi_idx}: {final_class} (score={score:.2f})")

            # Keep the most severe class; labels missing from the ranking sort last.
            most_severe_class = min(detected_classes, key=lambda x: severity_rank.get(x, 99))

            # A label outside CLASS_NAMES has no pre-created folder; create it
            # on demand instead of failing the image with a KeyError.
            dst_dir = class_dirs.get(most_severe_class)
            if dst_dir is None:
                dst_dir = output_root / most_severe_class
                dst_dir.mkdir(exist_ok=True)
                class_dirs[most_severe_class] = dst_dir

            # Destination keeps the original file name.
            dst_path = dst_dir / img_path.name

            # Copy the source file itself. Re-encoding the decoded array with
            # cv2.imwrite would not reproduce the original file, and its
            # ignored return value would let a failed write count as a success;
            # copy2 raises on failure instead.
            shutil.copy2(img_path, dst_path)
            print(f"✅ 保存 -> [{most_severe_class}] {dst_path}")
            processed_count += 1
        except Exception as e:
            print(f"❌ 处理失败 {img_path.name}: {e}")

    print(f"\n🎉 批量处理完成!共处理 {processed_count} 张图像。")
# ---------------------------
# 使用示例(请根据实际情况修改路径)
# ---------------------------
if __name__ == "__main__":
    # User configuration: adjust the paths and parameters below to your setup.
    run_config = {
        "model_path": "/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls_resize/exp_cls2/weights/best.pt",
        "input_folder": "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1000",
        "output_root": "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1000/classified",
        "roi_file": "./roi_coordinates/1_rois.txt",
        "target_size": 640,
        # Weighted-score threshold for the small-pile vs. large-pile decision.
        "threshold": 0.4,
    }

    # Run the batch classification with the configuration above.
    batch_classify_images(**run_config)