Files
zjsh_yolov11/test_file.py

163 lines
6.1 KiB
Python
Raw Normal View History

2025-08-13 12:53:33 +08:00
import os

import cv2
import numpy as np
import torch
from ultralytics import YOLO
# ------------------ Configuration ------------------
# Trained segmentation weights (relative path: resolved against the current working directory).
model_path = 'ultralytics_yolo11-main/runs/train/exp4/weights/best.pt'
img_folder = '/home/hx/yolo/ultralytics_yolo11-main/dataset1/test'  # folder containing the images to process
output_mask_dir = 'output_masks1'  # receives the masks and the visualizations
os.makedirs(output_mask_dir, exist_ok=True)
# Supported image extensions (matched against lower-cased file names).
SUPPORTED_FORMATS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
# ------------------ Load model ------------------
model = YOLO(model_path)
# Prefer the GPU, but fall back to the CPU instead of crashing on machines
# where CUDA is not available.
if torch.cuda.is_available():
    model.to('cuda')
else:
    print("⚠️ CUDA 不可用,改用 CPU 推理")
    model.to('cpu')
def _build_composite_mask(img):
    """Run the module-level YOLO ``model`` on ``img`` and union all instance masks.

    Args:
        img: Image as returned by ``cv2.imread`` (BGR, H x W x 3).

    Returns:
        ``np.uint8`` array of shape (H, W): 255 where any predicted instance
        mask covers the pixel, 0 elsewhere.
    """
    h, w = img.shape[:2]
    composite = np.zeros((h, w), dtype=np.uint8)
    # Pass the already-decoded image (not the path) so it is not decoded a
    # second time and the masks refer to exactly this array.
    # retina_masks=True makes masks.data come back at the original image size.
    # Without it the masks have the letterboxed inference size, and a plain
    # resize to (w, h) ignores the letterbox padding and misaligns them.
    results = model(img, imgsz=1280, conf=0.5, retina_masks=True)
    for r in results:
        if r.masks is None:
            continue
        for mask in r.masks.data.cpu().numpy():
            mask = mask.astype(np.float32, copy=False)
            if mask.shape != (h, w):
                # Defensive fallback; not expected with retina_masks=True.
                mask = cv2.resize(mask, (w, h))
            # Assumes mask values in [0, 1] (the original code scaled them by
            # 255). Binarize so the saved mask is strictly 0/255 and
            # interpolation halos do not enlarge the contour.
            composite[mask > 0.5] = 255
    return composite


def _polygon_segments(pts, min_length=20):
    """Split a closed polygon into edge-segment records.

    Args:
        pts: List of (x, y) vertices; the last vertex is joined back to the first.
        min_length: Edges shorter than this many pixels are skipped.

    Returns:
        List of dicts with keys 'start', 'end' (numpy points), 'length',
        'dx', 'dy', 'slope' (dy / dx, +inf when dx == 0) and
        'vertical_score' (|dx| / |dy|, +inf when dy == 0; smaller means
        closer to vertical).
    """
    segments = []
    # Pair every vertex with its successor, wrapping around to close the polygon.
    for a, b in zip(pts, pts[1:] + pts[:1]):
        p1 = np.array(a)
        p2 = np.array(b)
        dx = p2[0] - p1[0]
        dy = p2[1] - p1[1]
        length = np.linalg.norm([dx, dy])
        if length < min_length:
            continue  # ignore edges that are too short
        vertical_score = float('inf') if abs(dy) < 1e-6 else abs(dx) / abs(dy)
        slope = float('inf') if abs(dx) < 1e-6 else dy / dx
        segments.append({
            'start': p1,
            'end': p2,
            'length': length,
            'dx': dx,
            'dy': dy,
            'slope': slope,
            'vertical_score': vertical_score,  # smaller = closer to vertical
        })
    return segments


