first commit
This commit is contained in:
129
ailai_pc/point_test.py
Normal file
129
ailai_pc/point_test.py
Normal file
@ -0,0 +1,129 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from ultralytics import YOLO
|
||||
import os
|
||||
|
||||
# ====================== 用户配置 ======================
|
||||
MODEL_PATH = 'best.pt'
|
||||
IMAGE_SOURCE_DIR = './train' # 👈 修改为你的图像文件夹路径
|
||||
OUTPUT_DIR = './output_images' # 保存结果的文件夹
|
||||
|
||||
# 支持的图像扩展名
|
||||
IMG_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif', '.webp'}
|
||||
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# ====================== 可视化函数 ======================
|
||||
def draw_keypoints_on_image(image, kpts_xy, kpts_conf, orig_shape):
|
||||
"""
|
||||
在图像上绘制关键点
|
||||
:param image: OpenCV 图像
|
||||
:param kpts_xy: (N, K, 2) 坐标
|
||||
:param kpts_conf: (N, K) 置信度
|
||||
:param orig_shape: 原图尺寸 (H, W)
|
||||
"""
|
||||
colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0)] # 1红, 2蓝, 3绿, 4青
|
||||
|
||||
for i in range(len(kpts_xy)):
|
||||
xy = kpts_xy[i] # (4, 2)
|
||||
conf = kpts_conf[i] if kpts_conf.ndim == 2 else kpts_conf[i:i+1] # (4,) 或标量
|
||||
|
||||
for j in range(len(xy)):
|
||||
x, y = xy[j]
|
||||
c = conf[j] if hasattr(conf, '__len__') else conf
|
||||
|
||||
x, y = int(x), int(y)
|
||||
|
||||
# 检查坐标是否在图像范围内
|
||||
if x < 0 or y < 0 or x >= orig_shape[1] or y >= orig_shape[0]:
|
||||
continue
|
||||
|
||||
# 只绘制置信度 > 0.5 的点
|
||||
if c < 0.5:
|
||||
continue
|
||||
|
||||
# 绘制实心圆
|
||||
cv2.circle(image, (x, y), radius=15, color=colors[j], thickness=-1)
|
||||
# 标注编号(偏移避免遮挡)
|
||||
cv2.putText(image, f'{j+1}', (x + 20, y - 20),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 1.5, colors[j], 5)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
# ====================== 主程序 ======================
|
||||
if __name__ == "__main__":
|
||||
print("🚀 开始批量关键点检测任务")
|
||||
|
||||
# 加载模型
|
||||
print("🔄 加载 YOLO 模型...")
|
||||
model = YOLO(MODEL_PATH)
|
||||
print(f"✅ 模型加载完成: {MODEL_PATH}")
|
||||
|
||||
# 获取所有图像文件
|
||||
image_files = [
|
||||
f for f in os.listdir(IMAGE_SOURCE_DIR)
|
||||
if os.path.splitext(f.lower())[1] in IMG_EXTENSIONS
|
||||
]
|
||||
|
||||
if not image_files:
|
||||
print(f"❌ 错误:在 {IMAGE_SOURCE_DIR} 中未找到支持的图像文件")
|
||||
exit(1)
|
||||
|
||||
print(f"📁 发现 {len(image_files)} 张图像待处理")
|
||||
|
||||
# 遍历每张图像
|
||||
for img_filename in image_files:
|
||||
img_path = os.path.join(IMAGE_SOURCE_DIR, img_filename)
|
||||
print(f"\n🖼️ 正在处理: {img_filename}")
|
||||
|
||||
# 读取图像
|
||||
img = cv2.imread(img_path)
|
||||
if img is None:
|
||||
print(f"❌ 无法读取图像,跳过: {img_path}")
|
||||
continue
|
||||
print(f" ✅ 图像加载成功 (shape: {img.shape})")
|
||||
|
||||
# 推理
|
||||
print(" 🔍 正在推理...")
|
||||
results = model(img)
|
||||
|
||||
processed = False # 标记是否处理了关键点
|
||||
|
||||
for i, result in enumerate(results):
|
||||
if result.keypoints is not None:
|
||||
kpts = result.keypoints
|
||||
orig_shape = kpts.orig_shape # (H, W)
|
||||
|
||||
# 获取坐标和置信度
|
||||
kpts_xy = kpts.xy.cpu().numpy() # (N, K, 2)
|
||||
kpts_conf = kpts.conf.cpu().numpy() if kpts.conf is not None else np.ones(kpts_xy.shape[:2])
|
||||
|
||||
print(f" ✅ 检测到 {len(kpts_xy)} 个实例")
|
||||
|
||||
# 绘制关键点
|
||||
img_with_kpts = draw_keypoints_on_image(img.copy(), kpts_xy, kpts_conf, orig_shape)
|
||||
|
||||
# 保存图像
|
||||
save_filename = f"keypoints_{img_filename}"
|
||||
save_path = os.path.join(OUTPUT_DIR, save_filename)
|
||||
cv2.imwrite(save_path, img_with_kpts)
|
||||
print(f" 💾 结果已保存: {save_path}")
|
||||
|
||||
# 可选:显示图像(每次一张,按任意键继续)
|
||||
# display_img = cv2.resize(img_with_kpts, (1280, 720))
|
||||
# cv2.imshow("Keypoints Detection", display_img)
|
||||
# print(" ⌨️ 按任意键继续...")
|
||||
# cv2.waitKey(0)
|
||||
# cv2.destroyAllWindows()
|
||||
|
||||
processed = True
|
||||
|
||||
if not processed:
|
||||
print(f" ❌ 未检测到关键点,跳过保存")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("🎉 批量推理完成!")
|
||||
print(f"📊 总共处理 {len(image_files)} 张图像")
|
||||
print(f"📁 结果保存在: {OUTPUT_DIR}")
|
||||
print("=" * 60)
|
||||
Reference in New Issue
Block a user