From 90068ab30afab32a8603f0965d5b7b0a008c3d3a Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Thu, 14 Sep 2023 03:13:28 +0000 Subject: [PATCH] Fix CUDA-12 wheel loading on AmazonLinux (#109244) Or any other distro that have different purelib and platlib paths Regression was introduced, when small wheel base dependency was migrated from CUDA-11 to CUDA-12 Not sure why, but minor version of the package is no longer shipped with following CUDA-12: - nvidia_cuda_nvrtc_cu12-12.1.105 - nvidia-cuda-cupti-cu12-12.1.105 - nvidia-cuda-cupti-cu12-12.1.105 But those were present in CUDA-11 release, i.e: ``` shell bash-5.2# curl -OL https://files.pythonhosted.org/packages/ef/25/922c5996aada6611b79b53985af7999fc629aee1d5d001b6a22431e18fec/nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl; unzip -t nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl |grep \.so testing: nvidia/cuda_nvrtc/lib/libnvrtc-builtins.so.11.7 OK testing: nvidia/cuda_nvrtc/lib/libnvrtc.so.11.2 OK bash-5.2# curl -OL https://files.pythonhosted.org/packages/b6/9f/c64c03f49d6fbc56196664d05dba14e3a561038a81a638eeb47f4d4cfd48/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl; unzip -t nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl|grep \.so testing: nvidia/cuda_nvrtc/lib/libnvrtc-builtins.so.12.1 OK testing: nvidia/cuda_nvrtc/lib/libnvrtc.so.12 OK ``` Fixes https://github.com/pytorch/pytorch/issues/109221 Pull Request resolved: https://github.com/pytorch/pytorch/pull/109244 Approved by: https://github.com/huydhn --- torch/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/__init__.py b/torch/__init__.py index c13fca9244a0..10611c70a955 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -174,13 +174,13 @@ def _load_global_deps() -> None: ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) except OSError as err: # Can only happen for wheel with cuda libs as PYPI deps - # As PyTorch is not purelib, but nvidia-*-cu11 is + # As PyTorch is not purelib, but nvidia-*-cu12 is cuda_libs: Dict[str, str] = { 'cublas': 'libcublas.so.*[0-9]', 'cudnn': 'libcudnn.so.*[0-9]', - 'cuda_nvrtc': 'libnvrtc.so.*[0-9].*[0-9]', - 'cuda_runtime': 'libcudart.so.*[0-9].*[0-9]', - 'cuda_cupti': 'libcupti.so.*[0-9].*[0-9]', + 'cuda_nvrtc': 'libnvrtc.so.*[0-9]', + 'cuda_runtime': 'libcudart.so.*[0-9]', + 'cuda_cupti': 'libcupti.so.*[0-9]', 'cufft': 'libcufft.so.*[0-9]', 'curand': 'libcurand.so.*[0-9]', 'cusolver': 'libcusolver.so.*[0-9]',