Files
zjsh_yolov11/yemian/yemian_line/danmu_d.py

189 lines
4.8 KiB
Python
Raw Normal View History

2026-03-10 13:58:21 +08:00
import os
2025-12-11 08:37:09 +08:00
import cv2
import numpy as np
from pathlib import Path
from ultralytics import YOLO
# --------------------
# Configuration
# --------------------
# Side length (pixels) every ROI is resized to before inference.
TARGET_SIZE = 640

_DATASET_ROOT = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/612/train"
# Directory scanned for input images (.jpg / .jpeg / .png).
IMAGE_DIR = f"{_DATASET_ROOT}/class0"
# Weights loaded with ultralytics YOLO.
MODEL_PATH = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/61seg/exp2/weights/best.pt"
# Annotated copies are written here as "vis_<original file name>".
OUTPUT_DIR = f"{_DATASET_ROOT}/class2"

# Regions of interest as (x, y, width, height) in full-image pixels.
# A previously used setting was [(445, 540, 931, 319)].
ROIS = [(0, 0, 640, 640)]
# --------------------
# Left/right boundary points of the mask
# --------------------
def extract_left_right_edge_points(mask_bin):
    """Return the leftmost and rightmost foreground pixel of each mask row.

    Only rows with at least two foreground pixels (value > 0) contribute.

    Args:
        mask_bin: 2-D mask array; pixels > 0 count as foreground.

    Returns:
        (left_pts, right_pts): integer arrays of shape (n, 2) holding (x, y)
        pairs in increasing-y order, or two empty 1-D arrays when no row
        qualifies.
    """
    _, width = mask_bin.shape
    foreground = mask_bin > 0
    rows = np.flatnonzero(foreground.sum(axis=1) >= 2)
    if rows.size == 0:
        return np.array([]), np.array([])
    hits = foreground[rows]
    # argmax finds the first True per row; on the horizontally mirrored rows it
    # gives the distance of the last True from the right border.
    left_x = hits.argmax(axis=1)
    right_x = width - 1 - hits[:, ::-1].argmax(axis=1)
    return np.column_stack((left_x, rows)), np.column_stack((right_x, rows))
# --------------------
# Filter points by a fractional band of their own y span
# --------------------
def filter_by_seg_y_ratio(pts, y_start=0.35, y_end=0.85):
    """Keep the points whose y falls inside a fractional band of the points' y range.

    With span = max(y) - min(y), the band is
    [min(y) + int(span * y_start), min(y) + int(span * y_end)], inclusive.

    Args:
        pts: array of (x, y) rows.
        y_start: band start as a fraction of the y span.
        y_end: band end as a fraction of the y span.

    Returns:
        The filtered rows. The input object itself is returned unchanged when it
        has fewer than two points or its y span is below 10 pixels.
    """
    if len(pts) >= 2:
        ys = pts[:, 1]
        top = ys.min()
        span = ys.max() - top
        if span >= 10:
            lo = top + int(span * y_start)
            hi = top + int(span * y_end)
            return pts[(ys >= lo) & (ys <= hi)]
    return pts
# --------------------
# Line fit
# --------------------
def fit_line(pts):
    """Least-squares fit of x = m * y + b through (x, y) points.

    x is regressed on y, which keeps near-vertical lines finite.

    Args:
        pts: array of (x, y) rows.

    Returns:
        (m, b) as numpy floats, or None when fewer than two points are given.
    """
    if len(pts) >= 2:
        coeffs = np.polyfit(pts[:, 1], pts[:, 0], 1)
        return tuple(coeffs)
    return None
# --------------------
# y reference value (bottom of the segmentation)
# --------------------
def get_y_ref(mask_bin):
    """Return the mean bottom-most foreground row across the central columns.

    Columns int(w * 0.2) up to (but excluding) int(w * 0.8) are scanned; each
    column that contains foreground contributes the index of its lowest
    foreground pixel. The mean is truncated to int.

    Args:
        mask_bin: 2-D mask array; pixels > 0 count as foreground.

    Returns:
        The truncated mean row index, or h // 2 when no scanned column has
        foreground.
    """
    height, width = mask_bin.shape
    band = mask_bin[:, int(width * 0.2):int(width * 0.8)] > 0
    occupied = band.any(axis=0)
    if not occupied.any():
        return height // 2
    # Lowest foreground row per column: flip vertically and take the first hit.
    bottoms = height - 1 - band[::-1].argmax(axis=0)
    return int(np.mean(bottoms[occupied]))
# --------------------
# Single-image processing
# --------------------
def _segment_roi(roi, model):
    """Segment one ROI and return its first mask, binarized and at ROI size.

    The ROI is resized to TARGET_SIZE x TARGET_SIZE for inference. Returns a
    uint8 array of shape roi.shape[:2] with values 0/1, or None when the model
    returned no masks.
    """
    roi_h, roi_w = roi.shape[:2]
    resized = cv2.resize(roi, (TARGET_SIZE, TARGET_SIZE))
    result = model(resized, imgsz=TARGET_SIZE, verbose=False)[0]
    if result.masks is None:
        return None
    # Only the first returned mask instance is used.
    mask = result.masks.data[0].cpu().numpy()
    mask_bin = (mask > 0.5).astype(np.uint8)
    # `interpolation` must be passed by keyword: the third positional parameter
    # of cv2.resize is `dst`, so a positional INTER_NEAREST does not select
    # nearest-neighbour resampling.
    return cv2.resize(mask_bin, (roi_w, roi_h), interpolation=cv2.INTER_NEAREST)


