Files
xiantiao_CV/class_xiantiao_pc/cls_quexian.py
琉璃月光 8506c3af79 first commit
2025-12-16 15:12:02 +08:00

63 lines
2.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
from ultralytics import YOLO
class FullImageClassifier:
"""封装 YOLO 分类模型,对整张图像进行二分类(无缺陷/有缺陷)"""
def __init__(self, model_path):
if not os.path.exists(model_path):
raise FileNotFoundError(f"模型文件不存在: {model_path}")
self.model = YOLO(model_path)
print(f"[INFO] 成功加载 YOLO 分类模型: {model_path}")
def classify(self, img_np):
"""
对整张图像进行分类,返回类别 ID0 或 1
Args:
img_np (np.ndarray): BGR 格式的 OpenCV 图像 (H, W, C)
Returns:
int: 分类结果,{0: "有缺陷", 1: "无缺陷"}
"""
# 直接 resize 整图到模型输入尺寸YOLO 默认为 224x224 或 640x640由训练决定
# Ultralytics YOLO 会自动处理 resize但显式指定更可控
resized_img = cv2.resize(img_np, (640, 640))
# 推理verbose=False 关闭进度条)
results = self.model.predict(resized_img, verbose=False)
cls = int(results[0].probs.top1) # 获取 top-1 类别索引
return cls
def cls_quexian():
# ================== 配置 ==================
model_path = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/cls/exp_xiantiao_cls/weights/best.pt"
img_path = "1.png"
# ================== 1. 加载模型 ==================
classifier = FullImageClassifier(model_path)
# ================== 2. 加载图像 ==================
if not os.path.exists(img_path):
raise FileNotFoundError(f"图片不存在: {img_path}")
img_np = cv2.imread(img_path)
if img_np is None:
raise ValueError("图像加载失败,可能路径错误或图像文件损坏。")
# ================== 3. 推理(整图)==================
result = classifier.classify(img_np)
# ================== 4. 输出 ==================
label_map = {0: "有缺陷", 1: "无缺陷"}
print(f"\n===== 推理结果 =====")
print(f"分类结果:{result}{label_map.get(result, '未知')}\n")
if __name__ == "__main__":
cls_quexian()