"""Batch YOLO detection over a folder of images.

Every image in INPUT_FOLDER is run through the model; detections are printed,
drawn onto the image, and the annotated copy is written to OUTPUT_FOLDER as
``result_<original file name>``.
"""
from ultralytics import YOLO
from ultralytics.utils.ops import non_max_suppression  # not called here any more: predict() already applies NMS
import torch
import cv2
import os
import time
from pathlib import Path

# ======================
# Configuration
# ======================
MODEL_PATH = 'detect.pt'  # path to your model weights
INPUT_FOLDER = '/home/hx/开发/ailai_image_obb/ailai_pc/train'  # folder of input images
OUTPUT_FOLDER = '/home/hx/开发/ailai_image_obb/ailai_pc/results'  # output folder (created automatically)
CONF_THRESH = 0.5
IOU_THRESH = 0.45
MAX_DET = 100  # maximum number of detections kept per image
CLASS_NAMES = ['bag']
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
IMG_SIZE = 640
SHOW_IMAGE = False  # display each result as it is produced (useful for debugging)

# Supported image file extensions (compared case-insensitively)
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}


# ======================
# Collect image paths
# ======================
def get_image_paths(folder):
    """Return the image files directly inside *folder*, sorted by name.

    Only regular files whose lower-cased suffix is in IMG_EXTENSIONS are
    returned; subdirectories are not searched. Prints a warning (and returns
    an empty list) when no images are found.

    Raises:
        FileNotFoundError: if *folder* does not exist.
    """
    folder = Path(folder)
    if not folder.exists():
        raise FileNotFoundError(f"输入文件夹不存在: {folder}")
    paths = [p for p in folder.iterdir() if p.is_file() and p.suffix.lower() in IMG_EXTENSIONS]
    if not paths:
        print(f"⚠️ 在 {folder} 中未找到图片")
    return sorted(paths)  # sort by file name


def _class_name(cls_id, model_names):
    """Map a numeric class id to a display name.

    Prefers CLASS_NAMES; falls back to the model's own id->name mapping and
    finally to the id itself, so a model with more classes than CLASS_NAMES
    does not raise IndexError.
    """
    if 0 <= cls_id < len(CLASS_NAMES):
        return CLASS_NAMES[cls_id]
    return model_names.get(cls_id, str(cls_id))


def _detections(result):
    """Return the final detections of one ultralytics result as rows of
    ``[x1, y1, x2, y2, conf, cls]`` (a NumPy array), or ``[]`` if none.

    ``result.boxes`` is already conf-filtered, NMS'd (with the iou/max_det
    passed to predict()) and scaled to original-image pixels, so no further
    NMS is applied here.
    """
    boxes = result.boxes
    if boxes is None or len(boxes) == 0:
        return []
    # Assemble the columns explicitly rather than using boxes.data, whose
    # layout gains an extra track-id column when tracking is enabled.
    rows = torch.cat((boxes.xyxy, boxes.conf.unsqueeze(1), boxes.cls.unsqueeze(1)), dim=1)
    return rows.cpu().numpy()  # single device-to-host copy


def _draw_detections(img, det, model_names):
    """Print every detection and draw its box and label onto *img* in place."""
    print("\n📋 检测结果:")
    for *xyxy, conf, cls_id in det:
        x1, y1, x2, y2 = map(int, xyxy)
        cls_name = _class_name(int(cls_id), model_names)
        print(f" 类别: {cls_name}, 置信度: {conf:.3f}, 框: [{x1}, {y1}, {x2}, {y2}]")
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
        label = f"{cls_name} {conf:.2f}"
        # putText anchors at the text baseline; clamp so labels on boxes that
        # touch the top edge stay inside the image.
        text_y = max(y1 - 10, 20)
        cv2.putText(img, label, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)


# ======================
# Main (batch inference)
# ======================
def main():
    """Run detection on every image in INPUT_FOLDER and save annotated copies.

    Prints per-image detections and timings, then an overall summary. When
    SHOW_IMAGE is set, pressing 'q' in the preview window stops early.
    """
    print(f"✅ 使用设备: {DEVICE}")

    # Create the output folder
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)
    print(f"📁 输出结果将保存到: {OUTPUT_FOLDER}")

    # Load the model
    print("➡️ 加载 YOLO 模型...")
    model = YOLO(MODEL_PATH)
    model.to(DEVICE)

    img_paths = get_image_paths(INPUT_FOLDER)
    if not img_paths:
        return

    print(f"📸 共找到 {len(img_paths)} 张图片,开始批量推理...\n")

    processed = 0  # images actually inferred (skipped files / early quit excluded)
    total_start_time = time.time()

    for idx, img_path in enumerate(img_paths, 1):
        print('=' * 50)
        print(f"🖼️ 处理第 {idx}/{len(img_paths)} 张: {img_path.name}")
        start_time = time.time()

        # Decode first so an unreadable file is skipped before inference
        # (handing ultralytics an unreadable path raises instead) and the
        # image is decoded only once; the model accepts the BGR array.
        img = cv2.imread(str(img_path))
        if img is None:
            print(f"❌ 无法读取图像: {img_path}")
            continue

        # predict() applies the confidence threshold and NMS itself; passing
        # iou/max_det here is what makes IOU_THRESH take effect.
        infer_start = time.time()
        results = model(img, imgsz=IMG_SIZE, conf=CONF_THRESH, iou=IOU_THRESH,
                        max_det=MAX_DET, device=DEVICE, verbose=True)
        inference_time = time.time() - infer_start

        r = results[0]
        det = _detections(r)
        _draw_detections(img, det, r.names)

        # Save the annotated image; imwrite reports failure via its return value.
        output_path = os.path.join(OUTPUT_FOLDER, f"result_{img_path.name}")
        if cv2.imwrite(output_path, img):
            print(f"\n✅ 结果已保存: {output_path}")
        else:
            print(f"\n❌ 保存失败: {output_path}")
        processed += 1

        # Optional preview
        if SHOW_IMAGE:
            cv2.imshow("Detection", img)
            if cv2.waitKey(1) & 0xFF == ord('q'):  # press Q to quit
                break

        total_infer_time = time.time() - start_time
        print(f"⏱️ 推理时间: {inference_time * 1000:.1f}ms")
        print(f"⏱️ 总处理时间: {total_infer_time * 1000:.1f}ms (推理+后处理)")

    # Summary over the images that were actually processed
    total_elapsed = time.time() - total_start_time
    print(f"\n🎉 批量推理完成!共处理 {processed} 张图片,总耗时: {total_elapsed:.2f} 秒")
    if processed and total_elapsed > 0:
        per_image = total_elapsed / processed
        print(f"🚀 平均每张: {per_image * 1000:.1f} ms ({1 / per_image:.1f} FPS)")

    if SHOW_IMAGE:
        cv2.destroyAllWindows()


if __name__ == '__main__':
    main()