Files
zjsh_yolov11/angle_base_point/point_test.py

94 lines
3.3 KiB
Python
Raw Normal View History

2025-08-14 18:24:45 +08:00
import cv2
import numpy as np
from ultralytics import YOLO
import os
# ====================== 用户配置 ======================
MODEL_PATH = '/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_pose/weights/best.pt'
2025-09-01 14:14:18 +08:00
IMAGE_PATH = '/output_masks/3.png'
2025-08-14 18:24:45 +08:00
OUTPUT_DIR = '/home/hx/yolo/output_images' # 保存结果的文件夹
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ====================== 可视化函数 ======================
def draw_keypoints_on_image(image, kpts_xy, kpts_conf, orig_shape):
"""
在图像上绘制关键点
:param image: OpenCV 图像
:param kpts_xy: (N, K, 2) 坐标
:param kpts_conf: (N, K) 置信度
:param orig_shape: 原图尺寸 (H, W)
"""
colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0)] # 1红, 2蓝, 3绿, 4青
for i in range(len(kpts_xy)):
xy = kpts_xy[i] # (4, 2)
conf = kpts_conf[i] if kpts_conf.ndim == 2 else kpts_conf[i:i+1] # (4,) 或标量
for j in range(len(xy)):
x, y = xy[j]
c = conf[j] if hasattr(conf, '__len__') else conf
x, y = int(x), int(y)
# 检查坐标是否在图像范围内
if x < 0 or y < 0 or x >= orig_shape[1] or y >= orig_shape[0]:
continue
# 只绘制置信度 > 0.5 的点
if c < 0.5:
continue
# 绘制圆圈
cv2.circle(image, (x, y), radius=15, color=colors[j], thickness=-1)
# 标注编号
cv2.putText(image, f'{j+1}', (x + 20, y - 20),
cv2.FONT_HERSHEY_SIMPLEX, 1.5, colors[j], 5)
return image
# ====================== 主程序 ======================
if __name__ == "__main__":
print("正在加载模型...")
model = YOLO(MODEL_PATH)
print(f"✅ 模型加载完成: {MODEL_PATH}")
img = cv2.imread(IMAGE_PATH)
if img is None:
raise FileNotFoundError(f"无法加载图像: {IMAGE_PATH}")
print(f"✅ 图像加载成功: {IMAGE_PATH} (shape: {img.shape})")
print("\n开始推理...")
results = model(img)
for i, result in enumerate(results):
if result.keypoints is not None:
kpts = result.keypoints
orig_shape = kpts.orig_shape # (H, W)
# ✅ 正确获取坐标和置信度
kpts_xy = kpts.xy.cpu().numpy() # (N, K, 2)
kpts_conf = kpts.conf.cpu().numpy() if kpts.conf is not None else np.ones(kpts_xy.shape[:2])
print(f"\n 结果 {i + 1}:")
print(f" 检测到 {len(kpts_xy)} 个实例")
print(f" 坐标形状: {kpts_xy.shape}")
print(f" 置信度形状: {kpts_conf.shape}")
# 绘图
img_with_kpts = draw_keypoints_on_image(img.copy(), kpts_xy, kpts_conf, orig_shape)
# 保存
base_name = os.path.basename(IMAGE_PATH)
save_path = os.path.join(OUTPUT_DIR, f"keypoints_{base_name}")
cv2.imwrite(save_path, img_with_kpts)
print(f"✅ 结果已保存: {save_path}")
# 显示(可选)
display_img = cv2.resize(img_with_kpts, (1280, 720)) # 缩小显示
cv2.imshow("Keypoints", display_img)
print("按任意键关闭...")
cv2.waitKey(0)
cv2.destroyAllWindows()
else:
print(" ❌ 未检测到关键点")