Files
琉璃月光 c134abf749 first commit
2025-10-21 11:07:29 +08:00

174 lines
6.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
from ultralytics import YOLO
# ====================== User configuration ======================
# Weights file handed to YOLO(...) in the __main__ block.
MODEL_PATH = 'best.pt'
IMAGE_PATH = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/point2/train/1.jpg' # 👈 change this to the image you want to process
# Directory the annotated image is written to (created further down in this file).
OUTPUT_DIR = './output_images'
# Fixed reference point (e.g. the theoretical position obtained from calibration)
FIXED_REF_POINT = (535.0, 605)# (x, y), in pixels
def calculate_scale(width_mm, width_px):
    """
    Compute the horizontal scale factor in mm/px.

    :param width_mm: real-world width of the reference object (millimetres)
    :param width_px: width of the same object in the image (pixels)
    :return: ``width_mm / width_px`` as a float, or ``None`` (after printing
             a message) when ``width_px`` is zero
    """
    # Happy path first; the zero-width case falls through to the message.
    if width_px != 0:
        return width_mm / float(width_px)
    print("像素宽度不能为0")
    return None
# Horizontal calibration:
# a reference object whose real-world width is 70 mm spans 42 pixels in the image.
width_mm = 70.0 # real-world width (millimetres)
width_px = 42 # width in the image (pixels)
SCALE_X= calculate_scale(width_mm, width_px)
# NOTE(review): calculate_scale returns None when width_px is 0; the ":.3f"
# format spec below would then raise TypeError.
print(f"水平方向的缩放因子为: {SCALE_X:.3f} mm/px")
def calculate_scale_y(height_mm, height_px):
    """
    Compute the vertical scale factor in mm/px.

    :param height_mm: real-world height of the reference object (millimetres)
    :param height_px: height of the same object in the image (pixels)
    :return: ``height_mm / height_px`` as a float, or ``None`` (after printing
             a message) when ``height_px`` is zero
    """
    pixel_count_is_zero = height_px == 0
    if not pixel_count_is_zero:
        return height_mm / float(height_px)
    # A zero pixel height cannot be divided by: report it and signal with None.
    print("像素高度不能为0")
    return None
