Files
zjsh_yolov11/muju_cls/main_pc.py
琉璃月光 67883f1a50 最新推送
2026-03-10 14:14:14 +08:00

153 lines
5.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from pathlib import Path
import cv2
import numpy as np
import shutil
from ultralytics import YOLO
# ---------------------------
# Three-class label table.
# The position of each label is its class index, so the order here must match
# the class order the model was trained with.
# ---------------------------
_CLASS_LABELS_IN_INDEX_ORDER = ("模具车1", "模具车2", "有遮挡")
CLASS_NAMES = dict(enumerate(_CLASS_LABELS_IN_INDEX_ORDER))
# ---------------------------
# Load the ROI list
# ---------------------------
def load_global_rois(txt_path):
    """Read ROI rectangles from a text file, one ``x,y,w,h`` per line.

    Blank lines are ignored. A line that is not exactly four comma-separated
    integers is reported on stdout and skipped, so one bad line does not
    discard the rest of the file.

    Args:
        txt_path: Path of the ROI text file.

    Returns:
        A list of ``(x, y, w, h)`` integer tuples in file order. The list is
        empty when ``txt_path`` is not an existing regular file.
    """
    rois = []
    # isfile (rather than exists) so a directory path is reported here instead
    # of raising from open().
    if not os.path.isfile(txt_path):
        print(f"❌ ROI 文件不存在: {txt_path}")
        return rois
    # utf-8-sig reads plain ASCII/UTF-8 and also drops a leading BOM, which
    # would otherwise make int() fail on the first field of the first line.
    with open(txt_path, 'r', encoding='utf-8-sig') as f:
        for line in f:
            s = line.strip()
            if not s:
                continue
            try:
                # ValueError covers both a non-integer field and a wrong
                # number of fields (tuple-unpacking mismatch).
                x, y, w, h = map(int, s.split(','))
            except ValueError as e:
                print(f"⚠️ 无法解析 ROI 行 '{s}': {e}")
                continue
            rois.append((x, y, w, h))
    return rois
# ---------------------------
# Crop each ROI and resize it
# ---------------------------
def crop_and_resize(img, rois, target_size=640):
    """Cut every ROI out of ``img`` and resize it to a square.

    Args:
        img: Image array; only ``img.shape[:2]`` (height, width) and 2-D
            slicing are used here.
        rois: Iterable of ``(x, y, w, h)`` rectangles in pixel coordinates.
        target_size: Side length, in pixels, of the resized square crop.

    Returns:
        A list of ``(resized_crop, roi_index)`` tuples, where ``roi_index`` is
        the position of the ROI in ``rois``. ROIs that are degenerate or that
        fall outside the image are reported on stdout and left out.
    """
    crops = []
    h_img, w_img = img.shape[:2]
    for i, (x, y, w, h) in enumerate(rois):
        # A zero or negative width/height can pass the bounds test below but
        # produces an empty slice, which cv2.resize rejects with an error.
        if w <= 0 or h <= 0:
            print(f"⚠️ ROI 宽高无效,跳过: ({x}, {y}, {w}, {h})")
            continue
        if x < 0 or y < 0 or x + w > w_img or y + h > h_img:
            print(f"⚠️ ROI 超出图像边界,跳过: ({x}, {y}, {w}, {h})")
            continue
        roi = img[y:y + h, x:x + w]
        roi_resized = cv2.resize(roi, (target_size, target_size), interpolation=cv2.INTER_AREA)
        crops.append((roi_resized, i))
    return crops
# ---------------------------
# Single-image inference (3-class classification)
# ---------------------------
def classify_image(image, model):
    """Classify one image and return its top label with the confidence.

    Args:
        image: Image passed straight to ``model``.
        model: Callable whose first result exposes ``probs.data`` as a tensor
            (presumably an Ultralytics classification model; confirm against
            the caller).

    Returns:
        ``(class_name, confidence)``: the ``CLASS_NAMES`` label of the
        highest-probability class (or an "unknown class" placeholder when the
        index is not in ``CLASS_NAMES``) and that probability as a float.
    """
    first_result = model(image)[0]
    # Flatten to a 1-D NumPy vector on the CPU so argmax yields a class index.
    scores = first_result.probs.data.cpu().numpy().flatten()
    class_id = int(scores.argmax())
    label = CLASS_NAMES.get(class_id, f"未知类别({class_id})")
    return label, float(scores[class_id])
