258 lines
9.2 KiB
Python
258 lines
9.2 KiB
Python
import os
|
||
import cv2
|
||
import numpy as np
|
||
import onnxruntime as ort
|
||
import logging
|
||
|
||
# Configure logging: INFO and above, printed as "LEVEL: message".
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

# Script description. It is a bare string literal placed after the imports and the
# logging call, so Python does not treat it as the module __doc__.
# English translation:
#   YOLOv8-OBB single-class ONNX inference script (strictly follows the 65 = 64 + 1 + 0 layout).
#   Output channels: 64 (xywh distribution) + 1 (obj) -> no explicit class branch.
#   The angle comes from a separate branch, Output[3].
# NOTE(review): in the standard YOLOv8 head the 64 channels are DFL bin logits
# (4 sides x 16 bins) for left/top/right/bottom distances, not x/y/w/h, and the single
# extra channel is the class logit. Confirm against the exported model.
"""
✅ YOLOv8-OBB 单类 ONNX 推理脚本(严格遵循 65 = 64 + 1 + 0 结构)
输出通道:64 (xywh分布) + 1 (obj) → 无显式类别分支
角度来自独立分支 Output[3]
"""
|
||
|
||
|
||
def letterbox(img, new_shape=640, color=(114, 114, 114), auto=False, scale_fill=False, scale_up=True, stride=32):
    """Resize an image keeping its aspect ratio, then pad it to the target size.

    Args:
        img: image array whose first two axes are (height, width).
        new_shape: target size; an int/float for a square canvas, or an
            (height, width) pair.
        color: constant border value used for padding.
        auto: if True, pad only up to the next multiple of ``stride``.
        scale_fill: if True, stretch to the full target size with no padding.
            NOTE: the returned ``r`` is then not the actual per-axis scale.
        scale_up: if False, never enlarge the image (ratio capped at 1.0).
        stride: alignment used when ``auto`` is True.

    Returns:
        (padded_img, (dw, dh), r): dw/dh are half the total horizontal/vertical
        padding (the left/top offset of the resized content, possibly fractional);
        r is the uniform resize ratio applied to the image.
    """
    # Accept both a scalar size and an (h, w) pair; scalars keep the old behaviour.
    if np.ndim(new_shape) == 0:
        new_shape = (new_shape, new_shape)
    new_h, new_w = new_shape

    shape = img.shape[:2]  # (height, width)
    r = min(new_h / shape[0], new_w / shape[1])
    if not scale_up:
        r = min(r, 1.0)

    new_unpad = (int(round(shape[1] * r)), int(round(shape[0] * r)))  # (width, height) as cv2 expects
    dw, dh = new_w - new_unpad[0], new_h - new_unpad[1]
    if auto:
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)
    elif scale_fill:
        dw, dh = 0.0, 0.0
        new_unpad = (new_w, new_h)
    dw /= 2
    dh /= 2

    if shape[::-1] != new_unpad:
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    # The +/-0.1 nudge sends an odd leftover pixel to the bottom/right edge.
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return img, (dw, dh), r
|
||
|
||
|
||
def load_model(onnx_path, providers=None):
    """Create an ONNX Runtime session for the model and log its output signatures.

    Args:
        onnx_path: path to the .onnx model file.
        providers: execution providers handed to ``ort.InferenceSession``;
            ``None`` keeps the previous CPU-only behaviour.

    Returns:
        The ``ort.InferenceSession``.
    """
    if providers is None:
        providers = ['CPUExecutionProvider']
    session = ort.InferenceSession(onnx_path, providers=providers)
    logging.info(f"✅ 模型加载成功: {onnx_path}")
    # Route output info through logging, like every other message in this script.
    for i, output in enumerate(session.get_outputs()):
        logging.info("Output %d: %s, shape: %s", i, output.name, output.shape)
    return session
