Files
zjsh_yolov11/test_cuda.py

100 lines
3.2 KiB
Python
Raw Permalink Normal View History

2025-08-18 15:13:54 +08:00
# test_cuda.py
"""
CUDA 环境全面检测脚本
Author: Qwen
功能检测 NVIDIA 驱动CUDAPyTorch GPU 是否正常工作
"""
import torch
import sys
import os
def print_separator():
print("\n" + "=" * 60 + "\n")
def check_nvidia_smi():
print("🔍 正在检查 nvidia-smiNVIDIA 驱动和 GPU 状态)...")
try:
result = os.popen("nvidia-smi").read()
print(result)
except Exception as e:
print(f"❌ 执行 nvidia-smi 失败: {e}")
def check_torch_cuda():
print("🔍 正在检查 PyTorch 与 CUDA 的集成...")
print(f"➡️ PyTorch 版本: {torch.__version__}")
# 检查 CUDA 是否可用
is_available = torch.cuda.is_available()
print(f"➡️ torch.cuda.is_available(): {is_available}")
if is_available:
print(f"➡️ CUDA 版本 (PyTorch 使用): {torch.version.cuda}")
print(f"➡️ GPU 数量: {torch.cuda.device_count()}")
print(f"➡️ 当前设备: {torch.cuda.current_device()}")
print(f"➡️ 设备名称: {torch.cuda.get_device_name(0)}")
# 尝试创建一个 GPU 张量
try:
x = torch.randn(3, 3).cuda()
print("✅ GPU 张量创建成功CUDA 可用且正常工作。")
except Exception as e:
print(f"❌ GPU 张量创建失败: {e}")
else:
print("❌ CUDA 不可用PyTorch 无法使用 GPU。")
print("💡 常见原因:")
print(" - 没有安装 NVIDIA 驱动")
print(" - PyTorch 安装的是 CPU 版本")
print(" - CUDA Toolkit 与 PyTorch 不匹配")
print(" - 环境变量 CUDA_VISIBLE_DEVICES 设置错误")
def check_environment():
print("🔍 检查环境变量...")
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
print(f"➡️ CUDA_VISIBLE_DEVICES: {cuda_visible_devices}")
print(f"➡️ Python 路径: {sys.executable}")
print(f"➡️ Python 版本: {sys.version}")
def check_torch_install_type():
print("🔍 检查 PyTorch 安装类型...")
if "cpu" in torch.__version__ or "cpu" in str(torch.version.cuda):
print("⚠️ PyTorch 可能是 CPU 版本")
else:
print("✅ PyTorch 应该是 GPU 版本(含 CUDA 支持)")
def main():
print("🚀 开始检测 CUDA 与 PyTorch 环境...\n")
print_separator()
check_environment()
print_separator()
check_nvidia_smi()
print_separator()
check_torch_install_type()
print_separator()
check_torch_cuda()
print_separator()
if torch.cuda.is_available():
print("🎉 恭喜CUDA 和 PyTorch 集成正常,可以使用 GPU 训练!")
else:
print("💔 问题严重CUDA 不可用,请按以下步骤排查:")
print(" 1. 运行 `nvidia-smi` 看是否显示 GPU 和驱动版本")
print(" 2. 如果没有,安装 NVIDIA 驱动")
print(" 3. 重新安装 PyTorch GPU 版本:")
print(" pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121")
print(" 4. 确保没有设置 CUDA_VISIBLE_DEVICES= 或设为无效值")
if __name__ == "__main__":
main()