Skip to content

Commit

Permalink
[reland][dynamo] fixes dict changed during runtime error (#88877)
Browse files Browse the repository at this point in the history
Reland #87526

Pull Request resolved: #88877
Approved by: https://github.com/ezyang
  • Loading branch information
anijain2305 authored and pytorchmergebot committed Nov 13, 2022
1 parent 4284862 commit 897d029
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 7 deletions.
3 changes: 0 additions & 3 deletions test/dynamo/test_aot_cudagraphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def fn(x, y):
y = torch.randn(3, device="cuda")
fn(x, y)

@patch("torch._dynamo.config.suppress_errors", True)
@patch_all()
def test_dtoh(self):
def model(x, y):
Expand Down Expand Up @@ -105,7 +104,6 @@ def fn(x, y):
y = torch.randn((), device="cpu")
fn(x, y)

@patch("torch._dynamo.config.suppress_errors", True)
@patch("functorch._src.config.use_functionalize", True)
@patch_all(ok=False) # input mutation not supported yet
def test_mutate_input(self):
Expand Down Expand Up @@ -145,7 +143,6 @@ def fn(x, y):
y = torch.randn(1, device="cuda")
fn(x, y)

@patch("torch._dynamo.config.suppress_errors", True)
@patch_all()
def test_factory(self):
def model(y):
Expand Down
30 changes: 30 additions & 0 deletions test/dynamo/test_repros.py
Original file line number Diff line number Diff line change
Expand Up @@ -1818,6 +1818,36 @@ def fn(x):
res = opt_fn(a)
self.assertTrue(same(ref, res))

def test_tokenization(self):
from collections import UserDict

class BatchEncoding(UserDict):
"""
Copied from tokenization
"""

def __init__(
self,
data,
):
super().__init__(data)

def __getattr__(self, item: str):
try:
return self.data[item]
except KeyError:
raise AttributeError

def tokenization(x):
encoding = BatchEncoding({"key": x})
return encoding["key"]

opt_fn = torch._dynamo.optimize("eager")(tokenization)
x = torch.rand((1, 4))
ref = tokenization(x)
res = opt_fn(x)
self.assertTrue(same(ref, res))

def test_modules(self):
class Foo(torch.nn.Module):
def __init__(self):
Expand Down
10 changes: 6 additions & 4 deletions torch/_dynamo/convert_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,17 +156,18 @@ def has_tensor(obj):
seen_ids[obj_id] = any([has_tensor(v) for v in obj])
return seen_ids[obj_id]
elif istype(obj, dict):
seen_ids[obj_id] = any([has_tensor(v) for v in obj.values()])
# Some packages like pytest can be updated during runtime. So, make a
# copy of values to avoid issues like "RuntimeError: dictionary
# changed size during iteration"
values = list(obj.values())
seen_ids[obj_id] = any([has_tensor(v) for v in values])
return seen_ids[obj_id]
elif istype(obj, (str, int, float, type(None), bool)):
seen_ids[obj_id] = False
return seen_ids[obj_id]
elif is_namedtuple(obj):
seen_ids[obj_id] = any([has_tensor(getattr(obj, v)) for v in obj._fields])
return seen_ids[obj_id]
elif not is_allowed(obj) and hasattr(obj, "__dict__") and len(obj.__dict__):
seen_ids[obj_id] = any([has_tensor(v) for v in obj.__dict__.values()])
return seen_ids[obj_id]
else:
# if config.debug:
# print(
Expand Down Expand Up @@ -302,6 +303,7 @@ def _convert_frame_assert(frame: types.FrameType, cache_size: int):
# setattr could be tricky to handle generally,
# but also not likely useful to compile- skip the whole frame
return None

# Check if the frame is generated by an exec builtin call
# TODO - Running exec generated frame seems propagates f_globals to the
# next frames.
Expand Down

0 comments on commit 897d029

Please sign in to comment.