forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 9
/
torch.py
699 lines (608 loc) · 25.6 KB
/
torch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
import logging
import re
import types
from typing import Dict, List
import numpy
import torch._C
import torch.nn
import torch.onnx.operators
from .. import config, variables
from ..allowed_functions import torch_get_name
from ..exc import unimplemented
from ..source import GetItemSource, NNModuleSource
from ..utils import (
check_constant_args,
check_unspec_python_args,
istype,
product,
proxy_args_kwargs,
specialize_args_kwargs,
tensortype_to_dtype,
)
from .base import VariableTracker
from .lists import ListVariable, TupleVariable
from .misc import AutocastModeVariable, ProfilerContextWrapperVariable
from .nn_module import NNModuleVariable
from .tensor import TensorWithTFOverrideVariable
log = logging.getLogger(__name__)
# TODO(voz): Maybe rename these later
# Reflected ("right-hand side") tensor operator methods. The first four are
# looked up on torch.Tensor, the remaining five on torch._C._TensorBase; those
# five are also the keys of `tensor_dunder_fns_remap` in this file.
# NOTE(review): this list is not read anywhere in this file — presumably it is
# consumed by another module; confirm before renaming or reordering.
tensor_dunder_fns = [
    torch.Tensor.__rmatmul__,
    torch.Tensor.__rmod__,
    torch.Tensor.__rpow__,
    torch.Tensor.__rsub__,
    torch._C._TensorBase.__radd__,
    torch._C._TensorBase.__rmul__,
    torch._C._TensorBase.__ror__,
    torch._C._TensorBase.__rxor__,
    torch._C._TensorBase.__rand__,
]

# Types that TorchVariable.__init__ accepts as a bound `__self__` in its
# debug checks.
torch_special_class_types = (torch._C.Generator,)

# Functions that TorchVariable.call_function rewrites into a `size` method
# call on their single tensor argument.
REWRITE_OPS_TO_TENSOR_SIZE_METHOD = [
    torch.onnx.operators.shape_as_tensor,
    torch._shape_as_tensor,
]
# TODO(voz): perhaps a decorator? This is rather readable for now tho, and not a public API.
# Each wrapper below is an ordinary Python function that forwards its
# positional arguments, unchanged, to one reflected operator defined on
# torch._C._TensorBase.
def remap_as_fn___radd__(*args):
    """Forward to ``torch._C._TensorBase.__radd__``."""
    reflected_add = torch._C._TensorBase.__radd__
    return reflected_add(*args)


def remap_as_fn___rmul__(*args):
    """Forward to ``torch._C._TensorBase.__rmul__``."""
    reflected_mul = torch._C._TensorBase.__rmul__
    return reflected_mul(*args)


def remap_as_fn___ror__(*args):
    """Forward to ``torch._C._TensorBase.__ror__``."""
    reflected_or = torch._C._TensorBase.__ror__
    return reflected_or(*args)


def remap_as_fn___rxor__(*args):
    """Forward to ``torch._C._TensorBase.__rxor__``."""
    reflected_xor = torch._C._TensorBase.__rxor__
    return reflected_xor(*args)


def remap_as_fn___rand__(*args):
    """Forward to ``torch._C._TensorBase.__rand__``."""
    reflected_and = torch._C._TensorBase.__rand__
    return reflected_and(*args)


# Lookup table from each C-level reflected operator to its wrapper above;
# TorchVariable.__init__ consults it to swap the operator for the wrapper.
tensor_dunder_fns_remap = {
    torch._C._TensorBase.__radd__: remap_as_fn___radd__,
    torch._C._TensorBase.__rmul__: remap_as_fn___rmul__,
    torch._C._TensorBase.__ror__: remap_as_fn___ror__,
    torch._C._TensorBase.__rxor__: remap_as_fn___rxor__,
    torch._C._TensorBase.__rand__: remap_as_fn___rand__,
}
try:
    # We need to monkeypatch transformers here, sadly.
    # TODO(voz): Upstream to transformers lib
    import transformers

    def _dynamo_overriden_transformers_eq(self, other):
        """Compare two objects by their instance ``__dict__``.

        An ``other`` without a ``__dict__`` attribute compares unequal.
        """
        return hasattr(other, "__dict__") and self.__dict__ == other.__dict__

    # Install the comparison above as PretrainedConfig.__eq__.
    setattr(
        transformers.configuration_utils.PretrainedConfig,
        "__eq__",
        _dynamo_overriden_transformers_eq,
    )
