Files
琉璃月光 8b263167f8 更新
2025-12-11 08:37:09 +08:00

159 lines
5.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
import numpy as np
from pathlib import Path
from ultralytics import YOLO
# Lower-case file extensions (including the leading dot) that are treated as
# images when the input directory is scanned.
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'}
def rotate_box_to_corners(cx, cy, w, h, angle_rad):
    """Return the four corners of a rotated rectangle.

    The rectangle has size ``w`` x ``h``, is centred on ``(cx, cy)`` and is
    rotated by ``angle_rad`` radians with the standard 2-D rotation matrix
    ``[[cos, -sin], [sin, cos]]``.

    Returns:
        A ``(4, 2)`` float32 array of ``(x, y)`` points. Row order follows the
        unrotated offsets ``(-w/2, -h/2)``, ``(+w/2, -h/2)``, ``(+w/2, +h/2)``,
        ``(-w/2, +h/2)``.
    """
    half_w = w / 2.0
    half_h = h / 2.0

    # Corner offsets relative to the box centre, before any rotation.
    offsets = np.array(
        [
            [-half_w, -half_h],
            [half_w, -half_h],
            [half_w, half_h],
            [-half_w, half_h],
        ],
        dtype=np.float32,
    )

    c = np.cos(angle_rad)
    s = np.sin(angle_rad)
    rotation = np.array([[c, -s], [s, c]], dtype=np.float32)
    center = np.array([cx, cy], dtype=np.float32)

    # The points are row vectors, so rotate by multiplying with the transposed
    # matrix, then translate to the box centre.
    return offsets.dot(rotation.T) + center
def reorder_points_fixed(pts):
    """Order four 2-D points as top-left, top-right, bottom-right, bottom-left.

    The points are first sorted by ``y``; the two with the smallest ``y`` form
    the "top" pair and the other two the "bottom" pair. Each pair is then
    ordered by ``x``.

    Both sorts are stable, so ties are resolved deterministically and the
    result is always a permutation of the input. Selecting the pair members
    with ``argmin``/``argmax`` would return the same point twice (and drop the
    other one) whenever both points of a pair share the same ``x``.

    Args:
        pts: Array-like of shape ``(4, N)`` with ``N >= 2``; column 0 is ``x``
            and column 1 is ``y``.

    Returns:
        A list of four row arrays: ``[top_left, top_right, bottom_right,
        bottom_left]``.

    Raises:
        ValueError: If ``pts`` is not a 2-D array with exactly four rows and
            at least two columns.
    """
    pts = np.asarray(pts)
    if pts.ndim != 2 or pts.shape[0] != 4 or pts.shape[1] < 2:
        raise ValueError(
            f"expected 4 points of shape (4, 2), got array of shape {pts.shape}"
        )

    by_y = pts[np.argsort(pts[:, 1], kind="stable")]
    top_pair = by_y[:2]
    bottom_pair = by_y[2:]

    # argsort (not argmin/argmax) guarantees two distinct rows per pair.
    top_left, top_right = top_pair[np.argsort(top_pair[:, 0], kind="stable")]
    bottom_left, bottom_right = bottom_pair[
        np.argsort(bottom_pair[:, 0], kind="stable")
    ]
    return [top_left, top_right, bottom_right, bottom_left]
def _list_image_files(image_dir):
    """Return the sorted names of regular files in image_dir with an image extension."""
    return sorted(
        entry.name
        for entry in image_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() in IMG_EXTENSIONS
    )


def _first_row(values):
    """Return row 0 of a 2-D array, or the array itself when it is not 2-D."""
    return values[0] if values.ndim == 2 else values


def _box_geometry(box):
    """Return (cx, cy, w, h, angle_rad) for one detection.

    Reads ``box.xywhr`` first; if that fails for any reason, falls back to
    ``box.xywh`` with an angle of 0. An exception from the fallback propagates
    to the caller.
    """
    try:
        cx, cy, bw, bh, r_rad = map(float, _first_row(box.xywhr.cpu().numpy()))
        return cx, cy, bw, bh, r_rad
    except Exception:
        cx, cy, bw, bh = map(float, _first_row(box.xywh.cpu().numpy()))
        return cx, cy, bw, bh, 0.0


def _format_obb_line(cls, cx, cy, bw, bh, r_rad, img_w, img_h):
    """Build one label line: class index plus four normalised corner points."""
    corners = reorder_points_fixed(rotate_box_to_corners(cx, cy, bw, bh, r_rad))
    coords = []
    for x, y in corners:
        # Clamp to [0, 1] so corners that fall outside the image stay valid.
        coords.append(min(max(x / img_w, 0), 1))
        coords.append(min(max(y / img_h, 0), 1))
    return str(cls) + " " + " ".join(f"{v:.6f}" for v in coords)


