From cceab7a67f4dd344d5507e5270950bebb634498a Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Wed, 13 Sep 2023 12:35:53 -0700 Subject: [PATCH] Fix CUDA-12 wheel loading on AmazonLinux Or any other distro that have different purelib and platlib paths Regression was introduced, when small wheel base dependency was migrated from CUDA-11 to CUDA-12 Not sure why, but minor version of the package is no longer shipped with following CUDA-12: - nvidia_cuda_nvrtc_cu12-12.1.105 - nvidia-cuda-cupti-cu12-12.1.105 - nvidia-cuda-cupti-cu12-12.1.105 But those were present in CUDA-11 release Fixes https://github.com/pytorch/pytorch/issues/109221 --- torch/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/__init__.py b/torch/__init__.py index c13fca9244a0..10611c70a955 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -174,13 +174,13 @@ def _load_global_deps() -> None: ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) except OSError as err: # Can only happen for wheel with cuda libs as PYPI deps - # As PyTorch is not purelib, but nvidia-*-cu11 is + # As PyTorch is not purelib, but nvidia-*-cu12 is cuda_libs: Dict[str, str] = { 'cublas': 'libcublas.so.*[0-9]', 'cudnn': 'libcudnn.so.*[0-9]', - 'cuda_nvrtc': 'libnvrtc.so.*[0-9].*[0-9]', - 'cuda_runtime': 'libcudart.so.*[0-9].*[0-9]', - 'cuda_cupti': 'libcupti.so.*[0-9].*[0-9]', + 'cuda_nvrtc': 'libnvrtc.so.*[0-9]', + 'cuda_runtime': 'libcudart.so.*[0-9]', + 'cuda_cupti': 'libcupti.so.*[0-9]', 'cufft': 'libcufft.so.*[0-9]', 'curand': 'libcurand.so.*[0-9]', 'cusolver': 'libcusolver.so.*[0-9]',