except ImportError:
    # Nothing to patch when the import fails.
    pass
class TorchVariable(VariableTracker):
    """Points to a module or method in torch.*

    ``self.value`` holds the wrapped torch object.  ``call_function``
    special-cases a number of torch callables (constant folding, grad-mode
    helpers, ``nn.Softmax`` / ``nn.CrossEntropyLoss`` construction, generator
    state access, ...) and otherwise records a ``call_function`` node through
    ``tx.output.create_proxy``.
    """

    def __init__(self, value, **kwargs):
        """Store ``value``, remapping C-level reflected operators.

        Args:
            value: the torch object to wrap.  If it is a key of
                ``tensor_dunder_fns_remap`` the mapped wrapper function is
                stored instead.
            **kwargs: forwarded to ``VariableTracker.__init__``.

        Raises:
            AssertionError: if ``value`` is bound (has ``__self__``) to
                something other than a torch/math module, the capsule type
                used by ``torch._C`` functions, or one of
                ``torch_special_class_types``.
        """
        super().__init__(**kwargs)

        if value in tensor_dunder_fns_remap:
            value = tensor_dunder_fns_remap[value]

        self.value = value

        # the remainder of this is just optional debug checks
        try:
            self_should_be_none = getattr(self.value, "__self__", None)
        except RuntimeError as e:
            # Only a "No such operator" RuntimeError is tolerated here.
            assert "No such operator" in str(e), str(e)
            self_should_be_none = None

        # assert "_ntuple.<locals>.parse" not in str(value)

        if self_should_be_none is None:
            pass
        elif isinstance(self_should_be_none, types.ModuleType):
            # weird ones like torch.nn.functional.avg_pool2d have __self__
            name = self_should_be_none.__name__
            assert re.match(r"^(torch|math)([.]|$)", name), f"__self__ set to {name}"
        elif isinstance(
            self_should_be_none, type(torch._C._get_tracing_state.__self__)
        ):
            # some _C functions have __self__ as a null capsule
            pass
        elif isinstance(self_should_be_none, torch_special_class_types):
            pass
        else:
            raise AssertionError(f"{value} found with __self__ set")

    def __repr__(self):
        return f"TorchVariable({self.value})"

    def unique_var_name(self):
        """Return an identifier-safe name for ``self.value``.

        The name comes from ``torch_get_name`` (falling back to
        ``allowed_fn_<id>``), with every run of characters outside
        ``[a-zA-Z0-9_]`` replaced by ``_`` and a ``__`` prefix added.
        """
        name = torch_get_name(self.value, f"allowed_fn_{id(self.value)}")
        return "__" + re.sub(r"[^a-zA-Z0-9_]+", "_", name)

    def reconstruct(self, codegen):
        """Delegate to ``codegen.setup_globally_cached`` with the unique name."""
        return codegen.setup_globally_cached(self.unique_var_name(), self.value)

    def as_proxy(self):
        """Return the wrapped object itself."""
        return self.value

    def python_type(self):
        """Return ``type(self.value)`` for tensors and modules.

        Any other value defers to the base-class implementation.
        """
        if isinstance(self.value, (torch.Tensor, torch.nn.Module)):
            return type(self.value)
        return super().python_type()

    def as_python_constant(self):
        """Return the wrapped object itself."""
        return self.value

    def can_constant_fold_through(self):
        """Return True if a call with constant arguments may be folded.

        True for a fixed set of torch helpers and for anything whose
        ``__module__`` is ``"math"``.
        """
        if self.value in (
            torch._assert,
            torch.device,
            torch.finfo,
            torch.iinfo,
            torch.is_floating_point,
            torch.is_tensor,
            torch.overrides.is_tensor_like,
        ):
            return True
        return getattr(self.value, "__module__", None) == "math"

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Model a call of ``self.value(*args, **kwargs)``.

        The branches are tried in order; the first that matches produces the
        result.  When nothing special applies, a ``call_function`` node is
        recorded via ``tx.output.create_proxy`` and wrapped with
        ``TensorVariable.create``.

        Args:
            tx: the current translator; this method uses ``tx.output``,
                ``tx.symbolic_locals`` and ``tx.find_symbolic_locals_name``.
            args: positional argument trackers.  NOTE: ``args[0]`` is replaced
                in place in the ``TensorWithTFOverrideVariable`` branch, so
                ``args`` must support item assignment there.
            kwargs: keyword argument trackers.

        Returns:
            The ``VariableTracker`` describing the call's result.
        """
        from . import ConstantVariable, GradModeVariable, TensorVariable

        constant_args = check_constant_args(args, kwargs)
        unspec_python_args = check_unspec_python_args(args, kwargs)
        options = VariableTracker.propagate(self, args, kwargs.values())

        if self.value in config.constant_functions:
            assert not args and not kwargs
            return ConstantVariable(config.constant_functions[self.value], **options)
        elif self.can_constant_fold_through() and (constant_args or unspec_python_args):
            args, kwargs = specialize_args_kwargs(tx, args, kwargs)
            # constant fold
            return ConstantVariable(
                self.as_python_constant()(
                    *[x.as_python_constant() for x in args],
                    **{k: v.as_python_constant() for k, v in kwargs.items()},
                ),
                **options,
            )
        elif istype(self.value, type) and issubclass(self.value, torch.nn.Module):
            if self.value is torch.nn.Softmax:
                return self._call_softmax(tx, args, kwargs, options)
            if self.value is torch.nn.CrossEntropyLoss:
                return self._call_cross_entropy_loss(tx, args, kwargs, options)
            else:
                unimplemented(f"construct nn.Module: {self.value.__name__}")
        elif (
            self.value
            in (
                torch.is_tensor,
                torch.is_floating_point,
                torch.is_complex,
                torch.overrides.is_tensor_like,
            )
            # Guard against keyword-only calls: without positional args this
            # branch does not apply and the call is handled further down.
            and args
            and isinstance(args[0], TensorVariable)
            and args[0].dtype is not None
        ):
            # The answer is known from the tracked dtype; no graph node needed.
            if self.value in (torch.is_tensor, torch.overrides.is_tensor_like):
                return ConstantVariable(True, **options)
            elif self.value is torch.is_floating_point:
                return ConstantVariable(args[0].dtype.is_floating_point, **options)
            elif self.value is torch.is_complex:
                return ConstantVariable(args[0].dtype.is_complex, **options)
            else:
                raise AssertionError()
        elif (
            self.value is torch.numel
            # Same guard as above for keyword-only calls.
            and args
            and isinstance(args[0], TensorVariable)
            and args[0].size is not None
        ):
            # Tracked size is available: fold numel to a constant.
            return ConstantVariable(product(args[0].size), **options)
        elif self.value in REWRITE_OPS_TO_TENSOR_SIZE_METHOD:
            assert len(args) == 1
            assert isinstance(args[0], TensorVariable)
            return args[0].call_method(tx, "size", [], {})
        elif self.value in (
            torch.nn.modules.utils._single,
            torch.nn.modules.utils._pair,
            torch.nn.modules.utils._triple,
            torch.nn.modules.utils._quadruple,
            torch.nn.modules.utils._ntuple,
        ):
            return self._call_ntuple(tx, args, kwargs, options)
        elif self.value is torch.no_grad:
            return GradModeVariable.create(tx, False, **options)
        elif self.value is torch.enable_grad:
            return GradModeVariable.create(tx, True, **options)
        elif self.value is torch.set_grad_enabled and len(args) == 1:
            return GradModeVariable.create(tx, args[0].as_python_constant(), **options)
        elif self.value is torch.is_grad_enabled:
            assert not (args or kwargs)
            # The current grad mode is baked in as a constant, with the
            # grad-mode guards attached to the result.
            return ConstantVariable(torch.is_grad_enabled(), **options).add_guards(
                GradModeVariable._guards_singleton
            )
        elif not config.dynamic_shapes and self.is_dynamic_shapes(args, kwargs):
            unimplemented(f"dynamic shapes: {self.value.__name__}")
        elif len(args) > 0 and isinstance(args[0], TensorWithTFOverrideVariable):
            # This code block implements inlining the __torch_function__
            # override of a tensor.

            tensor_with_tf_override = args[0]

            # TODO(future PR): make this implement the full __torch_function__ API
            # instead of assuming the relevant override is in the first argument.
            args[0] = args[0].tensor_variable

            unwrapped = TensorWithTFOverrideVariable.inline_torch_function_unwrapped(
                tx,
                self,
                tensor_with_tf_override.orig_tensor_variable_source,
                tensor_with_tf_override.subclass_torch_function__func,
                tensor_with_tf_override.subclass_type,
                options,
                args,
                kwargs,
            )

            # The wrapping here follows the logic in
            # `torch.Tensor.__torch_function__`.
            if self.value in torch.overrides.get_default_nowrap_functions():
                return unwrapped
            return TensorWithTFOverrideVariable(
                unwrapped,
                tensor_with_tf_override.orig_tensor_variable_source,
                tensor_with_tf_override.subclass_torch_function__func,
                tensor_with_tf_override.subclass_type,
            )
        elif self.value is torch.amp.autocast_mode.autocast:
            return AutocastModeVariable.create(tx, target_values=args, kwargs=kwargs)
        elif self.value in (
            torch.profiler.profile,
            torch.profiler.record_function,
            torch.autograd.profiler.profile,
            torch.autograd.profiler.record_function,
        ):
            log.warning("Profiler will be ignored")
            return ProfilerContextWrapperVariable(**options)
        elif self.value is torch.jit.annotate:
            # torch.jit.annotate(type, value): only the value is kept.
            assert len(args) == 2
            return args[1]

        # Nothing above matched (every matching branch returns, or ends in
        # `unimplemented`).  Next: generator state access.
        if (
            self.value.__name__ == "get_state"
            and hasattr(self.value, "__self__")
            and isinstance(self.value.__self__, torch._C.Generator)
        ):

            def get_state_from_generator():
                return self.value()

            return TensorVariable.create(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    get_state_from_generator,
                    *proxy_args_kwargs(args, kwargs),
                    current_tx=tx,
                ),
                example_value=self.value(),
                **options,
            )
        if (
            self.value.__name__ == "set_state"
            and hasattr(self.value, "__self__")
            and isinstance(self.value.__self__, torch._C.Generator)
        ) or self.value == torch.random.set_rng_state:
            assert len(args) == 1
            assert isinstance(args[0], TensorVariable)

            if config.fake_tensor_propagation:
                unimplemented(
                    "TODO: make torch.random.set_rng_state work with FakeTensor/aot_autograd"
                )
                # NOTE(review): if `unimplemented` always raises (its body is
                # not visible in this file), the example_value selection below
                # is unreachable — confirm and drop or re-enable.
                # In fake tensor case, this state doesn't matter, but
                # it needs to be valid to not segfault. Pull a real tensor out.
                # The value won't matter since we are running with fake tensors anyway, so rng doesn't matter.
                # However, it is imperative to record the call_function in the graph with the true args
                # (Not the fake example_value) - for the sake of graph correctness.
                if self.value == torch.random.set_rng_state:
                    example_value = torch.random.get_rng_state()
                else:
                    example_value = self.value.__self__.get_state()
            else:
                example_value = args[0].proxy.node.meta["example_value"]

            # Overwrites the callable's __module__ with this class's module
            # name before it is recorded in the graph.
            self.value.__module__ = self.__module__
            return TensorVariable.create(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    self.value,
                    *proxy_args_kwargs(args, kwargs),
                    current_tx=tx,
                ),
                example_value=example_value,
                **options,
            )
        elif (
            self.value == torch.numel
            and len(args) == 1
            and isinstance(args[0], TensorVariable)
            and len(kwargs) == 0
        ):
            # TODO(voz): This is rewritten as a call_method because
            # torch.numel(x) w/ sym shapes raises a RuntimeError and x.numel() does not
            return TensorVariable.create(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method",
                    "numel",
                    *proxy_args_kwargs(args, kwargs),
                    current_tx=tx,
                ),
                **options,
            )
        else:
            # Handle sth like torch.LongTensor(list(np.int64, np.int64, ...)),
            # as FX symbolic trace doesn't support numpy int/float as base types.
            if (
                self.value in tensortype_to_dtype
                and len(args) == 1
                and isinstance(args[0], ListVariable)
                and args[0].is_python_constant()
            ):
                # Convert numpy scalars to Python scalars in place.
                for x in args[0].items:
                    if isinstance(x.value, numpy.generic):
                        x.value = x.value.item()

            tensor_variable = TensorVariable.create(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    self.value,
                    *proxy_args_kwargs(args, kwargs),
                    current_tx=tx,
                ),
                **options,
            )

            if "out" in kwargs:
                # out variants of torch operators like torch.sort and
                # torch.sigmoid mutate the tensors in the out field. Track such
                # tensors and rewrite the symbolic locals.
                if isinstance(tensor_variable, TupleVariable):
                    assert isinstance(kwargs["out"], TupleVariable)
                    output_tensor_names = [
                        tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
                    ]
                    for idx, name in enumerate(output_tensor_names):
                        assert name in tx.symbolic_locals
                        tx.symbolic_locals[name] = tensor_variable.items[idx]
                elif isinstance(tensor_variable, TensorVariable):
                    assert isinstance(kwargs["out"], TensorVariable)
                    name = tx.find_symbolic_locals_name(kwargs["out"])
                    assert name in tx.symbolic_locals
                    tx.symbolic_locals[name] = tensor_variable
                else:
                    unimplemented(f"out variant of {type(kwargs['out'])}")

            return tensor_variable

    def is_dynamic_shapes(self, args, kwargs):
        """Check for dynamic shapes when shape specialization is enabled

        Returns True for ``torch.nonzero``/``unique``/``unique_consecutive``,
        anything named ``nms``, single-argument ``torch.where``, and for
        ``torch.arange`` / ``torch.repeat_interleave`` when the relevant
        arguments are not all Python constants.  Returns False otherwise.
        """
        # TODO(jansel): need to get a complete list
        if self.value in (
            torch.nonzero,
            torch.unique,
            torch.unique_consecutive,
        ) or self.value.__name__ in ("nms",):
            return True

        if self.value is torch.where and len(args) + len(kwargs) == 1:
            return True

        if self.value in (
            torch.arange,
            torch.repeat_interleave,
        ):
            none = variables.ConstantVariable(None)

            def has_non_const(it):
                return not all(x.is_python_constant() for x in it)

            # The two helpers mirror the positional/keyword layout of the
            # torch functions so that *args/**kwargs bind to the right names.
            def arange(start=none, end=none, step=none, **kwargs):
                return has_non_const([start, end, step])

            def repeat_interleave(input, repeats, dim=none, **kwargs):
                return has_non_const([repeats])

            # Explicit dispatch instead of looking the helper up by name in
            # locals().
            checker = arange if self.value == torch.arange else repeat_interleave
            return checker(*args, **kwargs)

        return False

    def _call_softmax(self, tx, args, kwargs, options):
        """rewrite the pattern nn.Softmax(dim=-1)(x) to F.softmax(x, -1)

        Returns a ``LambdaVariable`` which, when called with the input,
        records a ``torch.nn.functional.softmax(input, dim)`` node.  ``dim``
        is the first positional argument, else the ``dim`` keyword, else a
        constant ``None``.
        """
        dim = args[0] if args else kwargs.get("dim", variables.ConstantVariable(None))

        def fake_softmax(input):
            return variables.TensorVariable.create(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    torch.nn.functional.softmax,
                    *proxy_args_kwargs([input, dim], {}),
                    current_tx=tx,
                ),
                **VariableTracker.propagate([self, dim, input]),
            )

        return variables.LambdaVariable(fake_softmax, **options)

    def _call_cross_entropy_loss(self, tx, args, kwargs, options):
        """
        functional: input, target, weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean',
        label_smoothing=0.0

        non functional ctor: weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean',
        label_smoothing=0.0

        non functional loss call: input, target, optional_output

        Returns a ``LambdaVariable`` which, when called with
        ``(input, target)``, records a ``torch.nn.functional.cross_entropy``
        node whose remaining arguments are the constructor arguments
        normalized below.
        """
        from . import ConstantVariable

        # Binds the constructor's args/kwargs to named slots, filling in the
        # defaults listed in the docstring.
        def normalize_args(
            weight=ConstantVariable(None),
            size_average=ConstantVariable(None),
            ignore_index=ConstantVariable(-100),
            reduce=ConstantVariable(None),
            reduction=ConstantVariable("mean"),
            label_smoothing=ConstantVariable(0.0),
        ):
            return (
                weight,
                size_average,
                ignore_index,
                reduce,
                reduction,
                label_smoothing,
            )

        (
            weight,
            size_average,
            ignore_index,
            reduce_arg,
            reduction,
            label_smoothing,
        ) = normalize_args(*args, **kwargs)

        def fake_cross_entropy_loss(input, target):
            return variables.TensorVariable.create(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    torch.nn.functional.cross_entropy,
                    *proxy_args_kwargs(
                        [
                            input,
                            target,
                            weight,
                            size_average,
                            ignore_index,
                            reduce_arg,
                            reduction,
                            label_smoothing,
                        ],
                        {},
                    ),
                    current_tx=tx,
                ),
                **VariableTracker.propagate(
                    [
                        self,
                        weight,
                        size_average,
                        ignore_index,
                        reduce_arg,
                        reduction,
                        label_smoothing,
                        input,
                        target,
                    ]
                ),
            )

        return variables.LambdaVariable(fake_cross_entropy_loss, **options)

    def _call_ntuple(self, tx, args, kwargs, options):
        """inline behavior of torch.nn.modules.utils._ntuple

        For ``_ntuple`` itself the count is ``args[0]`` and a
        ``LambdaVariable`` is returned; for the other helpers the count is
        read from the function's closure and ``args[0]`` is converted
        immediately.
        """
        if self.value is torch.nn.modules.utils._ntuple:
            count = args[0].as_python_constant()
        else:
            count = self.value.__closure__[0].cell_contents
        assert isinstance(count, int)

        def handle_ntuple(value):
            if value.has_unpack_var_sequence(tx):
                # Already a sequence: unpack it into a tuple as-is.
                return variables.TupleVariable(
                    list(value.unpack_var_sequence(tx)),
                    **VariableTracker.propagate(self, value, args, kwargs.values()),
                )
            elif value.is_python_constant():
                # constant prop through it
                return variables.ConstantVariable(
                    torch.nn.modules.utils._ntuple(count)(value.as_python_constant()),
                    **VariableTracker.propagate(self, value, args, kwargs.values()),
                )
            else:
                unimplemented(f"torch.nn.modules.utils._ntuple({value})")

        if self.value is torch.nn.modules.utils._ntuple:
            return variables.LambdaVariable(handle_ntuple, **options)
        else:
            return handle_ntuple(args[0])
class TorchPyOperator(VariableTracker):
    """Tracks an operator object stored in ``self.value``.

    ``call_function`` records the call as a ``call_function`` graph node.
    An operator whose ``__name__`` is ``"cond"`` additionally has its two
    branch functions exported and registered as submodules.
    """

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Record a call of ``self.value(*args)`` in the output graph.

        Args:
            tx: the current translator; ``tx.output`` is used to create
                proxies and register submodules.
            args: positional argument trackers.
            kwargs: must be ``None`` or empty.

        Returns:
            A variable built by ``TensorVariable.create`` whose example value
            is obtained by actually calling ``self.value`` on the unwrapped
            real arguments.

        Raises:
            AssertionError: if kwargs are supplied, or if a ``cond`` call does
                not have the expected four arguments of the expected types.
        """
        from . import ListVariable, TensorVariable, UserFunctionVariable

        assert kwargs is None or len(kwargs) == 0, "kwargs are not supported, yet"
        # The assert admits None; normalize so `kwargs.values()` below is safe.
        if kwargs is None:
            kwargs = {}

        def unwrap_real(arg):
            """Turn a tracker into the real Python object it stands for."""
            if isinstance(arg, TensorVariable):
                return arg.get_real_value()
            if isinstance(arg, UserFunctionVariable):
                return arg.fn
            if isinstance(arg, NNModuleVariable):
                return tx.output.get_submodule(arg.module_key)
            if arg.has_unpack_var_sequence(tx):
                return [
                    unwrap_real(arg_inner) for arg_inner in arg.unpack_var_sequence(tx)
                ]
            return arg

        def make_attr(name, proxy_args=None):
            """Create a ``get_attr`` proxy for ``name``."""
            node = tx.output.create_proxy(
                "get_attr",
                name,
                tuple(proxy_args) if proxy_args else tuple(),
                {},
            )
            return node

        # Get values
        u_args = [unwrap_real(arg) for arg in args]

        def unwrap_proxy(arg):
            """Turn a tracker into a proxy; return it unchanged if it has none."""
            try:
                if isinstance(arg, TensorVariable):
                    return arg.as_proxy()
                if isinstance(arg, NNModuleVariable):
                    name = arg.module_key
                    mod = unwrap_real(arg)
                    options = VariableTracker.propagate(self, args, kwargs.values())
                    tx.output.register_attr_or_module(
                        mod,
                        name,
                        name,
                        source=NNModuleSource(
                            GetItemSource(self.source, arg.module_key)
                        ),
                        **options,
                    )
                    return make_attr(name)
                if arg.has_unpack_var_sequence(tx):
                    return [
                        unwrap_proxy(arg_inner)
                        for arg_inner in arg.unpack_var_sequence(tx)
                    ]
                return arg.as_proxy()
            except NotImplementedError:
                return arg

        def register_as_subgraph(fn, name, args):
            """Export ``fn`` and register the result under a fresh name.

            The registered name is ``<name>_<i>`` for the smallest ``i`` not
            already present in ``tx.output.nn_modules``.
            """
            from .. import export

            gm, guards = export(fn, *args)

            next_name = None
            i = 0
            while not next_name:
                # Fixed: previously the literal "name_{i}" ignored `name`.
                candidate = f"{name}_{i}"
                if candidate in tx.output.nn_modules:
                    i += 1
                else:
                    next_name = candidate

            gm.__name__ = next_name
            src = NNModuleSource(GetItemSource(self.source, next_name))
            gm.torchdynamo_force_dynamic = False
            tx.output.register_attr_or_module(gm, next_name, source=src)
            return next_name, gm, guards

        # Get args as proxies
        p_args = [unwrap_proxy(arg) for arg in args]
        if self.value.__name__ == "cond":
            # TODO(voz): Support fake tensor dispatch for recursive
            # ops - see torch/dispatch/_dispatcher.py
            if config.fake_tensor_propagation:
                unimplemented("Fake tensor mode not yet supported for cond")

            assert len(p_args) == 4
            assert type(args[0]) is TensorVariable  # predicate
            # The function trackers pass through unwrap_proxy unchanged, so
            # p_args[1] and p_args[2] are still trackers at this point.
            assert type(p_args[1]) is UserFunctionVariable  # true_fn
            assert type(p_args[2]) is UserFunctionVariable  # false_fn
            assert type(args[3]) is ListVariable  # args

            node_args = [unwrap_real(x) for x in args[3].unpack_var_sequence(tx)]
            proxy_args = [unwrap_proxy(x) for x in args[3].unpack_var_sequence(tx)]
            true_name, true_graph, true_guards = register_as_subgraph(
                p_args[1].get_function(), "true", node_args
            )
            false_name, false_graph, false_guards = register_as_subgraph(
                p_args[2].get_function(), "false", node_args
            )

            if config.enforce_cond_guards_match:
                assert (
                    true_guards == false_guards
                ), "Guards for true and false path must be equal."

            true_node = make_attr(true_name, proxy_args)
            false_node = make_attr(false_name, proxy_args)
            # Replace the branch function trackers with get_attr proxies.
            p_args[1] = true_node
            p_args[2] = false_node

        # Store the invocation as a call
        return variables.TensorVariable.create(
            tx=tx,
            proxy=tx.output.create_proxy(
                "call_function",
                self.value,
                args=tuple(p_args),
                kwargs={},
                current_tx=tx,
            ),
            example_value=self.value(*u_args),
        )