Files
zjsh_yolov11/推理图片反向上传CVAT/obb/tuili_save_txt_f.py

159 lines
5.3 KiB
Python
Raw Normal View History

2025-12-11 08:37:09 +08:00
import os
import cv2
import numpy as np
from pathlib import Path
from ultralytics import YOLO
# Lower-case file suffixes (leading dot included) that are treated as images;
# candidate filenames are lower-cased before their suffix is looked up here.
IMG_EXTENSIONS = {
    '.bmp',
    '.jpeg',
    '.jpg',
    '.png',
    '.tif',
    '.tiff',
    '.webp',
}
def rotate_box_to_corners(cx, cy, w, h, angle_rad):
    """Return the four corners of a rotated rectangle as a (4, 2) float32 array.

    The rectangle of size ``w`` x ``h`` is first laid out around the origin with
    corners (-w/2, -h/2), (w/2, -h/2), (w/2, h/2), (-w/2, h/2) (in that order).
    Each corner is rotated by ``angle_rad`` radians with the matrix
    [[cos, -sin], [sin, cos]] and then shifted to the centre (``cx``, ``cy``).

    Args:
        cx, cy: rectangle centre.
        w, h: rectangle width and height.
        angle_rad: rotation angle in radians.

    Returns:
        np.ndarray of shape (4, 2), dtype float32, one (x, y) row per corner.
    """
    half_w = w / 2.0
    half_h = h / 2.0
    local_corners = np.array(
        [
            [-half_w, -half_h],
            [half_w, -half_h],
            [half_w, half_h],
            [-half_w, half_h],
        ],
        dtype=np.float32,
    )
    c, s = np.cos(angle_rad), np.sin(angle_rad)
    rotation = np.array([[c, -s], [s, c]], dtype=np.float32)
    centre = np.array([cx, cy], dtype=np.float32)
    # Row-vector convention: p' = p @ R^T is the same as R @ p for column vectors.
    return local_corners @ rotation.T + centre
def reorder_points_fixed(pts):
    """Order four corner points as [top-left, top-right, bottom-right, bottom-left].

    The two points with the smallest y form the "top" pair and the other two
    the "bottom" pair (naming assumes image coordinates, where y grows
    downward). Within each pair the point with the smaller x is "left".

    Args:
        pts: four (x, y) points; anything ``np.asarray`` turns into a (4, 2) array.

    Returns:
        list of four NumPy rows ``[top_left, top_right, bottom_right, bottom_left]``.
        The result is always a permutation of the input points, even when both
        points of a pair share the same x coordinate.
    """
    pts = np.asarray(pts)
    by_y = pts[np.argsort(pts[:, 1], kind="stable")]
    top_two, bottom_two = by_y[:2], by_y[2:]
    # Sort each pair by x rather than using independent argmin/argmax: with
    # tied x values argmin and argmax both return index 0, which duplicated one
    # corner and dropped another.
    top_left, top_right = top_two[np.argsort(top_two[:, 0], kind="stable")]
    bottom_left, bottom_right = bottom_two[np.argsort(bottom_two[:, 0], kind="stable")]
    return [top_left, top_right, bottom_right, bottom_left]
def _first_scalar(tensor):
    """Return element 0 (after flattening) of a tensor exposing ``.cpu().numpy()``."""
    # Index explicitly instead of float(array): converting an ndim>0 array to a
    # Python scalar is deprecated since NumPy 1.25.
    return tensor.cpu().numpy().reshape(-1)[0]


def _select_top_detections(boxes, max_dets):
    """Return at most ``max_dets`` boxes, highest confidence first.

    Boxes whose confidence cannot be read are reported and skipped.
    ``max_dets=None`` keeps every box (plain slice semantics).
    """
    detections = []
    for box in boxes:
        try:
            conf = float(_first_scalar(box.conf))
        except Exception as e:
            print(f"⚠ 跳过无效 box: {e}")
            continue
        detections.append((conf, box))
    # Sort by confidence, descending.
    detections.sort(key=lambda item: item[0], reverse=True)
    return [box for _, box in detections[:max_dets]]


def _parse_box_geometry(box, W, H):
    """Return ``(cx, cy, w, h, angle_rad)`` in pixels for one box, or None.

    Prefers ``box.xywhr``; falls back to the axis-aligned ``box.xywh`` with a
    zero angle. Returns None when neither attribute can be parsed.
    """
    try:
        xywhr = box.xywhr.cpu().numpy()
        if xywhr.ndim == 2:
            xywhr = xywhr[0]
        cx, cy, bw, bh, r_rad = map(float, xywhr)
    except Exception:
        try:
            xywh = box.xywh.cpu().numpy()
            if xywh.ndim == 2:
                xywh = xywh[0]
            cx, cy, bw, bh = map(float, xywh)
            r_rad = 0.0
        except Exception:
            return None
    # Heuristic: if every value is <= 1 the coordinates are assumed to be
    # normalized and are scaled back to pixels.
    # NOTE(review): a genuinely tiny pixel box near the origin would also match.
    if cx <= 1 and cy <= 1 and bw <= 1 and bh <= 1:
        cx *= W
        cy *= H
        bw *= W
        bh *= H
    return cx, cy, bw, bh, r_rad


