Revert "[reland][dynamo] Better support for nn.Module (#88959)"
This reverts commit e950afc.

Reverted #88959 on behalf of https://github.com/malfet because it broke `test_accuracy_issue1`
pytorchmergebot committed Nov 13, 2022
1 parent 897d029 commit 98bcb4a
Showing 5 changed files with 20 additions and 205 deletions.
127 changes: 0 additions & 127 deletions test/dynamo/test_modules.py
@@ -904,133 +904,6 @@ def forward(self, x):
self.assertTrue(torch._dynamo.testing.same(real, graph(rx)))


class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()
self.linear = torch.nn.Linear(10, 10)
self.register_buffer("buf0", torch.randn(10, 10))

def forward(self, x):
return self.relu(self.linear(x) + self.buf0)


class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
def test_nn_module(self):
mod = MockModule()
cnt = torch._dynamo.testing.CompileCounter()
opt_mod = torch._dynamo.optimize(cnt)(mod)
self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule)

x = torch.randn(10, 10)
self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x)))
self.assertEqual(cnt.frame_count, 1)

def test_to(self):
mod = MockModule()
cnt = torch._dynamo.testing.CompileCounter()
opt_mod = torch._dynamo.optimize(cnt)(mod)
x = torch.randn(10, 10)
self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x)))
self.assertEqual(cnt.frame_count, 1)

# Ensure that there is no recompilation
opt_mod(x)
self.assertEqual(cnt.frame_count, 1)

opt_mod = opt_mod.to(device="cpu").to(dtype=torch.float64)
self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule)
x = torch.randn(10, 10).to(dtype=torch.float64)
opt_mod(x)
# Ensure that there is a recompilation
self.assertEqual(cnt.frame_count, 2)

def test_attr(self):
class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
self.register_buffer("buf0", torch.randn(10, 10))

def forward(self, x):
return self.r(torch.sin(x)) + self.buf0

mod = MockModule()
opt_mod = torch._dynamo.optimize("eager")(mod)

# Check parameters and buffers
for (p1, p2) in zip(mod.parameters(), opt_mod.parameters()):
self.assertTrue(id(p1) == id(p2))

def test_recursion(self):
mod = MockModule()
cnt = torch._dynamo.testing.CompileCounter()
opt_mod = torch._dynamo.optimize(cnt)(mod)

for _ in range(5):
opt_mod = torch._dynamo.optimize(cnt)(opt_mod)
opt_mod(torch.randn(10, 10))
self.assertEqual(cnt.frame_count, 1)

def test_composition(self):
class InnerModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()

def forward(self, x):
return self.relu(torch.sin(x))

opt_inner_mod = InnerModule()

class OuterModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.mod = opt_inner_mod

def forward(self, x):
return self.mod(torch.cos(x))

outer_mod = OuterModule()
cnt = torch._dynamo.testing.CompileCounter()
opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod)

x = torch.randn(4)
self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule)
self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x)))
self.assertEqual(cnt.frame_count, 1)

def test_composition_with_opt_mod(self):
class InnerModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()

def forward(self, x):
return self.relu(torch.sin(x))

inner_mod = InnerModule()
cnt = torch._dynamo.testing.CompileCounter()
opt_inner_mod = torch._dynamo.optimize(cnt)(inner_mod)

class OuterModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.mod = opt_inner_mod

def forward(self, x):
return self.mod(torch.cos(x))

outer_mod = OuterModule()
opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod)

x = torch.randn(4)
self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule)
self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x)))
# There will be a graph break for the inner mod being OptimizedModule
self.assertEqual(cnt.frame_count, 2)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

2 changes: 0 additions & 2 deletions torch/_dynamo/__init__.py
@@ -7,7 +7,6 @@
export,
optimize,
optimize_assert,
OptimizedModule,
reset_code,
run,
skip,
@@ -26,7 +25,6 @@
"reset",
"list_backends",
"skip",
"OptimizedModule",
]


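After this revert, OptimizedModule is no longer part of the torch._dynamo public surface; a minimal check (illustrative, not part of the diff):

import torch._dynamo

# Both checks reflect the state after this revert: the name is gone from
# __all__ and the class itself no longer exists in the module.
print("OptimizedModule" in torch._dynamo.__all__)   # expected: False
print(hasattr(torch._dynamo, "OptimizedModule"))    # expected: False
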
8 changes: 0 additions & 8 deletions torch/_dynamo/debug_utils.py
@@ -486,16 +486,8 @@ def same_two_models(gm, opt_gm, example_inputs, only_fwd=False):
"""
Check two models have same accuracy.
"""
from .eval_frame import OptimizedModule
from .testing import named_parameters_for_optimized_module
from .utils import same

if isinstance(gm, OptimizedModule):
gm.named_parameters = named_parameters_for_optimized_module(gm)

if isinstance(opt_gm, OptimizedModule):
opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm)

ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd)

try:
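As a rough illustration (not part of this diff), the forward-only comparison that same_two_models performs looks roughly like the sketch below; the helper name compare_forward_only and the Linear module are made up for illustration, and run_fwd_maybe_bwd plus the surrounding try/except are omitted:

import torch
import torch._dynamo
import torch._dynamo.testing


def compare_forward_only(gm, opt_gm, example_inputs):
    # Run both models on the same inputs and compare the outputs with
    # dynamo's tolerant equality helper.
    ref = gm(*example_inputs)
    res = opt_gm(*example_inputs)
    return torch._dynamo.testing.same(ref, res)


mod = torch.nn.Linear(8, 8)
opt_mod = torch._dynamo.optimize("eager")(mod)
print(compare_forward_only(mod, opt_mod, [torch.randn(4, 8)]))  # expected: True
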
74 changes: 20 additions & 54 deletions torch/_dynamo/eval_frame.py
@@ -5,7 +5,6 @@
import logging
import os
import sys
import textwrap
import threading
import traceback
import types
@@ -45,27 +44,6 @@
most_recent_backend = None


class OptimizedModule(torch.nn.Module):
"""
Wraps the original nn.Module object and later patches its
forward method with the optimized self.forward method.
"""

def __init__(self, mod):
super().__init__()
# Installs the params/buffer
self._orig_mod = mod

def __getattr__(self, name):
if name == "_orig_mod":
return self._modules["_orig_mod"]
return getattr(self._orig_mod, name)

def forward(self, *args, **kwargs):
# This will be monkey patched later
raise RuntimeError("Should not be here")


def remove_from_cache(f):
"""
Make sure f.__code__ is not cached to force a recompile
Expand Down Expand Up @@ -140,15 +118,31 @@ def __call__(self, fn):
# Optimize the forward method of torch.nn.Module object
if isinstance(fn, torch.nn.Module):
mod = fn
new_mod = OptimizedModule(mod)
new_mod.forward = self(mod.forward)
optimized_forward = self(mod.forward)

class TorchDynamoNNModuleWrapper:
"""
A wrapper that redirects forward calls to the optimized
forward, while all other attribute access is redirected to
the original module.
"""

def __getattr__(self, name):
return getattr(mod, name)

def forward(self, *args, **kwargs):
return optimized_forward(*args, **kwargs)

def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)

new_mod = TorchDynamoNNModuleWrapper()
# Save the function pointer to find the original callable while nesting
# of decorators.
new_mod._torchdynamo_orig_callable = mod.forward
new_mod._torchdynamo_orig_callable = mod
return new_mod

assert callable(fn)

callback = self.callback
on_enter = self.on_enter
backend_ctx_ctor = self.extra_ctx_ctor
@@ -190,34 +184,6 @@ def _fn(*args, **kwargs):
# If the function is called using torch._dynamo.optimize decorator, we
# should prevent any type of skipping.
if callback not in (None, False):
if not hasattr(fn, "__code__"):
raise RuntimeError(
textwrap.dedent(
"""
torch._dynamo.optimize is called on a non function object.
If this is a callable class, please optimize the individual methods that you are interested in optimizing.
>> class CallableClass:
>> def __init__(self):
>> super().__init__()
>> self.relu = torch.nn.ReLU()
>>
>> def __call__(self, x):
>> return self.relu(torch.sin(x))
>>
>> def print_hello(self):
>> print("Hello world")
>>
>> mod = CallableClass()
If you want to optimize the __call__ function
>> mod.__call__ = torch._dynamo.optimize(mod.__call__)
"""
)
)
always_optimize_code_objects[fn.__code__] = True

return _fn
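To illustrate (not part of this diff), a minimal sketch of how the restored TorchDynamoNNModuleWrapper behaves from the caller's side, assuming the code path shown above; the Linear module and tensor shapes are illustrative:

import torch
import torch._dynamo

mod = torch.nn.Linear(10, 10)
opt_mod = torch._dynamo.optimize("eager")(mod)

# The wrapper is a plain object, not an nn.Module subclass, so isinstance
# checks against nn.Module no longer hold after this revert.
print(isinstance(opt_mod, torch.nn.Module))  # expected: False

# __getattr__ forwards attribute access to the original module...
print(opt_mod.weight is mod.weight)          # expected: True

# ...while __call__/forward route through the compiled forward.
print(opt_mod(torch.randn(2, 10)).shape)     # expected: torch.Size([2, 10])
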
14 changes: 0 additions & 14 deletions torch/_dynamo/testing.py
@@ -32,18 +32,6 @@ def clone_me(x):
return x.detach().clone().requires_grad_(x.requires_grad)


def named_parameters_for_optimized_module(mod):
assert isinstance(mod, eval_frame.OptimizedModule)
return mod._orig_mod.named_parameters


def remove_optimized_module_prefix(name):
prefix = "_orig_mod."
assert name.startswith(prefix)
name = name[len(prefix) :]
return torch.distributed.fsdp._common_utils.clean_tensor_name(name)


def collect_results(model, prediction, loss, example_inputs):
results = []
results.append(prediction)
@@ -56,8 +44,6 @@ def collect_results(model, prediction, loss, example_inputs):
grads = dict()
params = dict()
for name, param in model.named_parameters():
if isinstance(model, eval_frame.OptimizedModule):
name = remove_optimized_module_prefix(name)
param_copy = param
grad = param.grad
# Treat None and zero grad as same
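As a sketch (not part of this diff) of what the removed remove_optimized_module_prefix compensated for: OptimizedModule registered the wrapped module under "_orig_mod", so its named_parameters() yielded names like "_orig_mod.linear.weight". The function below is a hypothetical stand-in that only strips that prefix (the removed helper additionally ran torch.distributed.fsdp._common_utils.clean_tensor_name); with the wrapper restored by this revert, parameter names come straight from the original module and no stripping is needed:

def strip_orig_mod_prefix(name):
    # Hypothetical stand-in for the removed helper: drop the "_orig_mod."
    # prefix that OptimizedModule added to parameter names.
    prefix = "_orig_mod."
    return name[len(prefix):] if name.startswith(prefix) else name


print(strip_orig_mod_prefix("_orig_mod.linear.weight"))  # linear.weight
print(strip_orig_mod_prefix("linear.weight"))            # unchanged
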
