# Source: zjsh_yolov11/zhuangtai_class_cls/lunkuoxian_f.py
# Python, 89 lines, 2.7 KiB; timestamp 2026-03-10 13:58:21 +08:00
import cv2
import numpy as np
from sklearn.linear_model import RANSACRegressor
from scipy.interpolate import UnivariateSpline
from pathlib import Path
import os
# -----------------------------
# Configuration
# -----------------------------
# Input images are read from this directory; results go to its "output" sub-folder.
_BASE_DIR = "/home/hx/yolo/yemian/test_image"
INPUT_FOLDER = _BASE_DIR
OUTPUT_FOLDER = f"{_BASE_DIR}/output"

# Region of interest inside each image, as (x, y, w, h) in pixels.
ROI = (519, 757, 785, 50)

# Curve-fitting parameters.
RANSAC_RES_THRESHOLD = 2  # residual_threshold for RANSACRegressor
SPLINE_S = 5  # smoothing factor `s` for UnivariateSpline

# Make sure the results directory exists before the batch starts.
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
# -----------------------------
# Batch processing
# -----------------------------
# File extensions (compared case-insensitively) treated as input images.
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}

# Contrast enhancer shared by every image (its parameters never change).
_CLAHE = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))


def _list_images(folder):
    """Return the image files directly inside *folder*, sorted by path.

    Only regular files whose lower-cased suffix is in IMAGE_SUFFIXES are
    returned. Sub-folders (e.g. the output folder) are not searched.
    A missing *folder* yields an empty list.
    """
    return sorted(
        p for p in Path(folder).glob("*")
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )


def _edge_points(roi_img):
    """Return Canny edge pixels of a BGR crop as an (N, 2) array of (x, y)."""
    gray = cv2.cvtColor(roi_img, cv2.COLOR_BGR2GRAY)
    # CLAHE + light Gaussian blur before Canny, as in the original pipeline.
    blur = cv2.GaussianBlur(_CLAHE.apply(gray), (3, 3), 0)
    edges = cv2.Canny(blur, 50, 150)
    ys, xs = np.nonzero(edges)
    return np.column_stack([xs, ys])


def _fit_curve(points, residual_threshold, smoothing, n_samples=500):
    """Fit y = f(x) to edge points with RANSAC, then smooth with a spline.

    Args:
        points: (N, 2) array of (x, y) edge coordinates.
        residual_threshold: RANSACRegressor ``residual_threshold``.
        smoothing: UnivariateSpline smoothing factor ``s``.
        n_samples: number of evenly spaced x samples to return.

    Returns:
        (x_new, y_new): x samples spanning the points' x range, and the
        smoothed y value at each sample (NaN/inf replaced by the mean fit).

    Raises:
        Whatever sklearn/scipy raise when a fit cannot be produced.
    """
    x_pts = points[:, 0].reshape(-1, 1)
    y_pts = points[:, 1]
    # Fixed random_state makes RANSAC's subset sampling reproducible, so
    # re-running the batch draws the same curve for the same image.
    ransac = RANSACRegressor(
        min_samples=5,
        residual_threshold=residual_threshold,
        max_trials=1000,
        random_state=0,
    )
    ransac.fit(x_pts, y_pts)
    y_ransac = ransac.predict(x_pts)
    # NOTE(review): with the default (linear) RANSAC estimator, y_ransac lies
    # on a straight line, so the spline below just reproduces that line.
    # Fitting the spline to the inlier points (ransac.inlier_mask_) may be
    # what was intended — confirm before relying on "curve" behaviour.
    sort_idx = np.argsort(x_pts[:, 0])
    x_sorted = x_pts[:, 0][sort_idx]
    y_sorted = y_ransac[sort_idx]
    spline = UnivariateSpline(x_sorted, y_sorted, s=smoothing)
    x_new = np.linspace(x_sorted[0], x_sorted[-1], n_samples)
    fill = np.mean(y_sorted)
    # Non-finite values would break the int conversion used for drawing.
    y_new = np.nan_to_num(spline(x_new), nan=fill, posinf=fill, neginf=fill)
    return x_new, y_new


def _draw_curve(img, x_new, y_new, color=(0, 0, 255), thickness=2):
    """Return a copy of *img* with the sampled curve drawn as an open polyline."""
    output = img.copy()
    # astype(np.int32) truncates toward zero, like int() on each coordinate.
    pts = np.column_stack([x_new, y_new]).astype(np.int32).reshape(-1, 1, 2)
    cv2.polylines(output, [pts], isClosed=False, color=color, thickness=thickness)
    return output


image_files = _list_images(INPUT_FOLDER)
for img_path in image_files:
    print(f"处理图片: {img_path.name}")
    img = cv2.imread(str(img_path))
    if img is None:
        print("无法读取图片,跳过")
        continue

    x, y, w, h = ROI
    roi_img = img[y:y + h, x:x + w].copy()
    # An ROI lying outside the image yields an empty crop; cv2.cvtColor would
    # raise on it and abort the whole batch, so skip this image instead.
    if roi_img.size == 0:
        print("ROI 超出图片范围,跳过")
        continue

    points = _edge_points(roi_img)
    if len(points) < 10:
        print("边缘点太少,跳过")
        continue

    try:
        x_new, y_new = _fit_curve(points, RANSAC_RES_THRESHOLD, SPLINE_S)
    except Exception as e:  # per-image boundary: report and continue the batch
        print(f"红色曲线拟合失败,跳过: {e}")
        continue

    output = _draw_curve(roi_img, x_new, y_new)
    out_path = Path(OUTPUT_FOLDER) / f"{img_path.stem}_curve.png"
    # cv2.imwrite signals failure through its return value, not an exception.
    if cv2.imwrite(str(out_path), output):
        print(f"保存可视化结果: {out_path.name}")
    else:
        print(f"保存失败: {out_path}")