def _box_to_label_line(box, W, H):
    """Format one detection as ``"cls x1 y1 x2 y2 x3 y3 x4 y4"``, or None.

    Corners are ordered top-left, top-right, bottom-right, bottom-left,
    normalized by the image size, clamped to [0, 1] and printed with 6 decimals.
    The class id defaults to 0 when it cannot be read from the box.
    """
    geometry = _parse_box_geometry(box, W, H)
    if geometry is None:
        return None
    cx, cy, bw, bh, r_rad = geometry
    try:
        cls = int(_first_scalar(box.cls))
    except Exception:
        cls = 0
    pts = reorder_points_fixed(rotate_box_to_corners(cx, cy, bw, bh, r_rad))
    pts_norm = []
    for x, y in pts:
        pts_norm.extend([min(max(x / W, 0), 1), min(max(y / H, 0), 1)])
    return str(cls) + " " + " ".join(f"{v:.6f}" for v in pts_norm)


def process_obb_and_save_yolo_txt(
    model_path,
    image_dir,
    output_dir="./inference_results",
    conf_thresh=0.15,
    imgsz=640,
    max_dets=2,
):
    """Run a YOLO model on every image in ``image_dir`` and save YOLO-OBB labels.

    For each image file (suffix in ``IMG_EXTENSIONS``, sorted by name), one
    ``<output_dir>/labels/<stem>.txt`` is written. Each line has the form
    ``cls x1 y1 x2 y2 x3 y3 x4 y4`` with normalized coordinates. Only the
    ``max_dets`` most confident detections are kept (``None`` keeps all). An
    empty txt file is written when the image cannot be read or nothing is
    detected.

    Args:
        model_path: weights path passed to ``YOLO``.
        image_dir: directory scanned (non-recursively) for images.
        output_dir: root output directory; labels go into its ``labels`` subdir.
        conf_thresh: confidence threshold passed to the model as ``conf``.
        imgsz: inference size passed to the model.
        max_dets: maximum number of detections written per image.

    Raises:
        Whatever ``os.listdir`` raises if ``image_dir`` does not exist.
    """
    image_dir = Path(image_dir)
    output_dir = Path(output_dir)
    labels_dir = output_dir / "labels"
    labels_dir.mkdir(parents=True, exist_ok=True)

    print("加载 YOLO 模型...")
    model = YOLO(model_path)
    print("✅ 模型加载完成")

    image_files = [
        f for f in sorted(os.listdir(image_dir))
        if os.path.splitext(f.lower())[1] in IMG_EXTENSIONS
    ]
    if not image_files:
        print(f"❌ 未找到图像文件:{image_dir}")
        return
    print(f"发现 {len(image_files)} 张图像,开始推理并保存 YOLO-OBB txt (最多保留 {max_dets} 个检测)...")

    for img_filename in image_files:
        img_path = image_dir / img_filename
        txt_out_path = labels_dir / f"{Path(img_filename).stem}.txt"

        img = cv2.imread(str(img_path))
        if img is None:
            print(f"❌ 无法读取图像,跳过: {img_path}")
            txt_out_path.write_text("")
            continue
        H, W = img.shape[:2]

        # NOTE(review): `mode` normally names the run mode (predict/val/...),
        # not the task; kept as in the original -- confirm it is intended.
        results = model(img, save=False, imgsz=imgsz, conf=conf_thresh, mode='obb')
        boxes = results[0].obb
        if boxes is None or len(boxes) == 0:
            txt_out_path.write_text("")
            print(f" 无检测: {img_filename} -> 生成空 txt")
            continue

        lines_to_write = []
        for i, box in enumerate(_select_top_detections(boxes, max_dets)):
            line = _box_to_label_line(box, W, H)
            if line is None:
                print(f"⚠ 无法解析 box跳过 {i}")
                continue
            lines_to_write.append(line)

        # One label per line, newline-terminated; an empty file when nothing parsed.
        txt_out_path.write_text("".join(f"{line}\n" for line in lines_to_write))
        print(f"{img_filename} -> {txt_out_path} (检测 {len(lines_to_write)} 个)")

    print("\n全部处理完成,结果保存在:", labels_dir)
if __name__ == "__main__":
    # Run configuration: machine-specific absolute paths, edit before use elsewhere.
    MODEL_PATH = r'/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_obb_new2/weights/best.pt'
    IMAGE_SOURCE_DIR = r"/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/zjdata17"
    OUTPUT_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/labels"

    run_options = dict(
        model_path=MODEL_PATH,
        image_dir=IMAGE_SOURCE_DIR,
        output_dir=OUTPUT_DIR,
        conf_thresh=0.15,
        imgsz=640,
        max_dets=2,  # key setting: keep at most 2 detections per image
    )
    process_obb_and_save_yolo_txt(**run_options)