def _measure_mask(mask_bin):
    """Fit the mask's left/right edges and evaluate both at the y reference row.

    Returns (left_line, right_line, y_ref, x_left, x_right) in ROI pixel
    coordinates, where each line is (m, b) with x = m * y + b, or None when
    either edge cannot be fitted.
    """
    left_pts, right_pts = extract_left_right_edge_points(mask_bin)
    left_line = fit_line(filter_by_seg_y_ratio(left_pts))
    right_line = fit_line(filter_by_seg_y_ratio(right_pts))
    if left_line is None or right_line is None:
        return None
    y_ref = get_y_ref(mask_bin)
    m1, b1 = left_line
    m2, b2 = right_line
    x_left = int(m1 * y_ref + b1)
    x_right = int(m2 * y_ref + b2)
    return left_line, right_line, y_ref, x_left, x_right


def _draw_measurement(roi_vis, left_line, right_line, y_ref, x_left, x_right, diff):
    """Draw edge lines, the y reference row, both intersections and diff onto roi_vis in place."""
    roi_h, roi_w = roi_vis.shape[:2]
    for (m, b), color in ((left_line, (0, 0, 255)), (right_line, (255, 0, 0))):
        # x = m * y + b evaluated at the top (y = 0) and bottom (y = roi_h) rows.
        cv2.line(roi_vis, (int(b), 0), (int(m * roi_h + b), roi_h), color, 3)
    cv2.line(roi_vis, (0, y_ref), (roi_w, y_ref), (0, 255, 255), 2)
    cv2.circle(roi_vis, (x_left, y_ref), 6, (0, 0, 255), -1)
    cv2.circle(roi_vis, (x_right, y_ref), 6, (255, 0, 0), -1)
    cv2.putText(
        roi_vis,
        f"diff={diff}px",
        (10, 40),
        cv2.FONT_HERSHEY_SIMPLEX,
        1,
        (0, 255, 255),
        2,
    )


def process_one(img_path, model):
    """Segment every ROI of one image and measure the mask width at y_ref.

    Args:
        img_path: path of the image file (str or Path).
        model: callable returning ultralytics-style results with `.masks`
            (run() passes a YOLO model).

    Returns:
        (vis, result_data): vis is a copy of the image with the mask overlay and
        the measurement drawn; result_data is (X_L, Y, X_R, Y, diff) in
        full-image coordinates for the last ROI that produced a measurement,
        or None if none did.

    Raises:
        OSError: if the image cannot be read.
    """
    img = cv2.imread(str(img_path))
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError(f"cannot read image: {img_path}")
    vis = img.copy()
    result_data = None  # (XL, Y, XR, Y, diff)

    for rx, ry, rw, rh in ROIS:
        # Slicing clips at the image border, so the real ROI can be smaller than
        # (rw, rh); every later size is taken from the actual slice shape.
        roi = img[ry:ry + rh, rx:rx + rw]
        if roi.size == 0:
            continue
        roi_h, roi_w = roi.shape[:2]

        mask_bin = _segment_roi(roi, model)
        if mask_bin is None:
            continue

        # Green overlay of the mask (drawn even if the measurement fails below).
        green = np.zeros_like(roi)
        green[mask_bin == 1] = (0, 255, 0)
        roi_vis = vis[ry:ry + roi_h, rx:rx + roi_w]
        roi_vis[:] = cv2.addWeighted(roi, 0.7, green, 0.3, 0)

        measurement = _measure_mask(mask_bin)
        if measurement is None:
            continue
        left_line, right_line, y_ref, x_left, x_right = measurement

        # Convert ROI coordinates to full-image coordinates.
        X_L = rx + x_left
        X_R = rx + x_right
        Y = ry + y_ref
        diff = X_R - X_L
        result_data = (X_L, Y, X_R, Y, diff)

        _draw_measurement(roi_vis, left_line, right_line, y_ref, x_left, x_right, diff)

    return vis, result_data
# --------------------
# Batch processing
# --------------------
def run():
    """Process every .jpg/.jpeg/.png in IMAGE_DIR and save annotated copies to OUTPUT_DIR.

    The measurement for each image is printed to stdout. An OSError raised
    while processing an image, or a failed write of the annotated copy, is
    reported and the batch continues with the next file.
    """
    model = YOLO(MODEL_PATH)
    out_dir = Path(OUTPUT_DIR)
    # parents=True so a missing parent directory is created as well.
    out_dir.mkdir(parents=True, exist_ok=True)
    for img in sorted(os.listdir(IMAGE_DIR)):
        if not img.lower().endswith((".jpg", ".png", ".jpeg")):
            continue
        try:
            vis, data = process_one(Path(IMAGE_DIR) / img, model)
        except OSError as err:
            print(f"[{img}] skipped: {err}")
            continue
        out = out_dir / f"vis_{img}"
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(str(out), vis):
            print(f"[{img}] failed to write {out}")
        if data:
            XL, YL, XR, YR, diff = data
            print(f"[{img}]")
            print(f" 左交点: ({XL}, {YL})")
            print(f" 右交点: ({XR}, {YR})")
            print(f" diff : {diff} px")
        else:
            print(f"[{img}] 无有效结果")


# --------------------
if __name__ == "__main__":
    run()