# Source file: zjsh_yolov11/test_line_angle_f1error.py
# Last modified: 2025-08-13 14:49:06 +08:00 (180 lines, 6.3 KiB, Python)
# NOTE: the repository viewer flagged this file as containing ambiguous
# Unicode characters (characters that can be confused with others). They are
# presumably the full-width punctuation inside the Chinese strings/comments;
# verify before "fixing" any of them, since the runtime messages use them.

import os

import cv2
import numpy as np
import torch
from ultralytics import YOLO
# ------------------ Configuration ------------------
# NOTE(review): model_path is relative, so it resolves against the current
# working directory, whereas img_folder is absolute. Confirm the script is
# launched from the expected directory.
model_path = 'ultralytics_yolo11-main/runs/train/exp4/weights/best.pt'
img_folder = '/home/hx/yolo/ultralytics_yolo11-main/dataset1/test'  # Folder of input images
output_mask_dir = 'output_masks1'  # Masks and visualisations are written here
os.makedirs(output_mask_dir, exist_ok=True)
# Image file extensions accepted by the batch loop (compared case-insensitively).
SUPPORTED_FORMATS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
# ------------------ Load model ------------------
# Prefer the GPU, but fall back to the CPU instead of failing at start-up on
# machines where CUDA is unavailable.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model = YOLO(model_path)
model.to(DEVICE)
def vector_to_normal(v):
    """Return the normal of direction vector ``v``, i.e. ``v`` rotated by +90 degrees.

    For ``v = (x, y)`` the result is ``(-y, x)``. That is counter-clockwise in a
    y-up frame, and appears clockwise in image coordinates where y points down.
    Only the first two components of ``v`` are used.
    """
    dx, dy = v[0], v[1]
    return np.array([-dy, dx])
def angle_between_normals(n1, n2):
    """Return the acute angle, in degrees within [0, 90], between vectors ``n1`` and ``n2``.

    Orientation is ignored: the raw angle and its supplement are compared and the
    smaller one is returned, so anti-parallel vectors count as (nearly) parallel.
    Each vector is normalised with a 1e-8 guard on its norm, so a zero vector
    yields 90 instead of raising a division error.
    """
    unit_a, unit_b = (vec / (np.linalg.norm(vec) + 1e-8) for vec in (n1, n2))
    # Clip guards arccos against dot products drifting just outside [-1, 1].
    theta = np.degrees(np.arccos(np.clip(np.dot(unit_a, unit_b), -1.0, 1.0)))
    supplement = 180 - theta
    return theta if theta <= supplement else supplement
def _as_point(p):
    """Convert a 2-element numpy point into a tuple of plain Python ints.

    Some OpenCV builds reject numpy integer scalars as drawing coordinates, and
    plain ints also print cleanly (no ``np.int32(...)`` repr).
    """
    return (int(p[0]), int(p[1]))


def _build_composite_mask(results, h, w):
    """Merge every instance mask in ``results`` into one uint8 (h, w) mask.

    Pixels where any instance mask exceeds 0.5 become 255; all others are 0.
    """
    composite_mask = np.zeros((h, w), dtype=np.uint8)
    for r in results:
        if r.masks is None:
            continue
        # float32 guarantees cv2.resize gets a supported dtype regardless of
        # how the masks tensor is typed.
        masks = r.masks.data.cpu().numpy().astype(np.float32)
        for mask in masks:
            if mask.shape != (h, w):
                # Only needed if masks are not already at the original resolution.
                mask = cv2.resize(mask, (w, h))
            mask_img = np.where(mask > 0.5, 255, 0).astype(np.uint8)
            np.maximum(composite_mask, mask_img, out=composite_mask)
    return composite_mask


def _extract_long_edges(contours, min_length):
    """Polygon-approximate the largest contour and return its edges longer than ``min_length``.

    Each edge is a dict with keys 'index', 'start', 'end', 'length' and 'vector'.
    """
    # Use the largest external contour only. Stacking several disjoint contours
    # into one point list makes approxPolyDP create spurious "bridge" edges
    # between blobs, and those then win the longest-edge ranking.
    contour = max(contours, key=cv2.contourArea)
    epsilon = 0.005 * cv2.arcLength(contour, True)
    approx = cv2.approxPolyDP(contour, epsilon, True)
    pts = [p[0] for p in approx]
    edges = []
    n = len(pts)
    for i in range(n):
        p1 = np.array(pts[i])
        p2 = np.array(pts[(i + 1) % n])  # wrap around: the polygon is closed
        vector = p2 - p1
        length = np.linalg.norm(vector)
        if length > min_length:
            edges.append({
                'index': i,
                'start': p1,
                'end': p2,
                'length': length,
                'vector': vector,
            })
    return edges


def _most_parallel_pair(edges):
    """Return ``(angle, (i, j))`` for the pair of edges whose normals form the smallest angle.

    Pairs are scanned in (0, 1), (0, 2), ... order and ties keep the first pair.
    """
    normals = [vector_to_normal(edge['vector']) for edge in edges]
    best_angle, best_pair = float('inf'), (0, 1)
    for i in range(len(normals)):
        for j in range(i + 1, len(normals)):
            angle = angle_between_normals(normals[i], normals[j])
            if angle < best_angle:
                best_angle, best_pair = angle, (i, j)
    return best_angle, best_pair


