129 lines
4.5 KiB
Python
129 lines
4.5 KiB
Python
import cv2
|
|
import numpy as np
|
|
from ultralytics import YOLO
|
|
import os
|
|
|
|
# ====================== 用户配置 ======================
|
|
MODEL_PATH = 'best.pt'
|
|
IMAGE_SOURCE_DIR = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/20251226' # 👈 修改为你的图像文件夹路径
|
|
OUTPUT_DIR = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/20251226/output_images' # 保存结果的文件夹
|
|
|
|
# 支持的图像扩展名
|
|
IMG_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif', '.webp'}
|
|
|
|
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
|
|
# ====================== 可视化函数 ======================
|
|
def draw_keypoints_on_image(image, kpts_xy, kpts_conf, orig_shape):
|
|
"""
|
|
在图像上绘制关键点
|
|
:param image: OpenCV 图像
|
|
:param kpts_xy: (N, K, 2) 坐标
|
|
:param kpts_conf: (N, K) 置信度
|
|
:param orig_shape: 原图尺寸 (H, W)
|
|
"""
|
|
colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0)] # 1红, 2蓝, 3绿, 4青
|
|
|
|
for i in range(len(kpts_xy)):
|
|
xy = kpts_xy[i] # (4, 2)
|
|
conf = kpts_conf[i] if kpts_conf.ndim == 2 else kpts_conf[i:i+1] # (4,) 或标量
|
|
|
|
for j in range(len(xy)):
|
|
x, y = xy[j]
|
|
c = conf[j] if hasattr(conf, '__len__') else conf
|
|
|
|
x, y = int(x), int(y)
|
|
|
|
# 检查坐标是否在图像范围内
|
|
if x < 0 or y < 0 or x >= orig_shape[1] or y >= orig_shape[0]:
|
|
continue
|
|
|
|
# 只绘制置信度 > 0.5 的点
|
|
if c < 0.5:
|
|
continue
|
|
|
|
# 绘制实心圆
|
|
cv2.circle(image, (x, y), radius=15, color=colors[j], thickness=-1)
|
|
# 标注编号(偏移避免遮挡)
|
|
cv2.putText(image, f'{j+1}', (x + 20, y - 20),
|
|
cv2.FONT_HERSHEY_SIMPLEX, 1.5, colors[j], 5)
|
|
|
|
return image
|
|
|
|
|
|
# ====================== 主程序 ======================
|
|
if __name__ == "__main__":
|
|
print("🚀 开始批量关键点检测任务")
|
|
|
|
# 加载模型
|
|
print("🔄 加载 YOLO 模型...")
|
|
model = YOLO(MODEL_PATH)
|
|
print(f"✅ 模型加载完成: {MODEL_PATH}")
|
|
|
|
# 获取所有图像文件
|
|
image_files = [
|
|
f for f in os.listdir(IMAGE_SOURCE_DIR)
|
|
if os.path.splitext(f.lower())[1] in IMG_EXTENSIONS
|
|
]
|
|
|
|
if not image_files:
|
|
print(f"❌ 错误:在 {IMAGE_SOURCE_DIR} 中未找到支持的图像文件")
|
|
exit(1)
|
|
|
|
print(f"📁 发现 {len(image_files)} 张图像待处理")
|
|
|
|
# 遍历每张图像
|
|
for img_filename in image_files:
|
|
img_path = os.path.join(IMAGE_SOURCE_DIR, img_filename)
|
|
print(f"\n🖼️ 正在处理: {img_filename}")
|
|
|
|
# 读取图像
|
|
img = cv2.imread(img_path)
|
|
if img is None:
|
|
print(f"❌ 无法读取图像,跳过: {img_path}")
|
|
continue
|
|
print(f" ✅ 图像加载成功 (shape: {img.shape})")
|
|
|
|
# 推理
|
|
print(" 🔍 正在推理...")
|
|
results = model(img)
|
|
|
|
processed = False # 标记是否处理了关键点
|
|
|
|
for i, result in enumerate(results):
|
|
if result.keypoints is not None:
|
|
kpts = result.keypoints
|
|
orig_shape = kpts.orig_shape # (H, W)
|
|
|
|
# 获取坐标和置信度
|
|
kpts_xy = kpts.xy.cpu().numpy() # (N, K, 2)
|
|
kpts_conf = kpts.conf.cpu().numpy() if kpts.conf is not None else np.ones(kpts_xy.shape[:2])
|
|
|
|
print(f" ✅ 检测到 {len(kpts_xy)} 个实例")
|
|
|
|
# 绘制关键点
|
|
img_with_kpts = draw_keypoints_on_image(img.copy(), kpts_xy, kpts_conf, orig_shape)
|
|
|
|
# 保存图像
|
|
save_filename = f"keypoints_{img_filename}"
|
|
save_path = os.path.join(OUTPUT_DIR, save_filename)
|
|
cv2.imwrite(save_path, img_with_kpts)
|
|
print(f" 💾 结果已保存: {save_path}")
|
|
|
|
# 可选:显示图像(每次一张,按任意键继续)
|
|
# display_img = cv2.resize(img_with_kpts, (1280, 720))
|
|
# cv2.imshow("Keypoints Detection", display_img)
|
|
# print(" ⌨️ 按任意键继续...")
|
|
# cv2.waitKey(0)
|
|
# cv2.destroyAllWindows()
|
|
|
|
processed = True
|
|
|
|
if not processed:
|
|
print(f" ❌ 未检测到关键点,跳过保存")
|
|
|
|
print("\n" + "=" * 60)
|
|
print("🎉 批量推理完成!")
|
|
print(f"📊 总共处理 {len(image_files)} 张图像")
|
|
print(f"📁 结果保存在: {OUTPUT_DIR}")
|
|
print("=" * 60) |