|
||
|
||
|
||
def _get_covariance_matrix(obb):
|
||
"""计算协方差矩阵参数 a, b, c 用于 ProbIoU"""
|
||
cx, cy = obb[..., 0], obb[..., 1]
|
||
w2 = (obb[..., 2] / 2) ** 2
|
||
h2 = (obb[..., 3] / 2) ** 2
|
||
cos_a = np.cos(obb[..., 4])
|
||
sin_a = np.sin(obb[..., 4])
|
||
a = cos_a ** 2 * w2 + sin_a ** 2 * h2
|
||
b = sin_a ** 2 * w2 + cos_a ** 2 * h2
|
||
c = (w2 - h2) * cos_a * sin_a
|
||
return a, b, c
|
||
|
||
|
||
def batch_probiou(obb1, obb2, eps=1e-7):
    """Compute pairwise ProbIoU between two sets of rotated boxes.

    Each row is (cx, cy, w, h, theta), theta in radians. The boxes are turned into
    2-D Gaussians via ``_get_covariance_matrix`` and compared by their Bhattacharyya
    distance, which is then mapped to 1 - Hellinger distance.

    Args:
        obb1: array of shape (N, 5).
        obb2: array of shape (M, 5).
        eps: small constant guarding the divisions, the log and the sqrt.

    Returns:
        Array of shape (N, M), roughly in [0, 1]; higher means more overlap.
    """
    obb1 = np.asarray(obb1, dtype=float)
    obb2 = np.asarray(obb2, dtype=float)

    # obb1 terms become columns (N, 1), so every pairwise term broadcasts to (N, M).
    x1, y1 = obb1[..., 0][:, None], obb1[..., 1][:, None]
    x2, y2 = obb2[..., 0], obb2[..., 1]
    a1, b1, c1 = (t[:, None] for t in _get_covariance_matrix(obb1))
    a2, b2, c2 = _get_covariance_matrix(obb2)

    a_sum, b_sum, c_sum = a1 + a2, b1 + b2, c1 + c2
    det_sum = a_sum * b_sum - c_sum ** 2  # determinant of (Sigma1 + Sigma2)
    dx, dy = x1 - x2, y1 - y2

    t1 = (a_sum * dy ** 2 + b_sum * dx ** 2) / (det_sum + eps) * 0.25
    t2 = (c_sum * -dx * dy) / (det_sum + eps) * 0.5
    # Per-box determinants are >= 0 mathematically; clamp round-off so sqrt never yields NaN.
    det1 = np.clip(a1 * b1 - c1 ** 2, 0, None)
    det2 = np.clip(a2 * b2 - c2 ** 2, 0, None)
    t3 = np.log(det_sum / (4 * np.sqrt(det1 * det2) + eps) + eps) * 0.5

    bd = np.clip(t1 + t2 + t3, eps, 100.0)  # Bhattacharyya distance
    hd = np.sqrt(1.0 - np.exp(-bd) + eps)  # Hellinger distance
    return 1 - hd
|
||
|
||
|
||
def rotated_nms(boxes, scores, iou_threshold=0.5):
    """Greedy non-maximum suppression over rotated boxes using ProbIoU.

    Args:
        boxes: array of (cx, cy, w, h, theta) rows, indexable by an index array.
        scores: 1-D array of confidences, one per box.
        iou_threshold: a candidate is dropped when its ProbIoU with the box just
            kept is >= this value.

    Returns:
        List of kept indices into ``boxes``, highest score first.
    """
    if len(boxes) == 0:
        return []

    remaining = scores.argsort()[::-1]  # candidate indices, best score first
    kept = []
    while remaining.size:
        best, rest = remaining[0], remaining[1:]
        kept.append(best)
        if rest.size == 0:
            break
        overlap = batch_probiou(boxes[best:best + 1], boxes[rest])[0]
        remaining = rest[overlap < iou_threshold]
    return kept
|
||
|
||
|
||
def run_inference(session, image_path, img_size=640):
    """Read one image, letterbox it to ``img_size`` and run the ONNX session on it.

    Returns:
        (outputs, im0, (dw, dh), r): the raw session outputs, the image as returned
        by cv2.imread, and the letterbox padding and ratio needed to map
        predictions back to the original image.

    Raises:
        ValueError: if cv2.imread cannot read the image.
    """
    original = cv2.imread(image_path)
    if original is None:
        raise ValueError(f"无法读取图像: {image_path}")

    padded, (pad_w, pad_h), ratio = letterbox(original, new_shape=img_size)
    # HWC -> CHW, reverse the channel axis (BGR -> RGB), scale to [0, 1], add batch axis.
    chw_rgb = padded.transpose(2, 0, 1)[::-1]
    tensor = np.ascontiguousarray(chw_rgb).astype(np.float32) / 255.0
    tensor = tensor[np.newaxis]

    feed = {session.get_inputs()[0].name: tensor}
    outputs = session.run(None, feed)
    return outputs, original, (pad_w, pad_h), ratio