def process_obb_and_save_yolo_txt(model_path, image_dir, output_dir="./inference_results", conf_thresh=0.15, imgsz=640, max_dets=2):
    """Run a YOLO model on every image in a directory and save label files.

    For each image in ``image_dir`` one ``<stem>.txt`` file is written to
    ``output_dir / "labels"``. Each line has the form
    ``cls x1 y1 x2 y2 x3 y3 x4 y4``: the corners are ordered top-left,
    top-right, bottom-right, bottom-left, normalised by image width/height,
    clamped to ``[0, 1]`` and printed with six decimals. An empty file is
    written when the image cannot be read or nothing is detected.

    Images sharing a stem (e.g. ``a.jpg`` and ``a.png``) write to the same
    label file; the later one overwrites the earlier one.

    Args:
        model_path: Path of the weights file passed to ``YOLO``.
        image_dir: Directory scanned (non-recursively) for image files.
        output_dir: Directory in which the ``labels`` sub-directory is created.
        conf_thresh: Confidence threshold forwarded to the model call.
        imgsz: Inference image size forwarded to the model call.
        max_dets: Maximum number of detections kept per image, chosen by
            descending confidence. ``None`` keeps all of them.

    Raises:
        ValueError: If ``max_dets`` is negative.
    """
    if max_dets is not None and max_dets < 0:
        # A negative value would silently slice from the end of the list.
        raise ValueError(f"max_dets must be >= 0 or None, got {max_dets}")

    image_dir = Path(image_dir)
    output_dir = Path(output_dir)
    labels_dir = output_dir / "labels"
    labels_dir.mkdir(parents=True, exist_ok=True)

    print("加载 YOLO 模型...")
    model = YOLO(model_path)
    print("✅ 模型加载完成")

    image_files = _list_image_files(image_dir)
    if not image_files:
        print(f"❌ 未找到图像文件:{image_dir}")
        return

    print(f"发现 {len(image_files)} 张图像,开始推理并保存 YOLO-OBB txt (最多保留 {max_dets} 个检测)...")

    for img_filename in image_files:
        img_path = image_dir / img_filename
        txt_out_path = labels_dir / f"{Path(img_filename).stem}.txt"

        img = cv2.imread(str(img_path))
        if img is None:
            print(f"❌ 无法读取图像,跳过: {img_path}")
            txt_out_path.write_text("", encoding="utf-8")
            continue

        H, W = img.shape[:2]

        # NOTE(review): mode='obb' is forwarded exactly as in the original
        # call; 'obb' looks like a task name rather than a mode - confirm.
        results = model(img, save=False, imgsz=imgsz, conf=conf_thresh, mode='obb')
        boxes = results[0].obb

        if boxes is None or len(boxes) == 0:
            txt_out_path.write_text("", encoding="utf-8")
            print(f" 无检测: {img_filename} -> 生成空 txt")
            continue

        # Keep only the max_dets detections with the highest confidence.
        scored = []
        for box in boxes:
            try:
                # Index the flattened array explicitly: float() on a
                # 1-element ndarray is deprecated in recent NumPy releases.
                conf = float(box.conf.cpu().numpy().reshape(-1)[0])
            except Exception as e:
                print(f"⚠ 跳过无效 box: {e}")
                continue
            scored.append((conf, box))
        # list.sort is stable, so equal confidences keep the model's order.
        scored.sort(key=lambda item: item[0], reverse=True)
        top_boxes = [box for _, box in scored[:max_dets]]

        lines_to_write = []
        for i, box in enumerate(top_boxes):
            try:
                cx, cy, bw, bh, r_rad = _box_geometry(box)
            except Exception:
                print(f"⚠ 无法解析 box跳过 {i}")
                continue

            # Heuristic: if every value is <= 1 the box is assumed to be
            # normalised already and is scaled back to pixels.
            if cx <= 1 and cy <= 1 and bw <= 1 and bh <= 1:
                cx *= W
                cy *= H
                bw *= W
                bh *= H

            try:
                cls = int(box.cls.cpu().numpy()[0])
            except Exception as e:
                # Best effort: fall back to class 0, but do not hide the failure.
                print(f"⚠ 无法解析类别,默认使用 0: {e}")
                cls = 0

            lines_to_write.append(_format_obb_line(cls, cx, cy, bw, bh, r_rad, W, H))

        txt_out_path.write_text(
            "".join(line + "\n" for line in lines_to_write), encoding="utf-8"
        )
        print(f"{img_filename} -> {txt_out_path} (检测 {len(lines_to_write)} 个)")

    print("\n全部处理完成,结果保存在:", labels_dir)
if __name__ == "__main__":
    # Machine-specific locations; adjust them before running elsewhere.
    MODEL_PATH = r'/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_obb_new2/weights/best.pt'
    IMAGE_SOURCE_DIR = r"/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/zjdata17"
    OUTPUT_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/labels"

    # Key parameter: max_dets=2 keeps at most two detections per image.
    process_obb_and_save_yolo_txt(
        MODEL_PATH,
        IMAGE_SOURCE_DIR,
        output_dir=OUTPUT_DIR,
        conf_thresh=0.15,
        imgsz=640,
        max_dets=2,
    )