2025-11-03 16:10:50 +08:00
|
|
|
from ultralytics import YOLO
|
|
|
|
|
from ultralytics.utils.ops import non_max_suppression
|
|
|
|
|
import torch
|
|
|
|
|
import os
|
|
|
|
|
import time
|
2025-12-30 17:29:49 +08:00
|
|
|
import shutil
|
2025-11-03 16:10:50 +08:00
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
# ======================
# Configuration
# ======================

# --- locations ---
MODEL_PATH = '/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_ailai_detect/weights/best.pt'
# NOTE: input and output are the same directory, so the bucket
# sub-directories end up inside the folder that is being scanned.
INPUT_FOLDER = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/ailaidete/train/bag'
OUTPUT_FOLDER = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/ailaidete/train/bag'

# --- detection thresholds ---
# Confidence thresholds that define the output buckets -- edit these to taste.
CONF_BUCKETS = [0.93, 0.95]
IOU_THRESH = 0.45
CLASS_NAMES = ['bag']

# --- inference settings ---
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
IMG_SIZE = 640

# Lower-case file suffixes (including the dot) that count as images.
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ======================
# Collect image paths
# ======================
def get_image_paths(folder, extensions=None):
    """Return the image files directly inside *folder*, sorted by path.

    The scan is not recursive. Only regular files are returned: a
    sub-directory is skipped even if its name ends in an image suffix.
    Suffix matching is case-insensitive.

    Args:
        folder: Directory to scan (str or Path).
        extensions: Optional iterable of suffixes including the dot
            (e.g. ``{'.jpg', '.png'}``). ``None`` means the module-level
            ``IMG_EXTENSIONS``.

    Returns:
        Sorted list of ``Path`` objects.

    Raises:
        OSError (e.g. FileNotFoundError / NotADirectoryError) from
        ``Path.iterdir()`` when *folder* is not a readable directory.
    """
    if extensions is None:
        extensions = IMG_EXTENSIONS
    # Normalise once so callers may pass upper-case suffixes too.
    wanted = {ext.lower() for ext in extensions}
    folder = Path(folder)
    return sorted(
        p for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in wanted
    )
|
2025-11-03 16:10:50 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# ======================
# Move without overwriting same-named files
# ======================
def safe_move(src, dst_dir):
    """Move the file *src* into *dst_dir* without clobbering an existing file.

    *dst_dir* is created (with parents) if it does not exist. When a file
    with the same basename is already present, ``_1``, ``_2``, ... is
    appended to the stem until an unused name is found. The existence
    check and the move are separate steps, so this is not atomic.

    Returns:
        The destination path (str) that *src* was moved to.
    """
    os.makedirs(dst_dir, exist_ok=True)
    base_name = os.path.basename(src)
    target = os.path.join(dst_dir, base_name)

    if os.path.exists(target):
        name_part, ext_part = os.path.splitext(base_name)
        counter = 0
        while True:
            counter += 1
            target = os.path.join(dst_dir, f"{name_part}_{counter}{ext_part}")
            if not os.path.exists(target):
                break

    shutil.move(src, target)
    return target
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ======================
# Pick the destination directory by confidence
# ======================
def get_bucket_dir(max_conf, output_root, buckets):
    """Map a confidence score to its bucket directory under *output_root*.

    Thresholds are tried from highest to lowest; the first one that
    *max_conf* reaches (``>=``) selects ``bag_<threshold>``, with the
    threshold rendered via its default ``str()`` form. When no threshold
    is reached, the fallback directory ``delet`` is returned.
    """
    matched = next(
        (threshold for threshold in sorted(buckets, reverse=True)
         if max_conf >= threshold),
        None,
    )
    sub_dir = "delet" if matched is None else f"bag_{matched}"
    return os.path.join(output_root, sub_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ======================
# Main logic
# ======================
def _max_class_conf(boxes, class_id=0):
    """Return the highest confidence among *boxes* of *class_id*, or 0.0 if none.

    *boxes* is the ``Results.boxes`` object returned by the model (or None);
    its ``conf`` and ``cls`` tensors are compared element-wise.
    """
    if boxes is None or len(boxes) == 0:
        return 0.0
    confs = boxes.conf[boxes.cls == class_id]
    return float(confs.max()) if confs.numel() else 0.0


def main():
    """Sort every image in INPUT_FOLDER into confidence buckets.

    For each image the model is run once. The highest confidence among
    class-0 ('bag') detections picks a bucket directory via
    get_bucket_dir(), and the image is moved there with safe_move().
    Images without a qualifying detection go to the fallback directory.
    """
    print(f"✅ 使用设备: {DEVICE}")

    model = YOLO(MODEL_PATH).to(DEVICE)

    img_paths = get_image_paths(Path(INPUT_FOLDER))
    if not img_paths:
        print("⚠️ 没有图片")
        return

    print(f"📸 共 {len(img_paths)} 张图片")
    print(f"📊 置信度档位: {CONF_BUCKETS}\n")

    # Boxes below the lowest bucket would land in the fallback directory
    # anyway, so let the model discard them up front.
    min_conf = min(CONF_BUCKETS)
    total = len(img_paths)

    for idx, img_path in enumerate(img_paths, 1):
        print(f"{'='*50}")
        print(f"🖼️ {idx}/{total}: {img_path.name}")

        start_time = time.perf_counter()

        # The predictor already applies the confidence filter and NMS.
        # Re-running non_max_suppression() on the post-NMS (N, 6) boxes was
        # redundant, and it never applied IOU_THRESH. Depending on the
        # ultralytics version it is a conf-only filter or misreads the
        # layout. So the IoU threshold is passed here instead.
        results = model(
            str(img_path),
            imgsz=IMG_SIZE,
            conf=min_conf,
            iou=IOU_THRESH,
            device=DEVICE,
            verbose=False,
        )
        max_conf = _max_class_conf(results[0].boxes, class_id=0)

        dst_dir = get_bucket_dir(max_conf, OUTPUT_FOLDER, CONF_BUCKETS)
        new_path = safe_move(str(img_path), dst_dir)

        if max_conf > 0:
            print(f"✅ bag max_conf={max_conf:.3f} → {os.path.basename(dst_dir)}")
        else:
            print("❌ 未检测到 bag")

        print(f"🚚 已移动到: {new_path}")
        print(f"⏱️ {(time.perf_counter() - start_time)*1000:.1f} ms")

    print("\n🎉 全部处理完成")


if __name__ == '__main__':
    main()
|