Skip to content

Commit

Permalink
Fix CUDA-12 wheel loading on AmazonLinux (#109244)
Browse files Browse the repository at this point in the history
Or any other distro that have different purelib and platlib paths Regression was introduced, when small wheel base dependency was migrated from CUDA-11 to CUDA-12

Not sure why, but minor version of the package is no longer shipped with following CUDA-12:
 - nvidia_cuda_nvrtc_cu12-12.1.105
 - nvidia-cuda-cupti-cu12-12.1.105
 - nvidia-cuda-cupti-cu12-12.1.105

But those were present in CUDA-11 release, i.e:
``` shell
bash-5.2# curl -OL https://files.pythonhosted.org/packages/ef/25/922c5996aada6611b79b53985af7999fc629aee1d5d001b6a22431e18fec/nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl; unzip -t nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl |grep \.so
    testing: nvidia/cuda_nvrtc/lib/libnvrtc-builtins.so.11.7   OK
    testing: nvidia/cuda_nvrtc/lib/libnvrtc.so.11.2   OK
bash-5.2# curl -OL https://files.pythonhosted.org/packages/b6/9f/c64c03f49d6fbc56196664d05dba14e3a561038a81a638eeb47f4d4cfd48/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl; unzip -t nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl|grep \.so
    testing: nvidia/cuda_nvrtc/lib/libnvrtc-builtins.so.12.1   OK
    testing: nvidia/cuda_nvrtc/lib/libnvrtc.so.12   OK
```

Fixes #109221

Pull Request resolved: #109244
Approved by: https://github.com/huydhn
  • Loading branch information
malfet authored and pytorchmergebot committed Sep 14, 2023
1 parent 47f79e9 commit 90068ab
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torch/__init__.py
Expand Up @@ -174,13 +174,13 @@ def _load_global_deps() -> None:
ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
except OSError as err:
# Can only happen for wheel with cuda libs as PYPI deps
# As PyTorch is not purelib, but nvidia-*-cu11 is
# As PyTorch is not purelib, but nvidia-*-cu12 is
cuda_libs: Dict[str, str] = {
'cublas': 'libcublas.so.*[0-9]',
'cudnn': 'libcudnn.so.*[0-9]',
'cuda_nvrtc': 'libnvrtc.so.*[0-9].*[0-9]',
'cuda_runtime': 'libcudart.so.*[0-9].*[0-9]',
'cuda_cupti': 'libcupti.so.*[0-9].*[0-9]',
'cuda_nvrtc': 'libnvrtc.so.*[0-9]',
'cuda_runtime': 'libcudart.so.*[0-9]',
'cuda_cupti': 'libcupti.so.*[0-9]',
'cufft': 'libcufft.so.*[0-9]',
'curand': 'libcurand.so.*[0-9]',
'cusolver': 'libcusolver.so.*[0-9]',
Expand Down

0 comments on commit 90068ab

Please sign in to comment.