Files
zjsh_yolov11/test_line_angle_f.py
2025-08-13 14:49:06 +08:00

163 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os

import cv2
import numpy as np
import torch
from ultralytics import YOLO
# ------------------ Configuration ------------------
# NOTE(review): model_path is relative to the current working directory while
# img_folder is absolute; run the script from the directory that contains
# ultralytics_yolo11-main/ (or make this path absolute).
model_path = 'ultralytics_yolo11-main/runs/train/exp4/weights/best.pt'
img_folder = '/home/hx/yolo/ultralytics_yolo11-main/dataset1/test'  # folder with the images to process
output_mask_dir = 'output_masks1'  # masks and visualizations are written here
os.makedirs(output_mask_dir, exist_ok=True)
# Accepted image file extensions (lower case).
SUPPORTED_FORMATS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
# ------------------ Load model ------------------
model = YOLO(model_path)
# Prefer the GPU, but fall back to CPU instead of crashing at import time on
# machines without a usable CUDA device.
if torch.cuda.is_available():
    model.to('cuda')
else:
    print("⚠️ 未检测到可用的 CUDA 设备, 使用 CPU 推理")
    model.to('cpu')
def _build_composite_mask(img, h, w):
    """Run the segmentation model on a BGR image and merge all instance masks.

    Args:
        img: BGR image as returned by ``cv2.imread``.
        h, w: Height and width of ``img``; the returned mask has shape (h, w).

    Returns:
        uint8 array of shape (h, w) holding, per pixel, the maximum over all
        predicted instance masks scaled to 0-255 (non-zero = covered by some
        instance).
    """
    composite_mask = np.zeros((h, w), dtype=np.uint8)
    # retina_masks=True makes Ultralytics return masks at the original image
    # size. Without it, masks.data is at the letterboxed inference size
    # (padding included) and stretching it to (w, h) shifts the mask.
    # Passing the already-decoded image guarantees the masks line up with
    # the exact pixels we later draw on.
    results = model(img, imgsz=1280, conf=0.5, retina_masks=True)
    for r in results:
        if r.masks is None:
            continue
        for mask in r.masks.data.cpu().numpy():
            mask = mask.astype(np.float32)
            # Safety net in case the returned mask size still differs.
            if mask.shape[:2] != (h, w):
                mask = cv2.resize(mask, (w, h))
            mask_img = (mask * 255).astype(np.uint8)
            composite_mask = np.maximum(composite_mask, mask_img)
    return composite_mask


def _contour_segments(contours, min_length=20, epsilon_ratio=0.005):
    """Simplify each contour to a closed polygon and describe its edges.

    Each contour is approximated on its own (epsilon = epsilon_ratio * its own
    perimeter). Stacking all contours into one point list would insert
    artificial "bridge" edges between unrelated blobs, which could then win the
    longest-edge ranking.

    Args:
        contours: Contours as returned by ``cv2.findContours``.
        min_length: Edges shorter than this many pixels are dropped.
        epsilon_ratio: Approximation tolerance relative to the contour perimeter.

    Returns:
        List of dicts, one per kept edge, with keys ``start``/``end`` (int
        (x, y) tuples), ``length``, ``dx``, ``dy``, ``slope`` (dy/dx, ``inf``
        when dx == 0) and ``vertical_score`` (|dx|/|dy|; smaller means closer
        to vertical, ``inf`` when dy == 0).
    """
    segments = []
    for contour in contours:
        epsilon = epsilon_ratio * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)
        # Plain Python ints: some OpenCV builds reject numpy integer scalars
        # inside point tuples passed to the drawing functions.
        pts = [(int(x), int(y)) for x, y in approx.reshape(-1, 2)]
        # Closed polygon: the last vertex connects back to the first.
        for p1, p2 in zip(pts, pts[1:] + pts[:1]):
            dx = p2[0] - p1[0]
            dy = p2[1] - p1[1]
            length = float(np.hypot(dx, dy))
            if length < min_length:
                continue  # ignore edges that are too short
            segments.append({
                'start': p1,
                'end': p2,
                'length': length,
                'dx': dx,
                'dy': dy,
                'slope': float('inf') if dx == 0 else dy / dx,
                'vertical_score': float('inf') if dy == 0 else abs(dx) / abs(dy),
            })
    return segments


