# Files
# 琉璃月光 67883f1a50 最新推送
# 2026-03-10 14:14:14 +08:00
#
# 135 lines
# 4.4 KiB
# Python

import os
import cv2
from pathlib import Path
from ultralytics import YOLO
# File suffixes accepted as images when scanning the input folder; candidate
# names are lower-cased before their extension is looked up in this set.
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'}
class ObjectDetector:
    """Thin wrapper around a YOLO object-detection model."""

    def __init__(self, model_path):
        """Load the model stored at ``model_path``.

        Args:
            model_path: Filesystem path of the model weights.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        model_missing = not os.path.exists(model_path)
        if model_missing:
            raise FileNotFoundError(f"模型文件不存在: {model_path}")

        self.model = YOLO(model_path)
        print(f"[INFO] 成功加载 YOLO 目标检测模型: {model_path}")

    def detect(self, img_np, conf_threshold=0.0):
        """Run inference on one image and flatten the output into dicts.

        Args:
            img_np: Image handed unchanged to the model's ``predict`` call.
            conf_threshold: Forwarded to ``predict`` as its ``conf`` argument,
                so the model decides which boxes pass the threshold.

        Returns:
            A list with one dict per detected box, holding ``'bbox_xyxy'``
            (the box's ``[x1, y1, x2, y2]`` entry), ``'confidence'`` (float)
            and ``'class_id'`` (int).
        """
        predictions = self.model.predict(
            img_np, conf=conf_threshold, verbose=False
        )
        return [
            {
                'bbox_xyxy': box.xyxy[0],  # [x1, y1, x2, y2]
                'confidence': float(box.conf.item()),
                'class_id': int(box.cls.item()),
            }
            for prediction in predictions
            for box in prediction.boxes.cpu().numpy()
        ]
def _select_best_per_class(detections):
    """Return the highest-confidence detection of each class, in first-seen class order."""
    best = {}
    for det in detections:
        current = best.get(det['class_id'])
        if current is None or det['confidence'] > current['confidence']:
            best[det['class_id']] = det
    return list(best.values())


def _to_yolo_label_line(class_id, bbox_xyxy, img_w, img_h):
    """Format one pixel-space xyxy box as a normalized YOLO label line."""
    x1, y1, x2, y2 = (float(v) for v in bbox_xyxy)

    # Clip the corners to the image BEFORE deriving centre/size. Clamping the
    # normalized centre and size independently would distort a box that
    # overhangs the image edge (the centre would no longer match the extent).
    x1 = min(max(x1, 0.0), float(img_w))
    x2 = min(max(x2, 0.0), float(img_w))
    y1 = min(max(y1, 0.0), float(img_h))
    y2 = min(max(y2, 0.0), float(img_h))

    cx_norm = (x1 + x2) / 2.0 / img_w
    cy_norm = (y1 + y2) / 2.0 / img_h
    # max() guards against an inverted box producing a negative size.
    w_norm = max(0.0, x2 - x1) / img_w
    h_norm = max(0.0, y2 - y1) / img_h
    return f"{class_id} {cx_norm:.6f} {cy_norm:.6f} {w_norm:.6f} {h_norm:.6f}"


def save_yolo_detect_labels_from_folder(
    model_path,
    image_dir,
    output_dir,
    conf_threshold=0.5,
    label_map=None,  # optional; accepted for compatibility, not read here
):
    """Run YOLO detection on every image in a folder and save YOLO labels.

    For each image only the highest-confidence box of every class is kept.
    One ``<image stem>.txt`` file is written per image under
    ``<output_dir>/labels``; each line has the form
    ``<class_id> <cx_norm> <cy_norm> <w_norm> <h_norm>``.

    Args:
        model_path: Path of the model weights, passed to ``ObjectDetector``.
        image_dir: Folder scanned (non-recursively) for files whose
            lower-cased suffix is in ``IMG_EXTENSIONS``.
        output_dir: Folder under which the ``labels`` directory is created.
        conf_threshold: Confidence threshold forwarded to the detector.
        label_map: Optional mapping of class id to name. It is not used by
            this function; the parameter is kept so existing callers that
            pass it keep working.

    Returns:
        None. If no supported image is found, a message is printed and the
        function returns before the model is loaded.

    Note:
        Images sharing a stem (e.g. ``a.jpg`` and ``a.png``) map to the same
        label file; the one processed last wins.
    """
    image_dir = Path(image_dir)
    output_dir = Path(output_dir)
    labels_dir = output_dir / "labels"
    labels_dir.mkdir(parents=True, exist_ok=True)

    # Collect image file names; skip directories that merely look like images.
    image_files = sorted(
        entry.name
        for entry in image_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() in IMG_EXTENSIONS
    )
    if not image_files:
        print(f"❌ 未在 {image_dir} 中找到支持的图像文件")
        return
    print(f"共找到 {len(image_files)} 张图像,开始推理...")

    # Load the model only once we know there is work to do.
    detector = ObjectDetector(model_path)

    for img_filename in image_files:
        img_path = image_dir / img_filename
        txt_path = labels_dir / f"{Path(img_filename).stem}.txt"

        img = cv2.imread(str(img_path))
        if img is None:
            # Unreadable image: still emit an empty label file so every
            # image has a matching .txt.
            print(f"⚠️ 跳过无效图像: {img_path}")
            txt_path.write_text("")
            continue
        img_h, img_w = img.shape[:2]

        # Inference, then keep the best box of each class.
        all_detections = detector.detect(img, conf_threshold=conf_threshold)
        top_detections = _select_best_per_class(all_detections)

        lines = [
            _to_yolo_label_line(det['class_id'], det['bbox_xyxy'], img_w, img_h)
            for det in top_detections
        ]

        # An image without detections gets an empty label file.
        txt_path.write_text("".join(f"{line}\n" for line in lines))
        print(f"{img_filename} -> {len(lines)} 个检测框已保存")

    print(f"\n🎉 全部完成!标签文件保存在: {labels_dir}")
# ------------------- Script entry point -------------------
if __name__ == "__main__":
    # Imported here because it is only needed when run as a script.
    import argparse

    # Defaults reproduce the original hard-coded run; every value can now be
    # overridden on the command line instead of editing the file.
    MODEL_PATH = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_ailai_detect3/weights/best.pt"
    IMAGE_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/ailaidete/train/delet"
    OUTPUT_DIR = "./inference_results"

    parser = argparse.ArgumentParser(
        description="Run YOLO detection on a folder of images and save YOLO-format labels."
    )
    parser.add_argument("--model", default=MODEL_PATH, help="path of the model weights")
    parser.add_argument("--images", default=IMAGE_DIR, help="folder containing the input images")
    parser.add_argument("--output", default=OUTPUT_DIR, help="folder that receives the labels directory")
    parser.add_argument("--conf", type=float, default=0.5, help="confidence threshold")
    args = parser.parse_args()

    save_yolo_detect_labels_from_folder(
        model_path=args.model,
        image_dir=args.images,
        output_dir=args.output,
        conf_threshold=args.conf,
    )