def process_image(img_path, output_dir, min_edge_length=20):
    """Segment one image, measure the angle between two of its long edges, and save the outputs.

    Steps:
      1. Run the module-level ``model`` and merge all instance masks into one
         binary mask, saved as ``mask_<name>.png`` in ``output_dir``.
      2. Polygon-approximate the largest external contour and keep the 4 longest
         edges strictly longer than ``min_edge_length`` pixels.
      3. Discard the pair of those edges whose normals are most nearly parallel,
         and report the angle (folded into [0, 90] degrees) between the other two.
      4. Draw the two kept edges and the angle, saved as ``kept_edges_<name>.jpg``.

    Problems (unreadable image, no contour, fewer than 4 long edges) are printed
    and the image is skipped. Always returns None.

    Args:
        img_path: path of the image to process.
        output_dir: directory that receives the mask and visualisation files.
        min_edge_length: minimum edge length in pixels (exclusive). Defaults to 20,
            the value previously hard-coded.
    """
    img = cv2.imread(img_path)
    if img is None:
        print(f"❌ 无法读取图像: {img_path}")
        return
    h, w = img.shape[:2]
    filename = os.path.basename(img_path)
    name_only = os.path.splitext(filename)[0]
    print(f"\n🔄 正在处理: {filename}")

    # ------------------ Composite mask ------------------
    # Pass the already-decoded BGR array so the file is not read twice.
    # retina_masks=True returns masks at the original image size. Without it,
    # masks come back at the letterboxed inference size, and resizing them
    # straight to (w, h) ignores the letterbox padding.
    results = model(img, imgsz=1280, conf=0.5, retina_masks=True)
    composite_mask = _build_composite_mask(results, h, w)
    mask_save_path = os.path.join(output_dir, f'mask_{name_only}.png')
    cv2.imwrite(mask_save_path, composite_mask)
    print(f"✅ 掩码已保存: {mask_save_path}")

    # ------------------ Contours ------------------
    contours, _ = cv2.findContours(composite_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if not contours:
        print(f"⚠️ 未检测到轮廓: {filename}")
        return

    # ------------------ Four longest edges ------------------
    line_segments = _extract_long_edges(contours, min_edge_length)
    line_segments.sort(key=lambda e: e['length'], reverse=True)
    top4 = line_segments[:4]
    if len(top4) < 4:
        print(f"⚠️ 不足4条有效边({len(top4)}条): {filename}")
        return
    print("✅ 已提取前4条最长边")

    # ------------------ Drop the most parallel pair ------------------
    min_angle, remove_pair = _most_parallel_pair(top4)
    # Report the pair actually removed (the loop indices were previously
    # printed after the loop, which always showed edges 3 and 4).
    first, second = remove_pair
    print(f"✅ 法向量最小夹角: {min_angle:.2f}° → 删除边{first + 1}与边{second + 1}")
    # top4 order is preserved, and exactly two edges remain.
    kept_edges = [edge for k, edge in enumerate(top4) if k not in remove_pair]
    edge1, edge2 = kept_edges

    # ------------------ Angle between the kept edges ------------------
    # The direction-agnostic angle folded into [0, 90] is exactly what
    # angle_between_normals computes, so reuse it rather than duplicating it.
    line_angle = angle_between_normals(edge1['vector'], edge2['vector'])

    # ------------------ Report ------------------
    print("✅ 保留边:")
    for idx, e in enumerate(kept_edges, 1):
        print(f"{idx}: 起点{_as_point(e['start'])} → 终点{_as_point(e['end'])}, 长度: {e['length']:.1f}")
    print(f"✅ 保留边之间的夹角: {line_angle:.2f}°")

    # ------------------ Visualisation ------------------
    vis_img = img.copy()
    colors = [(0, 255, 0), (255, 0, 255)]  # BGR: green, magenta
    for idx, e in enumerate(kept_edges):
        p1, p2 = _as_point(e['start']), _as_point(e['end'])
        cv2.line(vis_img, p1, p2, colors[idx], 3, cv2.LINE_AA)
        mid = ((p1[0] + p2[0]) // 2, (p1[1] + p2[1]) // 2)
        # Offset the second label vertically so the two labels do not overlap.
        cv2.putText(vis_img, f"{e['length']:.0f}", (mid[0], mid[1] + idx * 20),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, colors[idx], 2)
    # Hershey fonts are ASCII-only, so a degree sign would render as '?'
    # characters; spell the unit out instead.
    cv2.putText(vis_img, f"Final Angle: {line_angle:.2f} deg", (20, 50),
                cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 255), 2)
    vis_save_path = os.path.join(output_dir, f'kept_edges_{name_only}.jpg')
    cv2.imwrite(vis_save_path, vis_img)
    print(f"✅ 保留边可视化已保存: {vis_save_path}")
# ------------------ Main ------------------
if __name__ == '__main__':
    if not os.path.isdir(img_folder):
        print(f"❌ 图像文件夹不存在: {img_folder}")
        # Exit non-zero so shells/CI see the failure. The builtin exit() is a
        # site-module convenience (absent under `python -S`) and exited with 0.
        raise SystemExit(1)
    # os.listdir returns entries in arbitrary order; sort for a reproducible
    # processing and logging order.
    image_files = sorted(
        f for f in os.listdir(img_folder)
        if f.lower().endswith(SUPPORTED_FORMATS)
    )
    if not image_files:
        print(f"⚠️ 在 {img_folder} 中未找到支持的图像文件")
        raise SystemExit(0)
    print(f"✅ 发现 {len(image_files)} 张图像,开始批量处理...")
    for image_file in image_files:
        process_image(os.path.join(img_folder, image_file), output_mask_dir)
    print(f"\n🎉 所有图像处理完成!结果保存在: {output_mask_dir}")