Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Storage_ipc] Provides IPC extensions for 3rd devices. #125122

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 18 additions & 0 deletions aten/src/ATen/detail/PrivateUse1HooksInterface.h
Expand Up @@ -41,6 +41,24 @@ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
false,
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `resizePrivateUse1Bytes`.");
}

virtual void* Storage_shareDevice(void* self, void* noargs) const {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `Storage_sharedevice`.");
}

virtual void* Storage_newSharedDevice(void* _unused, void* args) const {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `Storage_newSharedDevice`.");
}

virtual void* Storage_releaseIPCCounterDevice(void* _unused, void* args) const {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `releaseIPCCounterDevice`.");
}
};

struct TORCH_API PrivateUse1HooksArgs {};
Expand Down
12 changes: 6 additions & 6 deletions c10/core/StorageImpl.h
Expand Up @@ -56,7 +56,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
size_bytes_(std::move(size_bytes)),
size_bytes_is_heap_allocated_(size_bytes_.is_heap_allocated()),
resizable_(resizable),
received_cuda_(false),
received_device_(false),
allocator_(allocator) {
if (resizable) {
TORCH_INTERNAL_ASSERT(
Expand Down Expand Up @@ -232,12 +232,12 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {

// This method can be used only after storage construction and cannot be used
// to modify storage status
void set_received_cuda(bool received_cuda) {
received_cuda_ = received_cuda;
void set_received_device(bool received_device) {
received_device_ = received_device;
}

bool received_cuda() {
return received_cuda_;
bool received_device() {
return received_device_;
}

impl::PyObjectSlot* pyobj_slot() {
Expand Down Expand Up @@ -294,7 +294,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
bool resizable_;
// Identifies that Storage was received from another process and doesn't have
// local to process cuda memory allocation
bool received_cuda_;
bool received_device_;
// All special checks in data/data_ptr calls are guarded behind this single
// boolean. This is for performance: .data/.data_ptr calls are commonly in the
// hot-path.
Expand Down
35 changes: 33 additions & 2 deletions torch/csrc/StorageSharing.cpp
Expand Up @@ -293,7 +293,7 @@ static PyObject* THPStorage_shareCuda(PyObject* self, PyObject* noargs) {
"_share_cuda_: only available on CUDA");
c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();

if (storage_impl->received_cuda()) {
if (storage_impl->received_device()) {
AT_ERROR(
"Attempted to send CUDA tensor received from another process; this is not currently supported. Consider cloning before sending.");
}
Expand Down Expand Up @@ -564,7 +564,7 @@ static PyObject* THPStorage_newSharedCuda(PyObject* _unused, PyObject* args) {
/*resizable=*/false);

base->set_resizable(false);
base->set_received_cuda(true);
base->set_received_device(true);

return THPStorage_NewWithStorage(
THPStorageClass,
Expand Down Expand Up @@ -651,13 +651,40 @@ PyObject* THPStorage_isShared(PyObject* self, PyObject* noargs) {
}
}


// # Provides IPC extensions for third-party devices.
static PyObject* THPStorage_shareDecice(PyObject* self, PyObject* noargs) {
at::globalContext().lazyInitPrivateUse1();
return (PyObject*)at::detail::getPrivateUse1Hooks().Storage_shareDevice(
self, noargs);
}

static PyObject* THPStorage_newSharedDevice(PyObject* _unused, PyObject* args) {
at::globalContext().lazyInitPrivateUse1();
return (PyObject*)at::detail::getPrivateUse1Hooks().Storage_newSharedDevice(
_unused, args);
}

static PyObject* THPStorage_releaseIPCCounterDevice(
PyObject* _unused,
PyObject* args) {
at::globalContext().lazyInitPrivateUse1();
return (PyObject*)at::detail::getPrivateUse1Hooks()
.Storage_releaseIPCCounterDevice(_unused, args);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static PyMethodDef THPStorage_sharingMethods[] = {
{"_new_with_weak_ptr",
THPStorage_newWithWeakPtr,
METH_O | METH_CLASS,
nullptr},
{"_share_device_", THPStorage_shareDecice, METH_NOARGS, nullptr},
{"_share_cuda_", THPStorage_shareCuda, METH_NOARGS, nullptr},
{"_new_shared_device",
THPStorage_newSharedDevice,
METH_VARARGS | METH_STATIC,
nullptr},
{"_new_shared_cuda",
THPStorage_newSharedCuda,
METH_VARARGS | METH_STATIC,
Expand All @@ -666,6 +693,10 @@ static PyMethodDef THPStorage_sharingMethods[] = {
THPStorage_releaseIPCCounter,
METH_VARARGS | METH_STATIC,
nullptr},
{"_release_ipc_counter_device",
THPStorage_releaseIPCCounterDevice,
METH_VARARGS | METH_STATIC,
nullptr},
{"_share_fd_cpu_", THPStorage_shareFd, METH_NOARGS, nullptr},
{"_new_shared_fd_cpu",
THPStorage_newSharedFd,
Expand Down
5 changes: 5 additions & 0 deletions torch/storage.py
Expand Up @@ -63,6 +63,8 @@ def _new_shared_filename_cpu(cls: Type[T], manager, obj, size, *, device=None, d
@classmethod
def _release_ipc_counter_cuda(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704
@classmethod
def _release_ipc_counter_device(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704
@classmethod
def _new_with_weak_ptr(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704
def _shared_decref(self) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
def _write_file(self, *args, **kwargs): ... # noqa: E704
Expand All @@ -71,9 +73,12 @@ def _weak_ref(self, *args, **kwargs) -> T: ... # type: ignore[empty-body, misc,
def _set_from_file(self, *args, **kwargs): ... # noqa: E704
def _set_cdata(self, *args, **kwargs): ... # noqa: E704
def _share_cuda_(self, *args, **kwargs): ... # noqa: E704
def _share_device_(self, *args, **kwargs): ... # noqa: E704
def is_shared(self) -> bool: ... # type: ignore[empty-body] # noqa: E704
@classmethod
def _new_shared_cuda(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704
@classmethod
def _new_shared_device(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704
def _shared_incref(self, *args, **kwargs): ... # noqa: E704
@classmethod
def _free_weak_ref(cls, *args, **kwargs): ... # noqa: E704
Expand Down
50 changes: 50 additions & 0 deletions torch/utils/backend_registration.py
Expand Up @@ -266,6 +266,26 @@ def wrap_storage_to(self, device=None, non_blocking=False):
_check_register_once(torch.storage._StorageBase, custom_backend_name)
setattr(torch.storage._StorageBase, custom_backend_name, wrap_storage_to)

# Provides IPC extensions for third-party devices.
_check_register_once(torch.storage._StorageBase,
f'_share_{custom_backend_name}_')
# type: ignore[attr-defined]
setattr(torch.storage._StorageBase, f'_share_{custom_backend_name}_',
torch.StorageBase._share_device_)

_check_register_once(torch.storage._StorageBase,
f'_new_shared_{custom_backend_name}')
# type: ignore[attr-defined]
setattr(torch.storage._StorageBase, f'_new_shared_{custom_backend_name}',
classmethod(torch.StorageBase._new_shared_device))

_check_register_once(torch.storage._StorageBase,
f'_release_ipc_counter_{custom_backend_name}')
# type: ignore[attr-defined]
setattr(torch.storage._StorageBase,
f'_release_ipc_counter_{custom_backend_name}',
classmethod(torch.StorageBase._release_ipc_counter_device))

# Register the corresponding attribute for the TypedStorage class.
# When the TypedStorage class is removed, the registration is also removed.

Expand All @@ -290,6 +310,36 @@ def wrap_typed_storage_to(self: torch.storage.TypedStorage,
_check_register_once(torch.TypedStorage, custom_backend_name)
setattr(torch.TypedStorage, custom_backend_name, wrap_typed_storage_to)

# Provides IPC extensions for third-party devices.
# type: ignore[attr-defined]
def wrap_typed_storage_share(self, *args, **kwargs):
return self._untyped_storage._share_device_(*args, **kwargs)

_check_register_once(torch.TypedStorage, f'_share_{custom_backend_name}_')
setattr(torch.TypedStorage, f'_share_{custom_backend_name}_',
wrap_typed_storage_share)

# type: ignore[attr-defined]
def wrap_typed_storage_new_shared(cls, *args, **kwargs):
return torch.UntypedStorage._new_shared_device(*args, **kwargs)

_check_register_once(torch.TypedStorage,
f'_new_shared_{custom_backend_name}')
setattr(torch.TypedStorage, f'_new_shared_{custom_backend_name}',
classmethod(wrap_typed_storage_new_shared))

# type: ignore[attr-defined]
def wrap_typed_storage_release_ipc_counter(cls,
*args,
device=None,
**kwargs):
return torch.UntypedStorage._release_ipc_counter_device(*args, **kwargs)

_check_register_once(torch.TypedStorage,
f'_release_ipc_counter_{custom_backend_name}')
setattr(torch.TypedStorage, f'_release_ipc_counter_{custom_backend_name}',
classmethod(wrap_typed_storage_release_ipc_counter))


def generate_methods_for_privateuse1_backend(for_tensor: bool = True, for_module: bool = True,
for_packed_sequence: bool = True,
Expand Down