rknn替换,板子是3568的
This commit is contained in:
@ -1,72 +1,76 @@
|
||||
# detect_pt.py
|
||||
import cv2
|
||||
import torch
|
||||
from ultralytics import YOLO
|
||||
from ultralytics.utils.ops import non_max_suppression
|
||||
import torch
|
||||
import cv2
|
||||
|
||||
# ======================
|
||||
# 配置参数
|
||||
# ======================
|
||||
MODEL_PATH = 'best.pt' # 你的训练模型路径(yolov8n.pt 或你自己训练的)
|
||||
#IMG_PATH = '/home/hx/开发/ailai_image_obb/ailai_pc/train/192.168.0.234_01_202510141514352.jpg' # 测试图像路径
|
||||
MODEL_PATH = '/home/hx/开发/ailai_image_obb/ailai_pc/best12.pt'
|
||||
IMG_PATH = '1.jpg'
|
||||
OUTPUT_PATH = '/home/hx/开发/ailai_image_obb/ailai_pc/output_pt.jpg' # 可视化结果保存路径
|
||||
CONF_THRESH = 0.5 # 置信度阈值
|
||||
CLASS_NAMES = ['bag'] # 你的类别名列表(按训练时顺序)
|
||||
|
||||
# 是否显示窗口(适合有 GUI 的 PC)
|
||||
SHOW_IMAGE = True
|
||||
OUTPUT_PATH = 'output_pt.jpg'
|
||||
CONF_THRESH = 0.5
|
||||
IOU_THRESH = 0.45
|
||||
CLASS_NAMES = ['bag']
|
||||
|
||||
# ======================
|
||||
# 主函数
|
||||
# 主函数(优化版)
|
||||
# ======================
|
||||
def main():
|
||||
# 检查 CUDA
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
print(f"✅ 使用设备: {device}")
|
||||
|
||||
# 加载模型
|
||||
print("➡️ 加载 YOLO 模型...")
|
||||
model = YOLO(MODEL_PATH) # 自动加载架构和权重
|
||||
model = YOLO(MODEL_PATH)
|
||||
model.to(device)
|
||||
|
||||
# 推理
|
||||
print("➡️ 开始推理...")
|
||||
results = model(IMG_PATH, imgsz=640, conf=CONF_THRESH, device=device)
|
||||
# 推理:获取原始结果(不立即解析)
|
||||
print("➡️ 开始推理...")
|
||||
results = model(IMG_PATH, imgsz=640, conf=CONF_THRESH, device=device, verbose=True)
|
||||
|
||||
# 获取第一张图的结果
|
||||
r = results[0]
|
||||
|
||||
# 获取原始图像(BGR)
|
||||
# 🚀 关键:使用原始 tensor 在 GPU 上处理
|
||||
# pred: [x1, y1, x2, y2, conf, cls] 形状为 [num_boxes, 6]
|
||||
pred = r.boxes.data # 已经在 GPU 上,类型: torch.Tensor
|
||||
|
||||
# 🔍 在 GPU 上做 NMS(这才是正确姿势)
|
||||
# 注意:non_max_suppression 输入是 [batch, num_boxes, 6]
|
||||
det = non_max_suppression(
|
||||
pred.unsqueeze(0), # 增加 batch 维度
|
||||
conf_thres=CONF_THRESH,
|
||||
iou_thres=IOU_THRESH,
|
||||
classes=None,
|
||||
agnostic=False,
|
||||
max_det=100
|
||||
)[0] # 取第一个(也是唯一一个)batch
|
||||
|
||||
# ✅ 此时所有后处理已完成,现在才从 GPU 拷贝到 CPU
|
||||
if det is not None and len(det):
|
||||
det = det.cpu().numpy() # ← 只拷贝一次!
|
||||
else:
|
||||
det = []
|
||||
|
||||
# 读取图像
|
||||
img = cv2.imread(IMG_PATH)
|
||||
if img is None:
|
||||
raise FileNotFoundError(f"无法读取图像: {IMG_PATH}")
|
||||
|
||||
print("\n📋 检测结果:")
|
||||
for box in r.boxes:
|
||||
# 获取数据
|
||||
xyxy = box.xyxy[0].cpu().numpy() # [x1, y1, x2, y2]
|
||||
conf = box.conf.cpu().numpy()[0] # 置信度
|
||||
cls_id = int(box.cls.cpu().numpy()[0]) # 类别 ID
|
||||
cls_name = CLASS_NAMES[cls_id] # 类别名
|
||||
|
||||
for *xyxy, conf, cls_id in det:
|
||||
x1, y1, x2, y2 = map(int, xyxy)
|
||||
cls_name = CLASS_NAMES[int(cls_id)]
|
||||
print(f" 类别: {cls_name}, 置信度: {conf:.3f}, 框: [{x1}, {y1}, {x2}, {y2}]")
|
||||
|
||||
# 画框
|
||||
# 画框和标签
|
||||
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||
# 画标签
|
||||
label = f"{cls_name} {conf:.2f}"
|
||||
cv2.putText(img, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
|
||||
|
||||
# 保存结果
|
||||
cv2.imwrite(OUTPUT_PATH, img)
|
||||
print(f"\n🖼️ 可视化结果已保存: {OUTPUT_PATH}")
|
||||
|
||||
# 显示(可选)
|
||||
if SHOW_IMAGE:
|
||||
cv2.imshow("YOLOv8 Detection", img)
|
||||
cv2.waitKey(0)
|
||||
cv2.destroyAllWindows()
|
||||
print(f"\n🖼️ 可视化结果已保存: {OUTPUT_PATH}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user