171 lines
7.1 KiB
Python
171 lines
7.1 KiB
Python
import json
|
||
import os
|
||
import glob
|
||
|
||
def _labelme_box_to_yolo(points, img_w, img_h):
    """Convert a LabelMe rectangle to a normalized YOLO ``(x_center, y_center, w, h)``.

    ``points`` holds the two opposite corners in pixel coordinates, in any
    order. Each corner is clipped to the image *before* conversion, so the
    resulting box always lies inside [0, 1]. Clamping the center and size
    independently does not guarantee that for boxes crossing the border.
    """
    (x1, y1), (x2, y2) = points[0], points[1]
    left, right = sorted(min(max(float(x), 0.0), float(img_w)) for x in (x1, x2))
    top, bottom = sorted(min(max(float(y), 0.0), float(img_h)) for y in (y1, y2))
    return (
        (left + right) / 2 / img_w,
        (top + bottom) / 2 / img_h,
        (right - left) / img_w,
        (bottom - top) / img_h,
    )


def _build_yolo_lines(data, base_name, target_box_label, class_id, img_w, img_h, keypoints_per_instance):
    """Validate one parsed LabelMe document and return its YOLO Pose label lines.

    Keypoints (point shapes labelled "1".."keypoints_per_instance") are taken
    in document order and split into consecutive groups of
    ``keypoints_per_instance``. Group i is paired with the i-th
    ``target_box_label`` rectangle. This assumes the annotator labelled all
    points of one instance before moving to the next. Every group must
    therefore contain each keypoint label exactly once; otherwise the pairing
    is wrong and the sample is rejected rather than written with scrambled
    points.

    Returns:
        list[str]: one newline-terminated line per instance, of the form
        ``class xc yc w h kx1 ky1 2 ... kxN kyN 2`` (visibility 2 = visible).

    Raises:
        ValueError: keypoint count mismatch, no target box, or a group whose
            labels are not exactly 1..N.
        KeyError / IndexError / TypeError: malformed shape entries.
    """
    shapes = data.get("shapes", [])
    keypoint_labels = {str(i) for i in range(1, keypoints_per_instance + 1)}

    # (label number, normalized x, normalized y), each coordinate clamped into [0, 1].
    keypoints = []
    for shape in shapes:
        label = shape["label"]
        if shape["shape_type"] == "point" and label in keypoint_labels:
            x, y = shape["points"][0]
            keypoints.append((
                int(label),
                max(0.0, min(1.0, x / img_w)),
                max(0.0, min(1.0, y / img_h)),
            ))

    # Each target rectangle is one instance.
    boxes = [
        s for s in shapes
        if s["label"] == target_box_label and s["shape_type"] == "rectangle"
    ]
    num_instances = len(boxes)
    total_keypoints = len(keypoints)
    expected_total = num_instances * keypoints_per_instance

    if total_keypoints != expected_total:
        print(f"❌ {base_name}: 关键点数量不匹配!期望 {expected_total},实际 {total_keypoints}")
        raise ValueError("关键点数量不匹配")
    if num_instances == 0:
        print(f"❌ {base_name}: 未找到任何 '{target_box_label}' 框")
        raise ValueError("无目标框")

    expected_labels = list(range(1, keypoints_per_instance + 1))
    lines = []
    for i, box in enumerate(boxes):
        start = i * keypoints_per_instance
        # Sort numerically: a string sort would put "10" before "2".
        group = sorted(keypoints[start:start + keypoints_per_instance], key=lambda kp: kp[0])
        # Reject groups with duplicated/missing labels: they mean the sequential
        # grouping mixed points from different instances.
        if [kp[0] for kp in group] != expected_labels:
            print(f"❌ {base_name}: 实例 {i+1} 关键点标签不完整或重复")
            raise ValueError("实例关键点标签不匹配")

        fields = [str(class_id)]
        fields.extend(f"{v:.6f}" for v in _labelme_box_to_yolo(box["points"], img_w, img_h))
        for _, kx, ky in group:
            fields.extend([f"{kx:.6f}", f"{ky:.6f}", "2"])  # v=2: visible
        lines.append(" ".join(fields) + "\n")
    return lines


def _discard_sample(output_path, image_file, json_file):
    """Delete a rejected sample's output .txt, image and JSON.

    Returns 1 when the removals completed, 0 if an OSError interrupted them.
    The return value is added to the caller's deleted-file counter.
    """
    try:
        if os.path.exists(output_path):
            os.remove(output_path)
        if image_file and os.path.exists(image_file):
            os.remove(image_file)
            print(f"🗑️ 已删除图片: {os.path.basename(image_file)}")
        if os.path.exists(json_file):
            os.remove(json_file)
            print(f"🗑️ 已删除 JSON: {os.path.basename(json_file)}")
        return 1
    except OSError as del_e:
        print(f"💥 删除文件时出错: {del_e}")
        return 0


