159 lines
5.3 KiB
Python
159 lines
5.3 KiB
Python
import os
|
||
import cv2
|
||
import numpy as np
|
||
from pathlib import Path
|
||
from ultralytics import YOLO
|
||
|
||
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'}
|
||
|
||
|
||
def rotate_box_to_corners(cx, cy, w, h, angle_rad):
    """Return the four corners of a rotated rectangle as a (4, 2) float32 array.

    The rectangle of size ``w`` x ``h`` is centred on the origin, rotated by
    ``angle_rad`` radians with the standard 2-D rotation matrix, and then
    translated to ``(cx, cy)``.

    Args:
        cx, cy: Centre of the rectangle.
        w, h: Width and height of the rectangle.
        angle_rad: Rotation angle in radians.

    Returns:
        np.ndarray of dtype float32 with rows ordered as the unrotated corners
        (-w/2, -h/2), (w/2, -h/2), (w/2, h/2), (-w/2, h/2) after rotation and
        translation.
    """
    half_w = w / 2.0
    half_h = h / 2.0
    offsets = np.array(
        [
            [-half_w, -half_h],
            [half_w, -half_h],
            [half_w, half_h],
            [-half_w, half_h],
        ],
        dtype=np.float32,
    )

    c, s = np.cos(angle_rad), np.sin(angle_rad)
    rotation = np.array([[c, -s], [s, c]], dtype=np.float32)
    centre = np.array([cx, cy], dtype=np.float32)

    # Row vectors: p' = p @ R^T, then shift to the box centre.
    return offsets.dot(rotation.T) + centre
|
||
|
||
|
||
def reorder_points_fixed(pts):
    """Order four corner points as [top-left, top-right, bottom-right, bottom-left].

    Image coordinates are assumed (y grows downwards). The two points with
    the smallest y form the "top" pair and the remaining two the "bottom"
    pair. Each pair is then split into left/right by x.

    Args:
        pts: Four (x, y) points; any array-like of shape (4, 2).

    Returns:
        A list of four length-2 NumPy rows in TL, TR, BR, BL order. The result
        is always a permutation of the input points.
    """
    pts = np.array(pts)
    # Stable sorts make tie handling deterministic (e.g. axis-aligned boxes).
    by_y = pts[np.argsort(pts[:, 1], kind="stable")]
    top_pair, bottom_pair = by_y[:2], by_y[2:]
    # Use argsort rather than argmin/argmax: with equal x values argmin and
    # argmax both return 0, which would emit the same corner twice.
    top_left, top_right = top_pair[np.argsort(top_pair[:, 0], kind="stable")]
    bottom_left, bottom_right = bottom_pair[np.argsort(bottom_pair[:, 0], kind="stable")]
    return [top_left, top_right, bottom_right, bottom_left]
|
||
|
||
|
||
def _first_row(tensor):
    """Return ``tensor`` as a NumPy array, taking row 0 if it is 2-D."""
    arr = tensor.cpu().numpy()
    return arr[0] if arr.ndim == 2 else arr


def _select_top_boxes(boxes, max_dets):
    """Return up to ``max_dets`` boxes from ``boxes``, highest confidence first.

    Boxes whose confidence cannot be read are reported and skipped. The
    limit is applied with slice semantics, so ``max_dets=None`` keeps all.
    """
    detections = []
    for box in boxes:
        try:
            # reshape(-1)[0] accepts both a 0-d and a (1,) tensor and avoids
            # NumPy's deprecated float(<1-element array>) conversion.
            conf = float(box.conf.cpu().numpy().reshape(-1)[0])
        except Exception as e:
            print(f"⚠ 跳过无效 box: {e}")
            continue
        detections.append((conf, box))

    # Descending by confidence; list.sort is stable, so ties keep model order.
    detections.sort(key=lambda item: item[0], reverse=True)
    return [box for _, box in detections[:max_dets]]


def _parse_box_geometry(box):
    """Return ``(cx, cy, w, h, angle_rad)`` for ``box``, or ``None`` if unreadable.

    Prefers the rotated ``xywhr`` attribute. It falls back to an
    axis-aligned ``xywh`` with angle 0 when ``xywhr`` is missing or
    malformed.
    """
    try:
        cx, cy, bw, bh, r_rad = map(float, _first_row(box.xywhr))
        return cx, cy, bw, bh, r_rad
    except Exception:
        pass
    try:
        cx, cy, bw, bh = map(float, _first_row(box.xywh))
        return cx, cy, bw, bh, 0.0
    except Exception:
        return None


def _parse_box_class(box):
    """Return the integer class id of ``box``, defaulting to 0 if unreadable."""
    try:
        # reshape(-1) so a 0-d cls tensor is not silently mapped to class 0.
        return int(box.cls.cpu().numpy().reshape(-1)[0])
    except Exception:
        return 0


def _to_yolo_obb_line(cls, cx, cy, bw, bh, r_rad, img_w, img_h):
    """Format one box as ``"cls x1 y1 x2 y2 x3 y3 x4 y4"`` with normalized coords.

    Corners are ordered TL, TR, BR, BL by ``reorder_points_fixed``. Each
    coordinate is divided by the image size, clipped to [0, 1] and written
    with 6 decimals.
    """
    pts = reorder_points_fixed(rotate_box_to_corners(cx, cy, bw, bh, r_rad))
    coords = []
    for x, y in pts:
        coords.append(min(max(x / img_w, 0), 1))
        coords.append(min(max(y / img_h, 0), 1))
    return str(cls) + " " + " ".join(f"{v:.6f}" for v in coords)


