diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py
index 930035f99a30..2fb83b3add6c 100644
--- a/test/dynamo/test_modules.py
+++ b/test/dynamo/test_modules.py
@@ -904,133 +904,6 @@ def forward(self, x):
         self.assertTrue(torch._dynamo.testing.same(real, graph(rx)))
 
 
-class MockModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.relu = torch.nn.ReLU()
-        self.linear = torch.nn.Linear(10, 10)
-        self.register_buffer("buf0", torch.randn(10, 10))
-
-    def forward(self, x):
-        return self.relu(self.linear(x) + self.buf0)
-
-
-class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
-    def test_nn_module(self):
-        mod = MockModule()
-        cnt = torch._dynamo.testing.CompileCounter()
-        opt_mod = torch._dynamo.optimize(cnt)(mod)
-        self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule)
-
-        x = torch.randn(10, 10)
-        self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x)))
-        self.assertEqual(cnt.frame_count, 1)
-
-    def test_to(self):
-        mod = MockModule()
-        cnt = torch._dynamo.testing.CompileCounter()
-        opt_mod = torch._dynamo.optimize(cnt)(mod)
-        x = torch.randn(10, 10)
-        self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x)))
-        self.assertEqual(cnt.frame_count, 1)
-
-        # Ensure that there is no recompilation
-        opt_mod(x)
-        self.assertEqual(cnt.frame_count, 1)
-
-        opt_mod = opt_mod.to(device="cpu").to(dtype=torch.float64)
-        self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule)
-        x = torch.randn(10, 10).to(dtype=torch.float64)
-        opt_mod(x)
-        # Ensure that there is a recompilation
-        self.assertEqual(cnt.frame_count, 2)
-
-    def test_attr(self):
-        class MockModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(10, 10)
-                self.register_buffer("buf0", torch.randn(10, 10))
-
-            def forward(self, x):
-                return self.r(torch.sin(x)) + self.buf0
-
-        mod = MockModule()
-        opt_mod = torch._dynamo.optimize("eager")(mod)
-
-        # Check parameteres and buffers
-        for (p1, p2) in zip(mod.parameters(), opt_mod.parameters()):
-            self.assertTrue(id(p1) == id(p2))
-
-    def test_recursion(self):
-        mod = MockModule()
-        cnt = torch._dynamo.testing.CompileCounter()
-        opt_mod = torch._dynamo.optimize(cnt)(mod)
-
-        for _ in range(5):
-            opt_mod = torch._dynamo.optimize(cnt)(opt_mod)
-        opt_mod(torch.randn(10, 10))
-        self.assertEqual(cnt.frame_count, 1)
-
-    def test_composition(self):
-        class InnerModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.relu = torch.nn.ReLU()
-
-            def forward(self, x):
-                return self.relu(torch.sin(x))
-
-        opt_inner_mod = InnerModule()
-
-        class OuterModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.mod = opt_inner_mod
-
-            def forward(self, x):
-                return self.mod(torch.cos(x))
-
-        outer_mod = OuterModule()
-        cnt = torch._dynamo.testing.CompileCounter()
-        opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod)
-
-        x = torch.randn(4)
-        self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule)
-        self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x)))
-        self.assertEqual(cnt.frame_count, 1)
-
-    def test_composition_with_opt_mod(self):
-        class InnerModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.relu = torch.nn.ReLU()
-
-            def forward(self, x):
-                return self.relu(torch.sin(x))
-
-        inner_mod = InnerModule()
-        cnt = torch._dynamo.testing.CompileCounter()
-        opt_inner_mod = torch._dynamo.optimize(cnt)(inner_mod)
-
-        class OuterModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.mod = opt_inner_mod
-
-            def forward(self, x):
-                return self.mod(torch.cos(x))
-
-        outer_mod = OuterModule()
-        opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod)
-
-        x = torch.randn(4)
-        self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule)
-        self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x)))
-        # There will be a graph break for the inner mod being OptimizedModule
-        self.assertEqual(cnt.frame_count, 2)
-
-
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
 
diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py
index 5eee609b0852..80f927aeef2f 100644
--- a/torch/_dynamo/__init__.py
+++ b/torch/_dynamo/__init__.py
@@ -7,7 +7,6 @@
     export,
     optimize,
     optimize_assert,
-    OptimizedModule,
     reset_code,
     run,
     skip,
@@ -26,7 +25,6 @@
     "reset",
     "list_backends",
     "skip",
-    "OptimizedModule",
 ]
 
 
diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py
index 089ef172d625..f09991f9bf34 100644
--- a/torch/_dynamo/debug_utils.py
+++ b/torch/_dynamo/debug_utils.py
@@ -486,16 +486,8 @@ def same_two_models(gm, opt_gm, example_inputs, only_fwd=False):
     """
     Check two models have same accuracy.
     """
-    from .eval_frame import OptimizedModule
-    from .testing import named_parameters_for_optimized_module
     from .utils import same
 
-    if isinstance(gm, OptimizedModule):
-        gm.named_parameters = named_parameters_for_optimized_module(gm)
-
-    if isinstance(opt_gm, OptimizedModule):
-        opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm)
-
     ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd)
 
     try:
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index 20e8c7de085e..8d9e3b7b6aa1 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -5,7 +5,6 @@
 import logging
 import os
 import sys
-import textwrap
 import threading
 import traceback
 import types
@@ -45,27 +44,6 @@
 most_recent_backend = None
 
 
-class OptimizedModule(torch.nn.Module):
-    """
-    Wraps the original nn.Module object and later patches its
-    forward method to optimized self.forward method.
-    """
-
-    def __init__(self, mod):
-        super().__init__()
-        # Installs the params/buffer
-        self._orig_mod = mod
-
-    def __getattr__(self, name):
-        if name == "_orig_mod":
-            return self._modules["_orig_mod"]
-        return getattr(self._orig_mod, name)
-
-    def forward(self, *args, **kwargs):
-        # This will be monkey patched later
-        raise RuntimeError("Should not be here")
-
-
 def remove_from_cache(f):
     """
     Make sure f.__code__ is not cached to force a recompile
@@ -140,15 +118,31 @@ def __call__(self, fn):
         # Optimize the forward method of torch.nn.Module object
         if isinstance(fn, torch.nn.Module):
             mod = fn
-            new_mod = OptimizedModule(mod)
-            new_mod.forward = self(mod.forward)
+            optimized_forward = self(mod.forward)
+
+            class TorchDynamoNNModuleWrapper:
+                """
+                A wrapper that redirects the forward call to the optimized
+                forward, while for rest it redirects the calls to the original
+                module.
+                """
+
+                def __getattr__(self, name):
+                    return getattr(mod, name)
+
+                def forward(self, *args, **kwargs):
+                    return optimized_forward(*args, **kwargs)
+
+                def __call__(self, *args, **kwargs):
+                    return self.forward(*args, **kwargs)
+
+            new_mod = TorchDynamoNNModuleWrapper()
             # Save the function pointer to find the original callable while nesting
             # of decorators.
-            new_mod._torchdynamo_orig_callable = mod.forward
+            new_mod._torchdynamo_orig_callable = mod
             return new_mod
 
         assert callable(fn)
-
         callback = self.callback
         on_enter = self.on_enter
         backend_ctx_ctor = self.extra_ctx_ctor
@@ -190,34 +184,6 @@ def _fn(*args, **kwargs):
         # If the function is called using torch._dynamo.optimize decorator, we
         # should prevent any type of skipping.
         if callback not in (None, False):
-            if not hasattr(fn, "__code__"):
-                raise RuntimeError(
-                    textwrap.dedent(
-                        """
-
-                        torch._dynamo.optimize is called on a non function object.
-                        If this is a callable class, please optimize the individual methods that you are interested in optimizing.
-
-                        >> class CallableClass:
-                        >>     def __init__(self):
-                        >>         super().__init__()
-                        >>         self.relu = torch.nn.ReLU()
-                        >>
-                        >>     def __call__(self, x):
-                        >>         return self.relu(torch.sin(x))
-                        >>
-                        >>     def print_hello(self):
-                        >>         print("Hello world")
-                        >>
-                        >> mod = CallableClass()
-
-                        If you want to optimize the __call__ function
-
-                        >> mod.__call__ = torch._dynamo.optimize(mod.__call__)
-
-                        """
-                    )
-                )
             always_optimize_code_objects[fn.__code__] = True
 
     return _fn
diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py
index 6e0d32d21f97..d6082ce48acf 100644
--- a/torch/_dynamo/testing.py
+++ b/torch/_dynamo/testing.py
@@ -32,18 +32,6 @@ def clone_me(x):
     return x.detach().clone().requires_grad_(x.requires_grad)
 
 
-def named_parameters_for_optimized_module(mod):
-    assert isinstance(mod, eval_frame.OptimizedModule)
-    return mod._orig_mod.named_parameters
-
-
-def remove_optimized_module_prefix(name):
-    prefix = "_orig_mod."
-    assert name.startswith(prefix)
-    name = name[len(prefix) :]
-    return torch.distributed.fsdp._common_utils.clean_tensor_name(name)
-
-
 def collect_results(model, prediction, loss, example_inputs):
     results = []
     results.append(prediction)
@@ -56,8 +44,6 @@ def collect_results(model, prediction, loss, example_inputs):
     grads = dict()
     params = dict()
     for name, param in model.named_parameters():
-        if isinstance(model, eval_frame.OptimizedModule):
-            name = remove_optimized_module_prefix(name)
         param_copy = param
         grad = param.grad
         # Treat None and zero grad as same
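
For reference, a minimal usage sketch of the wrapper path restored by this patch (illustration only, not part of the diff). ToyModule is a hypothetical module, and the assertions mirror what TorchDynamoNNModuleWrapper.__getattr__/__call__ and the _torchdynamo_orig_callable assignment above imply, in the style of the removed OptimizedModuleTest cases.

import torch
import torch._dynamo
import torch._dynamo.testing


class ToyModule(torch.nn.Module):
    # Hypothetical module used only for this sketch.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return torch.relu(self.linear(x))


mod = ToyModule()
cnt = torch._dynamo.testing.CompileCounter()
# After this patch, optimizing an nn.Module returns the closure-defined
# TorchDynamoNNModuleWrapper rather than an OptimizedModule (nn.Module) subclass.
opt_mod = torch._dynamo.optimize(cnt)(mod)

x = torch.randn(10, 10)
assert torch._dynamo.testing.same(mod(x), opt_mod(x))  # __call__ routes to the optimized forward
assert cnt.frame_count == 1

assert opt_mod.linear is mod.linear  # other attribute access falls through via __getattr__
assert opt_mod._torchdynamo_orig_callable is mod  # original module saved for nested decorators
assert not isinstance(opt_mod, torch.nn.Module)  # the wrapper itself is not an nn.Module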