Skip to content

Commit

Permalink
[dynamo] Support class members in nn modules (#87531)
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 authored and pytorchmergebot committed Oct 24, 2022
1 parent 272747d commit e46a897
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
17 changes: 17 additions & 0 deletions test/dynamo/test_repros.py
Original file line number Diff line number Diff line change
Expand Up @@ -1719,6 +1719,23 @@ def forward(self, getitem_1, getitem_2, add):
]
self.assertTrue(same_two_models(mod, opt_mod, args))

def test_class_member(self):
class Foo(torch.nn.Module):
a = 4
b = torch.ones(3, 4)

def __init__(self):
super().__init__()
self.c = 4

def forward(self, x):
return x.cos() + self.a + self.b + self.c

mod = Foo()
opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod)
args = (torch.randn(3, 4),)
self.assertTrue(same(mod(*args), opt_mod(*args)))


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
Expand Down
11 changes: 10 additions & 1 deletion torch/_dynamo/variables/nn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
from ..guards import GuardBuilder
from ..mutation_guard import GenerationTracker
from ..source import AttrSource, GetItemSource, NNModuleSource, NotNNModuleSource
from ..utils import is_lazy_module, istype, proxy_args_kwargs
from ..utils import (
is_lazy_module,
is_safe_constant,
istensor,
istype,
proxy_args_kwargs,
)
from .base import MutableLocal, typestr, VariableTracker
from .functions import invoke_and_store_as_constant
from .lists import SliceVariable
Expand Down Expand Up @@ -139,6 +145,9 @@ def var_getattr(self, tx, name):
return variables.UserFunctionVariable(subobj.__get__(base), **options)
elif istype(subobj, types.FunctionType):
return variables.UserMethodVariable(subobj, self, **options)
elif is_safe_constant(subobj) or istensor(subobj):
# Support possibly common cases of class members
return VariableBuilder(tx, NNModuleSource(source))(subobj)
else:
unimplemented(f"class property {typestr(base)} {typestr(subobj)}")

Expand Down

0 comments on commit e46a897

Please sign in to comment.