189 lines
4.8 KiB
Python
189 lines
4.8 KiB
Python
import os
|
||
import cv2
|
||
import numpy as np
|
||
from pathlib import Path
|
||
from ultralytics import YOLO
|
||
|
||
# --------------------
# Configuration
# --------------------
# Side length (px) of the square crop fed to the segmentation model.
TARGET_SIZE = 640

# Directory scanned for input images.
IMAGE_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/612/train/class0"

# Segmentation weights loaded at the start of the batch run.
MODEL_PATH = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/61seg/exp2/weights/best.pt"

# Directory that receives the annotated "vis_<name>" images.
OUTPUT_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/612/train/class2"

# Regions of interest as (x, y, width, height) in full-image pixels.
# A previously used alternative setting was: [(445, 540, 931, 319)]
ROIS = [(0, 0, 640, 640)]
# --------------------
# Extract left/right boundary points from the mask
# --------------------
def extract_left_right_edge_points(mask_bin):
    """Return the leftmost and rightmost foreground pixel of every mask row.

    Only rows that contain at least two foreground pixels (value ``> 0``)
    contribute a point pair.

    Args:
        mask_bin: 2-D mask array; pixels ``> 0`` are treated as foreground.

    Returns:
        ``(left_pts, right_pts)``: two integer arrays of shape ``(n, 2)`` whose
        rows are ``[x, y]`` pairs ordered by increasing ``y``. Both have shape
        ``(0, 2)`` when no row qualifies, so callers can always index
        ``pts[:, 0]`` / ``pts[:, 1]``.
    """
    fg = np.asarray(mask_bin) > 0
    _, w = fg.shape

    # Rows with fewer than two foreground pixels carry no usable edge pair.
    rows = np.flatnonzero(fg.sum(axis=1) >= 2)
    if rows.size == 0:
        empty = np.empty((0, 2), dtype=np.intp)
        return empty, empty.copy()

    sel = fg[rows]
    # argmax returns the first True; on the mirrored row it finds the last one.
    left_x = sel.argmax(axis=1)
    right_x = (w - 1) - sel[:, ::-1].argmax(axis=1)

    left_pts = np.column_stack((left_x, rows))
    right_pts = np.column_stack((right_x, rows))
    return left_pts, right_pts
# --------------------
# Keep points inside a vertical fraction band of the segmentation
# --------------------
def filter_by_seg_y_ratio(pts, y_start=0.35, y_end=0.85):
    """Keep only the points whose y lies in a fractional band of the points' y-span.

    The band runs from ``y_start`` to ``y_end`` (fractions of the span between
    the smallest and largest y, bounds truncated to int, both inclusive).
    The input is returned unchanged when it has fewer than two points or when
    its y-span is below 10 px.

    Args:
        pts: array of ``[x, y]`` rows.
        y_start: lower fraction of the y-span to keep.
        y_end: upper fraction of the y-span to keep.
    """
    if len(pts) < 2:
        return pts

    ys = pts[:, 1]
    top = ys.min()
    span = ys.max() - top
    if span < 10:
        return pts

    lower = top + int(span * y_start)
    upper = top + int(span * y_end)
    keep = np.logical_and(ys >= lower, ys <= upper)
    return pts[keep]
# --------------------
# Line fitting
# --------------------
def fit_line(pts):
    """Least-squares fit of ``x = m * y + b`` through an array of ``[x, y]`` rows.

    x is regressed on y, which keeps the slope finite for vertical edges.

    Returns:
        ``(m, b)`` as produced by ``np.polyfit``, or ``None`` when fewer than
        two points are given.
    """
    if len(pts) >= 2:
        xs, ys = pts[:, 0], pts[:, 1]
        slope, intercept = np.polyfit(ys, xs, 1)
        return slope, intercept
    return None
# --------------------
# Reference y value (bottom of the segmentation)
# --------------------
def get_y_ref(mask_bin):
    """Return the mean bottom row of the mask over the central columns.

    For every column in ``[int(0.2 * w), int(0.8 * w))`` that contains
    foreground (value ``> 0``), take its lowest foreground row. The result is
    ``int()`` of the mean of those rows, or ``h // 2`` when no such column exists.
    """
    h, w = mask_bin.shape
    band = mask_bin[:, int(w * 0.2):int(w * 0.8)] > 0
    has_fg = band.any(axis=0)
    if not has_fg.any():
        return h // 2
    # argmax on the vertically flipped band locates the last foreground row.
    bottom_rows = (h - 1) - band[::-1].argmax(axis=0)
    return int(bottom_rows[has_fg].mean())
# --------------------
# Single-image processing
# --------------------
def _segment_roi(roi, model):
    """Run ``model`` on ``roi``; return its first mask as a 0/1 uint8 array sized like ``roi``, or None."""
    crop_h, crop_w = roi.shape[:2]
    resized = cv2.resize(roi, (TARGET_SIZE, TARGET_SIZE))
    result = model(resized, imgsz=TARGET_SIZE, verbose=False)[0]
    if result.masks is None:
        return None
    # NOTE(review): only the first instance mask is used; this assumes the
    # model's output order puts the wanted object first — confirm.
    mask = result.masks.data[0].cpu().numpy()
    mask_bin = (mask > 0.5).astype(np.uint8)
    # The third positional parameter of cv2.resize is `dst`, not the
    # interpolation flag, so INTER_NEAREST must be passed by keyword.
    return cv2.resize(mask_bin, (crop_w, crop_h), interpolation=cv2.INTER_NEAREST)


