# Batch YOLO detection script (source listing: 135 lines, 4.3 KiB, Python).
|
|
from ultralytics import YOLO
|
|||
|
|
from ultralytics.utils.ops import non_max_suppression
|
|||
|
|
import torch
|
|||
|
|
import cv2
|
|||
|
|
import os
|
|||
|
|
import time
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
# ======================
# Configuration
# ======================

# Model weights and input/output locations.
MODEL_PATH = 'detect.pt'  # path to the trained model weights
INPUT_FOLDER = '/home/hx/开发/ailai_image_obb/ailai_pc/train'  # folder with the input images
OUTPUT_FOLDER = '/home/hx/开发/ailai_image_obb/ailai_pc/results'  # annotated results (created automatically)

# Detection thresholds.
CONF_THRESH = 0.5
IOU_THRESH = 0.45

# Display name for each class index.
CLASS_NAMES = ['bag']

# Use the GPU whenever PyTorch reports one is available.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

IMG_SIZE = 640  # inference size passed to the model as imgsz
SHOW_IMAGE = False  # show every annotated image in a window (debugging aid)

# Lower-case file suffixes accepted as images by get_image_paths.
IMG_EXTENSIONS = {
    '.jpg',
    '.jpeg',
    '.png',
    '.bmp',
    '.tiff',
    '.webp',
}
# ======================
# Collect all image paths in a folder
# ======================
def get_image_paths(folder, extensions=None):
    """Return the image files directly inside *folder*, sorted by path.

    Only regular files whose suffix (compared case-insensitively) is in
    *extensions* are returned; sub-directories are not searched. When
    nothing matches, a warning is printed and an empty list is returned.

    Args:
        folder: Directory to scan (str or path-like).
        extensions: Optional iterable of suffixes including the dot, e.g.
            {'.jpg', '.png'}. Defaults to the module-level IMG_EXTENSIONS.

    Returns:
        list of pathlib.Path: matching files in ascending path order.

    Raises:
        FileNotFoundError: *folder* does not exist.
        NotADirectoryError: *folder* exists but is not a directory.
    """
    folder = Path(folder)
    if not folder.exists():
        raise FileNotFoundError(f"输入文件夹不存在: {folder}")
    if not folder.is_dir():
        # Fail with a clear message instead of an opaque iterdir() error.
        raise NotADirectoryError(f"输入路径不是文件夹: {folder}")

    source_exts = IMG_EXTENSIONS if extensions is None else extensions
    allowed = {ext.lower() for ext in source_exts}

    # is_file() drops sub-directories whose names happen to end in an image suffix.
    paths = sorted(
        p for p in folder.iterdir() if p.is_file() and p.suffix.lower() in allowed
    )
    if not paths:
        print(f"⚠️ 在 {folder} 中未找到图片")
    return paths
# ======================
# Drawing helpers
# ======================
def _class_name(cls_id, model_names=None):
    """Return the display name for integer class index *cls_id*.

    Looks the index up in CLASS_NAMES first. If it is out of range, falls
    back to *model_names* (a mapping such as a Results ``names`` dict) and
    finally to the bare index, instead of raising IndexError.
    """
    if 0 <= cls_id < len(CLASS_NAMES):
        return CLASS_NAMES[cls_id]
    if model_names and cls_id in model_names:
        return str(model_names[cls_id])
    return str(cls_id)


def _draw_detections(img, det, model_names=None):
    """Print every detection and draw its box and label onto *img* in place.

    Args:
        img: BGR image array accepted by the cv2 drawing functions.
        det: Iterable of rows ``(x1, y1, x2, y2, conf, cls)``; may be empty.
        model_names: Optional class-index -> name mapping used as a fallback.
    """
    print(f"\n📋 检测结果:")
    for *xyxy, conf, cls_id in det:
        x1, y1, x2, y2 = map(int, xyxy)
        cls_name = _class_name(int(cls_id), model_names)
        print(f" 类别: {cls_name}, 置信度: {conf:.3f}, 框: [{x1}, {y1}, {x2}, {y2}]")
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
        label = f"{cls_name} {conf:.2f}"
        # Drawing at y1 - 10 pushes the label off-image for boxes touching the
        # top edge; keep the text baseline at least 20 px from the top.
        text_y = max(y1 - 10, 20)
        cv2.putText(img, label, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)


