103 lines
3.7 KiB
Python
103 lines
3.7 KiB
Python
import cv2
|
|
import os
|
|
import numpy as np
|
|
from ultralytics import YOLO
|
|
|
|
# Lower-case file suffixes (leading dot included) that are accepted as images
# when scanning the input folder.
IMG_EXTENSIONS = {
    '.jpg',
    '.jpeg',
    '.png',
    '.bmp',
    '.tif',
    '.tiff',
    '.webp',
}
|
|
|
|
|
|
def _main_direction_deg(w, h, r_rad):
    """Return the long-side direction of a rotated box, in degrees within [0, 180).

    If the box is taller than it is wide, the long side is perpendicular to the
    rotation angle, hence the extra pi/2. Taking the result modulo pi treats the
    direction as an undirected axis.
    """
    direction = r_rad if w >= h else r_rad + np.pi / 2
    return float(np.degrees(direction % np.pi))


def _pairwise_angles_deg(angles_deg):
    """Return ``(i, j, angle)`` for every index pair ``i < j`` of *angles_deg*.

    Inputs are axis directions in [0, 180) degrees; ``angle`` is the smaller
    angle between the two axes, in [0, 90] degrees. Pairs are produced in the
    same order as nested ``for i`` / ``for j > i`` loops.
    """
    from itertools import combinations

    pairs = []
    for (i, a), (j, b) in combinations(enumerate(angles_deg), 2):
        diff = abs(a - b)
        pairs.append((i, j, min(diff, 180.0 - diff)))
    return pairs


def process_obb_images(model_path, image_dir, output_dir="./inference_results", conf_thresh=0.15, imgsz=640):
    """Run YOLO OBB inference on every image in a folder and measure box orientations.

    For each image, the annotated detections are saved as
    ``<output_dir>/detected_<filename>``. The function also computes each box's
    main (long-side) direction and the angle between every pair of boxes.

    Args:
        model_path: Path to the YOLO weights file.
        image_dir: Folder containing the images (not searched recursively).
        output_dir: Folder for the annotated images; created if missing.
        conf_thresh: Confidence threshold passed to the model.
        imgsz: Inference image size passed to the model.

    Returns:
        dict mapping image filename to
        ``{'angles_deg': [...], 'pairwise_angles_deg': [...]}`` (plain floats).
        Main directions are in [0, 180) degrees; pairwise angles are the smaller
        angle between two box axes, in [0, 90] degrees. Images that cannot be
        read are omitted. An empty dict is returned (without loading the model)
        when the folder contains no image files.

    Raises:
        OSError: from ``os.listdir`` if *image_dir* does not exist or is not a folder.
    """
    os.makedirs(output_dir, exist_ok=True)
    results_dict = {}

    # Collect the image files before loading the model, so an empty folder does
    # not pay the model-loading cost. Sorting gives a deterministic order
    # (os.listdir order is arbitrary); the isfile check skips directories whose
    # names happen to end in an image extension.
    image_files = sorted(
        f for f in os.listdir(image_dir)
        if os.path.splitext(f.lower())[1] in IMG_EXTENSIONS
        and os.path.isfile(os.path.join(image_dir, f))
    )
    if not image_files:
        print(f"❌ 未找到图像文件:{image_dir}")
        return results_dict

    print(f"发现 {len(image_files)} 张图像待处理")

    print("加载 YOLO 模型...")
    model = YOLO(model_path)
    print("✅ 模型加载完成")

    for img_filename in image_files:
        img_path = os.path.join(image_dir, img_filename)
        print(f"\n正在处理:{img_filename}")

        img = cv2.imread(img_path)
        if img is None:
            print(f"❌ 跳过:无法读取图像 {img_path}")
            continue

        # The OBB task is taken from the loaded weights; no task/mode argument is needed.
        result = model(img, save=False, imgsz=imgsz, conf=conf_thresh)[0]

        # Save the visualization. cv2.imwrite signals failure through its return value.
        save_path = os.path.join(output_dir, "detected_" + img_filename)
        if cv2.imwrite(save_path, result.plot()):
            print(f"✅ 推理结果已保存至: {save_path}")
        else:
            print(f"❌ 保存失败: {save_path}")

        # Main direction of every detected box.
        angles_deg = []
        boxes = result.obb
        if boxes is None or len(boxes) == 0:
            print("❌ 该图像中未检测到任何目标")
        else:
            # Each xywhr row is (cx, cy, w, h, rotation in radians).
            xywhr = boxes.xywhr.cpu().numpy()
            confs = boxes.conf.cpu().numpy()
            class_ids = boxes.cls.cpu().numpy().astype(int)
            for i, (row, conf, cls_id) in enumerate(zip(xywhr, confs, class_ids)):
                _, _, w, h, r_rad = row
                angle_deg = _main_direction_deg(w, h, r_rad)
                angles_deg.append(angle_deg)
                print(f" Box {i + 1}: Class={int(cls_id)}, Conf={conf:.3f}, 主方向={angle_deg:.2f}°")

        # Angle between every pair of boxes.
        pairwise_angles_deg = []
        for i, j, angle in _pairwise_angles_deg(angles_deg):
            pairwise_angles_deg.append(angle)
            print(f" Box {i + 1} 与 Box {j + 1} 夹角: {angle:.2f}°")

        results_dict[img_filename] = {
            "angles_deg": angles_deg,
            "pairwise_angles_deg": pairwise_angles_deg,
        }

    print("\n所有图像处理完成!")
    return results_dict
|
|
|
|
|
|
# ------------------- Example run -------------------
if __name__ == "__main__":
    MODEL_PATH = "best.pt"
    IMAGE_SOURCE_DIR = "./test_image"
    OUTPUT_DIR = "./inference_results"

    results = process_obb_images(MODEL_PATH, IMAGE_SOURCE_DIR, OUTPUT_DIR)

    for img_name, info in results.items():
        # Convert to plain floats rounded like the per-box output: printing a
        # list of NumPy scalars shows "np.float64(...)" wrappers under NumPy >= 2.
        angles = [round(float(a), 2) for a in info["angles_deg"]]
        pairwise = [round(float(a), 2) for a in info["pairwise_angles_deg"]]
        print(f"\n {img_name}:")
        print(f"主方向角度列表: {angles}")
        print(f"两两夹角列表: {pairwise}")
|