Skip to content

Commit

Permalink
Fix CUDA-12 wheel loading on AmazonLinux
Browse files Browse the repository at this point in the history
Or any other distro that have different purelib and platlib paths
Regression was introduced, when small wheel base dependency was migrated
from CUDA-11 to CUDA-12

Not sure why, but minor version of the package is no longer shipped with
following CUDA-12:
 - nvidia_cuda_nvrtc_cu12-12.1.105
 - nvidia-cuda-cupti-cu12-12.1.105
 - nvidia-cuda-cupti-cu12-12.1.105

But those were present in CUDA-11 release

Fixes #109221
  • Loading branch information
malfet committed Sep 13, 2023
1 parent e066056 commit cceab7a
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torch/__init__.py
Expand Up @@ -174,13 +174,13 @@ def _load_global_deps() -> None:
ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
except OSError as err:
# Can only happen for wheel with cuda libs as PYPI deps
# As PyTorch is not purelib, but nvidia-*-cu11 is
# As PyTorch is not purelib, but nvidia-*-cu12 is
cuda_libs: Dict[str, str] = {
'cublas': 'libcublas.so.*[0-9]',
'cudnn': 'libcudnn.so.*[0-9]',
'cuda_nvrtc': 'libnvrtc.so.*[0-9].*[0-9]',
'cuda_runtime': 'libcudart.so.*[0-9].*[0-9]',
'cuda_cupti': 'libcupti.so.*[0-9].*[0-9]',
'cuda_nvrtc': 'libnvrtc.so.*[0-9]',
'cuda_runtime': 'libcudart.so.*[0-9]',
'cuda_cupti': 'libcupti.so.*[0-9]',
'cufft': 'libcufft.so.*[0-9]',
'curand': 'libcurand.so.*[0-9]',
'cusolver': 'libcusolver.so.*[0-9]',
Expand Down

0 comments on commit cceab7a

Please sign in to comment.