def labelme_to_yolo_keypoints_batch(json_dir, output_dir, target_box_label="J1", class_id=0,
                                    img_shape=None, keypoints_per_instance=4):
    """Batch-convert LabelMe JSON annotations to YOLO Pose label files (.txt).

    - Every ``keypoints_per_instance`` point shapes (labelled "1".."N") form
      one instance. They are paired, in document order, with a
      ``target_box_label`` rectangle.
    - Keypoints must correspond one-to-one with boxes.
    - When a file cannot be converted or its data does not match, the JSON is
      deleted, along with its image (``imagePath`` resolved inside
      ``json_dir``) and any stale output .txt.
    - Failures while *writing* the output file (disk full, permissions, ...)
      are reported. They do not delete the source JSON or image, because
      they say nothing about the annotation itself.

    Args:
        json_dir: directory scanned (non-recursively) for ``*.json``. Files
            ending in ``_mask.json`` are skipped.
        output_dir: directory for the generated ``<name>.txt``; created if missing.
        target_box_label: rectangle label that defines one instance.
        class_id: class index written as the first column of every line.
        img_shape: ``(height, width)`` in pixels used for normalization. The
            same size is assumed for every image.
        keypoints_per_instance: number of keypoints N per instance.

    Raises:
        ValueError: if ``img_shape`` is missing or has a non-positive dimension.
    """
    if img_shape is None:
        raise ValueError("必须提供 img_shape 参数,例如 (1440, 2506)")

    os.makedirs(output_dir, exist_ok=True)

    # sorted() makes the processing/log order deterministic (glob order is arbitrary).
    json_files = sorted(
        f for f in glob.glob(os.path.join(json_dir, "*.json"))
        if os.path.isfile(f) and not f.endswith("_mask.json")
    )
    if not json_files:
        print(f"❌ 在 {json_dir} 中未找到任何 JSON 文件")
        return

    img_h, img_w = img_shape
    if img_h <= 0 or img_w <= 0:
        # A zero size would make every file fail (ZeroDivisionError) and get deleted.
        raise ValueError(f"img_shape 必须为正数 (height, width),当前为 {img_shape}")

    converted_count = 0
    deleted_count = 0

    print(f"🔍 开始转换:目标框='{target_box_label}', 每实例 {keypoints_per_instance} 个关键点")

    for json_file in json_files:
        base_name = os.path.splitext(os.path.basename(json_file))[0]
        output_path = os.path.join(output_dir, f"{base_name}.txt")
        image_file_to_delete = None

        # Phase 1: read and validate. Any error here marks the sample as bad.
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            image_path = data.get("imagePath")
            if image_path:
                # LabelMe files saved on Windows use "\" separators, which
                # os.path.basename does not split on POSIX.
                image_file_to_delete = os.path.join(
                    json_dir, os.path.basename(image_path.replace("\\", "/")))
            else:
                print(f"⚠️ {base_name}: JSON 中无 imagePath,无法定位图片")

            lines = _build_yolo_lines(
                data, base_name, target_box_label, class_id,
                img_w, img_h, keypoints_per_instance)
        except Exception as e:
            print(f"❌ 转换失败 {base_name}: {e}")
            deleted_count += _discard_sample(output_path, image_file_to_delete, json_file)
            continue

        # Phase 2: write. Lines are fully built, so a validation error never
        # leaves a partial label; an OSError here is an environment problem.
        try:
            with open(output_path, 'w', encoding='utf-8') as f_out:
                f_out.writelines(lines)
        except OSError as write_e:
            print(f"💥 写入失败 {base_name}: {write_e}(已保留 JSON 和图片)")
            try:
                if os.path.exists(output_path):
                    os.remove(output_path)  # don't leave a truncated label file behind
            except OSError as del_e:
                print(f"💥 删除文件时出错: {del_e}")
            continue

        print(f"✅ 已转换: {os.path.basename(json_file)} -> {len(lines)} 个实例")
        converted_count += 1

    print("\n" + "=" * 60)
    print("🎉 批量转换完成!")
    print(f"✅ 成功保留: {converted_count} 个文件")
    print(f"❌ 异常删除: {deleted_count} 个文件(JSON + 图片)")
    print(f"📁 输出目录: {output_dir}")
    print(f"📦 每实例关键点数: {keypoints_per_instance}")
    print(f"🏷️ 目标框标签: {target_box_label}")
    print("=" * 60)
|
||
|
||
|
||
# ================== User configuration ==================
# Dataset root shared by the input and output directories.
_DATASET_ROOT = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/yolodataset/point1"
JSON_DIR = os.path.join(_DATASET_ROOT, "f11")  # LabelMe *.json files (images are looked up here too)
OUTPUT_DIR = os.path.join(_DATASET_ROOT, "labels_keypoints")  # destination for the YOLO *.txt files

TARGET_BOX_LABEL = "J1"  # rectangle label that defines one instance
CLASS_ID = 1  # class index written as the first column of each label line
IMG_SHAPE = (1440, 2506)  # (height, width)
KEYPOINTS_PER_INSTANCE = 3  # point labels "1".."3" belong to each instance
|
||
|
||
# ================== Run the conversion ==================
if __name__ == "__main__":
    # Gather the settings from the configuration section and forward them by keyword.
    run_kwargs = {
        "json_dir": JSON_DIR,
        "output_dir": OUTPUT_DIR,
        "target_box_label": TARGET_BOX_LABEL,
        "class_id": CLASS_ID,
        "img_shape": IMG_SHAPE,
        "keypoints_per_instance": KEYPOINTS_PER_INSTANCE,
    }
    labelme_to_yolo_keypoints_batch(**run_kwargs)