from ultralytics import YOLO
from ultralytics.utils.ops import non_max_suppression  # noqa: F401  (no longer used below; kept for compatibility)
import torch
import os
import time
import shutil
from pathlib import Path

# ======================
# Configuration
# ======================
MODEL_PATH = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_ailai_detect3/weights/best.pt"
IMAGE_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/ailaidete/train/delet"
INPUT_FOLDER = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/ailaidete/train/delet"
OUTPUT_FOLDER = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/ailaidete/train/delet"
CONF_BUCKETS = [0.93, 0.95]  # ← ⭐ edit this to change the confidence buckets
IOU_THRESH = 0.45
CLASS_NAMES = ['bag', 'bag35']
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
IMG_SIZE = 640
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
# Class index (into CLASS_NAMES) whose confidence decides the bucket: 'bag'.
TARGET_CLASS_ID = 0
# Per-image detection cap (the value previously passed to the second NMS call).
MAX_DET = 100


# ======================
# Collect image paths
# ======================
def get_image_paths(folder):
    """Return the image files directly inside *folder*, sorted by path.

    Only regular files whose lower-cased suffix is in IMG_EXTENSIONS are
    returned. Subdirectories are skipped even if their name ends in an image
    extension -- relevant because the bucket folders are created under
    OUTPUT_FOLDER, which may be the same directory as INPUT_FOLDER.
    """
    folder = Path(folder)
    return sorted(
        p for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in IMG_EXTENSIONS
    )


# ======================
# Move without overwriting
# ======================
def safe_move(src, dst_dir):
    """Move *src* into *dst_dir* without overwriting an existing file.

    Creates *dst_dir* if needed. If a file with the same basename already
    exists there, a numeric suffix is inserted before the extension
    (``name_1.jpg``, ``name_2.jpg``, ...) until a free name is found.

    Returns:
        The path the file was moved to.
    """
    os.makedirs(dst_dir, exist_ok=True)
    dst = os.path.join(dst_dir, os.path.basename(src))
    if not os.path.exists(dst):
        shutil.move(src, dst)
        return dst

    stem, suffix = os.path.splitext(os.path.basename(src))
    i = 1
    while True:
        new_dst = os.path.join(dst_dir, f"{stem}_{i}{suffix}")
        if not os.path.exists(new_dst):
            shutil.move(src, new_dst)
            return new_dst
        i += 1


# ======================
# Pick the destination directory by confidence
# ======================
def get_bucket_dir(max_conf, output_root, buckets):
    """Return the bucket directory for a given maximum confidence.

    Checks thresholds from highest to lowest and returns
    ``<output_root>/bag_<th>`` for the first threshold ``th`` with
    ``max_conf >= th``; falls back to ``<output_root>/delet``.
    """
    for th in sorted(buckets, reverse=True):
        if max_conf >= th:
            return os.path.join(output_root, f"bag_{th}")
    return os.path.join(output_root, "delet")


def _max_class_conf(boxes, class_id=TARGET_CLASS_ID):
    """Return the highest confidence among *boxes* of class *class_id*, else 0.0.

    *boxes* is the ``Results.boxes`` object from an Ultralytics prediction
    (its ``conf`` and ``cls`` tensors are read); ``None`` or empty -> 0.0.
    """
    if boxes is None or len(boxes) == 0:
        return 0.0
    confs = boxes.conf[boxes.cls.int() == class_id]
    return float(confs.max()) if len(confs) else 0.0


# ======================
# Main logic
# ======================
def main():
    """Run the detector on INPUT_FOLDER and move each image into a bucket.

    For each image, the highest confidence of class TARGET_CLASS_ID ('bag')
    selects the destination via get_bucket_dir: the highest CONF_BUCKETS
    threshold it reaches (``bag_<th>``), or ``delet`` otherwise. Images are
    moved (not copied) with safe_move.
    """
    print(f"✅ 使用设备: {DEVICE}")
    model = YOLO(MODEL_PATH).to(DEVICE)

    img_paths = get_image_paths(Path(INPUT_FOLDER))
    if not img_paths:
        print("⚠️ 没有图片")
        return

    print(f"📸 共 {len(img_paths)} 张图片")
    print(f"📊 置信度档位: {CONF_BUCKETS}\n")

    # Detections below the lowest bucket can never change the bucket choice.
    min_conf = min(CONF_BUCKETS)

    for idx, img_path in enumerate(img_paths, 1):
        print('=' * 50)
        print(f"🖼️ {idx}/{len(img_paths)}: {img_path.name}")
        start_time = time.perf_counter()

        # The predictor already runs NMS, so the IoU threshold belongs here.
        # Re-running non_max_suppression on r.boxes.data (post-NMS (N, 6)
        # boxes, not the raw head output it expects) never applied IOU_THRESH.
        results = model(
            str(img_path),
            imgsz=IMG_SIZE,
            conf=min_conf,
            iou=IOU_THRESH,
            max_det=MAX_DET,
            device=DEVICE,
            verbose=False,
        )
        max_conf = _max_class_conf(results[0].boxes)

        dst_dir = get_bucket_dir(max_conf, OUTPUT_FOLDER, CONF_BUCKETS)
        new_path = safe_move(str(img_path), dst_dir)

        if max_conf > 0:
            print(f"✅ bag max_conf={max_conf:.3f} → {os.path.basename(dst_dir)}")
        else:
            print("❌ 未检测到 bag")
        print(f"🚚 已移动到: {new_path}")
        print(f"⏱️ {(time.perf_counter() - start_time)*1000:.1f} ms")

    print("\n🎉 全部处理完成")


if __name__ == '__main__':
    main()