# Python source listing (135 lines, 4.4 KiB)
import argparse
import os
from pathlib import Path

import cv2
from ultralytics import YOLO


# Lower-case file suffixes (leading dot included) that are treated as
# images when scanning the input folder.
IMG_EXTENSIONS = set(".jpg .jpeg .png .bmp .tif .tiff .webp".split())


class ObjectDetector:
    """Thin wrapper around an Ultralytics YOLO object-detection model."""

    def __init__(self, model_path):
        """Load YOLO weights from *model_path*.

        Args:
            model_path: path to the weights file (str or path-like).

        Raises:
            FileNotFoundError: if *model_path* does not exist on disk.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件不存在: {model_path}")
        self.model = YOLO(model_path)
        print(f"[INFO] 成功加载 YOLO 目标检测模型: {model_path}")

    def detect(self, img_np, conf_threshold=0.0):
        """Run the model on one image and return all detections it reports.

        Args:
            img_np: image handed straight to ``self.model.predict``
                (the caller in this file passes the result of ``cv2.imread``).
            conf_threshold: forwarded as ``predict(conf=...)``; the
                confidence filtering itself is done by the model.

        Returns:
            list of dicts, one per box, with keys:
                'bbox_xyxy':  array ``[x1, y1, x2, y2]`` in image pixels
                'confidence': float
                'class_id':   int
            Empty list if nothing was detected.
        """
        results = self.model.predict(img_np, conf=conf_threshold, verbose=False)

        detections = []
        for result in results:
            # Results from models without box output (e.g. classification)
            # have ``boxes`` set to None; there is nothing to extract.
            if result.boxes is None:
                continue
            boxes = result.boxes.cpu().numpy()
            # Read the per-box columns once and zip them rather than
            # materializing a Boxes object for every index.
            for bbox, conf, cls_id in zip(boxes.xyxy, boxes.conf, boxes.cls):
                detections.append({
                    'bbox_xyxy': bbox,  # [x1, y1, x2, y2]
                    'confidence': float(conf),
                    'class_id': int(cls_id),
                })
        return detections


def _best_detection_per_class(detections):
    """Keep only the highest-confidence detection for each class id.

    Args:
        detections: iterable of dicts with at least 'class_id' and
            'confidence' keys (as returned by ObjectDetector.detect).

    Returns:
        list of detection dicts, at most one per class id, ordered by the
        first appearance of each class in *detections*.
    """
    best = {}
    for det in detections:
        cls_id = det['class_id']
        current = best.get(cls_id)
        # Strict '>' keeps the earliest detection when confidences tie.
        if current is None or det['confidence'] > current['confidence']:
            best[cls_id] = det
    return list(best.values())


def _xyxy_to_yolo_line(class_id, bbox_xyxy, img_w, img_h):
    """Convert one pixel-space ``[x1, y1, x2, y2]`` box to a YOLO label line.

    The corners are clipped to the image *before* converting to
    center/size form, so the written box always lies inside the image.
    Clamping cx/cy/w/h independently would not guarantee that.

    Returns:
        ``"<class_id> <cx> <cy> <w> <h>"`` normalized to [0, 1] with six
        decimals, or None if the box has no area after clipping.
    """
    x1, y1, x2, y2 = (float(v) for v in bbox_xyxy)
    x1, x2 = (min(max(v, 0.0), img_w) for v in (x1, x2))
    y1, y2 = (min(max(v, 0.0), img_h) for v in (y1, y2))
    box_w = x2 - x1
    box_h = y2 - y1
    if box_w <= 0.0 or box_h <= 0.0:
        return None
    cx_norm = (x1 + x2) / 2.0 / img_w
    cy_norm = (y1 + y2) / 2.0 / img_h
    return (
        f"{class_id} {cx_norm:.6f} {cy_norm:.6f} "
        f"{box_w / img_w:.6f} {box_h / img_h:.6f}"
    )


def save_yolo_detect_labels_from_folder(
    model_path,
    image_dir,
    output_dir,
    conf_threshold=0.5,
    label_map=None,
):
    """Run YOLO detection on every image in *image_dir* and save YOLO labels.

    For each image only the highest-confidence box of each class is kept,
    and the result is written to ``<output_dir>/labels/<stem>.txt`` in YOLO
    format, one box per line::

        <class_id> <cx_norm> <cy_norm> <w_norm> <h_norm>

    Images that OpenCV cannot read get an empty label file.

    Args:
        model_path: path to the YOLO weights; only loaded if at least one
            image is found.
        image_dir: directory scanned (non-recursively) for regular files
            whose lower-cased suffix is in IMG_EXTENSIONS.
        output_dir: directory under which ``labels/`` is created.
        conf_threshold: confidence threshold forwarded to the detector.
        label_map: optional ``{class_id: name}`` mapping used only for the
            progress log. Defaults to ``{0: "hole", 1: "crack"}``.
    """
    # None sentinel avoids a shared mutable default argument.
    if label_map is None:
        label_map = {0: "hole", 1: "crack"}

    image_dir = Path(image_dir)
    labels_dir = Path(output_dir) / "labels"
    labels_dir.mkdir(parents=True, exist_ok=True)

    # Only regular files count: a sub-directory named "x.jpg" is not an image.
    image_paths = sorted(
        (p for p in image_dir.iterdir()
         if p.is_file() and p.suffix.lower() in IMG_EXTENSIONS),
        key=lambda p: p.name,
    )
    if not image_paths:
        print(f"❌ 未在 {image_dir} 中找到支持的图像文件")
        return

    print(f"共找到 {len(image_paths)} 张图像,开始推理...")
    detector = ObjectDetector(model_path)

    stem_owner = {}  # label-file stem -> image filename that produced it
    for img_path in image_paths:
        img_filename = img_path.name
        stem = img_path.stem
        txt_path = labels_dir / f"{stem}.txt"

        # "a.jpg" and "a.png" share one label file; say so instead of
        # silently overwriting the earlier result.
        if stem in stem_owner:
            print(f"⚠️ {img_filename} 与 {stem_owner[stem]} 的标签文件同名,将覆盖: {txt_path}")
        stem_owner[stem] = img_filename

        img = cv2.imread(str(img_path))
        if img is None:
            print(f"⚠️ 跳过无效图像: {img_path}")
            txt_path.write_text("")  # keep an (empty) label file for the image
            continue

        img_h, img_w = img.shape[:2]
        detections = detector.detect(img, conf_threshold=conf_threshold)

        lines = []
        kept_class_ids = []
        for det in _best_detection_per_class(detections):
            line = _xyxy_to_yolo_line(det['class_id'], det['bbox_xyxy'], img_w, img_h)
            if line is None:
                continue
            lines.append(line)
            kept_class_ids.append(det['class_id'])

        # One line per box with a trailing newline; empty file if no boxes.
        txt_path.write_text("".join(f"{line}\n" for line in lines))

        class_names = ", ".join(str(label_map.get(cid, cid)) for cid in kept_class_ids)
        suffix = f" ({class_names})" if class_names else ""
        print(f"✅ {img_filename} -> {len(lines)} 个检测框已保存{suffix}")

    print(f"\n🎉 全部完成!标签文件保存在: {labels_dir}")


# ------------------- script entry point -------------------
if __name__ == "__main__":
    # The defaults reproduce the original hard-coded run; each value can be
    # overridden on the command line, e.g.
    #   python <script>.py --image-dir ./images --conf 0.25
    MODEL_PATH = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_ailai_detect3/weights/best.pt"
    IMAGE_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/ailaidete/train/delet"
    OUTPUT_DIR = "./inference_results"
    CONF_THRESHOLD = 0.5

    parser = argparse.ArgumentParser(
        description="Run YOLO detection on a folder of images and save YOLO-format labels."
    )
    parser.add_argument("--model", default=MODEL_PATH,
                        help="path to the YOLO weights file")
    parser.add_argument("--image-dir", default=IMAGE_DIR,
                        help="folder containing the input images")
    parser.add_argument("--output-dir", default=OUTPUT_DIR,
                        help="folder under which labels/ is created")
    parser.add_argument("--conf", type=float, default=CONF_THRESHOLD,
                        help="confidence threshold passed to the detector")
    args = parser.parse_args()

    save_yolo_detect_labels_from_folder(
        model_path=args.model,
        image_dir=args.image_dir,
        output_dir=args.output_dir,
        conf_threshold=args.conf,
    )