Skip to content

Commit

Permalink
Fix cuda deps search path
Browse files Browse the repository at this point in the history
Fixes #88869
  • Loading branch information
malfet committed Dec 7, 2022
1 parent a076bdb commit 73794d8
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion torch/__init__.py
Expand Up @@ -142,6 +142,20 @@
kernel32.SetErrorMode(prev_error_mode)


def _preload_cuda_deps():
# Should only be called on Linux if default path resolution have failed
assert platform.system() == 'Linux', 'Should only be called on Linux'
for path in sys.path:
nvidia_path = os.path.join(path, 'nvidia')
if not os.path.exists(nvidia_path):
continue
cublas_path = os.path.join(nvidia_path, 'cublas', 'lib', 'libcublas.so.11')
cudnn_path = os.path.join(nvidia_path, 'cudnn', 'lib', 'libcudnn.so.8')
ctypes.CDLL(cublas_path)
ctypes.CDLL(cudnn_path)
break


# See Note [Global dependencies]
def _load_global_deps():
if platform.system() == 'Windows' or sys.executable == 'torch_deploy':
Expand All @@ -151,7 +165,13 @@ def _load_global_deps():
here = os.path.abspath(__file__)
lib_path = os.path.join(os.path.dirname(here), 'lib', lib_name)

ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
try:
ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
except OSError as e:
if 'libcublas.so.11' not in e.args[0]:
raise e
_preload_cuda_deps()
ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)


if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv('TORCH_USE_RTLD_GLOBAL')) and \
Expand Down

0 comments on commit 73794d8

Please sign in to comment.