# ---------------------------
# Batch inference entry point (moves the original images)
# ---------------------------
def batch_classify_images(model_path, input_folder, output_root, roi_file, target_size=640):
    """Classify every image in a folder and move it into a per-class folder.

    For each image, every ROI from ``roi_file`` is cropped, resized and
    classified. The image is assigned the class with the highest
    ``severity_rank`` among its ROIs, and the original file (not the crops)
    is moved to ``output_root/<class name>/``.

    Args:
        model_path: Path of the weights file passed to ``YOLO``.
        input_folder: Folder whose image files are processed (not recursive).
        output_root: Folder under which one sub-folder per class is created.
        roi_file: Text file of ``x,y,w,h`` ROI lines, read with
            ``load_global_rois``.
        target_size: Side length of the square crops fed to the model.

    Returns:
        None. Progress and errors are printed to stdout. A failure on one
        image is reported and does not stop the remaining images.
    """
    # Load the model.
    model = YOLO(model_path)

    # Create one output folder per class.
    output_root = Path(output_root)
    output_root.mkdir(parents=True, exist_ok=True)
    class_dirs = {}
    for name in CLASS_NAMES.values():
        d = output_root / name
        d.mkdir(exist_ok=True)
        class_dirs[name] = d

    # Load the ROIs; without any there is nothing to classify.
    rois = load_global_rois(roi_file)
    if not rois:
        print("❌ 没有有效 ROI退出程序")
        return

    # Severity ranking: a smaller value means "more normal". The class with
    # the largest value among an image's ROIs decides where the image goes.
    severity_rank = {
        "模具车1": 0,
        "模具车2": 1,
        "有遮挡": 2
    }

    input_folder = Path(input_folder)
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}
    # Snapshot (and sort) the file list up front: files are moved out of this
    # folder inside the loop, so it must not be iterated lazily meanwhile.
    image_paths = sorted(
        p for p in input_folder.glob("*.*")
        if p.suffix.lower() in image_extensions
    )

    processed_count = 0
    for img_path in image_paths:
        try:
            print(f"\n📄 处理: {img_path.name}")
            img = cv2.imread(str(img_path))
            # Check for a failed read before touching img.shape.
            if img is None:
                print(f"❌ 无法读取图像: {img_path.name}")
                continue
            print(f"图像尺寸: {img.shape[1]} x {img.shape[0]}")

            crops = crop_and_resize(img, rois, target_size)
            if not crops:
                print(f"⚠️ 无有效 ROI 裁剪区域: {img_path.name}")
                continue

            detected_classes = []
            for roi_img, roi_idx in crops:
                final_class, conf = classify_image(roi_img, model)
                detected_classes.append(final_class)
                print(f" 🔍 ROI{roi_idx}: {final_class} (conf={conf:.2f})")

            # Pick the most severe class (largest severity_rank value);
            # labels missing from the table rank below every known class.
            most_severe_class = max(detected_classes, key=lambda x: severity_rank.get(x, -1))

            # Move the original image, not the crops.
            dst_path = class_dirs[most_severe_class] / img_path.name
            shutil.move(str(img_path), str(dst_path))
            print(f"📦 已移动 -> [{most_severe_class}] {dst_path}")
            processed_count += 1
        except Exception as e:
            # Per-image boundary: report the failure and go on to the next file.
            print(f"❌ 处理失败 {img_path.name}: {e}")

    print(f"\n🎉 批量处理完成!共处理 {processed_count} 张图像。")
# ---------------------------
# Script entry point
# ---------------------------
if __name__ == "__main__":
    # The paths below are hard-coded for one specific machine and dataset.
    batch_classify_images(
        model_path="/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls_resize_muju/exp_cls2/weights/best.pt",
        input_folder="/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/61/浇筑满",
        output_root="/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/61",
        roi_file="/home/hx/yolo/muju_cls/roi_coordinates/muju_roi.txt",
        target_size=640,
    )