|
||
|
||
|
||
def _sigmoid(x):
    """Element-wise logistic function."""
    return 1.0 / (1.0 + np.exp(-x))


def _dfl_expectation(logits, reg_max):
    """Collapse DFL bin logits of shape (4 * reg_max, k) into expected distances (4, k), in grid units."""
    logits = logits.reshape(4, reg_max, -1)
    # Subtract the per-side max so np.exp cannot overflow; softmax is unchanged.
    logits = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    bins = np.arange(reg_max, dtype=probs.dtype).reshape(1, reg_max, 1)
    return (probs * bins).sum(axis=1)


def parse_outputs(outputs, dw, dh, r, conf_threshold=0.5, iou_threshold=0.3, strides=(8, 16, 32)):
    """Decode raw single-class YOLOv8-OBB head outputs into rotated detections.

    Expected layout, with L = len(strides):
      * outputs[0..L-1]: per-level maps of shape (1, 4 * reg_max + 1, h, w). The first
        4 * reg_max channels are DFL bin logits for the left/top/right/bottom distances
        (side-major: channel = side * reg_max + bin). The last channel is the
        class logit, used as confidence after a sigmoid.
      * outputs[L]: angle branch of shape (1, 1, sum(h * w)), levels concatenated in
        order, each level flattened row-major.
        NOTE(review): values are used directly as radians. This assumes the export already
        applied the OBB head's angle mapping; confirm against the exported graph.

    Args:
        outputs: list of numpy arrays returned by the ONNX session.
        dw, dh: letterbox padding (left/top offset) in network-input pixels.
        r: letterbox resize ratio.
        conf_threshold: minimum confidence (inclusive) for a candidate.
        iou_threshold: ProbIoU threshold passed to rotated NMS.
        strides: stride of each pyramid level, in output order.

    Returns:
        List of dicts with keys "position" (four integer (x, y) corners in
        original-image pixels), "confidence" (float) and "angle" (float, radians),
        in the order returned by rotated_nms.
    """
    num_levels = len(strides)
    angle_flat = outputs[num_levels][0, 0]

    detections = []
    rboxes = []  # (cx, cy, w, h, angle) per detection, parallel to `detections`, for NMS
    offset = 0
    for level, stride in enumerate(strides):
        feat = outputs[level][0]  # (4 * reg_max + 1, h, w)
        channels, h, w = feat.shape
        if (channels - 1) % 4:
            raise ValueError(f"unexpected channel count {channels} at level {level}")
        reg_max = (channels - 1) // 4

        # Grid indices come from np.nonzero on a (h, w) map, so rows are y and
        # columns are x. This avoids the x/y swap of a default ('xy') meshgrid.
        angle_map = angle_flat[offset:offset + h * w].reshape(h, w)
        offset += h * w

        scores = _sigmoid(feat[4 * reg_max])  # (h, w)
        ys, xs = np.nonzero(scores >= conf_threshold)  # row-major, same order as a y/x scan
        if ys.size == 0:
            continue

        dist = _dfl_expectation(feat[:4 * reg_max, ys, xs], reg_max)  # (4, k): l, t, r, b
        lt, rb = dist[:2], dist[2:]
        theta = angle_map[ys, xs]
        cos_t, sin_t = np.cos(theta), np.sin(theta)

        # Centre offset in the box frame, rotated into the image frame, relative to
        # the cell centre (anchor = grid index + 0.5), then scaled by the stride.
        off_x, off_y = (rb - lt) / 2
        cx = (off_x * cos_t - off_y * sin_t + xs + 0.5) * stride
        cy = (off_x * sin_t + off_y * cos_t + ys + 0.5) * stride
        bw, bh = (lt + rb) * stride

        # Undo the letterbox: remove padding, then undo the resize ratio.
        cx = (cx - dw) / r
        cy = (cy - dh) / r
        bw = bw / r
        bh = bh / r

        for k in range(ys.size):
            detections.append({
                "position": calculate_obb_corners(cx[k], cy[k], bw[k], bh[k], theta[k]),
                "confidence": float(scores[ys[k], xs[k]]),
                "angle": float(theta[k]),
            })
            rboxes.append((cx[k], cy[k], bw[k], bh[k], theta[k]))

    if not detections:
        return []

    # NMS uses the exact float boxes, not boxes re-derived from rounded corners.
    nms_boxes = np.asarray(rboxes, dtype=np.float64)
    confs = np.array([det["confidence"] for det in detections])
    keep_indices = rotated_nms(nms_boxes, confs, iou_threshold)
    return [detections[i] for i in keep_indices]