def _draw_vertical_candidates(img, segments):
    """Return a copy of ``img`` with up to two segments drawn (red, then blue) plus a legend."""
    vis = img.copy()
    colors = [(0, 0, 255), (255, 0, 0)]  # BGR: red, blue
    for i, (seg, color) in enumerate(zip(segments, colors)):
        # OpenCV point arguments should be plain Python ints; numpy integer
        # scalars are not accepted by every OpenCV version.
        p1 = tuple(int(v) for v in seg['start'])
        p2 = tuple(int(v) for v in seg['end'])
        cv2.line(vis, p1, p2, color, 3, cv2.LINE_AA)
        mid_x = (p1[0] + p2[0]) // 2
        mid_y = (p1[1] + p2[1]) // 2
        # slope is always a float (possibly inf), which the format spec handles.
        cv2.putText(vis, f"{seg['length']:.0f}, k={seg['slope']:.1f}",
                    (mid_x, mid_y - 10 + i * 25),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
    # Legend
    cv2.putText(vis, "Red: 1st closest to vertical", (20, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 1)
    cv2.putText(vis, "Blue: 2nd closest to vertical", (20, 60),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 1)
    return vis


def process_image(img_path, output_dir):
    """Segment one image, then report and draw its two most vertical long edges.

    Steps:
      1. Run the module-level YOLO ``model`` and union all instance masks
         into one binary mask, saved as ``mask_<name>.png``.
      2. Stack all external contours of that mask, simplify them with
         ``cv2.approxPolyDP`` and split the polygon into edges, ignoring
         edges shorter than 20 px.
      3. Keep the 4 longest edges and, among them, the 2 with the smallest
         |dx| / |dy| (closest to vertical).
      4. Print the edge statistics and save an annotated copy of the image
         as ``vertical_<name>.jpg``.

    Args:
        img_path: Path of the image to process.
        output_dir: Existing directory that receives the mask and visualization.

    Returns:
        None. Unreadable images, images without contours and images with
        fewer than two usable edges are reported and skipped.
    """
    img = cv2.imread(img_path)
    if img is None:
        print(f"❌ 无法读取图像: {img_path}")
        return
    filename = os.path.basename(img_path)
    name_only = os.path.splitext(filename)[0]
    print(f"\n🔄 正在处理: {filename}")

    # ------------------ Composite mask ------------------
    composite_mask = _build_composite_mask(img)
    mask_save_path = os.path.join(output_dir, f'mask_{name_only}.png')
    cv2.imwrite(mask_save_path, composite_mask)
    print(f"✅ 掩码已保存: {mask_save_path}")

    # ------------------ Contours -> simplified polygon ------------------
    contours, _ = cv2.findContours(composite_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if not contours:
        print(f"⚠️ 未检测到轮廓: {filename}")
        return
    # NOTE(review): all external contours are concatenated and simplified as
    # ONE closed curve, so with several separate blobs the polygon also gets
    # "bridge" edges joining consecutive contours. If that is not intended,
    # consider simplifying only the largest contour.
    all_contours = np.vstack(contours)
    epsilon = 0.005 * cv2.arcLength(all_contours, True)
    approx = cv2.approxPolyDP(all_contours, epsilon, True)
    pts = [p[0] for p in approx]  # flatten (k, 1, 2) -> list of (x, y)

    # ------------------ Edge segments ------------------
    line_segments = _polygon_segments(pts, min_length=20)
    line_segments.sort(key=lambda s: s['length'], reverse=True)
    top_4_segments = line_segments[:4]
    if len(top_4_segments) < 2:
        print(f"⚠️ 不足2条有效边({len(top_4_segments)}条): {filename}")
        return

    # The two edges closest to vertical. sorted() leaves top_4_segments in
    # length order for the report below.
    vertical_candidates = sorted(top_4_segments, key=lambda s: s['vertical_score'])[:2]

    # ------------------ Report ------------------
    print("✅ 提取最长的4条边:")
    for i, seg in enumerate(top_4_segments, start=1):
        print(f"{i}长边: 长度={seg['length']:.1f}, 斜率={seg['slope']:.3f}, 垂直评分={seg['vertical_score']:.3f}")
    print("✅ 最接近垂直的两条边(垂直评分最小):")
    for i, seg in enumerate(vertical_candidates, start=1):
        print(f" 垂直边{i}: 长度={seg['length']:.1f}, 斜率={seg['slope']:.3f}, 垂直评分={seg['vertical_score']:.3f}")

    # ------------------ Visualization ------------------
    vis_img = _draw_vertical_candidates(img, vertical_candidates)
    vis_save_path = os.path.join(output_dir, f'vertical_{name_only}.jpg')
    cv2.imwrite(vis_save_path, vis_img)
    print(f"✅ 可视化图像已保存: {vis_save_path}")
# ------------------ Entry point: process every image in the folder ------------------
if __name__ == '__main__':
    if not os.path.isdir(img_folder):
        print(f"❌ 图像文件夹不存在: {img_folder}")
        # exit() is a site-module convenience for the interactive shell;
        # SystemExit works everywhere and reports failure to the caller.
        raise SystemExit(1)
    # Sorted for a deterministic processing order (os.listdir order is
    # arbitrary); keep only regular files with a supported extension.
    image_files = sorted(
        f for f in os.listdir(img_folder)
        if f.lower().endswith(SUPPORTED_FORMATS)
        and os.path.isfile(os.path.join(img_folder, f))
    )
    if not image_files:
        print(f"⚠️ 在 {img_folder} 中未找到支持的图像文件")
        raise SystemExit(0)  # nothing to do: not treated as an error
    print(f"✅ 发现 {len(image_files)} 张图像,开始批量处理...")
    for image_file in image_files:
        image_path = os.path.join(img_folder, image_file)
        process_image(image_path, output_mask_dir)
    print(f"\n🎉 所有图像处理完成!结果保存在: {output_mask_dir}")