Files
zjsh_yolov11/zhuangtai_class_cls/tuili_f_yuantusave61_nojiaquan.py
琉璃月光 67883f1a50 最新推送
2026-03-10 14:14:14 +08:00

171 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import shutil
from pathlib import Path
import cv2
import numpy as np
from ultralytics import YOLO
# ---------------------------
# Class mapping
# ---------------------------
# Keys are the argmax index of the classifier's probability vector (as used by
# classify_image); the order below is assumed to match the training class order.
# Labels: not piled / small pile / large pile / not fully poured / fully poured.
CLASS_NAMES = dict(enumerate((
    "未堆料",
    "小堆料",
    "大堆料",
    "未浇筑满",
    "浇筑满",
)))
# ---------------------------
# Load the ROI list
# ---------------------------
def load_global_rois(txt_path):
    """Read rectangular ROIs from a text file, one ``x,y,w,h`` per line.

    Blank lines are ignored. Lines that are not exactly four comma-separated
    integers, or whose width/height is not positive, are reported and skipped.

    Args:
        txt_path: Path to the ROI text file.

    Returns:
        list[tuple[int, int, int, int]]: ``(x, y, w, h)`` tuples in file order;
        an empty list if the file does not exist.
    """
    rois = []
    if not os.path.exists(txt_path):
        print(f"❌ ROI 文件不存在: {txt_path}")
        return rois
    # 'utf-8-sig' transparently strips a leading BOM (common for files saved by
    # Windows editors); otherwise int() rejects the first line and that ROI is lost.
    with open(txt_path, 'r', encoding='utf-8-sig') as f:
        for line in f:
            s = line.strip()
            if not s:
                continue
            try:
                x, y, w, h = map(int, s.split(','))
            except ValueError as e:
                # Wrong number of fields or a non-integer field.
                print(f"⚠️ 无法解析 ROI 行 '{s}': {e}")
                continue
            if w <= 0 or h <= 0:
                # A zero/negative size would produce an empty crop downstream.
                print(f"⚠️ ROI 宽高必须为正,跳过: ({x}, {y}, {w}, {h})")
                continue
            rois.append((x, y, w, h))
    return rois
# ---------------------------
# Crop each ROI and resize it to a square
# ---------------------------
def crop_and_resize(img, rois, target_size=640):
    """Crop every ROI from *img* and resize each crop to ``target_size`` x ``target_size``.

    ROIs with a non-positive width/height, or that extend outside the image,
    are reported and skipped instead of being cropped.

    Args:
        img: Image array indexed as ``img[row, col]`` (e.g. from ``cv2.imread``).
        rois: Iterable of ``(x, y, w, h)`` rectangles in pixel coordinates.
        target_size: Side length of the square output crops.

    Returns:
        list[tuple[ndarray, int]]: ``(resized_crop, roi_index)`` pairs, where
        ``roi_index`` is the ROI's position in *rois*, so indices stay stable
        even when some ROIs are skipped.
    """
    crops = []
    h_img, w_img = img.shape[:2]
    for i, (x, y, w, h) in enumerate(rois):
        # A zero/negative size passes the bounds test below but yields an empty
        # slice, which makes cv2.resize raise and aborts the whole image.
        if w <= 0 or h <= 0:
            print(f"⚠️ ROI 宽高无效,跳过: ({x}, {y}, {w}, {h})")
            continue
        if x < 0 or y < 0 or x + w > w_img or y + h > h_img:
            print(f"⚠️ ROI 超出图像边界,跳过: ({x}, {y}, {w}, {h})")
            continue
        roi = img[y:y + h, x:x + w]
        roi_resized = cv2.resize(
            roi, (target_size, target_size),
            interpolation=cv2.INTER_AREA
        )
        crops.append((roi_resized, i))
    return crops
# ---------------------------
# Classify a single image (no weighting)
# ---------------------------
def classify_image(image, model):
    """Run the classifier on one image and return its top-1 prediction.

    Args:
        image: Image passed unchanged to ``model(...)``.
        model: Callable model whose first result exposes ``probs.data``
            (tensor-like, supporting ``.cpu().numpy()``), such as the
            ultralytics YOLO model loaded in this file.

    Returns:
        tuple[str, float]: ``(class_name, confidence)``. The name is looked up
        in CLASS_NAMES; unmapped ids become ``"未知类别(<id>)"``.
    """
    first_result = model(image)[0]
    scores = first_result.probs.data.cpu().numpy().flatten()
    best_idx = int(scores.argmax())
    label = CLASS_NAMES.get(best_idx, f"未知类别({best_idx})")
    return label, float(scores[best_idx])
# ---------------------------
# Batch inference entry point (moves files)
# ---------------------------
# Severity used to merge per-ROI results into one label per image:
# a smaller value means a more severe state.
_SEVERITY_RANK = {
    "未堆料": 0,
    "大堆料": 1,
    "小堆料": 2,
    "未浇筑满": 3,
    "浇筑满": 4,
}
_IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}


