Files
琉璃月光 67883f1a50 最新推送
2026-03-10 14:14:14 +08:00

180 lines
6.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from ultralytics import YOLO
import cv2
import numpy as np
import os
# ------------------ Configuration ------------------
# Path of the trained weights file handed to YOLO() below.
model_path = 'ultralytics_yolo11-main/runs/train/exp4/weights/best.pt'
img_folder = '/home/hx/yolo/test' # input image folder
# Directory that receives every output file; it is created at import time
# (exist_ok=True, so an existing directory is not an error).
output_mask_dir = 'output_masks'
os.makedirs(output_mask_dir, exist_ok=True)
# File extensions accepted as input images; compared against the
# lower-cased file name in the main program.
SUPPORTED_FORMATS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
# ------------------ Load model ------------------
# The model is loaded once at module level and shared by process_image().
model = YOLO(model_path)
# NOTE(review): the device is hard-coded; presumably this fails on a machine
# without a CUDA-capable GPU -- confirm whether a CPU fallback is wanted.
model.to('cuda') # run on the GPU
def vector_to_normal(v):
    """Return the normal of a 2-D direction vector.

    The normal is the input rotated by 90 degrees counter-clockwise,
    i.e. ``(x, y)`` becomes ``(-y, x)``.

    Args:
        v: Indexable with at least two components (x at 0, y at 1).

    Returns:
        A NumPy array ``[-y, x]``.
    """
    dx, dy = v[0], v[1]
    return np.array([-dy, dx])
def angle_between_normals(n1, n2):
    """Return the smallest angle between two normal vectors, in degrees.

    The vectors are treated as undirected, so the result is folded into
    the range [0, 90]: an angle above 90 is replaced by its supplement.

    Args:
        n1: First vector as a NumPy array.
        n2: Second vector as a NumPy array.

    Returns:
        The folded angle in degrees.
    """
    # The small epsilon keeps the division finite for a zero-length vector.
    unit_a, unit_b = (vec / (np.linalg.norm(vec) + 1e-8) for vec in (n1, n2))
    # Clip so rounding error can never push the cosine outside arccos' domain.
    cosine = np.clip(np.dot(unit_a, unit_b), -1.0, 1.0)
    theta = np.degrees(np.arccos(cosine))
    return min(theta, 180 - theta)
def _to_int_point(pt):
    """Convert a 2-component point to a tuple of built-in ints.

    OpenCV drawing calls and printed output are more predictable with plain
    Python ints than with NumPy scalar types.
    """
    return int(pt[0]), int(pt[1])


