Files
xiantiao_CV/class_xiantiao_pc/class_xiantiao.py
琉璃月光 8506c3af79 first commit
2025-12-16 15:12:02 +08:00

86 lines
2.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
from ultralytics import YOLO
class ROIClassifier:
""" 封装 YOLO 分类模型 + ROI txt 推理 """
def __init__(self, model_path):
if not os.path.exists(model_path):
raise FileNotFoundError(f"模型文件不存在: {model_path}")
self.model = YOLO(model_path)
print(f"[INFO] 成功加载 YOLO 分类模型: {model_path}")
def load_roi(self, roi_txt_path):
"""读取 ROI txt可自动支持逗号/空格分隔"""
if not os.path.exists(roi_txt_path):
raise FileNotFoundError(f"ROI 文件不存在: {roi_txt_path}")
with open(roi_txt_path, "r") as f:
text = f.read().strip()
# 把逗号替换成空格,再 split
for ch in [",", ";"]:
text = text.replace(ch, " ")
parts = text.split()
if len(parts) != 4:
raise ValueError(f"ROI txt 格式错误应为4个数字解析得到: {parts}\n文件: {roi_txt_path}")
x1, y1, x2, y2 = map(int, parts)
return x1, y1, x2, y2
def classify(self, img_np, roi_txt_path):
"""对 ROI 区域做分类,返回 0/1"""
h, w = img_np.shape[:2]
x1, y1, x2, y2 = self.load_roi(roi_txt_path)
# -------- ROI 边界安全裁剪 --------
x1 = max(0, min(x1, w - 1))
x2 = max(0, min(x2, w - 1))
y1 = max(0, min(y1, h - 1))
y2 = max(0, min(y2, h - 1))
if x2 <= x1 or y2 <= y1:
raise ValueError(f"ROI坐标无效: {x1, y1, x2, y2}")
# -------- 1. 裁剪 ROI --------
roi_img = img_np[y1:y2, x1:x2]
# -------- 2. resize 到 640×640强制送给模型--------
roi_img = cv2.resize(roi_img, (640, 640))
# -------- 3. YOLO 分类推理 --------
results = self.model.predict(roi_img, verbose=False)
cls = int(results[0].probs.top1) # 0 或 1
return cls
def class_xiantiao():
# ================== 配置 ==================
model_path = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls/exp_xiantiao_cls/weights/best.pt"
img_path = "1.png"
roi_txt_path = "/home/hx/开发/ML_xiantiao/class_xiantiao/roi_1/1/1_rois1.txt"
# ================== 1. 加载模型 ==================
classifier = ROIClassifier(model_path)
# ================== 2. 加载图像 ==================
if not os.path.exists(img_path):
raise FileNotFoundError(f"图片不存在: {img_path}")
img_np = cv2.imread(img_path)
if img_np is None:
raise ValueError("图像加载失败,可能路径错误或图像文件损坏。")
# ================== 3. 推理 ==================
result = classifier.classify(img_np, roi_txt_path)
# ================== 4. 输出 ==================
print(f"\n===== 推理结果 =====")
print(f"ROI 文件:{roi_txt_path}")
print(f"分类结果:{result} (0=异常 / 1=正常)\n")
if __name__ == "__main__":
class_xiantiao()