63 lines
2.1 KiB
Python
63 lines
2.1 KiB
Python
import os
|
||
import cv2
|
||
from ultralytics import YOLO
|
||
|
||
|
||
class FullImageClassifier:
|
||
"""封装 YOLO 分类模型,对整张图像进行二分类(无缺陷/有缺陷)"""
|
||
|
||
def __init__(self, model_path):
|
||
if not os.path.exists(model_path):
|
||
raise FileNotFoundError(f"模型文件不存在: {model_path}")
|
||
|
||
self.model = YOLO(model_path)
|
||
print(f"[INFO] 成功加载 YOLO 分类模型: {model_path}")
|
||
|
||
def classify(self, img_np):
|
||
"""
|
||
对整张图像进行分类,返回类别 ID(0 或 1)
|
||
|
||
Args:
|
||
img_np (np.ndarray): BGR 格式的 OpenCV 图像 (H, W, C)
|
||
|
||
Returns:
|
||
int: 分类结果,{0: "有缺陷", 1: "无缺陷"}
|
||
"""
|
||
# 直接 resize 整图到模型输入尺寸(YOLO 默认为 224x224 或 640x640,由训练决定)
|
||
# Ultralytics YOLO 会自动处理 resize,但显式指定更可控
|
||
resized_img = cv2.resize(img_np, (640, 640))
|
||
|
||
# 推理(verbose=False 关闭进度条)
|
||
results = self.model.predict(resized_img, verbose=False)
|
||
|
||
cls = int(results[0].probs.top1) # 获取 top-1 类别索引
|
||
return cls
|
||
|
||
|
||
def cls_quexian():
|
||
# ================== 配置 ==================
|
||
model_path = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls/exp_xiantiao_cls/weights/best.pt"
|
||
img_path = "1.png"
|
||
|
||
# ================== 1. 加载模型 ==================
|
||
classifier = FullImageClassifier(model_path)
|
||
|
||
# ================== 2. 加载图像 ==================
|
||
if not os.path.exists(img_path):
|
||
raise FileNotFoundError(f"图片不存在: {img_path}")
|
||
|
||
img_np = cv2.imread(img_path)
|
||
if img_np is None:
|
||
raise ValueError("图像加载失败,可能路径错误或图像文件损坏。")
|
||
|
||
# ================== 3. 推理(整图)==================
|
||
result = classifier.classify(img_np)
|
||
|
||
# ================== 4. 输出 ==================
|
||
label_map = {0: "有缺陷", 1: "无缺陷"}
|
||
print(f"\n===== 推理结果 =====")
|
||
print(f"分类结果:{result} → {label_map.get(result, '未知')}\n")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
cls_quexian() |