|
||
|
||
|
||
def calculate_obb_corners(cx, cy, w, h, angle):
    """Return the four corners of a rotated box as integer (x, y) tuples.

    The box is centred at (cx, cy), has size (w, h) and is rotated by ``angle`` radians.
    In the box's own frame the corners are visited as (+w/2, +h/2), (-w/2, +h/2),
    (-w/2, -h/2), (+w/2, -h/2). Coordinates are converted with int(), which
    truncates toward zero.
    """
    c, s = np.cos(angle), np.sin(angle)
    half_w, half_h = w / 2, h / 2
    pts = []
    for sx, sy in ((1, 1), (-1, 1), (-1, -1), (1, -1)):
        lx, ly = sx * half_w, sy * half_h
        px = cx + c * lx - s * ly
        py = cy + s * lx + c * ly
        pts.append((int(px), int(py)))
    return pts
|
||
|
||
|
||
def save_result(im0, detections, output_path):
    """Draw detections onto ``im0`` (in place) and write the image to ``output_path``.

    Each detection dict must provide "position" (four integer (x, y) corners) and
    "confidence" (float). Boxes are drawn as closed green polygons, and the confidence
    is written 10 px above the first corner.

    Raises:
        OSError: if cv2.imwrite reports failure, so callers do not log a save
            that never happened.
    """
    color = (0, 255, 0)
    for det in detections:
        corners = det["position"]
        for j in range(4):
            cv2.line(im0, corners[j], corners[(j + 1) % 4], color, 2)
        cv2.putText(im0, f'{det["confidence"]:.2f}', (corners[0][0], corners[0][1] - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
    # cv2.imwrite returns False instead of raising for e.g. an unwritable directory.
    if not cv2.imwrite(output_path, im0):
        raise OSError(f"cv2.imwrite failed: {output_path}")
|
||
|
||
|
||
def main(model_path="obb.onnx", image_folder="/home/hx/yolo/output_masks/", output_folder="results/",
         conf_threshold=0.5, iou_threshold=0.3):
    """Run OBB inference on every .jpg/.jpeg/.png in ``image_folder`` and save annotated copies.

    The defaults reproduce the script's original configuration. Files are processed in
    sorted order so runs are reproducible. A failure on one image is logged with its
    traceback and does not stop the batch.

    Args:
        model_path: path to the .onnx model.
        image_folder: folder scanned (non-recursively) for input images.
        output_folder: where annotated images are written; created if missing.
        conf_threshold: confidence threshold passed to parse_outputs.
        iou_threshold: NMS ProbIoU threshold passed to parse_outputs.
    """
    os.makedirs(output_folder, exist_ok=True)
    session = load_model(model_path)

    for fname in sorted(os.listdir(image_folder)):
        if not fname.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue
        image_path = os.path.join(image_folder, fname)
        output_path = os.path.join(output_folder, fname)
        try:
            outputs, im0, (dw, dh), r = run_inference(session, image_path)
            detections = parse_outputs(outputs, dw, dh, r, conf_threshold, iou_threshold)
            save_result(im0, detections, output_path)
        except Exception as e:  # batch boundary: log (with traceback) and continue
            logging.exception(f"❌ 处理 {fname} 失败: {e}")
        else:
            logging.info(f"✅ 已保存: {output_path} (检测到 {len(detections)} 个目标)")

    logging.info("🎉 所有图像处理完成!")
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the batch inference only when executed as a script, not when imported.
    main()