Commit d048893
Add Context Manager for Disabling Multithreading in Backwards, use in aot autograd (#86245)

We were running into a few issues with running multithreaded backwards in aot_autograd, such as #86136, and with `FakeTensorMode` getting into a bad state because functions were not executed strictly sequentially. The multithreading is lost in translation when we trace out the backwards anyway, and it adds a lot of additional complexity.
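
As a rough usage sketch (not code from the patch; the CUDA device is only what makes the engine's worker threads observable), the new context manager forces backward nodes to execute on the calling thread:

    import torch

    x = torch.rand(4, device="cuda", requires_grad=True)
    with torch.autograd.set_multithreading_enabled(False):
        x.sum().backward()  # backward graph is evaluated single-threaded, on this thread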

Pull Request resolved: #86245
Approved by: https://github.com/albanD, https://github.com/yf225
eellison authored and pytorchmergebot committed Oct 6, 2022
1 parent 237316a commit d048893
Showing 14 changed files with 126 additions and 10 deletions.
4 changes: 4 additions & 0 deletions aten/src/ATen/ThreadLocalState.cpp
@@ -27,6 +27,10 @@ void ThreadLocalState::set_grad_mode(bool enabled) {
   autograd_tls_.set_grad_mode(enabled);
 }
 
+void ThreadLocalState::set_multithreading_enabled(bool enabled) {
+  autograd_tls_.set_multithreading_enabled(enabled);
+}
+
 /* static */
 void ThreadLocalState::setThreadLocalState(
     const ThreadLocalState& state) {
6 changes: 6 additions & 0 deletions aten/src/ATen/ThreadLocalState.h
@@ -30,6 +30,12 @@ class TORCH_API ThreadLocalState {
   // autograd engine.
   void set_grad_mode(bool enabled);
 
+  // set_multithreading_enabled - force the value of the
+  // multithreading TLS in the current state object.
+  // This is used for example in the
+  // autograd engine.
+  void set_multithreading_enabled(bool enabled);
+
   // Sets thread local variables in the current thread,
   // according to the thread boundary specified
   static void setThreadLocalState(const ThreadLocalState& state);
6 changes: 4 additions & 2 deletions c10/core/AutogradState.cpp
@@ -3,11 +3,13 @@
 namespace c10 {
 
 namespace {
-// By default, grad mode is enabled and inference mode is disabled
+// By default, grad mode and multithreading are enabled; inference mode is
+// disabled.
 thread_local AutogradState autograd_state_tls = AutogradState(
     /* grad_mode */ true,
     /* inference_mode */ false,
-    /* fw_grad_mode */ true);
+    /* fw_grad_mode */ true,
+    /* multithreading_enabled */ true);
 } // namespace
 
 AutogradState& AutogradState::get_tls_state() {
18 changes: 16 additions & 2 deletions c10/core/AutogradState.h
@@ -12,10 +12,15 @@ struct C10_API AutogradState {
   static AutogradState& get_tls_state();
   static void set_tls_state(AutogradState state);
 
-  AutogradState(bool grad_mode, bool inference_mode, bool fw_grad_mode)
+  AutogradState(
+      bool grad_mode,
+      bool inference_mode,
+      bool fw_grad_mode,
+      bool multithreading_enabled)
       : grad_mode_(grad_mode),
         inference_mode_(inference_mode),
-        fw_grad_mode_(fw_grad_mode) {}
+        fw_grad_mode_(fw_grad_mode),
+        multithreading_enabled_(multithreading_enabled) {}
 
   void set_grad_mode(bool enabled) {
     grad_mode_ = enabled;
@@ -29,6 +34,10 @@ struct C10_API AutogradState {
     inference_mode_ = enabled;
   }
 
+  void set_multithreading_enabled(bool multithreading_enabled) {
+    multithreading_enabled_ = multithreading_enabled;
+  }
+
   bool get_grad_mode() const {
     return grad_mode_;
   }
@@ -41,10 +50,15 @@ struct C10_API AutogradState {
     return inference_mode_;
   }
 
+  bool get_multithreading_enabled() const {
+    return multithreading_enabled_;
+  }
+
  private:
   bool grad_mode_ : 1;
   bool inference_mode_ : 1;
   bool fw_grad_mode_ : 1;
+  bool multithreading_enabled_ : 1;
 };
 
 } // namespace c10
3 changes: 2 additions & 1 deletion c10/core/InferenceMode.h
@@ -58,7 +58,8 @@ struct TORCH_API InferenceMode {
     AutogradState::set_tls_state(AutogradState(
         /* grad_mode */ !enabled,
         /* inference_mode */ enabled,
-        /* fw_grad_mode */ !enabled));
+        /* fw_grad_mode */ !enabled,
+        /* multithreading_enabled */ !enabled));
     DispatchKeySet included = enabled
         ? prev_keyset.included_.remove(c10::DispatchKey::ADInplaceOrView)
         : prev_keyset.included_.add(c10::DispatchKey::ADInplaceOrView);
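
A side effect of this hunk worth noting: entering `InferenceMode` now also clears the multithreading flag for its scope, since every constructor argument is derived from `enabled`. A minimal sketch of the observable state (my reading of the constructor, not code from the patch):

    import torch

    with torch.inference_mode():
        # grad mode and fw-grad mode are off in this thread, and with this
        # patch the multithreading TLS flag is cleared here as well
        y = torch.rand(3) + 1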
11 changes: 11 additions & 0 deletions docs/source/torch.rst
@@ -630,3 +630,14 @@ Operator Tags
 .. This module needs to be documented. Adding here in the meantime
 .. for tracking purposes
 .. py:module:: torch.utils.model_dump
+.. automodule:: torch.autograd
+.. currentmodule:: torch.autograd
+
+Engine Configuration
+----------------------------------
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    set_multithreading_enabled
2 changes: 1 addition & 1 deletion functorch/_src/aot_autograd.py
@@ -479,7 +479,7 @@ def create_aot_dispatcher_function(
     python_dispatcher_mode = enable_python_dispatcher() if config.use_dynamic_shapes else nullcontext()
     shape_env = ShapeEnv() if config.use_dynamic_shapes else None
 
-    with preserve_rng_state(), cross_ref, fake_mode, python_dispatcher_mode:
+    with torch.autograd.set_multithreading_enabled(False), preserve_rng_state(), cross_ref, fake_mode, python_dispatcher_mode:
 
         def process_inputs(flat_args):
            if config.use_fake_tensor:
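
This one-line change is the point of the PR: the entire AOT dispatch, including the traced backward, now runs with the autograd engine forced single-threaded, so tracer state such as `FakeTensorMode` is only ever touched sequentially. A toy re-creation of the idea (not the PR's code), tracing a joint forward/backward with `make_fx` under the new flag:

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def joint(x):
        out = torch.sin(x)
        (grad,) = torch.autograd.grad(out.sum(), x)
        return out, grad

    with torch.autograd.set_multithreading_enabled(False):
        # backward nodes execute (and trace) in order on this thread
        gm = make_fx(joint)(torch.rand(3, requires_grad=True))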
1 change: 1 addition & 0 deletions test/allowlist_for_publicAPI.json
@@ -199,6 +199,7 @@
     "no_grad",
     "set_detect_anomaly",
     "set_grad_enabled",
+    "set_multithreading_enabled",
     "variable"
   ],
   "torch.autograd.function": [
27 changes: 27 additions & 0 deletions test/test_autograd.py
@@ -9304,6 +9304,33 @@ def foo(x):
         with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
             foo(nt).backward(torch.nested.nested_tensor([torch.rand(1), torch.rand(1)], device=device))
 
+    @onlyCUDA
+    def test_backward_single_threaded(self):
+
+        threads_eq = None
+
+        class TestFn(Function):
+            @staticmethod
+            def forward(ctx, x, self):
+                ctx.self = self
+                ctx.tid = threading.get_ident()
+                return x.clone()
+
+            @staticmethod
+            def backward(ctx, gO):
+                nonlocal threads_eq
+                threads_eq = ctx.tid == threading.get_ident()
+                return gO, None
+
+        inp = torch.rand(10, device="cuda", requires_grad=True)
+
+        with torch.autograd.set_multithreading_enabled(False):
+            TestFn.apply(inp, None).sum().backward()
+        self.assertTrue(threads_eq)
+
+        TestFn.apply(inp, None).sum().backward()
+        self.assertFalse(threads_eq)
+
 # Import test cases from below autograd/ here. These are found
 # implicitly by the loader, so Flake8 thinks they are unused, hence
 # the suppressions.
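
The `@onlyCUDA` restriction is load-bearing: for a CPU-only graph the engine already drives backward on the calling thread, so the thread ids match even with multithreading enabled and the final `assertFalse` would not hold. A quick CPU contrast (a sketch, not part of the patch):

    import threading
    import torch

    tid = {}

    class Probe(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x.clone()

        @staticmethod
        def backward(ctx, g):
            tid["backward"] = threading.get_ident()  # record executing thread
            return g

    x = torch.rand(3, requires_grad=True)  # CPU tensor
    Probe.apply(x).sum().backward()
    assert tid["backward"] == threading.get_ident()  # same thread, no CUDA worker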
3 changes: 3 additions & 0 deletions torch/_C/__init__.pyi.in
@@ -948,6 +948,9 @@ class _DisableFuncTorch:
 class _EnableTorchFunction:
     def __init__(self) -> None: ...
 
+class _MultithreadingEnabled:
+    def __init__(self, mode: _bool) -> None: ...
+
 # Defined in torch/csrc/jit/python/script_init.cpp
 class LoggerBase(object):
     ...
2 changes: 1 addition & 1 deletion torch/autograd/__init__.py
@@ -15,7 +15,7 @@
 from .variable import Variable
 from .function import Function, NestedIOFunction
 from .gradcheck import gradcheck, gradgradcheck
-from .grad_mode import no_grad, enable_grad, set_grad_enabled, inference_mode
+from .grad_mode import no_grad, enable_grad, set_grad_enabled, inference_mode, set_multithreading_enabled
 from .anomaly_mode import detect_anomaly, set_detect_anomaly
 from ..overrides import has_torch_function, handle_torch_function, is_tensor_like
 from . import functional
36 changes: 34 additions & 2 deletions torch/autograd/grad_mode.py
@@ -5,7 +5,7 @@
 from typing import Any, Callable, TypeVar, cast
 
 __all__ = ['no_grad', 'enable_grad', 'set_grad_enabled',
-           'inference_mode']
+           'inference_mode', 'set_multithreading_enabled']
 
 
 # Used for annotating the decorator usage of 'no_grad' and 'enable_grad'.
@@ -184,7 +184,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
 
 
 class set_grad_enabled(_DecoratorContextManager):
-    r"""Context-manager that sets gradient calculation to on or off.
+    r"""Context-manager that sets gradient calculation on or off.
 
     ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
     It can be used as a context-manager or as a function.
@@ -298,3 +298,35 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
 
     def clone(self):
         return self.__class__(self.mode)
+
+
+class set_multithreading_enabled(_DecoratorContextManager):
+    r"""Context-manager that sets multithreaded backwards on or off.
+
+    ``set_multithreading_enabled`` will enable or disable multithreaded backwards based on its argument :attr:`mode`.
+    It can be used as a context-manager or as a function.
+
+    This context manager is thread local; it will not affect computation
+    in other threads.
+
+    Args:
+        mode (bool): Flag whether to enable multithreaded backwards (``True``), or disable
+                     (``False``).
+
+    .. note::
+        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
+    """
+
+    def __init__(self, mode: bool) -> None:
+        self.mode = mode
+        self.multithreading_enabled_guard = torch._C._MultithreadingEnabled(mode)
+
+    def __enter__(self) -> None:
+        pass
+
+    def __exit__(self, *args) -> None:
+        del self.multithreading_enabled_guard
+
+    def clone(self):
+        return self.__class__(self.mode)
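
Note the eager semantics encoded above, mirroring `set_grad_enabled`: the C++ guard is constructed in `__init__`, so the flag flips as soon as the object is created; `__enter__` is a no-op and `__exit__` restores the previous value by deleting the guard. A sketch of the observable behavior:

    import torch

    cm = torch.autograd.set_multithreading_enabled(False)  # flag already flipped
    with cm:
        pass  # a backward run here would execute single-threaded
    # __exit__ deleted the guard, restoring the previous thread-local value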
4 changes: 3 additions & 1 deletion torch/csrc/autograd/engine.cpp
@@ -1255,7 +1255,9 @@ void Engine::init_local_ready_queue(std::shared_ptr<ReadyQueue> ready_queue) {
 auto Engine::ready_queue(
     std::shared_ptr<ReadyQueue> cpu_ready_queue,
     at::Device device) -> std::shared_ptr<ReadyQueue> {
-  if (should_run_in_cpu_ready_queue(device.type())) {
+  bool multithreading_disabled =
+      !c10::AutogradState::get_tls_state().get_multithreading_enabled();
+  if (multithreading_disabled || should_run_in_cpu_ready_queue(device.type())) {
     // return the cpu ready queue passed in
     TORCH_INTERNAL_ASSERT(cpu_ready_queue);
     return cpu_ready_queue;
13 changes: 13 additions & 0 deletions torch/csrc/autograd/init.cpp
@@ -43,6 +43,17 @@ struct DisableFuncTorch {
   c10::impl::ExcludeDispatchKeyGuard back_guard_;
 };
 
+struct MultithreadingEnabled {
+  MultithreadingEnabled(bool enabled)
+      : old_(c10::AutogradState::get_tls_state().get_multithreading_enabled()) {
+    c10::AutogradState::get_tls_state().set_multithreading_enabled(enabled);
+  }
+  ~MultithreadingEnabled() {
+    c10::AutogradState::get_tls_state().set_multithreading_enabled(old_);
+  }
+  bool old_;
+};
+
 struct EnableTorchFunction {
   EnableTorchFunction()
       : old_(at::impl::PythonTorchFunctionTLS::is_disabled()) {
@@ -354,6 +365,8 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
           _C_m, "_DisablePythonDispatcher")
       .def(py::init<>());
   py::class_<DisableFuncTorch>(_C_m, "_DisableFuncTorch").def(py::init<>());
+  py::class_<MultithreadingEnabled>(_C_m, "_MultithreadingEnabled")
+      .def(py::init<bool>());
 
   py::class_<torch::autograd::SavedVariable>(m, "SavedTensor")
       .def(py::init([]() -> torch::autograd::SavedVariable {
