部署文件+resize_cls+resize_seg
This commit is contained in:
1
yemian/resize/roi_coordinates/1_rois.txt
Normal file
1
yemian/resize/roi_coordinates/1_rois.txt
Normal file
@ -0,0 +1 @@
|
||||
859,810,696,328
|
||||
299
yemian/resize/rtest.py
Normal file
299
yemian/resize/rtest.py
Normal file
@ -0,0 +1,299 @@
|
||||
import os
|
||||
import cv2
|
||||
import torch
|
||||
import argparse
|
||||
import numpy as np
|
||||
from ultralytics import YOLO
|
||||
from pathlib import Path
|
||||
|
||||
# ====================== 配置参数 ======================
|
||||
MODEL_PATH = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/seg_r/exp/weights/best.pt"
|
||||
#SOURCE_IMG_DIR = "/home/hx/yolo/output_masks" # 原始输入图像目录
|
||||
SOURCE_IMG_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/f6" # 原始输入图像目录
|
||||
OUTPUT_DIR = "/home/hx/yolo/output_masks2" # 推理输出根目录
|
||||
ROI_COORDS_FILE = "./roi_coordinates/1_rois.txt" # 必须与训练时相同
|
||||
CONF_THRESHOLD = 0.25
|
||||
IOU_THRESHOLD = 0.45
|
||||
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
SAVE_TXT = True
|
||||
SAVE_MASKS = True
|
||||
VIEW_IMG = False
|
||||
LINE_WIDTH = 2
|
||||
TARGET_SIZE = 640 # 必须与训练输入尺寸一致
|
||||
|
||||
def load_roi_coords(txt_path):
|
||||
"""
|
||||
加载 ROI 坐标 (x, y, w, h)
|
||||
支持多个 ROI(例如多视野检测)
|
||||
"""
|
||||
rois = []
|
||||
if not os.path.exists(txt_path):
|
||||
raise FileNotFoundError(f"❌ ROI 文件未找到: {txt_path}")
|
||||
with open(txt_path, 'r') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#'):
|
||||
try:
|
||||
x, y, w, h = map(int, line.split(','))
|
||||
rois.append((x, y, w, h))
|
||||
print(f"📌 加载 ROI: (x={x}, y={y}, w={w}, h={h})")
|
||||
except Exception as e:
|
||||
print(f"⚠️ 无法解析 ROI 行: '{line}' | 错误: {e}")
|
||||
return rois
|
||||
|
||||
|
||||
def plot_result_with_opacity_and_roi(
|
||||
result,
|
||||
orig_img_shape,
|
||||
roi_box,
|
||||
line_width=2,
|
||||
mask_opacity=0.5
|
||||
):
|
||||
"""
|
||||
绘制分割结果,支持透明叠加,并适配 ROI → 原图坐标的还原
|
||||
:param result: YOLO 推理结果(在 cropped-resized 图像上)
|
||||
:param orig_img_shape: 原图 (H, W)
|
||||
:param roi_box: (x, y, w, h) 当前 ROI 区域
|
||||
:return: 在原图上绘制的结果图像
|
||||
"""
|
||||
h_orig, w_orig = orig_img_shape
|
||||
img = np.zeros((h_orig, w_orig, 3), dtype=np.uint8) # 创建黑底原图大小画布
|
||||
x, y, w, h = roi_box
|
||||
|
||||
# 获取模型输出的裁剪+resize后的图像上的结果
|
||||
if result.masks is not None:
|
||||
masks = result.masks.data.cpu().numpy() # (N, 640, 640)
|
||||
# 将 mask resize 回 ROI 尺寸
|
||||
resized_masks = []
|
||||
for mask in masks:
|
||||
m = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
|
||||
m = (m > 0.5).astype(np.uint8)
|
||||
resized_masks.append(m)
|
||||
resized_masks = np.array(resized_masks)
|
||||
|
||||
# 上色叠加
|
||||
overlay = img.copy()
|
||||
colors = np.random.randint(0, 255, size=(len(resized_masks), 3), dtype=np.uint8)
|
||||
for i, mask in enumerate(resized_masks):
|
||||
color = colors[i].tolist()
|
||||
roi_region = overlay[y:y+h, x:x+w]
|
||||
roi_region[mask == 1] = color
|
||||
cv2.addWeighted(overlay, mask_opacity, img, 1 - mask_opacity, 0, img)
|
||||
|
||||
# 绘制边界框和标签
|
||||
if result.boxes is not None:
|
||||
boxes = result.boxes.xyxy.cpu().numpy()
|
||||
classes = result.boxes.cls.cpu().numpy().astype(int)
|
||||
confidences = result.boxes.conf.cpu().numpy()
|
||||
|
||||
colors = np.random.randint(0, 255, size=(len(classes), 3), dtype=np.uint8)
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
font_scale = 0.6
|
||||
thickness = 1
|
||||
|
||||
for i in range(len(boxes)):
|
||||
# 映射回 ROI 像素坐标
|
||||
box = boxes[i]
|
||||
x1 = int(box[0] * w / TARGET_SIZE) + x
|
||||
y1 = int(box[1] * h / TARGET_SIZE) + y
|
||||
x2 = int(box[2] * w / TARGET_SIZE) + x
|
||||
y2 = int(box[3] * h / TARGET_SIZE) + y
|
||||
|
||||
cls_id = classes[i]
|
||||
conf = confidences[i]
|
||||
color = colors[i].tolist()
|
||||
|
||||
cv2.rectangle(img, (x1, y1), (x2, y2), color, line_width)
|
||||
|
||||
label = f"{cls_id} {conf:.2f}"
|
||||
(text_w, text_h), baseline = cv2.getTextSize(label, font, font_scale, thickness)
|
||||
cv2.rectangle(img, (x1, y1 - text_h - 6), (x1 + text_w, y1), color, -1)
|
||||
cv2.putText(img, label, (x1, y1 - 4), font, font_scale,
|
||||
(255, 255, 255), thickness, cv2.LINE_AA)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def run_segmentation_inference_with_roi(
|
||||
model_path,
|
||||
source,
|
||||
output_dir,
|
||||
roi_coords_file,
|
||||
conf_threshold=0.25,
|
||||
iou_threshold=0.45,
|
||||
device="cuda:0",
|
||||
save_txt=True,
|
||||
save_masks=True,
|
||||
view_img=False,
|
||||
line_width=2,
|
||||
):
|
||||
print(f"🚀 加载模型: {model_path}")
|
||||
print(f"💻 使用设备: {device}")
|
||||
|
||||
# 加载模型
|
||||
model = YOLO(model_path)
|
||||
|
||||
# 加载 ROI
|
||||
rois = load_roi_coords(roi_coords_file)
|
||||
if len(rois) == 0:
|
||||
print("❌ 没有加载到任何有效的 ROI,程序退出。")
|
||||
return
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
txt_dir = output_dir / "labels"
|
||||
mask_dir = output_dir / "masks"
|
||||
vis_dir = output_dir / "visualize"
|
||||
|
||||
txt_dir.mkdir(parents=True, exist_ok=True)
|
||||
mask_dir.mkdir(parents=True, exist_ok=True)
|
||||
vis_dir.mkdir(parents=True, exist_ok=True) # ✅ 改这里:无条件创建!
|
||||
|
||||
# 获取图像文件
|
||||
source = Path(source)
|
||||
if source.is_file():
|
||||
img_files = [source]
|
||||
else:
|
||||
img_files = list(source.glob("*.jpg")) + \
|
||||
list(source.glob("*.jpeg")) + \
|
||||
list(source.glob("*.png")) + \
|
||||
list(source.glob("*.bmp"))
|
||||
|
||||
if not img_files:
|
||||
print(f"❌ 在 {source} 中未找到图像文件")
|
||||
return
|
||||
|
||||
print(f"🖼️ 共 {len(img_files)} 张图片待推理...")
|
||||
|
||||
# 推理循环
|
||||
for img_path in img_files:
|
||||
print(f"\n🔍 正在处理: {img_path.name}")
|
||||
orig_img = cv2.imread(str(img_path))
|
||||
if orig_img is None:
|
||||
print(f"❌ 无法读取图像: {img_path}")
|
||||
continue
|
||||
h_orig, w_orig = orig_img.shape[:2]
|
||||
|
||||
# 初始化全图级结果存储
|
||||
full_vis_img = orig_img.copy() # 可视化用
|
||||
all_txt_lines = []
|
||||
|
||||
# 对每个 ROI 进行推理
|
||||
for idx, (x, y, w, h) in enumerate(rois):
|
||||
# 越界检查
|
||||
if x < 0 or y < 0 or x + w > w_orig or y + h > h_orig:
|
||||
print(f"⚠️ ROI 越界,跳过: ({x},{y},{w},{h})")
|
||||
continue
|
||||
|
||||
# 提取 ROI 并 resize 到模型输入尺寸
|
||||
roi_img = orig_img[y:y+h, x:x+w]
|
||||
if roi_img.size == 0:
|
||||
print(f"⚠️ 空 ROI 区域: {idx}")
|
||||
continue
|
||||
|
||||
resized_img = cv2.resize(roi_img, (TARGET_SIZE, TARGET_SIZE), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
# 推理
|
||||
results = model(
|
||||
source=resized_img,
|
||||
conf=conf_threshold,
|
||||
iou=iou_threshold,
|
||||
imgsz=TARGET_SIZE,
|
||||
device=device,
|
||||
verbose=False
|
||||
)
|
||||
result = results[0]
|
||||
|
||||
# 可视化:将结果映射回原图
|
||||
vis_part = plot_result_with_opacity_and_roi(
|
||||
result=result,
|
||||
orig_img_shape=(h_orig, w_orig),
|
||||
roi_box=(x, y, w, h),
|
||||
line_width=line_width,
|
||||
mask_opacity=0.5
|
||||
)
|
||||
full_vis_img = cv2.addWeighted(full_vis_img, 1.0, vis_part, 0.7, 0)
|
||||
|
||||
# 保存标签(YOLO 格式,归一化到原图)
|
||||
if save_txt and result.masks is not None:
|
||||
scale_x = w / TARGET_SIZE
|
||||
scale_y = h / TARGET_SIZE
|
||||
for i in range(len(result.boxes)):
|
||||
cls_id = int(result.boxes.cls[i])
|
||||
seg = result.masks.xy[i] # (N, 2) in 640 space
|
||||
# 映射回 ROI 像素坐标
|
||||
seg[:, 0] = seg[:, 0] * scale_x + x
|
||||
seg[:, 1] = seg[:, 1] * scale_y + y
|
||||
# 归一化到原图
|
||||
seg_norm = seg / [w_orig, h_orig]
|
||||
seg_flat = seg_norm.flatten()
|
||||
line = f"{cls_id} {' '.join(f'{val:.6f}' for val in seg_flat)}"
|
||||
all_txt_lines.append(line)
|
||||
|
||||
# 保存单个 ROI 的 mask(可选)
|
||||
if save_masks and result.masks is not None:
|
||||
mask_data = result.masks.data.cpu().numpy()
|
||||
combined_mask = (mask_data.sum(axis=0) > 0).astype(np.uint8) * 255
|
||||
# resize 回 ROI 尺寸
|
||||
roi_mask = cv2.resize(combined_mask, (w, h), interpolation=cv2.INTER_NEAREST)
|
||||
full_mask = np.zeros((h_orig, w_orig), dtype=np.uint8)
|
||||
full_mask[y:y+h, x:x+w] = roi_mask
|
||||
mask_save_path = mask_dir / f"mask_{img_path.stem}_roi{idx}.png"
|
||||
cv2.imwrite(str(mask_save_path), full_mask)
|
||||
|
||||
# 保存最终可视化图像
|
||||
vis_save_path = vis_dir / f"vis_{img_path.name}"
|
||||
cv2.imwrite(str(vis_save_path), full_vis_img)
|
||||
print(f"✅ 保存可视化结果: {vis_save_path}")
|
||||
|
||||
# 保存合并的文本标签
|
||||
if save_txt and all_txt_lines:
|
||||
txt_save_path = txt_dir / f"{img_path.stem}.txt"
|
||||
with open(txt_save_path, 'w') as f:
|
||||
f.write("\n".join(all_txt_lines) + "\n")
|
||||
print(f"📝 保存标签: {txt_save_path}")
|
||||
|
||||
# 实时显示
|
||||
if view_img:
|
||||
display_img = cv2.resize(full_vis_img, (960, 540))
|
||||
cv2.imshow("Segmentation Result", display_img)
|
||||
if cv2.waitKey(0) == 27: # ESC
|
||||
cv2.destroyAllWindows()
|
||||
break
|
||||
|
||||
if view_img:
|
||||
cv2.destroyAllWindows()
|
||||
print(f"\n🎉 推理完成!结果保存在: {output_dir}")
|
||||
|
||||
|
||||
# ====================== 主程序 ======================
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", default=MODEL_PATH, help="模型权重路径")
|
||||
parser.add_argument("--source", default=SOURCE_IMG_DIR, help="图片路径或文件夹")
|
||||
parser.add_argument("--output", default=OUTPUT_DIR, help="输出目录")
|
||||
parser.add_argument("--roi-file", default=ROI_COORDS_FILE, help="ROI 坐标文件路径")
|
||||
parser.add_argument("--conf", type=float, default=CONF_THRESHOLD, help="置信度阈值")
|
||||
parser.add_argument("--iou", type=float, default=IOU_THRESHOLD, help="IoU 阈值")
|
||||
parser.add_argument("--device", default=DEVICE, help="设备: cuda:0, cpu")
|
||||
parser.add_argument("--view-img", action="store_true", help="显示图像")
|
||||
parser.add_argument("--save-txt", action="store_true", help="保存标签")
|
||||
parser.add_argument("--save-masks", action="store_true", help="保存掩码")
|
||||
|
||||
opt = parser.parse_args()
|
||||
|
||||
run_segmentation_inference_with_roi(
|
||||
model_path=opt.model,
|
||||
source=opt.source,
|
||||
output_dir=opt.output,
|
||||
roi_coords_file=opt.roi_file,
|
||||
conf_threshold=opt.conf,
|
||||
iou_threshold=opt.iou,
|
||||
device=opt.device,
|
||||
save_txt=opt.save_txt or SAVE_TXT,
|
||||
save_masks=opt.save_masks or SAVE_MASKS,
|
||||
view_img=opt.view_img,
|
||||
line_width=LINE_WIDTH,
|
||||
)
|
||||
184
yemian/resize/trans_photo_and_labels_1.py
Normal file
184
yemian/resize/trans_photo_and_labels_1.py
Normal file
@ -0,0 +1,184 @@
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
# ==================== 配置路径 ====================
|
||||
data_dir = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/seg/class4"
|
||||
roi_coords_file = "./roi_coordinates/1_rois.txt"
|
||||
output_images_dir = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/seg/resize"
|
||||
output_labels_dir = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/seg/resize"
|
||||
|
||||
target_size = 640
|
||||
|
||||
# 创建输出目录
|
||||
os.makedirs(output_images_dir, exist_ok=True)
|
||||
os.makedirs(output_labels_dir, exist_ok=True)
|
||||
|
||||
print(f"📁 输出图像目录: {output_images_dir}")
|
||||
print(f"📄 输出标签目录: {output_labels_dir}")
|
||||
|
||||
# 检查输出目录是否可写
|
||||
test_write_path = os.path.join(output_images_dir, "write_test.tmp")
|
||||
try:
|
||||
with open(test_write_path, 'w') as f:
|
||||
f.write("test")
|
||||
os.remove(test_write_path)
|
||||
print("✅ 输出目录可写")
|
||||
except Exception as e:
|
||||
print(f"❌ 输出目录不可写!错误: {e}")
|
||||
exit(1)
|
||||
|
||||
def load_global_rois(txt_path):
|
||||
"""加载全局 ROI 坐标"""
|
||||
rois = []
|
||||
if not os.path.exists(txt_path):
|
||||
print(f"❌ ROI 文件不存在: {txt_path}")
|
||||
return rois
|
||||
try:
|
||||
with open(txt_path, 'r') as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
x, y, w, h = map(int, line.split(','))
|
||||
rois.append((x, y, w, h))
|
||||
print(f"📌 加载 ROI #{line_num}: (x={x}, y={y}, w={w}, h={h})")
|
||||
except Exception as e:
|
||||
print(f"⚠️ 无法解析第 {line_num} 行: '{line}', 错误: {e}")
|
||||
except Exception as e:
|
||||
print(f"❌ 读取 ROI 文件失败: {e}")
|
||||
return []
|
||||
return rois
|
||||
|
||||
# 加载 ROI
|
||||
rois = load_global_rois(roi_coords_file)
|
||||
if len(rois) == 0:
|
||||
print("❌ 没有加载到任何有效的 ROI 坐标,程序退出。")
|
||||
exit(1)
|
||||
|
||||
# ==================== 关键修改:扁平结构处理 ====================
|
||||
print(f"\n🔍 扫描数据目录: {data_dir}")
|
||||
if not os.path.exists(data_dir):
|
||||
print(f"❌ 数据目录不存在!")
|
||||
exit(1)
|
||||
|
||||
all_files = os.listdir(data_dir)
|
||||
jpg_files = [f for f in all_files if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))]
|
||||
txt_files = [f for f in all_files if f.lower().endswith('.txt')]
|
||||
|
||||
print(f"📦 总文件数: {len(all_files)}")
|
||||
print(f"🖼️ 图像文件: {len(jpg_files)} 个")
|
||||
print(f"📝 标签文件: {len(txt_files)} 个")
|
||||
|
||||
if len(jpg_files) == 0:
|
||||
print("❌ 未找到任何图像文件")
|
||||
exit(1)
|
||||
|
||||
processed_any = False
|
||||
|
||||
for img_file in jpg_files:
|
||||
base_name, ext = os.path.splitext(img_file)
|
||||
img_path = os.path.join(data_dir, img_file)
|
||||
|
||||
# 读取图像
|
||||
img = cv2.imread(img_path)
|
||||
if img is None:
|
||||
print(f"❌ 无法读取图像: {img_path}")
|
||||
continue
|
||||
h_img, w_img = img.shape[:2]
|
||||
print(f"✅ 成功读取图像: {img_path} (尺寸: {w_img}x{h_img})")
|
||||
|
||||
# 读取标签
|
||||
label_path = os.path.join(data_dir, f"{base_name}.txt")
|
||||
labels = []
|
||||
if os.path.exists(label_path):
|
||||
try:
|
||||
with open(label_path, 'r') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
parts = list(map(float, line.split()))
|
||||
class_id = int(parts[0])
|
||||
coords = np.array(parts[1:]).reshape(-1, 2)
|
||||
coords[:, 0] *= w_img
|
||||
coords[:, 1] *= h_img
|
||||
labels.append((class_id, coords))
|
||||
print(f"🏷️ 加载标签: {label_path} ({len(labels)} 个对象)")
|
||||
except Exception as e:
|
||||
print(f"❌ 读取标签失败: {label_path}, 错误: {e}")
|
||||
else:
|
||||
print(f"🟡 未找到标签文件: {label_path}")
|
||||
|
||||
# 处理每个 ROI
|
||||
for i, (x, y, w_roi, h_roi) in enumerate(rois):
|
||||
print(f"🔲 处理 ROI #{i}: (x={x}, y={y}, w={w_roi}, h={h_roi})")
|
||||
|
||||
# 检查越界
|
||||
if x < 0 or y < 0 or x + w_roi > w_img or y + h_roi > h_img:
|
||||
print(f"⚠️ ROI 越界,跳过")
|
||||
continue
|
||||
|
||||
roi_img = img[y:y+h_roi, x:x+w_roi]
|
||||
if roi_img.size == 0:
|
||||
print(f"❌ ROI 图像为空")
|
||||
continue
|
||||
|
||||
resized_img = cv2.resize(roi_img, (target_size, target_size), interpolation=cv2.INTER_AREA)
|
||||
|
||||
# 标签处理
|
||||
new_labels = []
|
||||
scale_x = target_size / w_roi
|
||||
scale_y = target_size / h_roi
|
||||
|
||||
for class_id, poly in labels:
|
||||
shifted_poly = poly.copy() - [x, y]
|
||||
valid_mask = (shifted_poly[:, 0] >= 0) & (shifted_poly[:, 0] < w_roi) & \
|
||||
(shifted_poly[:, 1] >= 0) & (shifted_poly[:, 1] < h_roi)
|
||||
if not np.any(valid_mask):
|
||||
continue
|
||||
|
||||
scaled_poly = shifted_poly * [scale_x, scale_y]
|
||||
normalized_poly = scaled_poly / target_size
|
||||
new_labels.append((class_id, normalized_poly.flatten()))
|
||||
|
||||
# 保存图像
|
||||
suffix = f"_roi{i}" if len(rois) > 1 else ""
|
||||
save_img_name = f"{base_name}{suffix}{ext}"
|
||||
save_img_path = os.path.join(output_images_dir, save_img_name)
|
||||
|
||||
try:
|
||||
success = cv2.imwrite(save_img_path, resized_img)
|
||||
if success:
|
||||
file_size = os.path.getsize(save_img_path)
|
||||
print(f"✅ 保存图像成功: {save_img_path} ({file_size} 字节)")
|
||||
else:
|
||||
print(f"❌ cv2.imwrite 返回 False: {save_img_path}")
|
||||
except Exception as e:
|
||||
print(f"💥 保存图像异常: {save_img_path}, 错误: {e}")
|
||||
|
||||
# 保存标签
|
||||
save_label_name = f"{base_name}{suffix}.txt"
|
||||
save_label_path = os.path.join(output_labels_dir, save_label_name)
|
||||
try:
|
||||
with open(save_label_path, 'w') as f:
|
||||
for cls_id, norm_poly in new_labels:
|
||||
line = [str(cls_id)] + [f"{val:.6f}" for val in norm_poly]
|
||||
f.write(" ".join(line) + "\n")
|
||||
print(f"✅ 保存标签成功: {save_label_path} ({len(new_labels)} 行)")
|
||||
except Exception as e:
|
||||
print(f"💥 保存标签异常: {save_label_path}, 错误: {e}")
|
||||
|
||||
processed_any = True
|
||||
|
||||
# === 最终总结 ===
|
||||
print("\n" + "="*50)
|
||||
if processed_any:
|
||||
print("✅ 程序完成:已成功处理图像和标签")
|
||||
else:
|
||||
print("❌ 程序完成:但未处理任何图像")
|
||||
|
||||
print(f"📌 请检查输出目录:")
|
||||
print(f" {output_images_dir}")
|
||||
print(f" {output_labels_dir}")
|
||||
228
yemian/resize/trans_photo_and_labels_class.py
Normal file
228
yemian/resize/trans_photo_and_labels_class.py
Normal file
@ -0,0 +1,228 @@
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
# ==================== 配置路径 ====================
|
||||
original_images_parent_dir = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/seg/yeimian_seg"
|
||||
original_labels_parent_dir = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/seg/yeimian_seg" # 原始 YOLO 标签目录
|
||||
roi_coords_file = "./roi_coordinates/1_rois.txt"
|
||||
output_images_dir = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/seg/resize"
|
||||
output_labels_dir = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/seg/resize" # 新标签输出目录
|
||||
target_size = 640
|
||||
|
||||
# 创建输出目录
|
||||
os.makedirs(output_images_dir, exist_ok=True)
|
||||
os.makedirs(output_labels_dir, exist_ok=True)
|
||||
|
||||
print(f"📁 输出图像目录: {output_images_dir}")
|
||||
print(f"📄 输出标签目录: {output_labels_dir}")
|
||||
|
||||
# 检查输出目录是否可写
|
||||
test_write_path = os.path.join(output_images_dir, "write_test.tmp")
|
||||
try:
|
||||
with open(test_write_path, 'w') as f:
|
||||
f.write("test")
|
||||
os.remove(test_write_path)
|
||||
print("✅ 输出目录可写")
|
||||
except Exception as e:
|
||||
print(f"❌ 输出目录不可写!错误: {e}")
|
||||
exit(1)
|
||||
|
||||
def load_global_rois(txt_path):
|
||||
"""加载全局 ROI 坐标"""
|
||||
rois = []
|
||||
if not os.path.exists(txt_path):
|
||||
print(f"❌ ROI 文件不存在: {txt_path}")
|
||||
return rois
|
||||
try:
|
||||
with open(txt_path, 'r') as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
x, y, w, h = map(int, line.split(','))
|
||||
rois.append((x, y, w, h))
|
||||
print(f"📌 加载 ROI #{line_num}: (x={x}, y={y}, w={w}, h={h})")
|
||||
except Exception as e:
|
||||
print(f"⚠️ 无法解析第 {line_num} 行: '{line}', 错误: {e}")
|
||||
except Exception as e:
|
||||
print(f"❌ 读取 ROI 文件失败: {e}")
|
||||
return []
|
||||
return rois
|
||||
|
||||
# 加载 ROI
|
||||
rois = load_global_rois(roi_coords_file)
|
||||
if len(rois) == 0:
|
||||
print("❌ 没有加载到任何有效的 ROI 坐标,程序退出。")
|
||||
exit(1)
|
||||
|
||||
# 扫描原始父目录
|
||||
print(f"\n🔍 扫描原始图像父目录: {original_images_parent_dir}")
|
||||
if not os.path.exists(original_images_parent_dir):
|
||||
print(f"❌ 原始图像父目录不存在!")
|
||||
exit(1)
|
||||
|
||||
all_items = os.listdir(original_images_parent_dir)
|
||||
print(f"📦 内容: {all_items}")
|
||||
|
||||
# 过滤出有效类别(跳过 _cropped 或非目录)
|
||||
valid_classes = []
|
||||
for item in all_items:
|
||||
item_path = os.path.join(original_images_parent_dir, item)
|
||||
if os.path.isdir(item_path) and not item.endswith('_cropped'):
|
||||
valid_classes.append(item)
|
||||
|
||||
if len(valid_classes) == 0:
|
||||
print("❌ 未找到任何有效的原始类别目录(排除 _cropped)")
|
||||
exit(1)
|
||||
|
||||
print(f"🎯 发现 {len(valid_classes)} 个有效类别: {valid_classes}\n")
|
||||
|
||||
processed_any = False # 记录是否处理了任何图像
|
||||
|
||||
for class_name in valid_classes:
|
||||
class_img_dir = os.path.join(original_images_parent_dir, class_name)
|
||||
out_img_dir = os.path.join(output_images_dir, class_name)
|
||||
out_label_dir = os.path.join(output_labels_dir, class_name)
|
||||
|
||||
os.makedirs(out_img_dir, exist_ok=True)
|
||||
os.makedirs(out_label_dir, exist_ok=True)
|
||||
|
||||
print(f"🔄 处理类别: {class_name}")
|
||||
print(f" 📂 图像源: {class_img_dir}")
|
||||
print(f" 💾 图像输出: {out_img_dir}")
|
||||
print(f" 🏷️ 标签输出: {out_label_dir}")
|
||||
|
||||
# 检查图像目录
|
||||
if not os.path.exists(class_img_dir):
|
||||
print(f" ❌ 图像目录不存在")
|
||||
continue
|
||||
|
||||
img_files = [f for f in os.listdir(class_img_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))]
|
||||
print(f" 🖼️ 找到 {len(img_files)} 个图像文件")
|
||||
|
||||
if len(img_files) == 0:
|
||||
print(f" ⚠️ 该类别下无图像文件")
|
||||
continue
|
||||
|
||||
for img_file in img_files:
|
||||
base_name, ext = os.path.splitext(img_file)
|
||||
img_path = os.path.join(class_img_dir, img_file)
|
||||
|
||||
# 读取图像
|
||||
img = cv2.imread(img_path)
|
||||
if img is None:
|
||||
print(f" ❌ 无法读取图像: {img_path}")
|
||||
continue
|
||||
h_img, w_img = img.shape[:2]
|
||||
print(f" ✅ 成功读取图像: {img_path} (尺寸: {w_img}x{h_img})")
|
||||
|
||||
# 读取标签
|
||||
label_path = os.path.join(original_labels_parent_dir, class_name, f"{base_name}.txt")
|
||||
labels = []
|
||||
if os.path.exists(label_path):
|
||||
try:
|
||||
with open(label_path, 'r') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
parts = list(map(float, line.split()))
|
||||
class_id = int(parts[0])
|
||||
coords = np.array(parts[1:]).reshape(-1, 2)
|
||||
coords[:, 0] *= w_img
|
||||
coords[:, 1] *= h_img
|
||||
labels.append((class_id, coords))
|
||||
print(f" 🏷️ 加载标签: {label_path} ({len(labels)} 个对象)")
|
||||
except Exception as e:
|
||||
print(f" ❌ 读取标签失败: {label_path}, 错误: {e}")
|
||||
else:
|
||||
print(f" 🟡 未找到标签文件: {label_path}")
|
||||
|
||||
# 处理每个 ROI
|
||||
for i, (x, y, w_roi, h_roi) in enumerate(rois):
|
||||
print(f" 🔲 处理 ROI #{i}: (x={x}, y={y}, w={w_roi}, h={h_roi})")
|
||||
|
||||
# 检查越界
|
||||
if x < 0 or y < 0 or x + w_roi > w_img or y + h_roi > h_img:
|
||||
print(f" ⚠️ ROI 越界,跳过")
|
||||
continue
|
||||
|
||||
roi_img = img[y:y+h_roi, x:x+w_roi]
|
||||
if roi_img.size == 0:
|
||||
print(f" ❌ ROI 图像为空")
|
||||
continue
|
||||
|
||||
resized_img = cv2.resize(roi_img, (target_size, target_size), interpolation=cv2.INTER_AREA)
|
||||
|
||||
# 标签处理
|
||||
new_labels = []
|
||||
scale_x = target_size / w_roi
|
||||
scale_y = target_size / h_roi
|
||||
|
||||
for class_id, poly in labels:
|
||||
shifted_poly = poly.copy() - [x, y]
|
||||
valid_mask = (shifted_poly[:, 0] >= 0) & (shifted_poly[:, 0] < w_roi) & \
|
||||
(shifted_poly[:, 1] >= 0) & (shifted_poly[:, 1] < h_roi)
|
||||
if not np.any(valid_mask):
|
||||
continue
|
||||
|
||||
scaled_poly = shifted_poly * [scale_x, scale_y]
|
||||
normalized_poly = scaled_poly / target_size
|
||||
new_labels.append((class_id, normalized_poly.flatten()))
|
||||
|
||||
# 保存图像
|
||||
suffix = f"_roi{i}" if len(rois) > 1 else ""
|
||||
save_img_name = f"{base_name}{suffix}{ext}"
|
||||
save_img_path = os.path.join(out_img_dir, save_img_name)
|
||||
|
||||
try:
|
||||
success = cv2.imwrite(save_img_path, resized_img)
|
||||
if success:
|
||||
file_size = os.path.getsize(save_img_path)
|
||||
print(f" ✅ 保存图像成功: {save_img_path} ({file_size} 字节)")
|
||||
else:
|
||||
print(f" ❌ cv2.imwrite 返回 False: {save_img_path}")
|
||||
except Exception as e:
|
||||
print(f" 💥 保存图像异常: {save_img_path}, 错误: {e}")
|
||||
|
||||
# 保存标签
|
||||
save_label_name = f"{base_name}{suffix}.txt"
|
||||
save_label_path = os.path.join(out_label_dir, save_label_name)
|
||||
try:
|
||||
with open(save_label_path, 'w') as f:
|
||||
for cls_id, norm_poly in new_labels:
|
||||
line = [str(cls_id)] + [f"{val:.6f}" for val in norm_poly]
|
||||
f.write(" ".join(line) + "\n")
|
||||
print(f" ✅ 保存标签成功: {save_label_path} ({len(new_labels)} 行)")
|
||||
except Exception as e:
|
||||
print(f" 💥 保存标签异常: {save_label_path}, 错误: {e}")
|
||||
|
||||
processed_any = True
|
||||
|
||||
# === 强制测试写入能力 ===
|
||||
print("\n" + "="*50)
|
||||
print("🔧 强制测试:尝试创建一张测试图")
|
||||
test_img = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
|
||||
test_out = os.path.join(output_images_dir, "DEBUG_TEST_WRITE.jpg")
|
||||
try:
|
||||
success = cv2.imwrite(test_out, test_img)
|
||||
if success and os.path.exists(test_out):
|
||||
sz = os.path.getsize(test_out)
|
||||
print(f"🎉 成功写入测试图像: {test_out} ({sz} 字节)")
|
||||
else:
|
||||
print(f"❌ 写入失败: cv2.imwrite 返回 {success}, 文件存在: {os.path.exists(test_out)}")
|
||||
except Exception as e:
|
||||
print(f"💥 异常: {e}")
|
||||
|
||||
# === 最终总结 ===
|
||||
print("\n" + "="*50)
|
||||
if processed_any:
|
||||
print("✅ 程序完成:已尝试处理图像和标签")
|
||||
else:
|
||||
print("❌ 程序完成:但未处理任何图像,请检查路径和文件格式")
|
||||
|
||||
print(f"📌 请检查输出目录:")
|
||||
print(f" {output_images_dir}")
|
||||
print(f" {output_labels_dir}")
|
||||
Reference in New Issue
Block a user