def process_obb_and_save_yolo_txt(model_path, image_dir, output_dir="./inference_results", conf_thresh=0.15, imgsz=640, max_dets=2):
    """Run a YOLO OBB model over a directory and write one YOLO-OBB txt per image.

    For every image in ``image_dir`` (non-recursive, filtered by
    ``IMG_EXTENSIONS``), a file ``<output_dir>/labels/<stem>.txt`` is
    written. Each line has the form ``cls x1 y1 x2 y2 x3 y3 x4 y4`` with
    normalized corner coordinates. Only the ``max_dets`` highest-confidence
    detections are kept.

    An empty txt file is written when the image cannot be read or has no
    detections. Images that share a stem (e.g. ``a.jpg`` and ``a.png``)
    write to the same txt, and the later file in sorted order wins.

    Args:
        model_path: Path to the YOLO weights passed to ``YOLO``.
        image_dir: Directory containing the input images.
        output_dir: Root output directory; labels go into its ``labels`` subdir.
        conf_thresh: Confidence threshold passed to the model call.
        imgsz: Inference image size passed to the model call.
        max_dets: Maximum number of detections kept per image.
    """
    image_dir = Path(image_dir)
    output_dir = Path(output_dir)
    labels_dir = output_dir / "labels"
    labels_dir.mkdir(parents=True, exist_ok=True)

    print("加载 YOLO 模型...")
    model = YOLO(model_path)
    print("✅ 模型加载完成")

    image_files = [
        f for f in sorted(os.listdir(image_dir))
        if os.path.splitext(f.lower())[1] in IMG_EXTENSIONS
    ]
    if not image_files:
        print(f"❌ 未找到图像文件:{image_dir}")
        return

    print(f"发现 {len(image_files)} 张图像,开始推理并保存 YOLO-OBB txt (最多保留 {max_dets} 个检测)...")

    for img_filename in image_files:
        img_path = image_dir / img_filename
        txt_out_path = labels_dir / f"{Path(img_filename).stem}.txt"

        img = cv2.imread(str(img_path))
        if img is None:
            print(f"❌ 无法读取图像,跳过: {img_path}")
            open(txt_out_path, "w").close()
            continue

        H, W = img.shape[:2]

        # The task (OBB) comes from the loaded weights. The former
        # mode='obb' kwarg was dropped because ultralytics' `mode` is a run
        # mode (predict/val/...), not a task.
        result = model(img, save=False, imgsz=imgsz, conf=conf_thresh)[0]
        boxes = result.obb

        if boxes is None or len(boxes) == 0:
            open(txt_out_path, "w").close()
            print(f"ℹ 无检测: {img_filename} -> 生成空 txt")
            continue

        lines_to_write = []
        for i, box in enumerate(_select_top_boxes(boxes, max_dets)):
            geometry = _parse_box_geometry(box)
            if geometry is None:
                print(f"⚠ 无法解析 box,跳过 {i}")
                continue
            cx, cy, bw, bh, r_rad = geometry

            # Heuristic kept from the original: treat all-<=1 values as
            # normalized and scale back to pixels.
            # NOTE(review): xywhr is presumably already in pixels, in which
            # case this can only misfire for sub-pixel boxes at the origin.
            if cx <= 1 and cy <= 1 and bw <= 1 and bh <= 1:
                cx, cy, bw, bh = cx * W, cy * H, bw * W, bh * H

            cls = _parse_box_class(box)
            lines_to_write.append(_to_yolo_obb_line(cls, cx, cy, bw, bh, r_rad, W, H))

        with open(txt_out_path, "w", encoding="utf-8") as f:
            if lines_to_write:
                f.write("\n".join(lines_to_write) + "\n")

        print(f"✅ {img_filename} -> {txt_out_path} (检测 {len(lines_to_write)} 个)")

    print("\n全部处理完成,结果保存在:", labels_dir)
|
||
|
||
|
||
if __name__ == "__main__":
    # Command-line entry point. Every option defaults to the values previously
    # hard-coded here, so running the script without arguments is unchanged.
    import argparse

    parser = argparse.ArgumentParser(
        description="Run a YOLO OBB model over an image directory and write YOLO-OBB label txt files."
    )
    parser.add_argument(
        "--model",
        default=r'/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_obb_new2/weights/best.pt',
        help="Path to the YOLO OBB weights.",
    )
    parser.add_argument(
        "--images",
        default=r"/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/zjdata17",
        help="Directory containing the input images.",
    )
    parser.add_argument(
        "--output",
        default="/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/labels",
        help="Output root; txt files are written to its 'labels' subdirectory.",
    )
    parser.add_argument("--conf", type=float, default=0.15, help="Confidence threshold.")
    parser.add_argument("--imgsz", type=int, default=640, help="Inference image size.")
    # Key parameter: maximum number of detections kept per image.
    parser.add_argument("--max-dets", type=int, default=2, help="Max detections kept per image.")
    args = parser.parse_args()

    process_obb_and_save_yolo_txt(
        model_path=args.model,
        image_dir=args.images,
        output_dir=args.output,
        conf_thresh=args.conf,
        imgsz=args.imgsz,
        max_dets=args.max_dets,
    )