def _most_severe(class_names):
    """Return the most severe label in *class_names*; unranked labels sort last."""
    return min(class_names, key=lambda name: _SEVERITY_RANK.get(name, 99))


def _free_destination(src, dst):
    """Return a path in dst's folder that will not overwrite a different file.

    *dst* is returned unchanged if it does not exist or is *src* itself;
    otherwise ``<stem>_1<suffix>``, ``<stem>_2<suffix>``, ... is tried until
    an unused name is found.
    """
    if not dst.exists() or dst.samefile(src):
        return dst
    n = 1
    while True:
        candidate = dst.with_name(f"{dst.stem}_{n}{dst.suffix}")
        if not candidate.exists():
            return candidate
        n += 1


def batch_classify_images_move(
    model_path,
    input_folder,
    output_root,
    roi_file,
    target_size=640
):
    """Classify every image in *input_folder* and move it into a per-class folder.

    Each image is cropped into the ROIs listed in *roi_file*, every crop is
    classified, and the image is moved to ``output_root/<class name>/`` for
    the most severe class among its ROIs (see _SEVERITY_RANK). Only top-level
    regular files with a known image extension are processed. A file already
    present at the destination is never overwritten; a numeric suffix is
    appended to the moved file's name instead.

    Args:
        model_path: Path to the YOLO classification weights.
        input_folder: Folder whose images are classified and moved.
        output_root: Root folder; one sub-folder per CLASS_NAMES value is created.
        roi_file: Text file with one ``x,y,w,h`` ROI per line.
        target_size: Side length each ROI crop is resized to before inference.

    Returns:
        None. Progress is printed; per-image failures are reported and skipped.
    """
    # Validate ROIs first so a missing/empty ROI file fails fast, before the
    # model is loaded or any output folders are created.
    rois = load_global_rois(roi_file)
    if not rois:
        print("❌ 没有有效 ROI退出")
        return
    print(f"🎯 已加载 {len(rois)} 个 ROI")

    print("🚀 加载模型...")
    model = YOLO(model_path)

    output_root = Path(output_root)
    output_root.mkdir(parents=True, exist_ok=True)
    print(f"📁 输出目录: {output_root}")
    # One destination folder per known class name.
    class_dirs = {}
    for name in CLASS_NAMES.values():
        d = output_root / name
        d.mkdir(exist_ok=True)
        class_dirs[name] = d

    processed = 0
    # sorted() materialises the listing up front, so moving files (possibly into
    # sub-folders of input_folder itself) cannot disturb the iteration.
    for img_path in sorted(Path(input_folder).glob("*.*")):
        if not img_path.is_file() or img_path.suffix.lower() not in _IMAGE_EXTS:
            continue
        try:
            print(f"\n📄 处理: {img_path.name}")
            img = cv2.imread(str(img_path))
            if img is None:
                print("❌ 读取失败")
                continue
            crops = crop_and_resize(img, rois, target_size)
            if not crops:
                print("⚠️ 无有效 ROI")
                continue
            detected_classes = []
            for roi_img, roi_idx in crops:
                cls, conf = classify_image(roi_img, model)
                detected_classes.append(cls)
                print(f" 🔍 ROI{roi_idx}: {cls} ({conf:.2f})")
            # Keep the most severe result across all ROIs.
            final_class = _most_severe(detected_classes)
            target_dir = class_dirs.get(final_class)
            if target_dir is None:
                # An unmapped label (e.g. "未知类别(N)") has no folder; leave the file in place.
                print(f"⚠️ 无对应类别目录,保留原位: {final_class}")
                continue
            dst = _free_destination(img_path, target_dir / img_path.name)
            if dst.name != img_path.name:
                print(f"⚠️ 目标已存在,重命名为: {dst.name}")
            shutil.move(str(img_path), str(dst))
            print(f"✅ 移动 -> [{final_class}]")
            processed += 1
        except Exception as e:
            # Per-image boundary: report the failure and continue with the next image.
            print(f"❌ 处理失败 {img_path.name}: {e}")
    print(f"\n🎉 完成,共处理 {processed}")
# ---------------------------
# Example usage
# ---------------------------
if __name__ == "__main__":
    model_path = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls_resize1/exp_cls2/weights/best.pt"
    input_folder = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/camera02/ready"
    # Same folder as the input: images are moved into per-class sub-folders of it.
    output_root = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/camera02/ready"
    # Relative path: resolved against the current working directory.
    roi_file = "./roi_coordinates/2.txt"
    batch_classify_images_move(
        model_path=model_path,
        input_folder=input_folder,
        output_root=output_root,
        roi_file=roi_file,
        target_size=640
    )