# 25 lines · 857 B · Python
"""Run a YOLO pose model on one image and print a sample of its keypoints.

Usage:
    python <this_script>.py [IMAGE_PATH]

IMAGE_PATH is optional; without it the hard-coded DEFAULT_IMAGE_PATH is used.
"""
import sys

import cv2
import torch  # noqa: F401  -- not used below; kept from the original script
from ultralytics import YOLO

DEFAULT_MODEL_PATH = "best.pt"
DEFAULT_IMAGE_PATH = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/point2/train/1.jpg"

# Load the model weights.
model = YOLO(DEFAULT_MODEL_PATH)

# Read a real image; a path given on the command line overrides the default.
img_path = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_IMAGE_PATH
image = cv2.imread(img_path)
if image is None:
    # cv2.imread reports failure (missing file, unsupported format) by returning
    # None instead of raising, so stop here with a clear message.
    sys.exit(f"Could not read image: {img_path}")

# Run inference on the image exactly as OpenCV loaded it. Ultralytics expects
# numpy inputs in BGR order (the cv2.imread convention) and letterboxes to
# imgsz itself, so converting to RGB would swap the red/blue channels for the
# model. Pre-resizing to 640x640 would distort the aspect ratio. Keypoints
# therefore come back in original-image pixel coordinates.
results = model(image, imgsz=640)

# Print the keypoint tensor shape and a sample.
keypoints = getattr(results[0], "keypoints", None) if results else None
if keypoints is not None:
    kpt_data = keypoints.data
    print("Keypoints data shape:", kpt_data.shape)
    if kpt_data.shape[0] > 0:
        # Axis 0 indexes detections (per the Ultralytics Keypoints docs:
        # (num_detections, num_keypoints, 2 or 3)); show the first detection's
        # first (up to) 12 keypoints.
        print("Keypoints data sample:", kpt_data[0, :12])
    else:
        # The model ran but found nothing; say so instead of printing only a shape.
        print("No keypoints detected or invalid keypoints data.")
else:
    print("No keypoints detected or invalid keypoints data.")