def process_image(img_path, output_dir):
    """Segment one image, keep two of its four longest outline edges and report their angle.

    Steps:
      1. Run the module-level ``model`` on the image and merge every predicted
         mask into one composite mask, saved as ``mask_<name>.png``.
      2. Simplify the external contour points of the mask into a polygon and
         collect its edges longer than 20 px.
      3. Take the four longest edges, find the pair whose normals are closest
         to parallel, discard that pair and keep the other two edges.
      4. Print the angle (0-90 degrees) between the two kept edges and save
         an annotated image as ``kept_edges_<name>.jpg``.

    Args:
        img_path: Path of the input image.
        output_dir: Directory for the mask and the visualisation; created
            if it does not exist.

    Returns:
        None. Results are printed and written to ``output_dir``. The function
        returns early if the image cannot be read, no contour is found, or
        fewer than four edges exceed the length threshold.
    """
    img = cv2.imread(img_path)
    if img is None:
        print(f"❌ 无法读取图像: {img_path}")
        return
    h, w = img.shape[:2]
    filename = os.path.basename(img_path)
    name_only = os.path.splitext(filename)[0]
    print(f"\n🔄 正在处理: {name_only}")

    # Callers may pass a directory other than the one created at module level.
    os.makedirs(output_dir, exist_ok=True)

    # ------------------ Build the composite mask ------------------
    composite_mask = np.zeros((h, w), dtype=np.uint8)
    results = model(img_path, imgsz=1280, conf=0.5)
    for r in results:
        if r.masks is None:
            continue
        # NOTE(review): masks are resized straight to the image size;
        # presumably they come at the network's inference resolution, which
        # may include letterbox padding -- confirm the alignment is acceptable.
        for mask in r.masks.data.cpu().numpy():
            mask_resized = cv2.resize(mask, (w, h))
            mask_img = (mask_resized * 255).astype(np.uint8)
            composite_mask = np.maximum(composite_mask, mask_img)

    mask_save_path = os.path.join(output_dir, f'mask_{name_only}.png')
    # cv2.imwrite signals failure through its return value, not an exception.
    if cv2.imwrite(mask_save_path, composite_mask):
        print(f"✅ 掩码已保存: {mask_save_path}")
    else:
        print(f"❌ 掩码保存失败: {mask_save_path}")

    # ------------------ Extract the outline polygon ------------------
    contours, _ = cv2.findContours(composite_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if not contours:
        print(f"⚠️ 未检测到轮廓: {filename}")
        return
    # NOTE(review): points of all external contours are concatenated into one
    # point list before simplification; confirm this is intended when the mask
    # contains several disjoint regions.
    all_contours = np.vstack(contours)
    epsilon = 0.005 * cv2.arcLength(all_contours, True)
    approx = cv2.approxPolyDP(all_contours, epsilon, True)
    pts = approx.reshape(-1, 2)

    # ------------------ Collect edges and keep the 4 longest ------------------
    line_segments = []
    num_pts = len(pts)
    for i, p1 in enumerate(pts):
        p2 = pts[(i + 1) % num_pts]  # wrap around to close the polygon
        vec = p2 - p1
        length = np.linalg.norm(vec)
        if length > 20:
            line_segments.append({
                'index': i,
                'start': p1,
                'end': p2,
                'length': length,
                'vector': vec,
            })
    line_segments.sort(key=lambda seg: seg['length'], reverse=True)
    top4 = line_segments[:4]
    if len(top4) < 4:
        print(f"⚠️ 不足4条有效边({len(top4)}条): {filename}")
        return
    print("✅ 已提取前4条最长边")

    # ------------------ Find the most parallel pair of edges ------------------
    normals = [vector_to_normal(edge['vector']) for edge in top4]
    min_angle = float('inf')
    remove_pair = (0, 1)  # fallback: drop the first two edges
    for i in range(4):
        for j in range(i + 1, 4):
            angle = angle_between_normals(normals[i], normals[j])
            if angle < min_angle:
                min_angle = angle
                remove_pair = (i, j)
    # Report the selected pair, not the final values of the loop variables.
    print(f"✅ 法向量最小夹角: {min_angle:.2f}° "
          f"(边{remove_pair[0] + 1} vs 边{remove_pair[1] + 1})")

    # ------------------ Drop that pair, keep the other two ------------------
    # remove_pair always holds two distinct indices, so exactly two remain;
    # the list is built in ascending order to make the result deterministic.
    keep_indices = [k for k in range(4) if k not in remove_pair]
    edge1, edge2 = kept_edges = [top4[k] for k in keep_indices]
    print(f"✅ 保留边索引: {keep_indices}")
    print(f"✅ 删除边索引: {remove_pair}")

    # ------------------ Angle between the two kept edges ------------------
    # The helper folds the angle into [0, 90] and applies the same formula
    # to any pair of 2-D vectors, so it works for direction vectors too.
    line_angle = angle_between_normals(edge1['vector'], edge2['vector'])

    # ------------------ Report ------------------
    print("\n✅ 保留的两条边:")
    for i, e in enumerate(kept_edges, 1):
        print(f"{i}: 起点{_to_int_point(e['start'])} → 终点{_to_int_point(e['end'])}, "
              f"长度: {e['length']:.1f}")
    print(f"\n✅ 两条保留边之间的夹角: {line_angle:.2f}°")

    # ------------------ Visualisation ------------------
    vis_img = img.copy()
    color_keep = [(0, 255, 0), (255, 0, 255)]  # green and magenta for the kept edges
    for idx, e in enumerate(kept_edges):
        p1 = _to_int_point(e['start'])
        p2 = _to_int_point(e['end'])
        cv2.line(vis_img, p1, p2, color_keep[idx], 3, cv2.LINE_AA)
        mid = ((p1[0] + p2[0]) // 2, (p1[1] + p2[1]) // 2)
        # Offset the second label so the two length labels do not overlap.
        cv2.putText(vis_img, f"{e['length']:.0f}", (mid[0], mid[1] + idx * 20),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, color_keep[idx], 2)
    # Hershey fonts cannot render the degree sign (it is drawn as '?'),
    # so the unit is spelled out in ASCII.
    cv2.putText(vis_img, f"Final Angle: {line_angle:.2f} deg", (20, 50),
                cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 255), 2)

    vis_save_path = os.path.join(output_dir, f'kept_edges_{name_only}.jpg')
    if cv2.imwrite(vis_save_path, vis_img):
        print(f"✅ 保留边可视化已保存: {vis_save_path}")
    else:
        print(f"❌ 保留边可视化保存失败: {vis_save_path}")
# ------------------ Main program ------------------
if __name__ == '__main__':
    # A missing input folder is an error: stop with a non-zero exit status.
    # SystemExit is raised directly because the exit() helper is only
    # installed by the site module and is not guaranteed to exist.
    if not os.path.isdir(img_folder):
        print(f"❌ 图像文件夹不存在: {img_folder}")
        raise SystemExit(1)

    # os.listdir returns names in arbitrary order; sort them so runs are
    # reproducible. The extension match is case-insensitive.
    image_files = sorted(
        f for f in os.listdir(img_folder) if f.lower().endswith(SUPPORTED_FORMATS)
    )
    if not image_files:
        # Nothing to do is only a warning, so the exit status stays 0.
        print(f"⚠️ 未找到支持的图像文件: {img_folder}")
        raise SystemExit(0)

    print(f"✅ 发现 {len(image_files)} 张图像,开始处理...")
    for file in image_files:
        process_image(os.path.join(img_folder, file), output_mask_dir)
    print(f"\n🎉 所有图像处理完成!结果保存在: {output_mask_dir}")