# Source: zjsh_yolov11/yemian/yemian_line/danmu_d.py
# Last commit: 67883f1a50 by 琉璃月光 ("最新推送", i.e. "latest push"),
# 2026-03-10 14:14:14 +08:00 — 189 lines, 4.8 KiB, Python.
# Note from the hosting page: this file contains Unicode characters (CJK text
# and emoji in comments) that some viewers flag as potentially ambiguous.

import os
import cv2
import numpy as np
from pathlib import Path
from ultralytics import YOLO
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Square side length (px) every ROI crop is resized to before inference; it is
# also passed to the model as ``imgsz``.
TARGET_SIZE = 640
# Directory whose .jpg/.jpeg/.png files are processed.
IMAGE_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/612/train/class0"
# Segmentation weights loaded by ``YOLO``.
MODEL_PATH = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/61seg/exp2/weights/best.pt"
# Destination directory for the ``vis_<name>`` visualisation images.
OUTPUT_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/612/train/class2"

# Regions of interest as (x, y, width, height) in full-image pixel coordinates.
# A previously used configuration (kept for reference):
#     ROIS = [(445, 540, 931, 319)]
ROIS = [
    (0, 0, 640, 640),
]
# --------------------
# 从 mask 中提取左右边界点
# --------------------
def extract_left_right_edge_points(mask_bin):
h, w = mask_bin.shape
left_pts = []
right_pts = []
for y in range(h):
xs = np.where(mask_bin[y] > 0)[0]
if len(xs) >= 2:
left_pts.append([xs.min(), y])
right_pts.append([xs.max(), y])
return np.array(left_pts), np.array(right_pts)
# ---------------------------------------------------------------------------
# Keep points within a fractional y band of the segmentation
# ---------------------------------------------------------------------------
def filter_by_seg_y_ratio(pts, y_start=0.35, y_end=0.85):
    """Keep only points whose y lies in a relative band of the points' own y-extent.

    The band spans from ``y_min + int(h * y_start)`` to
    ``y_min + int(h * y_end)`` (inclusive), where ``h = y_max - y_min``.

    Args:
        pts: array of ``[x, y]`` rows.
        y_start: lower edge of the band as a fraction of the y-extent.
        y_end: upper edge of the band as a fraction of the y-extent.

    Returns:
        The filtered points. ``pts`` itself is returned unchanged when it has
        fewer than 2 points or its y-extent is below 10 px.
    """
    if len(pts) < 2:
        return pts
    ys = pts[:, 1]
    top, bottom = ys.min(), ys.max()
    span = bottom - top
    # Too short a vertical extent to crop meaningfully.
    if span < 10:
        return pts
    lo = top + int(span * y_start)
    hi = top + int(span * y_end)
    keep = (ys >= lo) & (ys <= hi)
    return pts[keep]
# ---------------------------------------------------------------------------
# Line fitting
# ---------------------------------------------------------------------------
def fit_line(pts):
    """Least-squares fit of x as a linear function of y: ``x = m * y + b``.

    x is regressed on y (rather than the reverse) so near-vertical edges stay
    well conditioned.

    Args:
        pts: array of ``[x, y]`` rows.

    Returns:
        ``(m, b)``, or ``None`` when fewer than 2 points are given.
    """
    if len(pts) < 2:
        return None
    slope, intercept = np.polyfit(pts[:, 1], pts[:, 0], deg=1)
    return slope, intercept
# ---------------------------------------------------------------------------
# Reference y value (bottom of the segmentation)
# ---------------------------------------------------------------------------
def get_y_ref(mask_bin):
    """Estimate the mask's bottom edge row from its central columns.

    For each column in ``[int(0.2 * w), int(0.8 * w))`` that has any pixel
    > 0, take the lowest (largest-y) foreground row. Return the truncated
    mean of those rows.

    Returns:
        An ``int`` row index, or ``h // 2`` when no central column contains
        foreground.
    """
    h, w = mask_bin.shape
    band = mask_bin[:, int(w * 0.2):int(w * 0.8)] > 0
    hit = band.any(axis=0)
    if not hit.any():
        return h // 2
    # argmax on the vertically flipped band finds the last foreground row.
    bottoms = (h - 1) - band[::-1, hit].argmax(axis=0)
    return int(np.mean(bottoms))
# ---------------------------------------------------------------------------
# Single-image processing
# ---------------------------------------------------------------------------
def _roi_mask(roi, model):
    """Segment ``roi`` with ``model``; return its first mask as 0/1 uint8 at ROI size, or None."""
    crop_h, crop_w = roi.shape[:2]
    resized = cv2.resize(roi, (TARGET_SIZE, TARGET_SIZE))
    result = model(resized, imgsz=TARGET_SIZE, verbose=False)[0]
    if result.masks is None:
        return None
    mask = result.masks.data[0].cpu().numpy()
    mask_bin = (mask > 0.5).astype(np.uint8)
    # interpolation must be passed by keyword: cv2.resize's third positional
    # parameter is ``dst``, so a positional INTER_NEAREST was silently ignored.
    return cv2.resize(mask_bin, (crop_w, crop_h), interpolation=cv2.INTER_NEAREST)


