63 lines
2.1 KiB
Python
63 lines
2.1 KiB
Python
|
|
import os
|
|||
|
|
import cv2
|
|||
|
|
from ultralytics import YOLO
|
|||
|
|
|
|||
|
|
|
|||
|
|
class FullImageClassifier:
|
|||
|
|
"""封装 YOLO 分类模型,对整张图像进行二分类(无缺陷/有缺陷)"""
|
|||
|
|
|
|||
|
|
def __init__(self, model_path):
|
|||
|
|
if not os.path.exists(model_path):
|
|||
|
|
raise FileNotFoundError(f"模型文件不存在: {model_path}")
|
|||
|
|
|
|||
|
|
self.model = YOLO(model_path)
|
|||
|
|
print(f"[INFO] 成功加载 YOLO 分类模型: {model_path}")
|
|||
|
|
|
|||
|
|
def classify(self, img_np):
|
|||
|
|
"""
|
|||
|
|
对整张图像进行分类,返回类别 ID(0 或 1)
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
img_np (np.ndarray): BGR 格式的 OpenCV 图像 (H, W, C)
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
int: 分类结果,{0: "有缺陷", 1: "无缺陷"}
|
|||
|
|
"""
|
|||
|
|
# 直接 resize 整图到模型输入尺寸(YOLO 默认为 224x224 或 640x640,由训练决定)
|
|||
|
|
# Ultralytics YOLO 会自动处理 resize,但显式指定更可控
|
|||
|
|
resized_img = cv2.resize(img_np, (640, 640))
|
|||
|
|
|
|||
|
|
# 推理(verbose=False 关闭进度条)
|
|||
|
|
results = self.model.predict(resized_img, verbose=False)
|
|||
|
|
|
|||
|
|
cls = int(results[0].probs.top1) # 获取 top-1 类别索引
|
|||
|
|
return cls
|
|||
|
|
|
|||
|
|
|
|||
|
|
def cls_quexian():
|
|||
|
|
# ================== 配置 ==================
|
|||
|
|
model_path = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls/exp_xiantiao_cls/weights/best.pt"
|
|||
|
|
img_path = "1.png"
|
|||
|
|
|
|||
|
|
# ================== 1. 加载模型 ==================
|
|||
|
|
classifier = FullImageClassifier(model_path)
|
|||
|
|
|
|||
|
|
# ================== 2. 加载图像 ==================
|
|||
|
|
if not os.path.exists(img_path):
|
|||
|
|
raise FileNotFoundError(f"图片不存在: {img_path}")
|
|||
|
|
|
|||
|
|
img_np = cv2.imread(img_path)
|
|||
|
|
if img_np is None:
|
|||
|
|
raise ValueError("图像加载失败,可能路径错误或图像文件损坏。")
|
|||
|
|
|
|||
|
|
# ================== 3. 推理(整图)==================
|
|||
|
|
result = classifier.classify(img_np)
|
|||
|
|
|
|||
|
|
# ================== 4. 输出 ==================
|
|||
|
|
label_map = {0: "有缺陷", 1: "无缺陷"}
|
|||
|
|
print(f"\n===== 推理结果 =====")
|
|||
|
|
print(f"分类结果:{result} → {label_map.get(result, '未知')}\n")
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
cls_quexian()
|