# Vertical calibration, done the same way as the horizontal one:
# a reference object whose real-world height is 890 mm spans 507 pixels.
height_mm = 890.0 # real-world height (millimetres)
height_px = 507 # height in the image (pixels)
SCALE_Y = calculate_scale_y(height_mm, height_px)
# NOTE(review): calculate_scale_y returns None when height_px is 0; the ":.3f"
# format spec below would then raise TypeError.
print(f"垂直方向的缩放因子为: {SCALE_Y:.3f} mm/px")
# Create the output directory; exist_ok=True makes this a no-op when it already exists.
import os
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ====================== 可视化函数(增强版)======================
def draw_keypoints_and_offset(image, kpts_xy, kpts_conf, orig_shape, fixed_point,
                              scale_x, scale_y, conf_threshold=0.5):
    """
    Draw the first two keypoints of each instance, their midpoint, the fixed
    reference point, and the offset (arrow plus text) between the two.

    The image is annotated in place and also returned. An instance is skipped,
    with a printed warning, when it has fewer than two keypoints, when either
    of its first two keypoints is below ``conf_threshold``, or when either
    keypoint lies outside the image (NaN coordinates count as outside).

    :param image: OpenCV image to draw on (modified in place)
    :param kpts_xy: keypoint coordinates, shape (N, K, 2)
    :param kpts_conf: keypoint confidences, shape (N, K); a 1-D array of shape
                      (N,) is treated as one confidence per instance
    :param orig_shape: original image size as (H, W)
    :param fixed_point: fixed reference point (fx, fy) in pixels
    :param scale_x: horizontal scale in mm/px
    :param scale_y: vertical scale in mm/px
    :param conf_threshold: minimum confidence required for both keypoints
    :return: ``(image, results_info)``; ``results_info`` holds one dict per
             drawn instance with keys ``instance``, ``center``, ``dx_px``,
             ``dy_px``, ``dx_mm`` and ``dy_mm``
    """
    # BGR colours: red, blue, green, cyan (only the first two are used for keypoints).
    colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0)]
    results_info = []

    # Loop-invariant values.
    h, w = orig_shape
    fx, fy = map(int, fixed_point)
    per_keypoint_conf = kpts_conf.ndim == 2

    for i, xy in enumerate(kpts_xy):
        # At least two keypoints are needed to form a midpoint.
        if len(xy) < 2:
            print(f"⚠️ 实例 {i} 的关键点数量不足2个")
            continue

        p1, p2 = xy[0], xy[1]
        if per_keypoint_conf:
            c1, c2 = kpts_conf[i][0], kpts_conf[i][1]
        else:
            # Only one confidence per instance is available: use it for both
            # keypoints instead of indexing past the end of a 1-element slice.
            c1 = c2 = kpts_conf[i]
        if c1 < conf_threshold or c2 < conf_threshold:
            print(f"⚠️ 实例 {i} 的前两个关键点置信度过低: c1={c1:.3f}, c2={c2:.3f}")
            continue

        # Validate before converting to int: int(NaN) raises, whereas the
        # range comparison is simply False for NaN and skips the instance.
        if not all(0 <= x < w and 0 <= y < h for x, y in (p1, p2)):
            print(f"⚠️ 实例 {i} 的关键点超出图像边界")
            continue

        # Integer coordinates are needed for drawing only.
        p1_int = tuple(map(int, p1))
        p2_int = tuple(map(int, p2))
        font = cv2.FONT_HERSHEY_SIMPLEX

        # The two keypoints as filled circles, then their index labels.
        cv2.circle(image, p1_int, radius=15, color=colors[0], thickness=-1)
        cv2.circle(image, p2_int, radius=15, color=colors[1], thickness=-1)
        cv2.putText(image, "1", (p1_int[0] + 20, p1_int[1] - 20), font, 1.5, colors[0], 5)
        cv2.putText(image, "2", (p2_int[0] + 20, p2_int[1] - 20), font, 1.5, colors[1], 5)

        # Midpoint of the two keypoints.
        center_x = (p1[0] + p2[0]) / 2.0
        center_y = (p1[1] + p2[1]) / 2.0
        dynamic_center = (int(center_x), int(center_y))
        cv2.circle(image, dynamic_center, radius=18, color=(0, 255, 0), thickness=3)
        cv2.putText(image, "Center", (dynamic_center[0] + 30, dynamic_center[1]),
                    font, 1.2, (0, 255, 0), 3)

        # Fixed reference point.
        cv2.circle(image, (fx, fy), radius=20, color=(255, 255, 0), thickness=3)
        cv2.putText(image, "Ref", (fx + 30, fy), font, 1.2, (255, 255, 0), 3)

        # Offset of the midpoint from the reference point, in px and in mm.
        dx_px = center_x - fixed_point[0]
        dy_px = center_y - fixed_point[1]
        dx_mm = dx_px * scale_x
        dy_mm = dy_px * scale_y

        cv2.arrowedLine(image, (fx, fy), dynamic_center, (0, 255, 255), 3, tipLength=0.05)
        # Hershey fonts cover ASCII only; a "Δ" would be drawn as "?", so the
        # on-image labels use plain "dX"/"dY".
        cv2.putText(image, f"dX={dx_mm:+.1f}mm", (fx + 40, fy - 40),
                    font, 1.0, (0, 255, 255), 3)
        cv2.putText(image, f"dY={dy_mm:+.1f}mm", (fx + 40, fy + 40),
                    font, 1.0, (0, 255, 255), 3)

        results_info.append({
            'instance': i,
            'center': (center_x, center_y),
            'dx_px': dx_px, 'dy_px': dy_px,
            'dx_mm': dx_mm, 'dy_mm': dy_mm
        })
    return image, results_info
if __name__ == "__main__":
    # Script entry point: load the image, run the model on it, annotate the
    # detected keypoints and save the annotated image under OUTPUT_DIR.
    img = cv2.imread(IMAGE_PATH)
    if img is None:
        print(f"❌ 无法读取图像,检查路径: {IMAGE_PATH}")
        # `exit()` is an interactive helper injected by the `site` module and
        # is not guaranteed to exist in programs; raise SystemExit directly.
        raise SystemExit(1)

    model = YOLO(MODEL_PATH)
    results = model(img)

    for result in results:
        kpts = result.keypoints
        if kpts is None:
            # Nothing to draw for a result without keypoints.
            continue

        orig_shape = kpts.orig_shape
        kpts_xy = kpts.xy.cpu().numpy()
        # Without confidences, fall back to 1.0 for every keypoint.
        if kpts.conf is not None:
            kpts_conf = kpts.conf.cpu().numpy()
        else:
            kpts_conf = np.ones(kpts_xy.shape[:2])

        # Draw on a copy so the loaded image stays untouched.
        img_with_kpts, offset_results = draw_keypoints_and_offset(
            img.copy(), kpts_xy, kpts_conf, orig_shape,
            fixed_point=FIXED_REF_POINT,
            scale_x=SCALE_X, scale_y=SCALE_Y
        )

        for info in offset_results:
            print(f" 📌 实例 {info['instance']}: "
                  f"ΔX={info['dx_mm']:+.2f}mm, ΔY={info['dy_mm']:+.2f}mm")

        # NOTE(review): the file name depends only on IMAGE_PATH, so several
        # results would overwrite one another; a single input image is assumed.
        save_filename = f"offset_{os.path.basename(IMAGE_PATH)}"
        save_path = os.path.join(OUTPUT_DIR, save_filename)
        # cv2.imwrite reports failure through its return value rather than by
        # raising, so only announce success when the write actually succeeded.
        if cv2.imwrite(save_path, img_with_kpts):
            print(f" 💾 结果已保存: {save_path}")
        else:
            # Message reads: "failed to save result".
            print(f" ❌ 结果保存失败: {save_path}")