2025-08-14 18:24:45 +08:00
|
|
|
import cv2
|
|
|
|
|
import numpy as np
|
|
|
|
|
from ultralytics import YOLO
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
# ====================== 用户配置 ======================
|
|
|
|
|
MODEL_PATH = '/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_pose/weights/best.pt'
|
2025-09-01 14:14:18 +08:00
|
|
|
IMAGE_PATH = '/output_masks/3.png'
|
2025-08-14 18:24:45 +08:00
|
|
|
OUTPUT_DIR = '/home/hx/yolo/output_images' # 保存结果的文件夹
|
|
|
|
|
|
|
|
|
|
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
# ====================== 可视化函数 ======================
|
|
|
|
|
def draw_keypoints_on_image(image, kpts_xy, kpts_conf, orig_shape):
|
|
|
|
|
"""
|
|
|
|
|
在图像上绘制关键点
|
|
|
|
|
:param image: OpenCV 图像
|
|
|
|
|
:param kpts_xy: (N, K, 2) 坐标
|
|
|
|
|
:param kpts_conf: (N, K) 置信度
|
|
|
|
|
:param orig_shape: 原图尺寸 (H, W)
|
|
|
|
|
"""
|
|
|
|
|
colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0)] # 1红, 2蓝, 3绿, 4青
|
|
|
|
|
|
|
|
|
|
for i in range(len(kpts_xy)):
|
|
|
|
|
xy = kpts_xy[i] # (4, 2)
|
|
|
|
|
conf = kpts_conf[i] if kpts_conf.ndim == 2 else kpts_conf[i:i+1] # (4,) 或标量
|
|
|
|
|
|
|
|
|
|
for j in range(len(xy)):
|
|
|
|
|
x, y = xy[j]
|
|
|
|
|
c = conf[j] if hasattr(conf, '__len__') else conf
|
|
|
|
|
|
|
|
|
|
x, y = int(x), int(y)
|
|
|
|
|
|
|
|
|
|
# 检查坐标是否在图像范围内
|
|
|
|
|
if x < 0 or y < 0 or x >= orig_shape[1] or y >= orig_shape[0]:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
# 只绘制置信度 > 0.5 的点
|
|
|
|
|
if c < 0.5:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
# 绘制圆圈
|
|
|
|
|
cv2.circle(image, (x, y), radius=15, color=colors[j], thickness=-1)
|
|
|
|
|
# 标注编号
|
|
|
|
|
cv2.putText(image, f'{j+1}', (x + 20, y - 20),
|
|
|
|
|
cv2.FONT_HERSHEY_SIMPLEX, 1.5, colors[j], 5)
|
|
|
|
|
|
|
|
|
|
return image
|
|
|
|
|
|
|
|
|
|
# ====================== 主程序 ======================
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
print("正在加载模型...")
|
|
|
|
|
model = YOLO(MODEL_PATH)
|
|
|
|
|
print(f"✅ 模型加载完成: {MODEL_PATH}")
|
|
|
|
|
|
|
|
|
|
img = cv2.imread(IMAGE_PATH)
|
|
|
|
|
if img is None:
|
|
|
|
|
raise FileNotFoundError(f"无法加载图像: {IMAGE_PATH}")
|
|
|
|
|
print(f"✅ 图像加载成功: {IMAGE_PATH} (shape: {img.shape})")
|
|
|
|
|
|
|
|
|
|
print("\n开始推理...")
|
|
|
|
|
results = model(img)
|
|
|
|
|
|
|
|
|
|
for i, result in enumerate(results):
|
|
|
|
|
if result.keypoints is not None:
|
|
|
|
|
kpts = result.keypoints
|
|
|
|
|
orig_shape = kpts.orig_shape # (H, W)
|
|
|
|
|
|
|
|
|
|
# ✅ 正确获取坐标和置信度
|
|
|
|
|
kpts_xy = kpts.xy.cpu().numpy() # (N, K, 2)
|
|
|
|
|
kpts_conf = kpts.conf.cpu().numpy() if kpts.conf is not None else np.ones(kpts_xy.shape[:2])
|
|
|
|
|
|
|
|
|
|
print(f"\n 结果 {i + 1}:")
|
|
|
|
|
print(f" 检测到 {len(kpts_xy)} 个实例")
|
|
|
|
|
print(f" 坐标形状: {kpts_xy.shape}")
|
|
|
|
|
print(f" 置信度形状: {kpts_conf.shape}")
|
|
|
|
|
|
|
|
|
|
# 绘图
|
|
|
|
|
img_with_kpts = draw_keypoints_on_image(img.copy(), kpts_xy, kpts_conf, orig_shape)
|
|
|
|
|
|
|
|
|
|
# 保存
|
|
|
|
|
base_name = os.path.basename(IMAGE_PATH)
|
|
|
|
|
save_path = os.path.join(OUTPUT_DIR, f"keypoints_{base_name}")
|
|
|
|
|
cv2.imwrite(save_path, img_with_kpts)
|
|
|
|
|
print(f"✅ 结果已保存: {save_path}")
|
|
|
|
|
|
|
|
|
|
# 显示(可选)
|
|
|
|
|
display_img = cv2.resize(img_with_kpts, (1280, 720)) # 缩小显示
|
|
|
|
|
cv2.imshow("Keypoints", display_img)
|
|
|
|
|
print("按任意键关闭...")
|
|
|
|
|
cv2.waitKey(0)
|
|
|
|
|
cv2.destroyAllWindows()
|
|
|
|
|
else:
|
|
|
|
|
print(" ❌ 未检测到关键点")
|