Files
zjsh_yolov11/angle_base_obb/anger_caculate_file.py
2025-09-11 20:44:35 +08:00

103 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from ultralytics import YOLO
import cv2
import numpy as np
import os
# ================== 配置参数 ==================
MODEL_PATH = r"/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_obb4/weights/best.pt"
IMAGE_SOURCE_DIR = r"/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/obb2/test" # 图像文件夹路径
OUTPUT_DIR = "./inference_results" # 输出结果保存路径
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'}
# 创建输出目录
os.makedirs(OUTPUT_DIR, exist_ok=True)
# 1. 加载模型
print("🔄 加载 YOLO 模型...")
model = YOLO(MODEL_PATH)
print("✅ 模型加载完成")
# 获取所有图像文件
image_files = [
f for f in os.listdir(IMAGE_SOURCE_DIR)
if os.path.splitext(f.lower())[1] in IMG_EXTENSIONS
]
if not image_files:
print(f"❌ 错误:在路径中未找到图像文件:{IMAGE_SOURCE_DIR}")
exit(1)
print(f"📁 发现 {len(image_files)} 张图像待处理")
# ================== 批量处理每张图像 ==================
for img_filename in image_files:
img_path = os.path.join(IMAGE_SOURCE_DIR, img_filename)
print(f"\n🖼️ 正在处理:{img_filename}")
# 读取图像
img = cv2.imread(img_path)
if img is None:
print(f"❌ 跳过:无法读取图像 {img_path}")
continue
# 推理OBB 模式)
results = model(
img,
save=False,
imgsz=640,
conf=0.15,
mode='obb'
)
result = results[0]
annotated_img = result.plot() # 绘制旋转框
# 保存结果图像
save_path = os.path.join(OUTPUT_DIR, "detected_" + img_filename)
cv2.imwrite(save_path, annotated_img)
print(f"✅ 推理结果已保存至: {save_path}")
# 提取旋转框信息
boxes = result.obb
directions = [] # 存储每个框的主方向(弧度),归一化到 [0, π)
if boxes is None or len(boxes) == 0:
print("❌ 该图像中未检测到任何目标")
else:
print(f"✅ 检测到 {len(boxes)} 个目标:")
for i, box in enumerate(boxes):
cls = int(box.cls.cpu().numpy()[0])
conf = box.conf.cpu().numpy()[0]
xywhr = box.xywhr.cpu().numpy()[0] # [cx, cy, w, h, r]
cx, cy, w, h, r_rad = xywhr
# 确定主方向(长边方向)
if w >= h:
direction = r_rad # 长边方向
else:
direction = r_rad + np.pi / 2 # 长边是宽的方向
# 归一化到 [0, π)
direction = direction % np.pi
directions.append(direction)
angle_deg = np.degrees(direction)
print(f" Box {i+1}: Class: {cls}, Confidence: {conf:.3f}, 主方向: {angle_deg:.2f}°")
# 计算两两之间的夹角最小夹角0°~90°
if len(directions) >= 2:
print("\n🔍 计算各框之间的夹角(主方向最小夹角):")
for i in range(len(directions)):
for j in range(i + 1, len(directions)):
dir1 = directions[i]
dir2 = directions[j]
diff = abs(dir1 - dir2)
min_diff_rad = min(diff, np.pi - diff) # 最小夹角(考虑周期性)
min_diff_deg = np.degrees(min_diff_rad)
print(f" Box {i+1} 与 Box {j+1} 之间夹角: {min_diff_deg:.2f}°")
else:
print("⚠️ 检测到少于两个目标,无法计算夹角。")
print("\n🎉 所有图像处理完成!")