# Source: ailai_image_point_diff/ailai_pc/detet_pc_f.py (132 lines, 3.5 KiB, Python)
# Last change: 2025-11-03 16:10:50 +08:00
from ultralytics import YOLO
from ultralytics.utils.ops import non_max_suppression
import torch
import os
import time
import shutil
# 2025-11-03 16:10:50 +08:00
from pathlib import Path
# ======================
# Configuration
# ======================
# Trained detector weights loaded by main().
MODEL_PATH = '/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_ailai_detect/weights/best.pt'
# Folder scanned (non-recursively) for images to classify.
INPUT_FOLDER = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/ailaidete/train/bag'
# Root for the bucket sub-folders. It is the same folder as INPUT_FOLDER, so the
# bucket folders are created inside it. They are not re-scanned, because the
# image list is built non-recursively before any file is moved.
OUTPUT_FOLDER = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/ailaidete/train/bag'
# Confidence thresholds, one output folder per value ("bag_<th>").  <- edit these.
# Must be non-empty: main() uses min(CONF_BUCKETS) as the inference threshold.
CONF_BUCKETS = [0.93, 0.95]

# IoU threshold for non-maximum suppression.
IOU_THRESH = 0.45
# Class names by index; main() only looks at class index 0.
CLASS_NAMES = ['bag']

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
# Inference image size passed to the model.
IMG_SIZE = 640
# Lower-case file suffixes (with the leading dot) treated as images.
IMG_EXTENSIONS = set('.jpg .jpeg .png .bmp .tiff .webp'.split())
# ======================
# Collect image paths
# ======================
def get_image_paths(folder, extensions=None):
    """Return the image files directly inside *folder*, sorted by path.

    Args:
        folder: Directory to scan (``str`` or ``Path``). Sub-directories are
            not descended into.
        extensions: Iterable of file suffixes, including the leading dot, to
            accept (compared case-insensitively). Defaults to the module-level
            ``IMG_EXTENSIONS``.

    Returns:
        A sorted list of ``Path`` objects.
    """
    if extensions is None:
        extensions = IMG_EXTENSIONS
    allowed = {ext.lower() for ext in extensions}
    folder = Path(folder)
    # is_file(): a directory whose name happens to end in an image suffix must
    # not be returned, or it would later be fed to the model and moved.
    return sorted(
        p for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in allowed
    )
# ======================
# Move without overwriting existing files
# ======================
def safe_move(src, dst_dir):
    """Move *src* into *dst_dir* without clobbering an existing file.

    The directory is created if needed. When ``<dst_dir>/<name>`` is already
    taken, the candidates ``<stem>_1<suffix>``, ``<stem>_2<suffix>``, ... are
    tried in order until a free one is found.

    Returns:
        The destination path (a ``str``) the file was moved to.
    """
    os.makedirs(dst_dir, exist_ok=True)
    base_name = os.path.basename(src)
    stem, suffix = os.path.splitext(base_name)

    target = os.path.join(dst_dir, base_name)
    counter = 0
    while os.path.exists(target):
        counter += 1
        target = os.path.join(dst_dir, f"{stem}_{counter}{suffix}")

    shutil.move(src, target)
    return target
# ======================
# Pick the output directory from the confidence
# ======================
def get_bucket_dir(max_conf, output_root, buckets):
    """Map a confidence score to its bucket directory under *output_root*.

    The highest threshold in *buckets* that *max_conf* reaches (``>=``) wins
    and yields ``<output_root>/bag_<threshold>``. When no threshold is
    reached (or *buckets* is empty), ``<output_root>/delet`` is returned.
    """
    descending = sorted(buckets, reverse=True)
    hit = next((th for th in descending if max_conf >= th), None)
    folder_name = "delet" if hit is None else f"bag_{hit}"
    return os.path.join(output_root, folder_name)
# ======================
# Main logic
# ======================
def main():
    """Bucket every image in INPUT_FOLDER by its highest 'bag' confidence.

    Runs the YOLO model once per image. The highest confidence among class-0
    detections selects a destination via get_bucket_dir(). The file is then
    moved there with safe_move(), which never overwrites an existing file.

    An image with no class-0 detection at or above min(CONF_BUCKETS) gets
    max_conf 0.0, which falls below the (positive) bucket thresholds and so
    lands in the "delet" folder.
    """
    print(f"✅ 使用设备: {DEVICE}")
    model = YOLO(MODEL_PATH).to(DEVICE)

    img_paths = get_image_paths(Path(INPUT_FOLDER))
    if not img_paths:
        print("⚠️ 没有图片")
        return
    total = len(img_paths)
    print(f"📸 共 {total} 张图片")
    print(f"📊 置信度档位: {CONF_BUCKETS}\n")

    # Boxes below the lowest bucket would end up in "delet" anyway, so there
    # is no point asking the model for them.
    min_conf = min(CONF_BUCKETS)

    for idx, img_path in enumerate(img_paths, 1):
        print('=' * 50)
        print(f"🖼️ {idx}/{total}: {img_path.name}")

        start_time = time.perf_counter()
        # The predictor applies the confidence filter and NMS itself, so the
        # IoU threshold and detection cap are passed here. Per the Ultralytics
        # docs, results[0].boxes holds final post-NMS detections, not the raw
        # (batch, 4 + nc, anchors) tensor non_max_suppression() expects.
        # Feeding boxes.data back into it mis-reads the columns and can fail
        # outright when there are only a few detections.
        r = model(
            str(img_path),
            imgsz=IMG_SIZE,
            conf=min_conf,
            iou=IOU_THRESH,
            max_det=100,
            device=DEVICE,
            verbose=False,
        )[0]

        # Highest confidence among class-0 ('bag') detections; 0.0 if none.
        max_conf = 0.0
        if r.boxes is not None and len(r.boxes):
            confs = r.boxes.conf.cpu().tolist()
            cls_ids = r.boxes.cls.cpu().tolist()
            max_conf = max(
                (float(c) for c, k in zip(confs, cls_ids) if int(k) == 0),
                default=0.0,
            )

        dst_dir = get_bucket_dir(max_conf, OUTPUT_FOLDER, CONF_BUCKETS)
        new_path = safe_move(str(img_path), dst_dir)
        if max_conf > 0:
            print(f"✅ bag max_conf={max_conf:.3f} → {os.path.basename(dst_dir)}")
        else:
            print("❌ 未检测到 bag")
        print(f"🚚 已移动到: {new_path}")
        print(f"⏱️ {(time.perf_counter() - start_time) * 1000:.1f} ms")

    print("\n🎉 全部处理完成")


if __name__ == '__main__':
    main()