Files
xiantiao_CV/class_xiantiao_pc/cls_quexian.py

63 lines
2.1 KiB
Python
Raw Normal View History

2025-12-16 15:12:02 +08:00
import os
import cv2
from ultralytics import YOLO
class FullImageClassifier:
"""封装 YOLO 分类模型,对整张图像进行二分类(无缺陷/有缺陷)"""
def __init__(self, model_path):
if not os.path.exists(model_path):
raise FileNotFoundError(f"模型文件不存在: {model_path}")
self.model = YOLO(model_path)
print(f"[INFO] 成功加载 YOLO 分类模型: {model_path}")
def classify(self, img_np):
"""
对整张图像进行分类返回类别 ID0 1
Args:
img_np (np.ndarray): BGR 格式的 OpenCV 图像 (H, W, C)
Returns:
int: 分类结果{0: "有缺陷", 1: "无缺陷"}
"""
# 直接 resize 整图到模型输入尺寸YOLO 默认为 224x224 或 640x640由训练决定
# Ultralytics YOLO 会自动处理 resize但显式指定更可控
resized_img = cv2.resize(img_np, (640, 640))
# 推理verbose=False 关闭进度条)
results = self.model.predict(resized_img, verbose=False)
cls = int(results[0].probs.top1) # 获取 top-1 类别索引
return cls
def cls_quexian():
# ================== 配置 ==================
model_path = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls/exp_xiantiao_cls/weights/best.pt"
img_path = "1.png"
# ================== 1. 加载模型 ==================
classifier = FullImageClassifier(model_path)
# ================== 2. 加载图像 ==================
if not os.path.exists(img_path):
raise FileNotFoundError(f"图片不存在: {img_path}")
img_np = cv2.imread(img_path)
if img_np is None:
raise ValueError("图像加载失败,可能路径错误或图像文件损坏。")
# ================== 3. 推理(整图)==================
result = classifier.classify(img_np)
# ================== 4. 输出 ==================
label_map = {0: "有缺陷", 1: "无缺陷"}
print(f"\n===== 推理结果 =====")
print(f"分类结果:{result}{label_map.get(result, '未知')}\n")
if __name__ == "__main__":
cls_quexian()