Files
xiantiao_CV/class_xiantiao_pc/class_xiantiao.py

86 lines
2.9 KiB
Python
Raw Permalink Normal View History

2025-12-16 15:12:02 +08:00
import os
import cv2
from ultralytics import YOLO
class ROIClassifier:
""" 封装 YOLO 分类模型 + ROI txt 推理 """
def __init__(self, model_path):
if not os.path.exists(model_path):
raise FileNotFoundError(f"模型文件不存在: {model_path}")
self.model = YOLO(model_path)
print(f"[INFO] 成功加载 YOLO 分类模型: {model_path}")
def load_roi(self, roi_txt_path):
"""读取 ROI txt可自动支持逗号/空格分隔"""
if not os.path.exists(roi_txt_path):
raise FileNotFoundError(f"ROI 文件不存在: {roi_txt_path}")
with open(roi_txt_path, "r") as f:
text = f.read().strip()
# 把逗号替换成空格,再 split
for ch in [",", ";"]:
text = text.replace(ch, " ")
parts = text.split()
if len(parts) != 4:
raise ValueError(f"ROI txt 格式错误应为4个数字解析得到: {parts}\n文件: {roi_txt_path}")
x1, y1, x2, y2 = map(int, parts)
return x1, y1, x2, y2
def classify(self, img_np, roi_txt_path):
"""对 ROI 区域做分类,返回 0/1"""
h, w = img_np.shape[:2]
x1, y1, x2, y2 = self.load_roi(roi_txt_path)
# -------- ROI 边界安全裁剪 --------
x1 = max(0, min(x1, w - 1))
x2 = max(0, min(x2, w - 1))
y1 = max(0, min(y1, h - 1))
y2 = max(0, min(y2, h - 1))
if x2 <= x1 or y2 <= y1:
raise ValueError(f"ROI坐标无效: {x1, y1, x2, y2}")
# -------- 1. 裁剪 ROI --------
roi_img = img_np[y1:y2, x1:x2]
# -------- 2. resize 到 640×640强制送给模型--------
roi_img = cv2.resize(roi_img, (640, 640))
# -------- 3. YOLO 分类推理 --------
results = self.model.predict(roi_img, verbose=False)
cls = int(results[0].probs.top1) # 0 或 1
return cls
def class_xiantiao():
# ================== 配置 ==================
model_path = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls/exp_xiantiao_cls/weights/best.pt"
img_path = "1.png"
roi_txt_path = "/home/hx/开发/ML_xiantiao/class_xiantiao/roi_1/1/1_rois1.txt"
# ================== 1. 加载模型 ==================
classifier = ROIClassifier(model_path)
# ================== 2. 加载图像 ==================
if not os.path.exists(img_path):
raise FileNotFoundError(f"图片不存在: {img_path}")
img_np = cv2.imread(img_path)
if img_np is None:
raise ValueError("图像加载失败,可能路径错误或图像文件损坏。")
# ================== 3. 推理 ==================
result = classifier.classify(img_np, roi_txt_path)
# ================== 4. 输出 ==================
print(f"\n===== 推理结果 =====")
print(f"ROI 文件:{roi_txt_path}")
print(f"分类结果:{result} (0=异常 / 1=正常)\n")
if __name__ == "__main__":
class_xiantiao()