更新加入料带目标检测,判断料带到位,以及控制滚筒逻辑

This commit is contained in:
琉璃月光
2025-12-30 17:29:49 +08:00
parent d6918e90f2
commit 2028a96819
27 changed files with 1499 additions and 1224 deletions

View File

@ -1,135 +1,131 @@
from ultralytics import YOLO
from ultralytics.utils.ops import non_max_suppression
import torch
import cv2
import os
import time
import shutil
from pathlib import Path
# ======================
# 配置参数
# ======================
MODEL_PATH = 'detect.pt' # 你的模型路径
INPUT_FOLDER = '/home/hx/开发/ailai_image_obb/ailai_pc/train' # 输入图片文件夹
OUTPUT_FOLDER = '/home/hx/开发/ailai_image_obb/ailai_pc/results' # 输出结果文件夹(自动创建)
CONF_THRESH = 0.5
MODEL_PATH = '/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_ailai_detect/weights/best.pt'
INPUT_FOLDER = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/ailaidete/train/bag'
OUTPUT_FOLDER = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/ailaidete/train/bag'
CONF_BUCKETS = [0.93, 0.95] # ← ⭐ 自己改这里
IOU_THRESH = 0.45
CLASS_NAMES = ['bag']
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
IMG_SIZE = 640
SHOW_IMAGE = False # 是否逐张显示图像(适合调试)
# 支持的图像格式
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
# ======================
# 获取文件夹中所有图片路径
# 获取图片路径
# ======================
def get_image_paths(folder):
folder = Path(folder)
if not folder.exists():
raise FileNotFoundError(f"输入文件夹不存在: {folder}")
paths = [p for p in folder.iterdir() if p.suffix.lower() in IMG_EXTENSIONS]
if not paths:
print(f"⚠️ 在 {folder} 中未找到图片")
return sorted(paths) # 按名称排序
return sorted([p for p in folder.iterdir() if p.suffix.lower() in IMG_EXTENSIONS])
# ======================
# 主函数(批量推理)
# 防止重名覆盖
# ======================
def safe_move(src, dst_dir):
os.makedirs(dst_dir, exist_ok=True)
dst = os.path.join(dst_dir, os.path.basename(src))
if not os.path.exists(dst):
shutil.move(src, dst)
return dst
stem, suffix = os.path.splitext(os.path.basename(src))
i = 1
while True:
new_dst = os.path.join(dst_dir, f"{stem}_{i}{suffix}")
if not os.path.exists(new_dst):
shutil.move(src, new_dst)
return new_dst
i += 1
# ======================
# 根据置信度选择目录
# ======================
def get_bucket_dir(max_conf, output_root, buckets):
for th in sorted(buckets, reverse=True):
if max_conf >= th:
return os.path.join(output_root, f"bag_{th}")
return os.path.join(output_root, "delet")
# ======================
# 主逻辑
# ======================
def main():
print(f"✅ 使用设备: {DEVICE}")
# 创建输出文件夹
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
print(f"📁 输出结果将保存到: {OUTPUT_FOLDER}")
model = YOLO(MODEL_PATH).to(DEVICE)
# 加载模型
print("➡️ 加载 YOLO 模型...")
model = YOLO(MODEL_PATH)
model.to(DEVICE)
# 获取图片列表
img_paths = get_image_paths(INPUT_FOLDER)
img_paths = get_image_paths(Path(INPUT_FOLDER))
if not img_paths:
print("⚠️ 没有图片")
return
print(f"📸 共找到 {len(img_paths)} 张图片,开始批量推理...\n")
total_start_time = time.time()
print(f"📸 共 {len(img_paths)} 张图片")
print(f"📊 置信度档位: {CONF_BUCKETS}\n")
for idx, img_path in enumerate(img_paths, 1):
print(f"{'=' * 50}")
print(f"🖼️ 处理第 {idx}/{len(img_paths)}: {img_path.name}")
print(f"{'='*50}")
print(f"🖼️ {idx}/{len(img_paths)}: {img_path.name}")
# 手动计时
start_time = time.time()
# 推理verbose=True 输出内部耗时)
results = model(str(img_path), imgsz=IMG_SIZE, conf=CONF_THRESH, device=DEVICE, verbose=True)
inference_time = time.time() - start_time
results = model(
str(img_path),
imgsz=IMG_SIZE,
conf=min(CONF_BUCKETS),
device=DEVICE,
verbose=False
)
# 获取结果
r = results[0]
pred = r.boxes.data # GPU 上的原始输出
pred = r.boxes.data
# 在 GPU 上做 NMS
det = non_max_suppression(
pred.unsqueeze(0),
conf_thres=CONF_THRESH,
conf_thres=min(CONF_BUCKETS),
iou_thres=IOU_THRESH,
classes=None,
agnostic=False,
max_det=100
)[0]
# 拷贝到 CPU仅一次
if det is not None and len(det):
det = det.cpu().numpy()
else:
det = []
# 读取图像并绘制
img = cv2.imread(str(img_path))
if img is None:
print(f"❌ 无法读取图像: {img_path}")
continue
max_conf = 0.0
for *_, conf, cls_id in det:
if int(cls_id) == 0:
max_conf = max(max_conf, float(conf))
print(f"\n📋 检测结果:")
for *xyxy, conf, cls_id in det:
x1, y1, x2, y2 = map(int, xyxy)
cls_name = CLASS_NAMES[int(cls_id)]
print(f" 类别: {cls_name}, 置信度: {conf:.3f}, 框: [{x1}, {y1}, {x2}, {y2}]")
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
label = f"{cls_name} {conf:.2f}"
cv2.putText(img, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
dst_dir = get_bucket_dir(max_conf, OUTPUT_FOLDER, CONF_BUCKETS)
new_path = safe_move(str(img_path), dst_dir)
# 保存结果
output_path = os.path.join(OUTPUT_FOLDER, f"result_{img_path.name}")
cv2.imwrite(output_path, img)
print(f"\n✅ 结果已保存: {output_path}")
if max_conf > 0:
print(f"✅ bag max_conf={max_conf:.3f}{os.path.basename(dst_dir)}")
else:
print("❌ 未检测到 bag")
# 显示(可选)
if SHOW_IMAGE:
cv2.imshow("Detection", img)
if cv2.waitKey(1) & 0xFF == ord('q'): # 按 Q 退出
break
print(f"🚚 已移动到: {new_path}")
print(f"⏱️ {(time.time() - start_time)*1000:.1f} ms")
# 输出总耗时
total_infer_time = time.time() - start_time
print(f"⏱️ 总处理时间: {total_infer_time * 1000:.1f}ms (推理+后处理)")
# 结束
total_elapsed = time.time() - total_start_time
print(f"\n🎉 批量推理完成!共处理 {len(img_paths)} 张图片,总耗时: {total_elapsed:.2f}")
print(
f"🚀 平均每张: {total_elapsed / len(img_paths) * 1000:.1f} ms ({1 / (total_elapsed / len(img_paths)):.1f} FPS)")
if SHOW_IMAGE:
cv2.destroyAllWindows()
print("\n🎉 全部处理完成")
if __name__ == '__main__':
main()
main()