Skip to content

Commit

Permalink
Update on "[dynamo][disable] Move disable impl to its own __call__ me…
Browse files Browse the repository at this point in the history
…thod"


There were internal cases where calling disable in distributed causes trace_rules to be generated, which imports distributed and causes circular import errors.

The code has also gone bulky. I think it is time for disable code to exist separately.


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
  • Loading branch information
anijain2305 committed May 8, 2024
2 parents 818f512 + 62c99af commit 85ee236
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions torch/_dynamo/eval_frame.py
Expand Up @@ -150,7 +150,10 @@ def __init__(self, mod: torch.nn.Module, dynamo_ctx):

def _initialize(self):
# Do this stuff in constructor to lower overhead slightly
if isinstance(self._orig_mod.forward, types.MethodType) and trace_rules.check(
if isinstance(self.dynamo_ctx, DisableContext):
# No need to check trace rules
self.forward = self.dynamo_ctx(self._orig_mod.__call__)
elif isinstance(self._orig_mod.forward, types.MethodType) and trace_rules.check(
self._orig_mod.forward
):
# This may be a torch.nn.* instance in trace_rules.py which
Expand Down Expand Up @@ -514,12 +517,11 @@ def __call__(self, fn):
# create any wrapper.
fn = innermost_fn(fn)

# Optimize the forward method of torch.nn.Module object
if isinstance(fn, torch.nn.Module):
mod = fn
mod.__call__ = self(mod.__call__)
mod._call_impl = self(mod._call_impl)
return mod
new_mod = OptimizedModule(mod, self)
new_mod._torchdynamo_orig_callable = mod.forward
return new_mod

if inspect.isclass(fn):
# User has wrapped the class with compile/disable decorator. Apply
Expand Down

0 comments on commit 85ee236

Please sign in to comment.