# ======================
# Main entry point (batch inference)
# ======================
def main():
    """Run detection on every image in INPUT_FOLDER and save annotated copies.

    For each image returned by get_image_paths(INPUT_FOLDER):
      * run the YOLO model (confidence filtering and NMS happen inside the
        model call, using CONF_THRESH and IOU_THRESH);
      * print and draw the detections;
      * write ``result_<name>`` to OUTPUT_FOLDER;
      * print the per-image timing.

    A summary with total and average time is printed at the end. With
    SHOW_IMAGE enabled, pressing 'q' in the preview window stops early.
    """
    print(f"✅ 使用设备: {DEVICE}")

    # Create the output folder if needed.
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)
    print(f"📁 输出结果将保存到: {OUTPUT_FOLDER}")

    # Load the model.
    print("➡️ 加载 YOLO 模型...")
    model = YOLO(MODEL_PATH)
    model.to(DEVICE)

    # Collect the input images.
    img_paths = get_image_paths(INPUT_FOLDER)
    if not img_paths:
        return

    print(f"📸 共找到 {len(img_paths)} 张图片,开始批量推理...\n")

    processed = 0  # images that were inferred and annotated (skips / early exit excluded)
    total_start_time = time.perf_counter()

    for idx, img_path in enumerate(img_paths, 1):
        print(f"{'=' * 50}")
        print(f"🖼️ 处理第 {idx}/{len(img_paths)} 张: {img_path.name}")

        start_time = time.perf_counter()

        # The model call already applies the confidence threshold AND NMS, so
        # r.boxes holds the final detections. Previously those rows were fed
        # into ops.non_max_suppression a second time. That function expects the
        # raw (batch, 4 + num_classes, num_anchors) head output and misread the
        # (N, 6) boxes, which gave bogus detections. It also meant IOU_THRESH
        # never reached the real NMS, so it is passed here instead.
        results = model(
            str(img_path),
            imgsz=IMG_SIZE,
            conf=CONF_THRESH,
            iou=IOU_THRESH,
            max_det=100,
            device=DEVICE,
            verbose=True,  # ultralytics prints its own internal timings
        )
        r = results[0]

        # One device -> CPU copy of rows (x1, y1, x2, y2, conf, cls).
        boxes = r.boxes
        det = boxes.data.cpu().numpy() if boxes is not None and len(boxes) else []

        # Reuse the image already decoded for inference instead of reading it again.
        img = r.orig_img.copy() if r.orig_img is not None else cv2.imread(str(img_path))
        if img is None:
            print(f"❌ 无法读取图像: {img_path}")
            continue

        _draw_detections(img, det, r.names)

        # Save the result; cv2.imwrite signals failure through its return value.
        output_path = os.path.join(OUTPUT_FOLDER, f"result_{img_path.name}")
        if cv2.imwrite(output_path, img):
            print(f"\n✅ 结果已保存: {output_path}")
        else:
            print(f"❌ 保存失败: {output_path}")
        processed += 1

        # Optional preview: waitKey(1) only flashes each frame; press 'q' to stop.
        if SHOW_IMAGE:
            cv2.imshow("Detection", img)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

        total_infer_time = time.perf_counter() - start_time
        print(f"⏱️ 总处理时间: {total_infer_time * 1000:.1f}ms (推理+后处理)")

    # Summary over the images that were actually processed.
    total_elapsed = time.perf_counter() - total_start_time
    print(f"\n🎉 批量推理完成!共处理 {processed} 张图片,总耗时: {total_elapsed:.2f} 秒")
    if processed and total_elapsed > 0:
        avg = total_elapsed / processed
        print(f"🚀 平均每张: {avg * 1000:.1f} ms ({1 / avg:.1f} FPS)")

    if SHOW_IMAGE:
        cv2.destroyAllWindows()
if __name__ == "__main__":
    # Start the batch inference only when run as a script, not when imported.
    main()