def _draw_vertical_candidates(img, vertical_candidates):
    """Return a copy of *img* with the candidate edges (red, then blue) and a legend drawn on it."""
    vis_img = img.copy()
    colors = [(0, 0, 255), (255, 0, 0)]  # red, blue (BGR)
    for i, (seg, color) in enumerate(zip(vertical_candidates, colors)):
        p1, p2 = seg['start'], seg['end']
        cv2.line(vis_img, p1, p2, color, 3, cv2.LINE_AA)
        mid_x = (p1[0] + p2[0]) // 2
        mid_y = (p1[1] + p2[1]) // 2
        # The second label is shifted down so the two labels do not overlap.
        cv2.putText(vis_img, f"{seg['length']:.0f}, k={seg['slope']:.1f}",
                    (mid_x, mid_y - 10 + i * 25),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
    # Legend
    cv2.putText(vis_img, "Red: 1st closest to vertical", (20, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 1)
    cv2.putText(vis_img, "Blue: 2nd closest to vertical", (20, 60),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 1)
    return vis_img


def process_image(img_path, output_dir):
    """Process one image and report its near-vertical edges.

    Segments the image, takes the 4 longest polygon edges of the merged mask,
    and reports and draws the 2 of those that are closest to vertical.

    Writes ``mask_<name>.png`` and ``vertical_<name>.jpg`` into *output_dir*
    (created if missing) and prints progress. Returns None. Problems
    (unreadable image, no contours, fewer than 2 usable edges) are printed
    and the function returns early.
    """
    img = cv2.imread(img_path)
    if img is None:
        print(f"❌ 无法读取图像: {img_path}")
        return
    h, w = img.shape[:2]
    filename = os.path.basename(img_path)
    name_only = os.path.splitext(filename)[0]
    print(f"\n🔄 正在处理: {filename}")
    # cv2.imwrite fails silently (returns False) if the directory is missing.
    os.makedirs(output_dir, exist_ok=True)

    # ------------------ Composite mask ------------------
    composite_mask = _build_composite_mask(img, h, w)
    mask_save_path = os.path.join(output_dir, f'mask_{name_only}.png')
    if cv2.imwrite(mask_save_path, composite_mask):
        print(f"✅ 掩码已保存: {mask_save_path}")
    else:
        print(f"❌ 掩码保存失败: {mask_save_path}")

    # ------------------ Contours -> polygon edges ------------------
    contours, _ = cv2.findContours(composite_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if len(contours) == 0:
        print(f"⚠️ 未检测到轮廓: {filename}")
        return
    line_segments = _contour_segments(contours)

    # Keep the 4 longest edges.
    line_segments.sort(key=lambda seg: seg['length'], reverse=True)
    top_4_segments = line_segments[:4]
    if len(top_4_segments) < 2:
        print(f"⚠️ 不足2条有效边({len(top_4_segments)}条): {filename}")
        return

    # Of those, the 2 with the smallest vertical_score are closest to vertical.
    # sorted() (not an in-place sort) keeps top_4_segments in length order,
    # so the "N-th longest edge" printout below is ranked correctly.
    vertical_candidates = sorted(top_4_segments, key=lambda seg: seg['vertical_score'])[:2]

    # ------------------ Report ------------------
    print("✅ 提取最长的4条边:")
    for i, seg in enumerate(top_4_segments):
        print(f"{i + 1}长边: 长度={seg['length']:.1f}, 斜率={seg['slope']:.3f}, 垂直评分={seg['vertical_score']:.3f}")
    print("✅ 最接近垂直的两条边(垂直评分最小):")
    for i, seg in enumerate(vertical_candidates):
        print(f" 垂直边{i + 1}: 长度={seg['length']:.1f}, 斜率={seg['slope']:.3f}, 垂直评分={seg['vertical_score']:.3f}")

    # ------------------ Visualization ------------------
    vis_img = _draw_vertical_candidates(img, vertical_candidates)
    vis_save_path = os.path.join(output_dir, f'vertical_{name_only}.jpg')
    if cv2.imwrite(vis_save_path, vis_img):
        print(f"✅ 可视化图像已保存: {vis_save_path}")
    else:
        print(f"❌ 可视化图像保存失败: {vis_save_path}")
# ------------------ Main: process every image in the folder ------------------
if __name__ == '__main__':
    if not os.path.isdir(img_folder):
        print(f"❌ 图像文件夹不存在: {img_folder}")
        # Non-zero exit status so shell scripts / CI can detect the failure.
        # SystemExit is used instead of the site-provided exit(), which is not
        # guaranteed to exist (e.g. under `python -S`).
        raise SystemExit(1)
    # os.listdir order is arbitrary; sort for a reproducible processing order.
    # The isfile check skips directories whose names end in an image extension.
    image_files = sorted(
        f for f in os.listdir(img_folder)
        if f.lower().endswith(SUPPORTED_FORMATS)
        and os.path.isfile(os.path.join(img_folder, f))
    )
    if not image_files:
        print(f"⚠️ 在 {img_folder} 中未找到支持的图像文件")
        raise SystemExit  # exit status 0: nothing to do is not treated as an error
    print(f"✅ 发现 {len(image_files)} 张图像,开始批量处理...")
    for image_file in image_files:
        process_image(os.path.join(img_folder, image_file), output_mask_dir)
    print(f"\n🎉 所有图像处理完成!结果保存在: {output_mask_dir}")