86 lines
2.9 KiB
Python
86 lines
2.9 KiB
Python
import os
|
||
import cv2
|
||
from ultralytics import YOLO
|
||
|
||
|
||
class ROIClassifier:
|
||
""" 封装 YOLO 分类模型 + ROI txt 推理 """
|
||
|
||
def __init__(self, model_path):
|
||
if not os.path.exists(model_path):
|
||
raise FileNotFoundError(f"模型文件不存在: {model_path}")
|
||
|
||
self.model = YOLO(model_path)
|
||
print(f"[INFO] 成功加载 YOLO 分类模型: {model_path}")
|
||
|
||
def load_roi(self, roi_txt_path):
|
||
"""读取 ROI txt,可自动支持逗号/空格分隔"""
|
||
if not os.path.exists(roi_txt_path):
|
||
raise FileNotFoundError(f"ROI 文件不存在: {roi_txt_path}")
|
||
with open(roi_txt_path, "r") as f:
|
||
text = f.read().strip()
|
||
# 把逗号替换成空格,再 split
|
||
for ch in [",", ";"]:
|
||
text = text.replace(ch, " ")
|
||
parts = text.split()
|
||
if len(parts) != 4:
|
||
raise ValueError(f"ROI txt 格式错误,应为4个数字,解析得到: {parts}\n文件: {roi_txt_path}")
|
||
x1, y1, x2, y2 = map(int, parts)
|
||
return x1, y1, x2, y2
|
||
|
||
def classify(self, img_np, roi_txt_path):
|
||
"""对 ROI 区域做分类,返回 0/1"""
|
||
h, w = img_np.shape[:2]
|
||
x1, y1, x2, y2 = self.load_roi(roi_txt_path)
|
||
|
||
# -------- ROI 边界安全裁剪 --------
|
||
x1 = max(0, min(x1, w - 1))
|
||
x2 = max(0, min(x2, w - 1))
|
||
y1 = max(0, min(y1, h - 1))
|
||
y2 = max(0, min(y2, h - 1))
|
||
|
||
if x2 <= x1 or y2 <= y1:
|
||
raise ValueError(f"ROI坐标无效: {x1, y1, x2, y2}")
|
||
|
||
# -------- 1. 裁剪 ROI --------
|
||
roi_img = img_np[y1:y2, x1:x2]
|
||
|
||
# -------- 2. resize 到 640×640(强制送给模型)--------
|
||
roi_img = cv2.resize(roi_img, (640, 640))
|
||
|
||
# -------- 3. YOLO 分类推理 --------
|
||
results = self.model.predict(roi_img, verbose=False)
|
||
|
||
cls = int(results[0].probs.top1) # 0 或 1
|
||
return cls
|
||
|
||
|
||
def class_xiantiao():
|
||
# ================== 配置 ==================
|
||
model_path = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls/exp_xiantiao_cls/weights/best.pt"
|
||
img_path = "1.png"
|
||
roi_txt_path = "/home/hx/开发/ML_xiantiao/class_xiantiao/roi_1/1/1_rois1.txt"
|
||
|
||
# ================== 1. 加载模型 ==================
|
||
classifier = ROIClassifier(model_path)
|
||
|
||
# ================== 2. 加载图像 ==================
|
||
if not os.path.exists(img_path):
|
||
raise FileNotFoundError(f"图片不存在: {img_path}")
|
||
|
||
img_np = cv2.imread(img_path)
|
||
if img_np is None:
|
||
raise ValueError("图像加载失败,可能路径错误或图像文件损坏。")
|
||
|
||
# ================== 3. 推理 ==================
|
||
result = classifier.classify(img_np, roi_txt_path)
|
||
|
||
# ================== 4. 输出 ==================
|
||
print(f"\n===== 推理结果 =====")
|
||
print(f"ROI 文件:{roi_txt_path}")
|
||
print(f"分类结果:{result} (0=异常 / 1=正常)\n")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
class_xiantiao()
|