Skip to content

Commit

Permalink
[Storage_ipc] Provides IPC extensions for 3rd devices. pytorch#124902
Browse files Browse the repository at this point in the history
  • Loading branch information
mengpenghui committed Apr 28, 2024
1 parent 94b328e commit f58f7b6
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 0 deletions.
18 changes: 18 additions & 0 deletions aten/src/ATen/detail/PrivateUse1HooksInterface.h
Expand Up @@ -41,6 +41,24 @@ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
false,
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `resizePrivateUse1Bytes`.");
}

virtual void* Storage_shareDevice(void* self, void* noargs) const {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `Storage_sharedevice`.");
}

virtual void* Storage_newSharedDevice(void* _unused, void* args) const {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `Storage_newSharedDevice`.");
}

virtual void* Storage_releaseIPCCounterDevice(void* _unused, void* args) const {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `releaseIPCCounterDevice`.");
}
};

struct TORCH_API PrivateUse1HooksArgs {};
Expand Down
28 changes: 28 additions & 0 deletions torch/csrc/StorageSharing.cpp
Expand Up @@ -651,13 +651,37 @@ PyObject* THPStorage_isShared(PyObject* self, PyObject* noargs) {
}
}


// # Provides IPC extensions for third-party devices.
static PyObject* THPStorage_shareDecice(PyObject* self, PyObject* noargs) {
at::globalContext().lazyInitPrivateUse1();
return (PyObject *)at::detail::getPrivateUse1Hooks().Storage_shareDevice(self, noargs);
}

static PyObject* THPStorage_newSharedDevice(PyObject* _unused, PyObject* args) {
at::globalContext().lazyInitPrivateUse1();
return (PyObject *)at::detail::getPrivateUse1Hooks().Storage_newSharedDevice(_unused, args);
}

static PyObject* THPStorage_releaseIPCCounterDevice(
PyObject* _unused,
PyObject* args) {
at::globalContext().lazyInitPrivateUse1();
return (PyObject *)at::detail::getPrivateUse1Hooks().Storage_releaseIPCCounterDevice(_unused, args);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static PyMethodDef THPStorage_sharingMethods[] = {
{"_new_with_weak_ptr",
THPStorage_newWithWeakPtr,
METH_O | METH_CLASS,
nullptr},
{"_share_device_", THPStorage_shareDecice, METH_NOARGS, nullptr},
{"_share_cuda_", THPStorage_shareCuda, METH_NOARGS, nullptr},
{"_new_shared_device",
THPStorage_newSharedDevice,
METH_VARARGS | METH_STATIC,
nullptr},
{"_new_shared_cuda",
THPStorage_newSharedCuda,
METH_VARARGS | METH_STATIC,
Expand All @@ -666,6 +690,10 @@ static PyMethodDef THPStorage_sharingMethods[] = {
THPStorage_releaseIPCCounter,
METH_VARARGS | METH_STATIC,
nullptr},
{"_release_ipc_counter_device",
THPStorage_releaseIPCCounterDevice,
METH_VARARGS | METH_STATIC,
nullptr},
{"_share_fd_cpu_", THPStorage_shareFd, METH_NOARGS, nullptr},
{"_new_shared_fd_cpu",
THPStorage_newSharedFd,
Expand Down
56 changes: 56 additions & 0 deletions torch/utils/backend_registration.py
Expand Up @@ -266,6 +266,23 @@ def wrap_storage_to(self, device=None, non_blocking=False):
_check_register_once(torch.storage._StorageBase, custom_backend_name)
setattr(torch.storage._StorageBase, custom_backend_name, wrap_storage_to)

# Provides IPC extensions for third-party devices.
_check_register_once(torch.storage._StorageBase,
f'_share_{custom_backend_name}_')
setattr(torch.storage._StorageBase, f'_share_{custom_backend_name}_',
torch._C.StorageBase._share_device_)

_check_register_once(torch.storage._StorageBase,
f'_new_shared_{custom_backend_name}')
setattr(torch.storage._StorageBase, f'_new_shared_{custom_backend_name}',
classmethod(torch._C.StorageBase._new_shared_device))

_check_register_once(torch.storage._StorageBase,
f'_release_ipc_counter_{custom_backend_name}')
setattr(torch.storage._StorageBase,
f'_release_ipc_counter_{custom_backend_name}',
classmethod(torch._C.StorageBase._release_ipc_counter_device))

# Register the corresponding attribute for the TypedStorage class.
# When the TypedStorage class is removed, the registration is also removed.

Expand All @@ -290,6 +307,45 @@ def wrap_typed_storage_to(self: torch.storage.TypedStorage,
_check_register_once(torch.TypedStorage, custom_backend_name)
setattr(torch.TypedStorage, custom_backend_name, wrap_typed_storage_to)

# Provides IPC extensions for third-party devices.
def wrap_typed_storage_share(self, *args, **kwargs):
return self._untyped_storage._share_device_(*args, **kwargs)

_check_register_once(torch.TypedStorage, f'_share_{custom_backend_name}_')
setattr(torch.TypedStorage, f'_share_{custom_backend_name}_',
wrap_typed_storage_share)

def wrap_typed_storage_new_shared(cls, *args, **kwargs):
return torch.UntypedStorage._new_shared_device(*args, **kwargs)

_check_register_once(torch.TypedStorage,
f'_new_shared_{custom_backend_name}')
setattr(torch.TypedStorage, f'_new_shared_{custom_backend_name}',
classmethod(wrap_typed_storage_new_shared))

def wrap_typed_storage_release_ipc_counter(cls,
*args,
device=None,
**kwargs):
return torch.UntypedStorage._release_ipc_counter_device(*args, **kwargs)

_check_register_once(torch.TypedStorage,
f'_release_ipc_counter_{custom_backend_name}')
setattr(torch.TypedStorage, f'_release_ipc_counter_{custom_backend_name}',
classmethod(wrap_typed_storage_release_ipc_counter))

def wrap_legacy_storage_release_ipc_counter(cls,
*args,
device=None,
**kwargs):
return torch.UntypedStorage._release_ipc_counter_device(*args, **kwargs)

_check_register_once(torch.storage._LegacyStorage,
f'_release_ipc_counter_{custom_backend_name}')
setattr(torch.storage._LegacyStorage,
f'_release_ipc_counter_{custom_backend_name}',
classmethod(wrap_legacy_storage_release_ipc_counter))


def generate_methods_for_privateuse1_backend(for_tensor: bool = True, for_module: bool = True,
for_packed_sequence: bool = True,
Expand Down

0 comments on commit f58f7b6

Please sign in to comment.