Files
ailai_image_point_diff/ailai_pc/point_test.py

129 lines
4.5 KiB
Python

import cv2
import numpy as np
from ultralytics import YOLO
import os
# ====================== User configuration ======================
# Path to the trained YOLO pose/keypoint weights.
MODEL_PATH = 'best.pt'
# Folder containing the input images -- change this to your own image folder.
IMAGE_SOURCE_DIR = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/20251226'
# Folder where the annotated result images are written.
OUTPUT_DIR = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/20251226/output_images'
# Image file extensions (lower-case, with leading dot) that will be processed.
IMG_EXTENSIONS = {
    '.png',
    '.jpg',
    '.jpeg',
    '.bmp',
    '.tiff',
    '.tif',
    '.webp',
}
# Create the output folder at import time; exist_ok=True makes re-runs harmless.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ====================== Visualization ======================
def draw_keypoints_on_image(image, kpts_xy, kpts_conf, orig_shape, *,
                            conf_threshold=0.5, radius=15):
    """Draw numbered, colour-coded keypoint markers onto ``image`` and return it.

    Drawing is done with OpenCV calls that modify ``image`` in place; the same
    object is returned for convenience.

    :param image: OpenCV image array (BGR, as produced by ``cv2.imread``).
    :param kpts_xy: (N, K, 2) array of keypoint pixel coordinates.
    :param kpts_conf: (N, K) per-keypoint confidences, or (N,) one confidence
        per instance that is applied to all of that instance's keypoints.
    :param orig_shape: (H, W) of the original image; points outside it are skipped.
    :param conf_threshold: keypoints with confidence below this are not drawn.
    :param radius: radius in pixels of the filled circle drawn per keypoint.
    :return: ``image`` with the markers drawn on it.
    """
    # BGR colours for keypoints 1..4: red, blue, green, cyan.
    # They are reused cyclically when there are more than 4 keypoints.
    colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0)]
    height, width = orig_shape[0], orig_shape[1]
    # 2-D confidences are per keypoint; 1-D confidences are per instance (a scalar each).
    per_keypoint_conf = np.ndim(kpts_conf) == 2

    for i, instance_xy in enumerate(kpts_xy):
        for j, (x, y) in enumerate(instance_xy):
            c = kpts_conf[i][j] if per_keypoint_conf else kpts_conf[i]
            # Bounds are checked on the float values, before int() truncation.
            # This way e.g. x = -0.5 is not turned into 0 and drawn.
            # NaN/inf also fail these comparisons and are skipped.
            if not (0 <= x < width and 0 <= y < height):
                continue
            # Only draw sufficiently confident keypoints.
            if c < conf_threshold:
                continue
            color = colors[j % len(colors)]
            px, py = int(x), int(y)
            # Filled circle (thickness=-1) at the keypoint.
            cv2.circle(image, (px, py), radius=radius, color=color, thickness=-1)
            # 1-based keypoint number, offset up-right so it does not cover the dot.
            cv2.putText(image, f'{j+1}', (px + 20, py - 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 1.5, color, 5)
    return image
# ====================== Main program ======================
if __name__ == "__main__":
    print("🚀 开始批量关键点检测任务")

    # Load the model once; it is reused for every image.
    print("🔄 加载 YOLO 模型...")
    model = YOLO(MODEL_PATH)
    print(f"✅ 模型加载完成: {MODEL_PATH}")

    # Collect regular files with a supported (case-insensitive) extension.
    # They are sorted so the processing order is deterministic; os.listdir order is arbitrary.
    image_files = sorted(
        f for f in os.listdir(IMAGE_SOURCE_DIR)
        if os.path.splitext(f.lower())[1] in IMG_EXTENSIONS
        and os.path.isfile(os.path.join(IMAGE_SOURCE_DIR, f))
    )
    if not image_files:
        print(f"❌ 错误:在 {IMAGE_SOURCE_DIR} 中未找到支持的图像文件")
        # SystemExit is used instead of the site-provided exit().
        # exit() is not guaranteed to exist (e.g. under `python -S`).
        raise SystemExit(1)
    print(f"📁 发现 {len(image_files)} 张图像待处理")

    saved_count = 0
    for img_filename in image_files:
        img_path = os.path.join(IMAGE_SOURCE_DIR, img_filename)
        print(f"\n🖼️ 正在处理: {img_filename}")

        # cv2.imread returns None (rather than raising) when it cannot read a file.
        img = cv2.imread(img_path)
        if img is None:
            print(f"❌ 无法读取图像,跳过: {img_path}")
            continue
        print(f" ✅ 图像加载成功 (shape: {img.shape})")

        print(" 🔍 正在推理...")
        results = model(img)
        processed = False  # set once an annotated image has been written
        for result in results:
            kpts = result.keypoints
            if kpts is None:
                continue
            orig_shape = kpts.orig_shape  # (H, W)
            kpts_xy = kpts.xy.cpu().numpy()  # (N, K, 2)
            # Fall back to confidence 1.0 for every keypoint when none is provided.
            kpts_conf = (kpts.conf.cpu().numpy() if kpts.conf is not None
                         else np.ones(kpts_xy.shape[:2]))
            print(f" ✅ 检测到 {len(kpts_xy)} 个实例")
            # A keypoints result with zero instances has nothing to draw.
            # Skip it so it ends up in the "no keypoints, skip saving" branch below.
            if len(kpts_xy) == 0:
                continue

            # Draw on a copy so the loaded image stays untouched.
            img_with_kpts = draw_keypoints_on_image(img.copy(), kpts_xy, kpts_conf, orig_shape)

            save_filename = f"keypoints_{img_filename}"
            save_path = os.path.join(OUTPUT_DIR, save_filename)
            # cv2.imwrite reports failure through its return value, not an exception.
            if not cv2.imwrite(save_path, img_with_kpts):
                print(f" ❌ 保存失败: {save_path}")
                continue
            print(f" 💾 结果已保存: {save_path}")
            # Optional: show each result (press any key to continue).
            # display_img = cv2.resize(img_with_kpts, (1280, 720))
            # cv2.imshow("Keypoints Detection", display_img)
            # print(" ⌨️ 按任意键继续...")
            # cv2.waitKey(0)
            # cv2.destroyAllWindows()
            processed = True

        if processed:
            saved_count += 1
        else:
            print(" ❌ 未检测到关键点,跳过保存")

    print("\n" + "=" * 60)
    print("🎉 批量推理完成!")
    print(f"📊 总共处理 {len(image_files)} 张图像")
    print(f"💾 成功保存 {saved_count} 张结果图像")
    print(f"📁 结果保存在: {OUTPUT_DIR}")
    print("=" * 60)