66 lines
2.2 KiB
Python
66 lines
2.2 KiB
Python
import cv2
|
||
import os
|
||
import shutil
|
||
from ultralytics import YOLO
|
||
|
||
# ====================== Configuration ======================
# Weights file loaded with ultralytics.YOLO by the main section.
MODEL_PATH = 'point.pt'
# Directory whose image files are sorted (listed non-recursively).
IMAGE_SOURCE_DIR = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/ailaipoint'

# Lower-case extensions (with leading dot) recognised as images.
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}

# Bucket folders are created in a 'train_split' subfolder of the source directory.
OUTPUT_ROOT = os.path.join(IMAGE_SOURCE_DIR, 'train_split')
OUTPUT_DIR_0 = os.path.join(OUTPUT_ROOT, '0')  # no detections / conf treated as 0
OUTPUT_DIR_1 = os.path.join(OUTPUT_ROOT, '1')  # 0 < conf < 0.5
OUTPUT_DIR_2 = os.path.join(OUTPUT_ROOT, '2')  # conf >= 0.5

# Create all bucket folders up front (this runs at import time);
# exist_ok=True keeps repeated runs from failing.
for bucket_dir in (OUTPUT_DIR_0, OUTPUT_DIR_1, OUTPUT_DIR_2):
    os.makedirs(bucket_dir, exist_ok=True)

# ====================== Main program ======================


def max_box_confidence(results):
    """Return the highest box confidence in the first prediction result.

    Args:
        results: The sequence returned by calling a YOLO model on one image;
            only ``results[0].boxes.conf`` is inspected.

    Returns:
        The largest value of ``results[0].boxes.conf`` as a float, or ``0.0``
        when ``results`` is empty, ``boxes`` is None, or there are no boxes.
    """
    if not results:
        return 0.0
    boxes = results[0].boxes
    if boxes is None or len(boxes.conf) == 0:
        # No detection at all -> treated as conf = 0.
        return 0.0
    # Take the maximum explicitly rather than assuming conf[0] is the best box.
    return float(boxes.conf.max().item())


def bucket_index(conf, threshold=0.5):
    """Map a confidence to a bucket index.

    Returns 0 for no detection (conf <= 0), 1 for 0 < conf < threshold,
    and 2 for conf >= threshold.
    """
    if conf <= 0.0:
        return 0
    if conf < threshold:
        return 1
    return 2


def main():
    """Move every image in IMAGE_SOURCE_DIR into a bucket folder by box confidence."""
    print("🚀 bbox 置信度分桶(移动原图,含无目标图像)")

    model = YOLO(MODEL_PATH)
    bucket_dirs = (OUTPUT_DIR_0, OUTPUT_DIR_1, OUTPUT_DIR_2)

    # sorted(): reproducible processing/log order.
    # isfile(): skip directories whose names happen to end in an image extension.
    image_files = sorted(
        f for f in os.listdir(IMAGE_SOURCE_DIR)
        if os.path.splitext(f.lower())[1] in IMG_EXTENSIONS
        and os.path.isfile(os.path.join(IMAGE_SOURCE_DIR, f))
    )

    print(f"📸 找到图片 {len(image_files)} 张")

    for img_name in image_files:
        src_path = os.path.join(IMAGE_SOURCE_DIR, img_name)

        img = cv2.imread(src_path)
        if img is None:
            # Report unreadable/corrupt files instead of skipping them silently;
            # they remain in the source directory.
            print(f"⚠️ cannot read image, skipped: {img_name}")
            continue

        # NOTE(review): ultralytics predict applies a default confidence
        # threshold (0.25 per its docs), so bucket 1 probably only receives
        # [0.25, 0.5) unless `conf=` is passed here -- confirm this is intended.
        results = model(img, verbose=False)

        bbox_conf = max_box_confidence(results)
        dst_dir = bucket_dirs[bucket_index(bbox_conf)]
        dst_path = os.path.join(dst_dir, img_name)

        # shutil.move replaces an existing destination file silently (os.rename
        # semantics on POSIX); refuse rather than lose a previously sorted image.
        if os.path.exists(dst_path):
            print(f"⚠️ destination already exists, skipped: {dst_path}")
            continue

        shutil.move(src_path, dst_path)

        print(f"{img_name} -> conf={bbox_conf:.3f} -> {os.path.basename(dst_dir)}")

    print("✅ 完成(含无目标图片)")


if __name__ == "__main__":
    main()