# test_cuda.py """ CUDA 环境全面检测脚本 Author: Qwen 功能:检测 NVIDIA 驱动、CUDA、PyTorch 与 GPU 是否正常工作 """ import torch import sys import os def print_separator(): print("\n" + "=" * 60 + "\n") def check_nvidia_smi(): print("🔍 正在检查 nvidia-smi(NVIDIA 驱动和 GPU 状态)...") try: result = os.popen("nvidia-smi").read() print(result) except Exception as e: print(f"❌ 执行 nvidia-smi 失败: {e}") def check_torch_cuda(): print("🔍 正在检查 PyTorch 与 CUDA 的集成...") print(f"➡️ PyTorch 版本: {torch.__version__}") # 检查 CUDA 是否可用 is_available = torch.cuda.is_available() print(f"➡️ torch.cuda.is_available(): {is_available}") if is_available: print(f"➡️ CUDA 版本 (PyTorch 使用): {torch.version.cuda}") print(f"➡️ GPU 数量: {torch.cuda.device_count()}") print(f"➡️ 当前设备: {torch.cuda.current_device()}") print(f"➡️ 设备名称: {torch.cuda.get_device_name(0)}") # 尝试创建一个 GPU 张量 try: x = torch.randn(3, 3).cuda() print("✅ GPU 张量创建成功!CUDA 可用且正常工作。") except Exception as e: print(f"❌ GPU 张量创建失败: {e}") else: print("❌ CUDA 不可用!PyTorch 无法使用 GPU。") print("💡 常见原因:") print(" - 没有安装 NVIDIA 驱动") print(" - PyTorch 安装的是 CPU 版本") print(" - CUDA Toolkit 与 PyTorch 不匹配") print(" - 环境变量 CUDA_VISIBLE_DEVICES 设置错误") def check_environment(): print("🔍 检查环境变量...") cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) print(f"➡️ CUDA_VISIBLE_DEVICES: {cuda_visible_devices}") print(f"➡️ Python 路径: {sys.executable}") print(f"➡️ Python 版本: {sys.version}") def check_torch_install_type(): print("🔍 检查 PyTorch 安装类型...") if "cpu" in torch.__version__ or "cpu" in str(torch.version.cuda): print("⚠️ PyTorch 可能是 CPU 版本") else: print("✅ PyTorch 应该是 GPU 版本(含 CUDA 支持)") def main(): print("🚀 开始检测 CUDA 与 PyTorch 环境...\n") print_separator() check_environment() print_separator() check_nvidia_smi() print_separator() check_torch_install_type() print_separator() check_torch_cuda() print_separator() if torch.cuda.is_available(): print("🎉 恭喜!CUDA 和 PyTorch 集成正常,可以使用 GPU 训练!") else: print("💔 问题严重:CUDA 不可用,请按以下步骤排查:") print(" 1. 运行 `nvidia-smi` 看是否显示 GPU 和驱动版本") print(" 2. 如果没有,安装 NVIDIA 驱动") print(" 3. 重新安装 PyTorch GPU 版本:") print(" pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121") print(" 4. 确保没有设置 CUDA_VISIBLE_DEVICES= 或设为无效值") if __name__ == "__main__": main()