def _draw_measurement(roi_vis, lines, y_ref, x_left, x_right, diff):
    """Draw the fitted edge lines, reference row, edge points and diff label onto ``roi_vis`` in place."""
    crop_h, crop_w = roi_vis.shape[:2]
    for (m, b), color in zip(lines, ((0, 0, 255), (255, 0, 0))):
        # Lines are x = m*y + b, so evaluate them at the top and bottom rows.
        cv2.line(roi_vis, (int(m * 0 + b), 0), (int(m * crop_h + b), crop_h), color, 3)
    cv2.line(roi_vis, (0, y_ref), (crop_w, y_ref), (0, 255, 255), 2)
    cv2.circle(roi_vis, (x_left, y_ref), 6, (0, 0, 255), -1)
    cv2.circle(roi_vis, (x_right, y_ref), 6, (255, 0, 0), -1)
    cv2.putText(
        roi_vis,
        f"diff={diff}px",
        (10, 40),
        cv2.FONT_HERSHEY_SIMPLEX,
        1,
        (0, 255, 255),
        2,
    )


def process_one(img_path, model):
    """Segment each ROI of one image and measure the mask width at its bottom.

    For every ROI in ``ROIS``, the crop is resized to ``TARGET_SIZE`` and passed
    to ``model``. The first returned mask is scaled back to the crop. The
    left/right mask edges (restricted to a vertical band by
    ``filter_by_seg_y_ratio``) are fitted with straight lines, and both lines
    are evaluated at the reference row from ``get_y_ref``.

    Args:
        img_path: path of the image, read with ``cv2.imread``.
        model: segmentation model called as ``model(image, imgsz=..., verbose=False)``
            whose first result exposes ``.masks``.

    Returns:
        ``(vis, result_data)``. ``vis`` is a copy of the image with the mask
        overlay, fitted lines and measurement drawn on it. ``result_data`` is
        ``(X_L, Y, X_R, Y, diff)`` in full-image pixel coordinates for the last
        ROI that produced a measurement, or ``None`` if none did.

    Raises:
        OSError: if the image cannot be read.
    """
    img = cv2.imread(str(img_path))
    if img is None:
        # cv2.imread reports missing/unreadable files by returning None.
        raise OSError(f"cannot read image: {img_path}")
    vis = img.copy()

    result_data = None  # (X_L, Y, X_R, Y, diff)

    for rx, ry, rw, rh in ROIS:
        roi = img[ry:ry + rh, rx:rx + rw]
        # Slicing clips the ROI at the image border; use the real crop size so
        # the mask, overlay and crop always have matching shapes.
        crop_h, crop_w = roi.shape[:2]
        if crop_h == 0 or crop_w == 0:
            continue

        mask_bin = _segment_roi(roi, model)
        if mask_bin is None:
            continue

        # Blend a green mask overlay into the visualization.
        green = np.zeros_like(roi)
        green[mask_bin == 1] = (0, 255, 0)
        vis[ry:ry + crop_h, rx:rx + crop_w] = cv2.addWeighted(roi, 0.7, green, 0.3, 0)

        # Edge points -> restricted to a vertical band -> straight-line fits.
        left_pts, right_pts = extract_left_right_edge_points(mask_bin)
        left_line = fit_line(filter_by_seg_y_ratio(left_pts))
        right_line = fit_line(filter_by_seg_y_ratio(right_pts))
        if left_line is None or right_line is None:
            continue
        m1, b1 = left_line
        m2, b2 = right_line

        y_ref = get_y_ref(mask_bin)

        # Edge x positions at the reference row, in ROI coordinates.
        x_left = int(m1 * y_ref + b1)
        x_right = int(m2 * y_ref + b2)

        # Translate to full-image coordinates.
        X_L = rx + x_left
        X_R = rx + x_right
        Y = ry + y_ref
        diff = X_R - X_L
        result_data = (X_L, Y, X_R, Y, diff)

        _draw_measurement(
            vis[ry:ry + crop_h, rx:rx + crop_w],
            ((m1, b1), (m2, b2)),
            y_ref,
            x_left,
            x_right,
            diff,
        )

    return vis, result_data
# --------------------
# Batch processing
# --------------------
def run():
    """Process every image in IMAGE_DIR and write annotated copies to OUTPUT_DIR.

    Each ``.jpg``/``.png``/``.jpeg`` file (case-insensitive, sorted by name) is
    passed to ``process_one``. The visualization is saved as
    ``OUTPUT_DIR/vis_<name>``, and the measured intersection points and width
    are printed. Files that cannot be read are reported and skipped.
    """
    model = YOLO(MODEL_PATH)
    out_dir = Path(OUTPUT_DIR)
    # parents=True so a missing intermediate directory does not abort the run.
    out_dir.mkdir(parents=True, exist_ok=True)

    for name in sorted(os.listdir(IMAGE_DIR)):
        if not name.lower().endswith((".jpg", ".png", ".jpeg")):
            continue

        try:
            vis, data = process_one(Path(IMAGE_DIR) / name, model)
        except OSError as err:
            # One unreadable file should not stop the whole batch.
            print(f"[{name}] skipped: {err}")
            continue

        out = out_dir / f"vis_{name}"
        # cv2.imwrite signals failure by returning False rather than raising.
        if not cv2.imwrite(str(out), vis):
            print(f"[{name}] failed to write {out}")

        if data is None:
            print(f"[{name}] 无有效结果")
            continue

        XL, YL, XR, YR, diff = data
        print(f"[{name}]")
        print(f" 左交点: ({XL}, {YL})")
        print(f" 右交点: ({XR}, {YR})")
        print(f" diff : {diff} px")


# --------------------
# Script entry point
# --------------------
if __name__ == "__main__":
    run()