Files
zjsh_yolov11/angle_base_point/angle_main.py
琉璃月光 254afbbc43 修改
2025-08-14 18:24:45 +08:00

125 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
from ultralytics import YOLO
import time
def calculate_angle_between_lines(line1_start, line1_end, line2_start, line2_end):
    """
    Return the angle in degrees between two directed line segments, in [0, 180].

    Each segment is given by its start and end point. The angle is measured
    between the direction vectors ``end - start``, so reversing one segment
    turns an angle ``a`` into ``180 - a``.

    Returns 0.0 when either segment has zero length (no defined direction).
    """
    dir_a = np.array(line1_end) - np.array(line1_start)
    dir_b = np.array(line2_end) - np.array(line2_start)
    len_a = np.linalg.norm(dir_a)
    len_b = np.linalg.norm(dir_b)
    if len_a == 0 or len_b == 0:
        return 0.0
    # Clip guards arccos against dot products drifting just outside [-1, 1]
    # through floating-point rounding.
    cos_theta = np.clip(np.dot(dir_a / len_a, dir_b / len_b), -1.0, 1.0)
    return np.degrees(np.arccos(cos_theta))
def draw_keypoints_and_angle(image, keypoints, angle, text_org=(50, 50)):
    """
    Draw the keypoints, the 1-3 and 2-4 connecting lines and the angle label.

    Args:
        image: image array that OpenCV draws on in place. Colors below are
            given in BGR order, matching images loaded with ``cv2.imread``.
        keypoints: sequence of keypoints; each entry is read as ``kp[0]`` = x
            and ``kp[1]`` = y (any further entries are ignored). At least 4
            keypoints are required because lines 1-3 and 2-4 are drawn;
            fewer raises ``IndexError``.
        angle: angle in degrees shown in the label.
        text_org: bottom-left corner of the angle label. Defaults to (50, 50);
            pass a different position per instance to keep the labels of
            several detections from overlapping.

    Returns:
        The same ``image`` object that was passed in.
    """
    pts = [(int(kp[0]), int(kp[1])) for kp in keypoints]

    # Keypoints: 1 and 3 in red, 2 and 4 in blue (BGR), each labelled with
    # its 1-based index.
    for i, pt in enumerate(pts):
        color = (0, 0, 255) if i in (0, 2) else (255, 0, 0)
        cv2.circle(image, pt, radius=8, color=color, thickness=-1)
        cv2.putText(image, f'{i+1}', (pt[0] + 10, pt[1] - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)

    # Connecting lines: 1-3 in green, 2-4 in (255, 255, 0), which is cyan in BGR.
    cv2.line(image, pts[0], pts[2], color=(0, 255, 0), thickness=3)
    cv2.line(image, pts[1], pts[3], color=(255, 255, 0), thickness=3)

    # Hershey fonts only cover printable ASCII; a '°' would be drawn as '??',
    # so the unit is spelled out instead.
    text = f"Angle: {angle:.1f} deg"
    org = (int(text_org[0]), int(text_org[1]))
    cv2.putText(image, text, org, cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 255), 3)
    return image
# ====================== User configuration ======================
MODEL_PATH = '/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_pose/weights/best.pt'  # replace with your model path
IMAGE_PATH = '/home/hx/yolo/test_image/1.png'  # replace with your test image path
CONF_THRESHOLD = 0.25  # confidence threshold passed to the model call
DEVICE = '0'  # device passed to the model call, e.g. '0' for the first GPU or 'cpu'
SAVE_RESULT = True  # write the annotated image to SAVE_PATH
SAVE_PATH = '/home/hx/yolo/test_image/result_with_angle.jpg'
SHOW_IMAGE = True  # show the annotated image in an OpenCV window


def _result_keypoints(result):
    """Return the raw keypoint array of one ultralytics result, or None if absent.

    ``result.keypoints`` is an ultralytics ``Keypoints`` wrapper; indexing the
    wrapper yields another wrapper rather than plain numbers, so the underlying
    tensor is taken via ``.data`` and converted to a NumPy array.
    Presumably shaped (N, K, 3) = x, y, confidence per keypoint — confirm for
    your model.
    """
    if result.keypoints is None:
        return None
    return result.keypoints.data.cpu().numpy()


def main():
    """Run pose inference on IMAGE_PATH, draw the 1-3 / 2-4 angle and report timings."""
    # 1. Load the model. Only the first inference call below is timed; it may
    #    include one-off initialisation / warm-up cost, so it is not a
    #    steady-state figure.
    print("正在加载模型...")
    model = YOLO(MODEL_PATH)
    print(f"✅ 模型加载完成: {MODEL_PATH}")

    # 2. Read the image.
    img = cv2.imread(IMAGE_PATH)
    if img is None:
        raise FileNotFoundError(f"无法加载图像: {IMAGE_PATH}")
    print(f"正在处理图像: {IMAGE_PATH}")

    # perf_counter is monotonic and high-resolution, unlike time.time().
    total_start = time.perf_counter()

    # 3. Model inference (timed).
    infer_start = time.perf_counter()
    results = model(img, conf=CONF_THRESHOLD, device=DEVICE)
    inference_time = (time.perf_counter() - infer_start) * 1000  # ms

    # 4. Post-processing: compute the angle between lines 1-3 and 2-4 per instance (timed).
    post_start = time.perf_counter()
    detection_count = 0
    for result in results:
        kpts = _result_keypoints(result)
        if kpts is None or len(kpts) == 0:
            print("未检测到关键点")
            continue
        for i, instance_kpts in enumerate(kpts):
            if len(instance_kpts) < 4:
                # Lines 1-3 and 2-4 need at least four keypoints.
                print(f"实例 {i+1} 关键点不足 4 个,已跳过")
                continue
            pt1, pt2, pt3, pt4 = (kp[:2] for kp in instance_kpts[:4])
            angle = calculate_angle_between_lines(pt1, pt3, pt2, pt4)
            img = draw_keypoints_and_angle(img, instance_kpts, angle)
            detection_count += 1
            print(f"✅ 检测到实例 {i+1},角度 = {angle:.1f}°")
    post_end = time.perf_counter()
    post_time = (post_end - post_start) * 1000
    total_time = (post_end - total_start) * 1000

    # 5. Timing report.
    print("\n" + "="*50)
    print(" 推理时间统计")
    print("="*50)
    print(f"模型推理耗时: {inference_time:6.2f} ms")
    print(f"后处理耗时: {post_time:6.2f} ms")
    print(f"总耗时: {total_time:6.2f} ms")
    if total_time > 0:
        print(f"FPS ≈ {1000 / total_time:6.2f}")
    print(f"检测目标数: {detection_count}")
    print("="*50)

    # 6. Save and/or show the result.
    if SAVE_RESULT:
        # cv2.imwrite signals failure (e.g. missing directory) by returning False.
        if cv2.imwrite(SAVE_PATH, img):
            print(f"结果图像已保存至: {SAVE_PATH}")
        else:
            print(f"结果图像保存失败: {SAVE_PATH}")
    if SHOW_IMAGE:
        cv2.imshow("Keypoint Detection with Angle", img)
        print("按任意键关闭窗口...")
        cv2.waitKey(0)
        cv2.destroyAllWindows()


if __name__ == "__main__":
    main()