def _draw_measurement(roi_vis, left_line, right_line, y_ref, x_left, x_right, diff):
    """Draw fitted edge lines, reference row, edge points and diff label onto ``roi_vis`` in place."""
    crop_h, crop_w = roi_vis.shape[:2]
    for (m, b), color in ((left_line, (0, 0, 255)), (right_line, (255, 0, 0))):
        # Lines are x = m*y + b: evaluate at the top (y=0) and bottom (y=crop_h).
        cv2.line(roi_vis, (int(m * 0 + b), 0), (int(m * crop_h + b), crop_h), color, 3)
    cv2.line(roi_vis, (0, y_ref), (crop_w, y_ref), (0, 255, 255), 2)
    cv2.circle(roi_vis, (x_left, y_ref), 6, (0, 0, 255), -1)
    cv2.circle(roi_vis, (x_right, y_ref), 6, (255, 0, 0), -1)
    cv2.putText(
        roi_vis,
        f"diff={diff}px",
        (10, 40),
        cv2.FONT_HERSHEY_SIMPLEX,
        1,
        (0, 255, 255),
        2,
    )


def process_one(img_path, model):
    """Segment each ROI of one image and measure the gap between the mask's side edges.

    For every ``(x, y, w, h)`` in ``ROIS`` the crop is resized to
    ``TARGET_SIZE`` and passed to ``model``. The first returned mask is
    thresholded at 0.5 and mapped back to crop size. Straight lines are fitted
    to its left and right edges (restricted to a y band by
    ``filter_by_seg_y_ratio``) and evaluated at the row from ``get_y_ref``.

    Args:
        img_path: path (str or Path) of the image, read with ``cv2.imread``.
        model: callable segmentation model, called as
            ``model(image, imgsz=..., verbose=False)``; presumably an
            ultralytics YOLO model — confirm if another is substituted.

    Returns:
        ``(vis, result_data)``. ``vis`` is a copy of the image with the mask
        overlay and measurements drawn. ``result_data`` is
        ``(X_L, Y, X_R, Y, diff)`` in full-image pixels for the last ROI
        that produced a measurement, or ``None`` if none did.

    Raises:
        OSError: if the image cannot be read or decoded.
    """
    img = cv2.imread(str(img_path))
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise OSError(f"cannot read image: {img_path}")
    vis = img.copy()
    result_data = None  # (X_L, Y, X_R, Y, diff)
    for rx, ry, rw, rh in ROIS:
        roi = img[ry:ry + rh, rx:rx + rw]
        if roi.size == 0:
            continue
        # numpy slicing silently clips an ROI that runs past the image edge;
        # use the real crop size so mask and overlay shapes match it.
        rh, rw = roi.shape[:2]
        mask_bin = _roi_mask(roi, model)
        if mask_bin is None:
            continue

        # Overlay: the whole crop is blended 70/30 with a layer that is green
        # on mask pixels and black elsewhere.
        green = np.zeros_like(roi)
        green[mask_bin == 1] = (0, 255, 0)
        vis[ry:ry + rh, rx:rx + rw] = cv2.addWeighted(roi, 0.7, green, 0.3, 0)

        left_pts, right_pts = extract_left_right_edge_points(mask_bin)
        left_line = fit_line(filter_by_seg_y_ratio(left_pts))
        right_line = fit_line(filter_by_seg_y_ratio(right_pts))
        if left_line is None or right_line is None:
            continue
        (m1, b1), (m2, b2) = left_line, right_line
        y_ref = get_y_ref(mask_bin)

        # Edge x positions at the reference row, in ROI coordinates.
        x_left = int(m1 * y_ref + b1)
        x_right = int(m2 * y_ref + b2)
        # Translate to full-image coordinates.
        X_L = rx + x_left
        X_R = rx + x_right
        Y = ry + y_ref
        diff = X_R - X_L
        result_data = (X_L, Y, X_R, Y, diff)

        # Slicing yields a view, so drawing on it updates ``vis``.
        _draw_measurement(
            vis[ry:ry + rh, rx:rx + rw], left_line, right_line, y_ref, x_left, x_right, diff
        )
    return vis, result_data
# ---------------------------------------------------------------------------
# Batch processing
# ---------------------------------------------------------------------------
def run():
    """Process every image in IMAGE_DIR and write visualisations to OUTPUT_DIR.

    Only ``.jpg``/``.png``/``.jpeg`` files (case-insensitive) are handled, in
    sorted filename order. For each one ``vis_<name>`` is written to
    OUTPUT_DIR and the measured edge points and gap are printed. An image
    that fails to load with OSError is reported and skipped.
    """
    model = YOLO(MODEL_PATH)
    out_dir = Path(OUTPUT_DIR)
    # parents=True so a missing intermediate directory does not abort the run.
    out_dir.mkdir(parents=True, exist_ok=True)
    for name in sorted(os.listdir(IMAGE_DIR)):
        if not name.lower().endswith((".jpg", ".png", ".jpeg")):
            continue
        try:
            vis, data = process_one(Path(IMAGE_DIR) / name, model)
        except OSError as err:
            # One unreadable image should not abort the whole batch.
            print(f"[{name}] 读取失败: {err}")
            continue
        out = out_dir / f"vis_{name}"
        # cv2.imwrite reports failure via its return value, not an exception.
        if not cv2.imwrite(str(out), vis):
            print(f"[{name}] 保存失败: {out}")
        if data is None:
            print(f"[{name}] 无有效结果")
            continue
        XL, YL, XR, YR, diff = data
        print(f"[{name}]")
        print(f" 左交点: ({XL}, {YL})")
        print(f" 右交点: ({XR}, {YR})")
        print(f" diff : {diff} px")


# ---------------------------------------------------------------------------
if __name__ == "__main__":
    run()