120 lines
3.6 KiB
Python
120 lines
3.6 KiB
Python
import cv2
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import os
|
||
|
||
# =========================
|
||
# 强制使用非 GUI 后端(关键!)
|
||
# =========================
|
||
import matplotlib
|
||
|
||
matplotlib.use('Agg') # 必须在 import pyplot 之前设置
|
||
|
||
|
||
def visualize_obb(image_path, label_path, output_dir="output_visualizations"):
|
||
"""
|
||
可视化图片及其 OBB 标签,并保存结果图像到指定目录。
|
||
|
||
:param image_path: 图片路径
|
||
:param label_path: 标签路径
|
||
:param output_dir: 输出目录(自动创建)
|
||
"""
|
||
# 读取图像
|
||
image = cv2.imread(image_path)
|
||
if image is None:
|
||
print(f"❌ 无法读取图像: {image_path}")
|
||
return
|
||
|
||
h, w = image.shape[:2]
|
||
print(f"✅ 正在处理图像: {os.path.basename(image_path)} | 尺寸: {w} x {h}")
|
||
|
||
# 创建用于绘图的副本(BGR → 绘图用)
|
||
img_draw = image.copy()
|
||
|
||
# 读取标签
|
||
try:
|
||
with open(label_path, 'r') as f:
|
||
lines = f.readlines()
|
||
except Exception as e:
|
||
print(f"❌ 无法读取标签文件 {label_path}: {e}")
|
||
return
|
||
|
||
for line in lines:
|
||
parts = line.strip().split()
|
||
if len(parts) < 9:
|
||
print(f"⚠️ 跳过无效标签行: {line}")
|
||
continue
|
||
|
||
# 解析:class_id x1 y1 x2 y2 x3 y3 x4 y4
|
||
try:
|
||
points = np.array([float(x) for x in parts[1:9]]).reshape(4, 2)
|
||
except:
|
||
print(f"⚠️ 坐标解析失败: {line}")
|
||
continue
|
||
|
||
# 归一化坐标 → 像素坐标
|
||
points[:, 0] *= w # x
|
||
points[:, 1] *= h # y
|
||
points = np.int32(points)
|
||
|
||
# 绘制四边形(绿色)
|
||
cv2.polylines(img_draw, [points], isClosed=True, color=(0, 255, 0), thickness=3)
|
||
|
||
# 绘制顶点(红色圆圈)
|
||
for (x, y) in points:
|
||
cv2.circle(img_draw, (x, y), 6, (0, 0, 255), -1) # 红色实心圆
|
||
|
||
# 转为 RGB 用于 matplotlib 保存
|
||
img_rgb = cv2.cvtColor(img_draw, cv2.COLOR_BGR2RGB)
|
||
|
||
# 创建输出目录
|
||
os.makedirs(output_dir, exist_ok=True)
|
||
|
||
# 生成输出路径
|
||
filename = os.path.splitext(os.path.basename(image_path))[0] + "_vis.png"
|
||
output_path = os.path.join(output_dir, filename)
|
||
|
||
# 使用 matplotlib 保存图像(不显示)
|
||
plt.figure(figsize=(16, 9), dpi=100)
|
||
plt.imshow(img_rgb)
|
||
plt.title(f"OBB Visualization - {os.path.basename(image_path)}", fontsize=14)
|
||
plt.axis('off')
|
||
plt.tight_layout()
|
||
plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
|
||
plt.close() # 释放内存
|
||
|
||
print(f"✅ 可视化结果已保存: {output_path}")
|
||
|
||
|
||
def process_directory(directory):
|
||
"""
|
||
遍历目录,处理所有图片和对应的 .txt 标签文件
|
||
"""
|
||
print(f"🔍 开始处理目录: {directory}")
|
||
count = 0
|
||
|
||
for filename in os.listdir(directory):
|
||
if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')):
|
||
image_path = os.path.join(directory, filename)
|
||
label_path = os.path.splitext(image_path)[0] + ".txt"
|
||
|
||
if os.path.exists(label_path):
|
||
visualize_obb(image_path, label_path)
|
||
count += 1
|
||
else:
|
||
print(f"🟡 跳过 (无标签): {filename}")
|
||
|
||
print(f"🎉 处理完成!共处理 {count} 张图像。")
|
||
|
||
|
||
# =========================
|
||
# 主程序入口
|
||
# =========================
|
||
if __name__ == "__main__":
|
||
# 设置你的数据目录
|
||
directory = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/obb4/labels'
|
||
|
||
if not os.path.exists(directory):
|
||
raise FileNotFoundError(f"目录不存在: {directory}")
|
||
|
||
process_directory(directory) |