From f9221bf53b376d1284e2356b716c2cd47fcd65f2 Mon Sep 17 00:00:00 2001 From: Ian Graves Date: Fri, 11 Nov 2022 00:19:20 +0000 Subject: [PATCH 01/62] [pytorch] Enable memory map file support for Android, Apple, and CXX (#88545) Summary: See title. Left Windows out so it still compiles. Test Plan: Add a `#fail` below [this line](https://fburl.com/code/p0mlhlw4) and build for various platforms and confirm it fails which proves the `#ifdef` was hit. ``` buck2 build xplat/langtech/tuna/cli:tuclixAndroid buck2 build xplat/langtech/tuna/cli:tuclix ``` CI/CD for the rest. Differential Revision: D41054824 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88545 Approved by: https://github.com/qihqi --- c2_defs.bzl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/c2_defs.bzl b/c2_defs.bzl index 573ba9f6ad64c2d..0a89bb88093db69 100644 --- a/c2_defs.bzl +++ b/c2_defs.bzl @@ -166,6 +166,7 @@ def get_c2_fbandroid_xplat_compiler_flags(): # T95767731 -- remove this once all builds are on at least llvm-13 "-Wno-unknown-warning-option", "-Wno-unused-but-set-variable", + "-DHAVE_MMAP", ] if get_c2_strip_glog(): @@ -392,6 +393,7 @@ def c2_cxx_library(**kwargs): args = get_c2_default_cxx_args() args.update(kwargs) args.setdefault("platforms", (ANDROID, APPLE, CXX, WINDOWS)) + fb_xplat_cxx_library( labels = [ "supermodule:android/default/caffe2", From 072834d56dada58f99216ce398fb57cce57968a9 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 8 Nov 2022 07:59:12 -0800 Subject: [PATCH 02/62] [ao] qconfig_mapping.py fixing public v private (#87518) Summary: made _GLOBAL_DICT_KEY, _OBJECT_TYPE_DICT_KEY, _MODULE_NAME_REGEX_DICT_KEY, _MODULE_NAME_DICT_KEY, _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY private Test Plan: python test/test_public_bindings.py Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D40709278](https://our.internmc.facebook.com/intern/diff/D40709278) Pull Request resolved: https://github.com/pytorch/pytorch/pull/87518 Approved by: https://github.com/jcaip --- test/quantization/fx/test_quantize_fx.py | 20 ++++++------ .../quantization/fx/qconfig_mapping_utils.py | 8 ++--- torch/ao/quantization/qconfig_mapping.py | 32 +++++++++---------- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index 8c75658a04e1be0..6eb9246c85a7cb8 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -90,11 +90,11 @@ from torch.ao.quantization.qconfig_mapping import ( _get_symmetric_qnnpack_qconfig_mapping, - GLOBAL_DICT_KEY, - MODULE_NAME_DICT_KEY, - MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, - MODULE_NAME_REGEX_DICT_KEY, - OBJECT_TYPE_DICT_KEY, + _GLOBAL_DICT_KEY, + _MODULE_NAME_DICT_KEY, + _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, + _MODULE_NAME_REGEX_DICT_KEY, + _OBJECT_TYPE_DICT_KEY, QConfigMapping, ) @@ -1972,20 +1972,20 @@ def _get_qconfig_dict_for_qconfig_mapping_test(self, global_qconfig, qconfig1, q Return a dummy qconfig_dict to test QConfigMapping's to_dict and from_dict methods. """ return { - GLOBAL_DICT_KEY: global_qconfig, - OBJECT_TYPE_DICT_KEY: [ + _GLOBAL_DICT_KEY: global_qconfig, + _OBJECT_TYPE_DICT_KEY: [ (torch.nn.Linear, qconfig1), (torch.nn.ReLU, qconfig2), ], - MODULE_NAME_REGEX_DICT_KEY: [ + _MODULE_NAME_REGEX_DICT_KEY: [ ("foo.*bar", qconfig1), ("foo.*", qconfig2), ], - MODULE_NAME_DICT_KEY: [ + _MODULE_NAME_DICT_KEY: [ ("bazbaz", qconfig1), ("borbor", qconfig2), ], - MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [ + _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [ ("bazbaz", torch.nn.Linear, 0, qconfig1), ("foofoo", torch.nn.ReLU, 1, qconfig2), ], diff --git a/torch/ao/quantization/fx/qconfig_mapping_utils.py b/torch/ao/quantization/fx/qconfig_mapping_utils.py index 66dffd50cd0081b..0b0407c0b106e49 100644 --- a/torch/ao/quantization/fx/qconfig_mapping_utils.py +++ b/torch/ao/quantization/fx/qconfig_mapping_utils.py @@ -23,9 +23,9 @@ get_qconfig_dtypes, ) from ..qconfig_mapping import ( - OBJECT_TYPE_DICT_KEY, - MODULE_NAME_DICT_KEY, - MODULE_NAME_REGEX_DICT_KEY, + _OBJECT_TYPE_DICT_KEY, + _MODULE_NAME_DICT_KEY, + _MODULE_NAME_REGEX_DICT_KEY, QConfigMapping, ) from ..qconfig_mapping_utils import ( @@ -223,7 +223,7 @@ def compare_prepare_convert_qconfig_mappings( convert_qconfig_mapping.module_name_qconfigs, convert_qconfig_mapping.module_name_regex_qconfigs, ] - dict_names = [OBJECT_TYPE_DICT_KEY, MODULE_NAME_DICT_KEY, MODULE_NAME_REGEX_DICT_KEY] + dict_names = [_OBJECT_TYPE_DICT_KEY, _MODULE_NAME_DICT_KEY, _MODULE_NAME_REGEX_DICT_KEY] for i in range(len(prepare_dicts)): for name, qconfig in prepare_dicts[i].items(): assert name in convert_dicts[i], "Missing key {} {} in convert QConfigMapping \ diff --git a/torch/ao/quantization/qconfig_mapping.py b/torch/ao/quantization/qconfig_mapping.py index 418cbb334814c94..e3410a52a9d83d6 100644 --- a/torch/ao/quantization/qconfig_mapping.py +++ b/torch/ao/quantization/qconfig_mapping.py @@ -33,11 +33,11 @@ # TODO: replace all usages with these constants -GLOBAL_DICT_KEY = "" -OBJECT_TYPE_DICT_KEY = "object_type" -MODULE_NAME_REGEX_DICT_KEY = "module_name_regex" -MODULE_NAME_DICT_KEY = "module_name" -MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order" +_GLOBAL_DICT_KEY = "" +_OBJECT_TYPE_DICT_KEY = "object_type" +_MODULE_NAME_REGEX_DICT_KEY = "module_name_regex" +_MODULE_NAME_DICT_KEY = "module_name" +_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order" _FIXED_QPARAMS_OP_TO_OBSERVER: Dict[Union[Callable, str], _PartialWrapper] = { torch.nn.Hardsigmoid: default_fixed_qparams_range_0to1_observer, @@ -274,11 +274,11 @@ def to_dict(self) -> Dict[str, Any]: The values of this dictionary are lists of tuples. """ return { - GLOBAL_DICT_KEY: self.global_qconfig, - OBJECT_TYPE_DICT_KEY: list(self.object_type_qconfigs.items()), - MODULE_NAME_REGEX_DICT_KEY: list(self.module_name_regex_qconfigs.items()), - MODULE_NAME_DICT_KEY: list(self.module_name_qconfigs.items()), - MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [ + _GLOBAL_DICT_KEY: self.global_qconfig, + _OBJECT_TYPE_DICT_KEY: list(self.object_type_qconfigs.items()), + _MODULE_NAME_REGEX_DICT_KEY: list(self.module_name_regex_qconfigs.items()), + _MODULE_NAME_DICT_KEY: list(self.module_name_qconfigs.items()), + _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [ (*k, v) for k, v in self.module_name_object_type_order_qconfigs.items() ], } @@ -302,14 +302,14 @@ def from_dict(cls, qconfig_dict: Dict[str, Any]) -> QConfigMapping: The values of this dictionary are expected to be lists of tuples. """ conf = cls() - if GLOBAL_DICT_KEY in qconfig_dict: - conf.set_global(qconfig_dict[GLOBAL_DICT_KEY]) - for object_type, qconfig in qconfig_dict.get(OBJECT_TYPE_DICT_KEY, []): + if _GLOBAL_DICT_KEY in qconfig_dict: + conf.set_global(qconfig_dict[_GLOBAL_DICT_KEY]) + for object_type, qconfig in qconfig_dict.get(_OBJECT_TYPE_DICT_KEY, []): conf.set_object_type(object_type, qconfig) - for module_name_regex, qconfig in qconfig_dict.get(MODULE_NAME_REGEX_DICT_KEY, []): + for module_name_regex, qconfig in qconfig_dict.get(_MODULE_NAME_REGEX_DICT_KEY, []): conf.set_module_name_regex(module_name_regex, qconfig) - for module_name, qconfig in qconfig_dict.get(MODULE_NAME_DICT_KEY, []): + for module_name, qconfig in qconfig_dict.get(_MODULE_NAME_DICT_KEY, []): conf.set_module_name(module_name, qconfig) - for module_name, object_type, index, qconfig in qconfig_dict.get(MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, []): + for module_name, object_type, index, qconfig in qconfig_dict.get(_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, []): conf.set_module_name_object_type_order(module_name, object_type, index, qconfig) return conf From 534ae6ae4790aec1b148b7e878ae60828ae45ac0 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Fri, 11 Nov 2022 01:08:16 +0000 Subject: [PATCH 03/62] [primTorch] Implement group norm reference (#87054) Add group norm reference Split from #81191 Pull Request resolved: https://github.com/pytorch/pytorch/pull/87054 Approved by: https://github.com/mruberry --- test/test_fx.py | 4 +- test/test_ops.py | 5 +- torch/_decomp/decompositions.py | 31 ------ torch/_refs/__init__.py | 62 ++++++++++++ torch/_refs/nn/functional/__init__.py | 40 ++++++++ torch/nn/functional.py | 2 + .../_internal/common_methods_invocations.py | 97 ++++++++++++++++--- 7 files changed, 191 insertions(+), 50 deletions(-) diff --git a/test/test_fx.py b/test/test_fx.py index 0aa5b28a3de7df4..0aff631b8e814a8 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -3925,7 +3925,6 @@ def tearDown(self): "max_pool2d": PROXY_ITERABLE, "max_pool3d": PROXY_ITERABLE, - "group_norm": PROXY_ITERATED, "lp_pool2d": PROXY_ITERATED, "max_unpool1d": PROXY_ITERATED, "max_unpool2d": PROXY_ITERATED, @@ -3959,6 +3958,7 @@ def tearDown(self): "gaussian_nll_loss": CONTROL_FLOW, "glu": CONTROL_FLOW, "grid_sample": CONTROL_FLOW, + "group_norm": CONTROL_FLOW, "gumbel_softmax": CONTROL_FLOW, "hardsigmoid": CONTROL_FLOW, "hardswish": CONTROL_FLOW, @@ -4029,7 +4029,7 @@ def tearDown(self): "max_pool2d": PROXY_ITERATED, "max_pool3d": PROXY_ITERATED, - "group_norm": LEN_ERROR + "group_norm": CONTROL_FLOW } @classmethod diff --git a/test/test_ops.py b/test/test_ops.py index d0aa0906784dc25..73758bfc6b466e2 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -417,9 +417,10 @@ def test_python_ref_executor(self, device, dtype, op, executor): # skip zero-dim tensors for some composites of reduction operations and view skip_zero_dim_ops = [ - "_refs.softmax", "_refs.logsumexp", "_refs.log_softmax", + "_refs.native_group_norm", + "_refs.softmax", "_refs.sum_to_size", "ops.nvprims.view", ] @@ -1659,11 +1660,13 @@ class TestRefsOpsInfo(TestCase): '_refs.index_add_', '_refs.index_copy_', '_refs.index_fill_', + '_refs.native_group_norm', } not_in_decomp_table = { # duplicated in _decomp and _refs '_refs.nn.functional.elu', + '_refs.nn.functional.group_norm', '_refs.nn.functional.mse_loss', '_refs.rsub', # duplicated due to efficiency concerns of the ref vs the decomp diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 0e1d1cd1dd511c4..fe63e0db007a7b1 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1138,37 +1138,6 @@ def normalize(input, norm_dims, eps): return out, mean, rstd -@register_decomposition(aten.native_group_norm.default) -def native_group_norm( - input: Tensor, - weight: Optional[Tensor], - bias: Optional[Tensor], - N: int, - C: int, - HxW: int, - group: int, - eps: float, -) -> Tuple[Tensor, Tensor, Tensor]: - orig_shape = input.shape - input = input.view(N, group, C // group, HxW) - reduction_dims = [2, 3] - out, mean, rstd = normalize(input, reduction_dims, eps) - mean = _squeeze_multiple(mean, reduction_dims) - rstd = _squeeze_multiple(rstd, reduction_dims) - out = out.view(orig_shape) - if weight is not None: - weight = _unsqueeze_to_dim(weight, out.dim() - 1) - out = out * weight - if bias is not None: - bias = _unsqueeze_to_dim(bias, out.dim() - 1) - out = out + bias - - out = out.to(dtype=input.dtype) - mean = mean.to(dtype=input.dtype) - rstd = rstd.to(dtype=input.dtype) - return (out, mean, rstd) - - @register_decomposition(aten.native_group_norm_backward) @pw_cast_for_opmath def native_group_norm_backward( diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index cd0344eba7a914b..36fef59df37578b 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -238,6 +238,7 @@ "movedim", "narrow", "narrow_copy", + "native_group_norm", "native_layer_norm", "permute", "ravel", @@ -2781,6 +2782,7 @@ def _normalize( mean (Tensor): mean of the tensor along norm_dims. rstd (Tensor): 1/std of the tensor along norm_dims. """ + norm_dims = utils.canonicalize_dims(a.ndim, norm_dims) computation_dtype = utils.get_computation_dtype(a.dtype) a_acc = _maybe_convert_to_dtype(a, computation_dtype) assert isinstance(a_acc, TensorLike) # to avoid mypy error for var_mean @@ -2792,6 +2794,66 @@ def _normalize( return out, mean, rstd +# add all specified dimensions +def _unsqueeze_multiple(x: TensorLikeType, dimensions: List[int]) -> TensorLikeType: + for dim in sorted(dimensions): + x = torch.unsqueeze(x, dim) + return x + + +@register_decomposition(torch.ops.aten.native_group_norm.default) +def native_group_norm( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + batch_size: int, + num_channels: int, + flattened_inner_size: int, + num_groups: int, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor]: + utils.check( + input.ndim >= 2, + lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}", + ) + utils.check( + num_channels % num_groups == 0, + lambda: "Expected number of channels in input to be divisible by num_groups, " + + f"but got input of shape {input.shape} and num_groups = {num_groups}", + ) + + # num_channels / num_groups and flattened inner dimension are the reduction axes + reduction_dims = [2, 3] + input_reshaped = torch.reshape( + input, + [batch_size, num_groups, num_channels // num_groups, flattened_inner_size], + ) + out, mean, rstd = _normalize(input_reshaped, reduction_dims, eps) + out = out.view(input.shape) + + broadcast_dims = [0] + list(dim for dim in range(2, input.ndim)) + unsqueeze_bias = None + if bias is not None: + unsqueeze_bias = _unsqueeze_multiple(bias, broadcast_dims) + unsqueeze_weight = None + if weight is not None: + unsqueeze_weight = _unsqueeze_multiple(weight, broadcast_dims) + + if unsqueeze_weight is not None: + out = out * unsqueeze_weight + if unsqueeze_bias is not None: + out = out + unsqueeze_bias + + out = _maybe_convert_to_dtype(out, input.dtype) # type: ignore[assignment] + mean = _maybe_convert_to_dtype(mean, input.dtype) # type: ignore[assignment] + rstd = _maybe_convert_to_dtype(rstd, input.dtype) # type: ignore[assignment] + + # remove broadcast dimensions from mean and rstd + mean = prims.squeeze(mean, reduction_dims) + rstd = prims.squeeze(rstd, reduction_dims) + return (out, mean, rstd) + + @register_decomposition(torch.ops.aten.native_layer_norm) def native_layer_norm( input: Tensor, diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py index 3cde6784494766f..dcd86d8952d2662 100644 --- a/torch/_refs/nn/functional/__init__.py +++ b/torch/_refs/nn/functional/__init__.py @@ -171,6 +171,46 @@ def relu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: return torch.where(torch.le(a, 0), 0, a) +def group_norm( + input: Tensor, + num_groups: int, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: + """ + Reference implementation of :func:`torch.nn.functional.group_norm`. + """ + utils.check( + input.ndim >= 2, + lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}", + ) + + batch_size = input.shape[0] + num_channels = input.shape[1] + utils.check( + num_channels % num_groups == 0, + lambda: "Expected number of channels in input to be divisible by num_groups, " + + f"but got input of shape {input.shape} and num_groups = {num_groups}", + ) + + # input shape is (N, C, *), so we flatten all inner dimensions except (N, C) + flattened_inner_size = 1 + for dim_length in input.shape[2:]: + flattened_inner_size *= dim_length + + return torch.native_group_norm( + input, + weight, + bias, + batch_size, + num_channels, + flattened_inner_size, + num_groups, + eps, + )[0] + + def layer_norm( input: Tensor, normalized_shape: ShapeType, diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 79bf6297e587192..961dd83f57b2cfb 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -2524,6 +2524,8 @@ def group_norm( """ if has_torch_function_variadic(input, weight, bias): return handle_torch_function(group_norm, (input, weight, bias,), input, num_groups, weight=weight, bias=bias, eps=eps) + if input.dim() < 2: + raise RuntimeError(f"Expected at least 2 dimensions for input tensor but received {input.dim()}") _verify_batch_size([input.size(0) * input.size(1) // num_groups, num_groups] + list(input.size()[2:])) return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 731dc008ccce7af..b702c116186047d 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -3334,27 +3334,72 @@ def sample_inputs_conv2d(op_info, device, dtype, requires_grad, jit_fail_sample= def sample_inputs_group_norm(opinfo, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) - # Ordered as input shape, num groups, and eps + # Ordered as input shape, num groups, and kwargs for eps cases: Tuple[Tuple[int], int, float] = ( # type: ignore[assignment] - ((1, 6, 3), 2, 0.5), - ((2, 6, 3), 2, -0.5), - ((1, 2), 1, None), - ((0, 2), 1, None), + ((1, 6, 3), 2, {'eps' : 0.5}), + ((2, 6, 3), 2, {'eps' : -0.5}), + ((1, 3), 1, {'eps' : 1e-5}), + ((0, 2), 1, {'eps' : 1e-5}), + ((S, S, S), 1, {'eps' : 0.5}), ) - for input_shape, num_groups, eps in cases: + # num_channels is inferred to be input.shape[1] dimension + for input_shape, num_groups, kwargs in cases: # Shape of weight and bias should be the same as num_channels - weight = make_arg(input_shape[1]) - bias = make_arg(input_shape[1]) - kwargs = {'weight': weight, 'bias': bias} if eps is None else {'weight': weight, 'bias': bias, 'eps': eps} - yield SampleInput( - make_arg(input_shape), - args=(num_groups,), - kwargs=kwargs - ) + channels = input_shape[1] if len(input_shape) > 1 else 0 + weight_tensor = make_arg(channels) + bias_tensor = make_arg(channels) + + # Checking for permutations of weights and biases as `None` + weights = [weight_tensor, None] + biases = [bias_tensor, None] + for weight, bias in itertools.product(weights, biases): + kwargs = { + 'weight': weight, + 'bias': bias, + **kwargs + } + yield SampleInput(make_arg(input_shape), num_groups, **kwargs) + # Without any optional args yield SampleInput(make_arg((1, 2)), args=(1,)) +def reference_inputs_group_norm(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_group_norm( + op_info, device, dtype, requires_grad, **kwargs) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as input shape, num groups, and kwargs for eps + cases: Tuple[Tuple[int], int, float] = ( # type: ignore[assignment] + ((20, 6, 10, 10), 3, {'eps' : 1e-5}), + # equivalent with InstanceNorm + # GroupNorm(C, num_groups=C) == InstanceNorm(num_features=C) + ((20, 6, 10, 10), 6, {'eps' : 1e-5}), + # equivalent with LayerNorm + # GroupNorm(C, num_groups=1, affine=False) == LayerNorm(normalized_shape=[C, H, W], elementwise_affine=False) + ((20, 6, 10, 10), 1, {'eps' : 1e-5}), + ) + + # num_channels is inferred to be input.shape[1] dimension + for input_shape, num_groups, kwargs in cases: + # Shape of weight and bias should be the same as num_channels + channels = input_shape[1] if len(input_shape) > 1 else 0 + input_tensor = make_arg(input_shape) + weight_tensor = make_arg(channels) + bias_tensor = make_arg(channels) + + # Checking for permutations of weights and biases as `None` + weights = [weight_tensor, None] + biases = [bias_tensor, None] + for weight, bias in itertools.product(weights, biases): + kwargs = { + 'weight': weight, + 'bias': bias, + **kwargs + } + yield SampleInput(input_tensor, num_groups, **kwargs) + def sample_inputs_instance_norm(opinfo, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -3481,6 +3526,18 @@ def sample_inputs_native_layer_norm(opinfo, device, dtype, requires_grad, **kwar args=(normalized_shape, None, None, eps), ) +def error_inputs_group_norm(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False) + + # check that input has minimum number of dimensions + err_msg1 = "Expected at least 2 dimensions for input tensor but received" + s1 = SampleInput(make_arg((1)), args=(1,)) + yield ErrorInput(s1, error_regex=err_msg1) + + # check that the channels dimension is compatible with number of groups + err_msg2 = "Expected number of channels in input to be divisible by num_groups, but got input of shape" + s2 = SampleInput(make_arg((2, 7, 4)), args=(2,)) + yield ErrorInput(s2, error_regex=err_msg2) def error_inputs_native_layer_norm(opinfo, device, **kwargs): make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False) @@ -7747,12 +7804,12 @@ def reference_group_norm(inp: np.ndarray, num_groups: int, weight=None, bias=Non if weight is not None: # weight is a vector of length equal to the channel if len(Y.shape) > 2: - weight = np.tile(np.expand_dims(weight, 1), [1] + list(inp.shape[2:])) + weight = np.expand_dims(weight, [0] + [idx + 2 for idx in range(inp.ndim - 2)]) Y = Y * weight if bias is not None: # bias is a vector of length equal to the channel if len(Y.shape) > 2: - bias = np.tile(np.expand_dims(bias, 1), [1] + list(inp.shape[2:])) + bias = np.expand_dims(bias, [0] + [idx + 2 for idx in range(inp.ndim - 2)]) Y = Y + bias return Y @@ -10921,12 +10978,14 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, + error_inputs_func=error_inputs_group_norm, decorators=[ # RuntimeError: Cannot insert a Tensor that requires grad as a constant. # Consider making it a parameter or input, or detaching the gradient DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)) ], sample_inputs_func=sample_inputs_group_norm, + reference_inputs_func=reference_inputs_group_norm, supports_expanded_weight=True,), OpInfo('nn.functional.instance_norm', # no ref because instance_norm will often have numerical instability (large numbers or nan) @@ -17941,6 +18000,12 @@ def reference_flatten(input, start_dim=0, end_dim=-1): DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'), ) ), + PythonRefInfo( + "_refs.nn.functional.group_norm", + torch_opinfo_name="nn.functional.group_norm", + supports_nvfuser=False, + validate_view_consistency=False, + ), PythonRefInfo( "_refs.narrow_copy", torch_opinfo_name="narrow_copy", From c961e45ee559a61bfb4f1e8a548e574ef89d3102 Mon Sep 17 00:00:00 2001 From: Nikolay Korovaiko Date: Thu, 10 Nov 2022 12:21:50 -0800 Subject: [PATCH 04/62] handle zero dims in reductions (#88280) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88280 Approved by: https://github.com/ngimel --- test/inductor/test_torchinductor.py | 21 +++++++++++++++++ torch/_inductor/ir.py | 36 +++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 8fd4fa29bf98a88..121f3d31f39c201 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4224,6 +4224,27 @@ def forward(x): ] self.common(forward, args) + def test_zero_dim_reductions(self): + for kd in [True, False]: + inps0 = (torch.zeros(2, 0, device=self.device, dtype=torch.float16), 1, kd) + failed_ops = [aten.argmin, aten.argmax, aten.max, aten.min] + for fo in failed_ops: + with self.assertRaisesRegex( + IndexError, "Expected reduction dim 1 to have non-zero size" + ): + mod = make_fx(fo)(*inps0) + _ = compile_fx_inner(mod, inps0) + + pass_ops = [ + lambda *x: fn(*x) for fn in [aten.sum, aten.prod, aten.any, aten.all] + ] + for po in pass_ops: + compiled = torch._dynamo.optimize("inductor")(po) + expected = po(*inps0) + actual = compiled(*inps0) + + self.assertTrue(torch.allclose(actual, expected, atol=1e-3, rtol=1e-3)) + @requires_cuda() def test_unspec_inputs(self): def fn(x, y): diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 924ec7aaa7b2e47..448c057ecb0e15d 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -729,6 +729,42 @@ def create( reduction_hint: ReductionHint = ReductionHint.DEFAULT, ): reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) + + if reduction_numel == 0: + + # N.B. This is a hack to generate the literal of the given type + # Ideally, we should be fixing `def constant` in triton.py + # but it breaks due to hardcoded dtypes in other places + def py_cnst(val): + return ( + bool(val) + if dst_dtype == torch.bool + else float(val) + if dst_dtype.is_floating_point + else int(val) + ) + + rtypes_to_inits = { + "sum": py_cnst(0), + "prod": py_cnst(1), + "any": py_cnst(0), + # "all" is desugared to `!any(!val)` + } + + assert ( + reduction_type in rtypes_to_inits.keys() + ), f"{reduction_type} not supported for zero-dimension tensors!" + + def const_fn(index): + return ops.constant(rtypes_to_inits[reduction_type], dst_dtype) + + return Pointwise.create( + device=device, + dtype=src_dtype, + inner_fn=const_fn, + ranges=list(ranges), + ) + if reduction_numel == 1: # this reduction is actually a pointwise op if reduction_type in ("argmin", "argmax"): From fc9e36dd426d4747bb7c71ee93bcbaa700bda01d Mon Sep 17 00:00:00 2001 From: anjali411 Date: Thu, 10 Nov 2022 22:41:47 +0000 Subject: [PATCH 05/62] Add meta support for scalar_tensor and argmax (#88590) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88590 Approved by: https://github.com/albanD --- test/functorch/test_vmap.py | 1 + test/test_proxy_tensor.py | 6 +-- torch/_meta_registrations.py | 42 +++++++++++++++++++ .../_internal/common_methods_invocations.py | 32 ++++++++++++-- 4 files changed, 74 insertions(+), 7 deletions(-) diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 3acab4172fce11e..5ba35de21b8b73c 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -3229,6 +3229,7 @@ def test(): xfail('linspace', ''), # test runner can't handle factory functions xfail('arange', ''), # test runner can't handle factory functions xfail('logspace', ''), # test runner can't handle factory functions + xfail('scalar_tensor'), # test runner can't handle factory functions xfail('empty', ''), # test runner can't handle factory functions xfail('ones', ''), # test runner can't handle factory functions xfail('zeros', ''), # test runner can't handle factory functions diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index fbeaa04aa65d980..72c7249f4f14582 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1116,8 +1116,8 @@ def f(a, b, c, d, e): skip('masked.logsumexp', ''), # Tensors of type TensorImpl do not have numel xfail('masked.amax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition xfail('masked.amin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.argmax', ''), # aten.argmax.default - couldn't find symbolic meta function/decomposition - xfail('masked.argmin', ''), # aten.argmin.default - couldn't find symbolic meta function/decomposition + xfail('masked.argmax', ''), # broadcast_to(): argument 'size' (position 2) must be tuple of ints, but found ... + xfail('masked.argmin', ''), # broadcast_to(): argument 'size' (position 2) must be tuple of ints, but found ... xfail('masked.cumprod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition xfail('masked.cumsum', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition xfail('masked.log_softmax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition @@ -1134,8 +1134,6 @@ def f(a, b, c, d, e): xfail('addmv', ''), # aten.addmv.default - couldn't find symbolic meta function/decomposition xfail('addr', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('aminmax', ''), # aten.aminmax.default - couldn't find symbolic meta function/decomposition - xfail('argmax', ''), # aten.argmax.default - couldn't find symbolic meta function/decomposition - xfail('argmin', ''), # aten.argmin.default - couldn't find symbolic meta function/decomposition xfail('argwhere', ''), # aten.nonzero.default - couldn't find symbolic meta function/decomposition xfail('baddbmm', ''), # aten.baddbmm.default - couldn't find symbolic meta function/decomposition xfail('bucketize', ''), # aten.bucketize.Tensor - couldn't find symbolic meta function/decomposition diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 5035eadf84a4740..04c522ab9e3b4db 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1735,6 +1735,48 @@ def meta_sort(self, stable=None, dim=-1, descending=False): return torch.empty_like(self), torch.empty_like(self, dtype=torch.int64) +def zero_numel_check_dims(self, dim, fn_name): + if self.ndim == 0: + check( + dim == 0 or dim == -1, + lambda: f"{fn_name}: Expected reduction dim -1 or 0 for scalar but got {dim}", + IndexError, + ) + else: + check( + self.size(dim) != 0, + lambda: f"{fn_name}: Expected reduction dim {dim} to have non-zero size.", + IndexError, + ) + + +# From aten/src/ATen/native/ReduceOps.cpp +def check_argmax_argmin(name, self, dim): + if dim is not None: + dim = maybe_wrap_dim(dim, self.dim()) + zero_numel_check_dims(self, dim, name) + else: + check( + self.numel() != 0, + lambda: f"{name}: Expected reduction dim to be specified for input.numel() == 0.", + ) + + +@register_meta([aten.argmax.default, aten.argmin.default]) +def argmax_argmin_meta(self, dim=None, keepdim=False): + check_argmax_argmin("argmax", self, dim) + dims = utils.reduction_dims(self.shape, (dim,) if dim is not None else None) + shape = _compute_reduction_shape(self, dims, keepdim) + return self.new_empty(shape, dtype=torch.int64) + + +@register_meta(aten.scalar_tensor.default) +def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None): + return torch.empty( + (), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + # We must also trigger meta registrations from PrimTorch ref # decompositions import torch._refs diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index b702c116186047d..b41e74a24c10483 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1372,6 +1372,15 @@ def sample_inputs_empty(op, device, dtype, requires_grad, **kwargs): for case in cases: yield SampleInput(case, device=device, dtype=dtype, requires_grad=requires_grad) +def sample_inputs_scalar_tensor(op, device, dtype, requires_grad, **kwargs): + # Not including a scalar tensor in vals because meta tests start failing due to + # lack of meta support for _local_scalar_dense + # torch.tensor(2, device=device) + vals = (-5, 0, 1) + + for item in vals: + yield SampleInput(item, device=device, dtype=dtype, requires_grad=requires_grad) + def sample_inputs_eye(op, device, dtype, requires_grad, **kwargs): # only ints >= 0 are allowed for both arguments, unless m is omitted sizes = (None, 0, 1, 2, 3, 4, 7, L, M, S) @@ -9287,9 +9296,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): error_inputs_func=error_inputs_diag), OpInfo('diag_embed', dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), - # TODO: this is very questionable, because we do have - # diag_embed.out but it's not bound to Python somehow - # https://github.com/pytorch/pytorch/issues/88598 supports_out=False, # Runs very slowly on slow gradcheck - alternatively reduce input sizes gradcheck_fast_mode=True, @@ -10546,6 +10552,8 @@ def reference_flatten(input, start_dim=0, end_dim=-1): assert_jit_shape_analysis=True, sample_inputs_func=sample_inputs_native_batch_norm, skips=( + # NotImplementedError: Could not run + # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"), # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0] DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"), @@ -14511,6 +14519,24 @@ def reference_flatten(input, start_dim=0, end_dim=-1): # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), )), + OpInfo('scalar_tensor', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_scalar_tensor, + supports_autograd=False, + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # fails to match any schemas despite working in the interpreter + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + # fails to match any schemas despite working in the interpreter + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + )), OpInfo('new_full', op=lambda x, *args, **kwargs: x.new_full(*args, **kwargs), dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), From 3fbf748f2109de408bd47efb1a43e3897d7a775c Mon Sep 17 00:00:00 2001 From: Michael Voznesensky Date: Fri, 11 Nov 2022 02:30:29 +0000 Subject: [PATCH 06/62] Assert we have triton before scheduling on triton (#88849) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88849 Approved by: https://github.com/wconstab, https://github.com/ngimel, https://github.com/jansel --- torch/_inductor/scheduler.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 2f1c4b7c2e64357..cb71a44438049c2 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -16,7 +16,7 @@ from . import config, dependencies, ir from .dependencies import MemoryDep, StarDep from .sizevars import SimplifyIndexing -from .utils import cache_on_self, cmp, dynamo_utils +from .utils import cache_on_self, cmp, dynamo_utils, has_triton from .virtualized import V log = logging.getLogger(__name__) @@ -1078,6 +1078,16 @@ def create_backend(self, device: torch.device): return CppScheduling(self) else: + if not has_triton(): + device_props = torch.cuda.get_device_properties(device) + if device_props.major < 6: + raise RuntimeError( + f"Found {device_props.name} which is too old to be supported by the triton GPU compiler, which is used as the backend. Triton only supports devices of CUDA Capability >= 6.0, but your device is of CUDA capability {device_props.major}.{device_props.minor}" # noqa: B950 + ) + else: + raise RuntimeError( + "Cannot find a working triton installation. More information on installing Triton can be found at https://github.com/openai/triton" # noqa: B950 + ) from .codegen.triton import TritonScheduling return TritonScheduling(self) From 495e7b1c729e64693e794ea22640b4552816f0ef Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Thu, 10 Nov 2022 21:22:29 +0000 Subject: [PATCH 07/62] Ref for aten.full; symint changes in prim (#88762) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88762 Approved by: https://github.com/ezyang --- test/functorch/test_vmap.py | 1 + test/test_ops.py | 1 - torch/_prims_common/__init__.py | 5 ++- torch/_refs/__init__.py | 17 +++++--- .../_internal/common_methods_invocations.py | 40 +++++++++++++++++++ 5 files changed, 56 insertions(+), 8 deletions(-) diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 5ba35de21b8b73c..6d95077b627e2c2 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -3233,6 +3233,7 @@ def test(): xfail('empty', ''), # test runner can't handle factory functions xfail('ones', ''), # test runner can't handle factory functions xfail('zeros', ''), # test runner can't handle factory functions + xfail('full', ''), # test runner can't handle factory functions xfail('eye', ''), # non-tensor input xfail('broadcast_shapes', ''), # test runner can't handle non-Tensor ops xfail('sparse.sampled_addmm'), # sparse diff --git a/test/test_ops.py b/test/test_ops.py index 73758bfc6b466e2..c688f6521af1421 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1743,7 +1743,6 @@ class TestRefsOpsInfo(TestCase): '_refs.unflatten', '_refs.sum_to_size', # ref implementation missing kwargs - '_refs.full', # missing "layout" '_refs.full_like', # missing "layout" '_refs.ones_like', # missing "layout" '_refs.round', # missing "decimals" diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 90777ed6601aacb..128796dfa3d0717 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -837,10 +837,11 @@ def type_to_dtype(typ: type) -> torch.dtype: if typ is bool: return torch.bool - if typ is int: + if typ in [int, torch.SymInt]: return torch.long - if typ is float: + if typ in [float, torch.SymFloat]: return torch.get_default_dtype() + # TODO: sym_complex_float? if typ is complex: return corresponding_complex_dtype(torch.get_default_dtype()) diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 36fef59df37578b..43b0c74192dee5f 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -322,7 +322,7 @@ def _broadcast_shapes(*_shapes): common_shape = [ 1, ] * reduce(max, (len(shape) for shape in shapes)) - for shape in shapes: + for arg_idx, shape in enumerate(shapes): for idx in range(-1, -1 - len(shape), -1): if common_shape[idx] == 1: if shape[idx] < 0: @@ -333,9 +333,9 @@ def _broadcast_shapes(*_shapes): elif shape[idx] != 1: if common_shape[idx] != shape[idx]: raise RuntimeError( - "Attempting to broadcast a dimension of length ", - str(shape[idx]), - "!", + f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! " + f"Mismatching argument at index {arg_idx} had {shape}; but expected shape " + f"should be broadcastable to {common_shape}" ) return common_shape @@ -4495,6 +4495,7 @@ def eye( # result.requires_grad_(requires_grad) +@register_decomposition(torch.ops.aten.full) @out_wrapper() def full( shape: ShapeType, @@ -4506,6 +4507,12 @@ def full( pin_memory: bool = False, requires_grad: bool = False, ) -> TensorLikeType: + utils.check_layout(layout) + utils.check_pin_memory(pin_memory) + + dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value)) + device = device if device is not None else torch.device("cpu") + e = empty( shape, dtype=dtype, @@ -4514,7 +4521,7 @@ def full( pin_memory=pin_memory, requires_grad=requires_grad, ) - return fill(e, fill_value) + return torch.fill(e, fill_value) # type: ignore[arg-type] def full_like( diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index b41e74a24c10483..5178ec978bd1c63 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -772,6 +772,20 @@ def sample_inputs_ones_zeros(op, device, dtype, requires_grad, **kwargs): for size in sizes: yield SampleInput(size, kwargs={'dtype': dtype, 'device': device}) +def sample_inputs_full(op, device, dtype, requires_grad, **kwargs): + def get_val(dtype): + return make_tensor([], dtype=dtype, device="cpu").item() + + sizes = ( + (M,), + (S, S), + ) + fill_values = [get_val(dtype), get_val(torch.int)] + + for size, fill_value in product(sizes, fill_values): + yield SampleInput(size, fill_value, dtype=dtype, device=device) + + def error_inputs_uniform(op, device, **kwargs): t = torch.zeros([10], device=device) yield ErrorInput( @@ -14373,6 +14387,32 @@ def reference_flatten(input, start_dim=0, end_dim=-1): # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), )), + OpInfo('full', + op=torch.full, + supports_autograd=False, + is_factory_function=True, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=True, + sample_inputs_func=sample_inputs_full, + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # boolean alpha not handled properly + DecorateInfo(unittest.expectedFailure, + 'TestCudaFuserOpInfo', + 'test_nvfuser_correctness', + dtypes=(torch.bool,)), + # RuntimeError: UNSUPPORTED DTYPE: bool + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bool,)), + )), OpInfo('new_empty', op=lambda x, *args, **kwargs: x.new_empty(*args, **kwargs), dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), From 3082378701605884ff07f7ba7984864340b19b34 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 11 Nov 2022 03:33:55 +0000 Subject: [PATCH 08/62] [vision hash update] update the pinned vision hash (#88853) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/master/.github/workflows/_update-commit-hash.yml). Update the pinned vision hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88853 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/vision.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/vision.txt b/.github/ci_commit_pins/vision.txt index d8180093d8859f2..48685938a146b49 100644 --- a/.github/ci_commit_pins/vision.txt +++ b/.github/ci_commit_pins/vision.txt @@ -1 +1 @@ -ffd5a567eb90abf6b5555063da434d3c130d540f +d72e90640ec8514e0369b5419d7f3b74a387b1d7 From 9d09968bbe05fc6d7d7c3d8b1acfbe1b1b1413a8 Mon Sep 17 00:00:00 2001 From: Emil Lynegaard Date: Fri, 11 Nov 2022 03:34:54 +0000 Subject: [PATCH 09/62] Disable check for dropout in MultiheadAttention fast_path (#88831) Since we already enforce eval mode for the fast_path, we do not need to also check for a falsy dropout value, as a model trained with dropout will have a non-zero dropout during eval mode, even though it won't be applied. Fixes #88806 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88831 Approved by: https://github.com/drisspg --- torch/nn/modules/activation.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 5f5615b496d7d05..7b0e7e3effaac4c 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -904,7 +904,6 @@ class MultiheadAttention(Module): - inputs are batched (3D) with ``batch_first==True`` - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad`` - training is disabled (using ``.eval()``) - - dropout is 0 - ``add_bias_kv`` is ``False`` - ``add_zero_attn`` is ``False`` - ``batch_first`` is ``True`` and the input is batched @@ -1088,8 +1087,6 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O why_not_fast_path = "self.bias_k was not None" elif self.bias_v is not None: why_not_fast_path = "self.bias_v was not None" - elif self.dropout: - why_not_fast_path = f"dropout was {self.dropout}, required zero" elif self.add_zero_attn: why_not_fast_path = "add_zero_attn was enabled" elif not self._qkv_same_embed_dim: From c4fc5d372f3db37380fe213b5726403cb1330d5d Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 7 Nov 2022 23:46:29 +0000 Subject: [PATCH 10/62] [FSDP][state_dict][1/N] Moving state_dict logic to pre_state_dict_hook (#87900) This is one step toward the ultimate goal: remove the overwritten state_dict in FSDP. All the logic should be either in `pre_state_dict_hook` or `post_state_dict_hook`. Since current `nn.Module` does not support `pre_state_dict_hook`, this PR mimic `pre_state_dict_hook` by calling the pre hook inside post the hook, effectively ditching all the work done by `nn.Module.state_dict`. Once `pre_state_dict_hook` is supported by `nn.Module`, these pre hook calls can be moved out from the post hooks and be registered to `nn.Module.pre_state_dict_hook`. The major issue of this temporary solution is that `post_state_dict_hook` is called from the leaf node to the root node. This makes the `module._lazy_init()` invalid as FSDP assumes `_lazy_init()` to be called from the root. As a result, `FSDP.state_dict` currently contains only one logic -- calling `module._lazy_init()`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/87900 Approved by: https://github.com/rohan-varma --- test/distributed/fsdp/test_fsdp_state_dict.py | 2 +- torch/distributed/fsdp/_runtime_utils.py | 19 +- torch/distributed/fsdp/_state_dict_utils.py | 388 +++++++++++++----- .../fsdp/fully_sharded_data_parallel.py | 101 +---- 4 files changed, 288 insertions(+), 222 deletions(-) diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py index 133405033730d5e..48dad3118db749f 100644 --- a/test/distributed/fsdp/test_fsdp_state_dict.py +++ b/test/distributed/fsdp/test_fsdp_state_dict.py @@ -447,7 +447,7 @@ def test_state_dict_rank0_offload_save_load_flow(self, use_orig_params: bool): ) @parametrize("fp16", [True, False]) @parametrize("state_dict_rank0_and_offload", [True, False]) - @parametrize("use_orig_params", [False, True]) + @parametrize("use_orig_params", [True, False]) def test_basic_save_and_load_state_dict( self, state_dict_type: StateDictType, diff --git a/torch/distributed/fsdp/_runtime_utils.py b/torch/distributed/fsdp/_runtime_utils.py index 9aee15a016c445f..e0986d300a65ac2 100644 --- a/torch/distributed/fsdp/_runtime_utils.py +++ b/torch/distributed/fsdp/_runtime_utils.py @@ -1113,28 +1113,23 @@ def _get_buffers_and_dtypes_for_computation( @no_type_check -def _get_buffers_and_dtypes_for_checkpoint( +def _get_buffer_dtypes( state: _FSDPState, - root_module: nn.Module, -) -> Tuple[List[torch.Tensor], List[torch.dtype]]: + buffer_names: List[str], +) -> List[torch.dtype]: """ - Returns all buffers in the module tree rooted at ``root_module`` and a - corresponding list of the buffer dtypes for checkpointing. Each buffer - dtype is the original buffer dtype ignoring any buffer mixed precision. + Returns the original buffer types of the given buffer names. """ - p_assert(state._is_root, "Expects the root to cast buffers") - buffers: List[torch.Tensor] = [] - buffer_dtypes: List[Optional[torch.dtype]] = [] - for buffer_name, buffer in root_module.named_buffers(): + buffer_dtypes: List[torch.dtype] = [] + for buffer_name in buffer_names: p_assert( buffer_name in state._buffer_name_to_orig_dtype, f"{buffer_name} is missing from pre-computed dict on rank " f"{state.rank}, which only has keys " f"{state._buffer_name_to_orig_dtype.keys()}", ) - buffers.append(buffer) buffer_dtypes.append(state._buffer_name_to_orig_dtype[buffer_name]) - return buffers, buffer_dtypes + return buffer_dtypes def _cast_buffers_to_dtype_and_device( diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index 0169aa8f10eb2ed..1109f1e881506d2 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -1,7 +1,7 @@ import functools import math import warnings -from typing import Any, cast, Dict +from typing import Any, Callable, cast, Dict import torch import torch.distributed as dist @@ -11,15 +11,22 @@ import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file import torch.nn as nn import torch.nn.functional as F + from torch.distributed._shard.sharded_tensor import ( init_from_local_shards, Shard, ShardedTensor, ) -from torch.distributed.fsdp._common_utils import clean_tensor_name +from torch.distributed.fsdp._common_utils import ( + clean_tensor_name, + FSDP_PREFIX, + TrainingState, +) from torch.distributed.fsdp._runtime_utils import ( _cast_buffers_to_dtype_and_device, - _get_buffers_and_dtypes_for_computation, + _clear_grads_if_needed, + _get_buffer_dtypes, + _lazy_init, ) from torch.distributed.utils import _replace_by_prefix @@ -31,49 +38,218 @@ from .flat_param import FlatParamHandle -def _full_post_state_dict_hook( +def _enter_full_param_ctx( + module, + recurse: bool = False, + writeback: bool = False, + rank0_only: bool = False, + offload_to_cpu: bool = False, + with_grads: bool = False, +) -> None: + """ + state_dict hooks cannot use the pure context call as the checkpoint flow + requires to enter the context in the pre-hook but leave the context in the + post-hook. This API enters the context of ``summon_full_params``. + """ + assert module._full_param_ctx is None, ( + "Entering the ``summon_full_params`` context but module._full_param_ctx " + "is not None." + ) + assert module.training_state != TrainingState.SUMMON_FULL_PARAMS, ( + "Entering the summon_full_params context but the state is already " + "SUMMON_FULL_PARAMS." + ) + module._full_param_ctx = module._summon_full_params( + recurse=recurse, + writeback=writeback, + rank0_only=rank0_only, + offload_to_cpu=offload_to_cpu, + with_grads=with_grads, + ) + module._full_param_ctx.__enter__() + + +def _exit_full_param_ctx(module) -> None: + """A helper function to exit ``summon_full_params`` context.""" + module._assert_state([TrainingState.SUMMON_FULL_PARAMS]) + assert module._full_param_ctx is not None + module._full_param_ctx.__exit__(None, None, None) + module._full_param_ctx = None + + +def _common_pre_state_dict_hook( + module, + state_dict: Dict[str, Any], + prefix: str, +) -> None: + """Performs the pre-state_dict tasks shared by all state_dict types.""" + if torch.cuda.is_available(): + torch.cuda.synchronize() + _lazy_init(module, module) + # TODO: change to this call after pre_state_dict_hook is in `nn.Module`. + # if module.is_root: + # _clear_grads_if_needed(module._fsdp_handles(module)) + if module._has_params: + _clear_grads_if_needed([module._handles[0]]) + + +def _common_summon_pre_state_dict_hook( + module, + offload_to_cpu: bool, + rank0_only: bool, +) -> None: + """ + Performs the pre-state_dict tasks shared by all state_dict types that require + ``summon_full_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this hook. + """ + _enter_full_param_ctx( + module, + recurse=False, + writeback=False, + offload_to_cpu=offload_to_cpu, + rank0_only=rank0_only, + ) + + +# TODO: change to the decorator style. See ``_full_pre_state_dict_hook``. +def _common_summon_post_state_dict_hook( module, state_dict: Dict[str, Any], prefix: str, + param_hook: Callable, ) -> Dict[str, Any]: """ - Hook that runs after model.state_dict() is called before returning result to - user. For FSDP, we may have to clone the tensors in state_dict as params go - back to sharded version after _summon_full_params ends, and also remove - the ``FSDP_WRAPPED_MODULE`` prefix. + The post-state_dict flow that shared by all state_dict types that require + ``summon_full_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this + hook. """ - _replace_by_prefix(state_dict, prefix + f"{fsdp_file.FSDP_PREFIX}", prefix) - module._assert_state([fsdp_file.TrainingState.SUMMON_FULL_PARAMS]) + _replace_by_prefix(state_dict, prefix + f"{FSDP_PREFIX}", prefix) + module._assert_state([TrainingState.SUMMON_FULL_PARAMS]) # Return early for trivial cases if not state_dict or not module._has_params: + _exit_full_param_ctx(module) return state_dict - # If a rank has already exited the `summon_full_params()` context here - # (e.g. when `rank0_only=True` and `rank != 0`), then the rank only - # needed to participate in the all-gather and does not need to save the - # state dict. For `use_orig_params=False`, we can check this via - # `FlatParameter` registration. - # TODO: For `use_orig_params=True`, we check for the reshard upon - # exiting `summon_full_params()` via the parameter shape. However, for - # `NO_SHARD`, we cannot tell from the shape, so we do not return early. - if ( - not module._use_orig_params - and fsdp_file.FLAT_PARAM in module.module._parameters - ) or ( - module._use_orig_params - and module._handles - and module._handles[0].uses_sharded_strategy - and module._handles[0].is_sharded(module._handles[0].flat_param) - ): - return state_dict + # TODO: Once pre_state_dict hook is supported, this pop should be removed. + # For `use_orig_params=True`, the `FlatParameter` is not registered, so + # there is no entry in the state dict for it to pop. + if not module._use_orig_params: + state_dict.pop(f"{prefix}{fsdp_file.FLAT_PARAM}") - offload_to_cpu = module._state_dict_config.offload_to_cpu - cpu_device = torch.device("cpu") + # If a rank does not have unsharded parameters(when `rank0_only=True` + # and `rank != 0`), then the rank only needed to participate in the + # all-gather and does not need to save the # state dict. We simply check + # rank0_only to ensure this issue. + rank0_only = ( + module._state_dict_type == fsdp_file.StateDictType.FULL_STATE_DICT + and cast(fsdp_file.FullStateDictConfig, module._state_dict_config).rank0_only + ) + # no_fsdp_return means the state_dict returned by this rank should contain + # only non-FSDP controlled parameters and buffers. + no_fsdp_return = rank0_only and module.rank != 0 + if no_fsdp_return and not module._use_orig_params: + for clean_key in module._buffer_names: + # This is a hack to support activation checkpoint. + clean_key = clean_key.replace( + f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", "" + ) + state_dict.pop(f"{prefix}{clean_key}", None) + _exit_full_param_ctx(module) + return state_dict # Loop only the parameters saved in this instance's wrapped module to # avoid processing buffers. for fqn, param_name, module_name in module._param_fqns: + # TODO: remove the parameter retrieval. See ``_full_pre_state_dict_hook``. + param = functools.reduce(getattr, fqn.split("."), module.module) fqn = f"{prefix}{fqn}" + if no_fsdp_return: + state_dict.pop(fqn) + continue + state_dict[fqn] = param + assert fqn in state_dict, ( + f"FSDP assumes {fqn} is in the state_dict but the state_dict only " + f"has {state_dict.keys()}. " + f"prefix={prefix}, module_name={module_name}, " + f"param_name={param_name} rank={module.rank}." + ) + + param_hook(module, state_dict, prefix, fqn) + _exit_full_param_ctx(module) + + cpu_device = torch.device("cpu") + buffer_clean_fqns = [] + buffers = [] + for clean_key in module._buffer_names: + # This is a hack to support activation checkpoint. + clean_key = clean_tensor_name(clean_key) + fqn = f"{prefix}{clean_key}" + if fqn not in state_dict: + # A buffer can be registered as non-persistent. + continue + if no_fsdp_return: + state_dict.pop(fqn) + else: + buffer = state_dict[fqn] + if module._state_dict_config.offload_to_cpu and buffer.device != cpu_device: + state_dict[fqn] = buffer.to(cpu_device) + # TODO: for composable FSDP, this should be clean_tensor_name(clean_key), + buffer_clean_fqns.append(clean_key) + buffers.append(state_dict[fqn]) + if buffers and module._mixed_precision_enabled_for_buffers(): + buffer_dtypes = _get_buffer_dtypes(module, buffer_clean_fqns) + _cast_buffers_to_dtype_and_device(buffers, buffer_dtypes, module.compute_device) + for buffers, clean_fqn in zip(buffers, buffer_clean_fqns): + fqn = f"{prefix}{clean_fqn}" + state_dict[fqn] = buffer.clone() + return state_dict + + +def _full_pre_state_dict_hook( + module, + state_dict: Dict[str, Any], + prefix: str, +) -> None: + """ + Hook that runs before model.state_dict() is called. pre-state_dict hook is + not actually supported by ``nn.Module``. As a result, this API is called + from ``_full_post_state_dict_hook()`` to simulate the case. Once pre-state_dict + is supported in ``nn.Module``, this hook will be registered as a hook in + ``nn.Module``. + + TODO: clean the callsites and hacks after ``pre_state_dict_hook` ` is supported + in ``nn.Module``. + """ + _common_pre_state_dict_hook(module, state_dict, prefix) + _common_summon_pre_state_dict_hook( + module, + offload_to_cpu=module._state_dict_config.offload_to_cpu, + rank0_only=cast( + fsdp_file.FullStateDictConfig, module._state_dict_config + ).rank0_only, + ) + + +def _full_post_state_dict_hook( + module, + state_dict: Dict[str, Any], + prefix: str, +) -> Dict[str, Any]: + """ + Hook that runs after model.state_dict() is called before returning result to + user. For FSDP, we may have to clone the tensors in state_dict as params go + back to sharded version after _summon_full_params ends, and also remove + the ``FSDP_WRAPPED_MODULE`` prefix. + """ + # TODO: remove the hack. See ``_full_pre_state_dict_hook``. + _full_pre_state_dict_hook(module, state_dict, prefix) + + def param_hook( + module, + state_dict: Dict[str, Any], + prefix: str, + fqn: str, + ) -> None: clean_key = fqn clean_prefix = clean_tensor_name(prefix) # Strip prefix out of key if needed as buffer names and param names @@ -84,11 +260,6 @@ def _full_post_state_dict_hook( # Clone non-ignored parameters before exiting the # `_summon_full_params()` context - assert fqn in state_dict, ( - f"FSDP assumes {fqn} is in the state_dict but the state_dict " - f"only has {state_dict.keys()}. prefix={prefix}, " - f"module_name={module_name} param_name={param_name} rank={module.rank}." - ) if clean_key not in module._ignored_param_names and not getattr( state_dict[fqn], "_has_been_cloned", False ): @@ -104,24 +275,7 @@ def _full_post_state_dict_hook( f"implementation of {fqn}. Error: {str(e)}" ) - # Offload the buffer to CPU if needed -- we do not do this in - # `_summon_full_params()` since without care, that would free - # the original buffer's GPU memory and require reallocating - # that memory later; this only affects the state dict's buffer - # variable and leaves the original buffer's GPU memory intact - if offload_to_cpu: - for clean_key in module._buffer_names: - # This is a hack to support activation checkpoint. - clean_key = clean_key.replace( - f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", "" - ) - fqn = f"{prefix}{clean_key}" - if fqn not in state_dict: - # A buffer can be registered as non-persistent. - continue - if state_dict[fqn].device != cpu_device: - state_dict[fqn] = state_dict[fqn].to(cpu_device) - return state_dict + return _common_summon_post_state_dict_hook(module, state_dict, prefix, param_hook) def _full_pre_load_state_dict_hook( @@ -129,21 +283,30 @@ def _full_pre_load_state_dict_hook( state_dict: Dict[str, Any], prefix: str, ) -> None: - # We do not expect to be calling pre-hooks twice without post-hook - # call in between. - assert getattr(module, "_full_param_ctx", None) is None - # Note that it needs writeback=True to persist. - module._full_param_ctx = module._summon_full_params(recurse=False, writeback=True) - module._full_param_ctx.__enter__() - _replace_by_prefix(state_dict, prefix, prefix + f"{fsdp_file.FSDP_PREFIX}") + _enter_full_param_ctx(module, recurse=False, writeback=True) + _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}") def _full_post_load_state_dict_hook(module, *args, **kwargs) -> None: - # We should exit summon_full_params context. - module._assert_state([fsdp_file.TrainingState.SUMMON_FULL_PARAMS]) - assert getattr(module, "_full_param_ctx", None) is not None - module._full_param_ctx.__exit__(None, None, None) - module._full_param_ctx = None + _exit_full_param_ctx(module) + + +def _local_pre_state_dict_hook( + module, + state_dict: Dict[str, Any], + prefix: str, +) -> None: + """ + Hook that runs before model.state_dict() is called. Right now, pre-state_dict + hook is not supported by the PyTorch core. So this API is called from + `_local_post_state_dict_hook()` to simulate the case. + """ + if module._has_params and not module._handles[0].uses_sharded_strategy: + raise RuntimeError( + "``local_state_dict`` can only be used when parameters are flatten " + "and sharded." + ) + _common_pre_state_dict_hook(module, state_dict, prefix) def _local_post_state_dict_hook( @@ -156,7 +319,10 @@ def _local_post_state_dict_hook( the state_dict[f"{prefix}{FLAT_PARAM}] with the ShardedTensor. No copy will happen. The underlying storage is the same. """ - _replace_by_prefix(state_dict, f"{prefix}{fsdp_file.FSDP_PREFIX}", prefix) + # TODO: remove the hack. See ``_full_pre_state_dict_hook``. + _local_pre_state_dict_hook(module, state_dict, prefix) + + _replace_by_prefix(state_dict, f"{prefix}{FSDP_PREFIX}", prefix) if not module._has_params: return state_dict @@ -198,8 +364,8 @@ def _local_pre_load_state_dict_hook( state_dict. The flat_param should be a ShardedTensor. This hook converts the ShardedTensor to a tensor. No copy happen unless padding is required. """ - _replace_by_prefix(state_dict, prefix, f"{prefix}{fsdp_file.FSDP_PREFIX}") - fqn = f"{prefix}{fsdp_file.FSDP_PREFIX}{fsdp_file.FLAT_PARAM}" + _replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_PREFIX}") + fqn = f"{prefix}{FSDP_PREFIX}{fsdp_file.FLAT_PARAM}" if fqn not in state_dict: assert not module._has_params, ( "No `FlatParameter` in `state_dict` for this FSDP instance " @@ -229,6 +395,30 @@ def _local_pre_load_state_dict_hook( state_dict[fqn] = load_tensor +def _sharded_pre_state_dict_hook( + module, + state_dict: Dict[str, Any], + prefix: str, +) -> None: + """ + Hook that runs before model.state_dict() is called. Check + ``_full_pre_load_state_dict_hook`` for the detail. + """ + if module._has_params and not module._handles[0].uses_sharded_strategy: + raise RuntimeError( + "``sharded_state_dict`` can only be used when parameters are flatten " + "and sharded." + ) + _common_pre_state_dict_hook(module, state_dict, prefix) + # Setting offload_to_cpu here does not work even if offload_to_cpu is True. + # We have to create ShardedTensor first then move it to CPU. + _common_summon_pre_state_dict_hook( + module, + offload_to_cpu=False, + rank0_only=False, + ) + + def _sharded_post_state_dict_hook( module, state_dict: Dict[str, Any], @@ -238,33 +428,24 @@ def _sharded_post_state_dict_hook( The hook replaces the unflattened, unsharded parameter in the state_dict with a unflattened, sharded parameter (a ShardedTensor). """ - _replace_by_prefix(state_dict, f"{prefix}{fsdp_file.FSDP_PREFIX}", prefix) - if not module._has_params: - return state_dict - assert module.training_state != fsdp_file.TrainingState.SUMMON_FULL_PARAMS, ( - "Inside _sharded_post_state_dict_hook, the training_state must " - "not be SUMMON_FULL_PARAMS." - ) - with module._summon_full_params(recurse=False, writeback=False): - for fqn, _, _ in module._param_fqns: - # Create a ShardedTensor for the unflattened, non-sharded parameter. - param = functools.reduce(getattr, fqn.split("."), module.module) - sharded_tensor = _ext_chunk_tensor( - tensor=param, - rank=module.rank, - world_size=module.world_size, - num_devices_per_node=torch.cuda.device_count(), - pg=module.process_group, - ) - if module._state_dict_config.offload_to_cpu: - sharded_tensor = sharded_tensor.cpu() - state_dict[f"{prefix}{fqn}"] = sharded_tensor - # For `use_orig_params=True`, the `FlatParameter` is not registered, so - # there is no entry in the state dict for it to pop. - if not module._use_orig_params: - state_dict.pop(f"{prefix}{fsdp_file.FLAT_PARAM}") - return state_dict + # TODO: remove the hack. See ``_full_pre_state_dict_hook``. + _sharded_pre_state_dict_hook(module, state_dict, prefix) + + def param_hook(module, state_dict: Dict[str, Any], prefix: str, fqn: str): + param = state_dict[fqn] + sharded_tensor = _ext_chunk_tensor( + tensor=param, + rank=module.rank, + world_size=module.world_size, + num_devices_per_node=torch.cuda.device_count(), + pg=module.process_group, + ) + if module._state_dict_config.offload_to_cpu: + sharded_tensor = sharded_tensor.cpu() + state_dict[fqn] = sharded_tensor + + return _common_summon_post_state_dict_hook(module, state_dict, prefix, param_hook) def _sharded_post_load_state_dict_hook(module, *args, **kwargs) -> None: @@ -281,7 +462,7 @@ def _sharded_pre_load_state_dict_hook( The hook combines the unflattened, sharded parameters (ShardedTensor) to a new FlatParameter and shards the new FlatParameter to the local chunk. """ - _replace_by_prefix(state_dict, prefix, prefix + f"{fsdp_file.FSDP_PREFIX}") + _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}") if not module._has_params: return @@ -295,7 +476,7 @@ def _sharded_pre_load_state_dict_hook( shared_fqns = [fqn for fqn, _, _ in module._shared_param_fqns] loaded_shapes = [] for fqn, _, _ in module._param_fqns: - full_fqn = f"{prefix}{fsdp_file.FSDP_PREFIX}{fqn}" + full_fqn = f"{prefix}{FSDP_PREFIX}{fqn}" param = state_dict.pop(full_fqn) if fqn in shared_fqns: continue @@ -353,9 +534,7 @@ def _sharded_pre_load_state_dict_hook( f"The loaded local chunk has different padding({num_to_pad}) " f"from the local chunk {flat_param._shard_numel_padded}." ) - state_dict[ - f"{prefix}{fsdp_file.FSDP_PREFIX}{fsdp_file.FLAT_PARAM}" - ] = loaded_flat_tensor + state_dict[f"{prefix}{FSDP_PREFIX}{fsdp_file.FLAT_PARAM}"] = loaded_flat_tensor if module._use_orig_params: module._deregister_orig_params() @@ -381,17 +560,6 @@ def _post_state_dict_hook( processed_state_dict = _post_state_dict_hook_fn[fsdp_module._state_dict_type]( fsdp_module, state_dict, prefix ) - # Restore buffers, which currently are in their full precision type, - # back to their mixed precision type. This is because buffers are cast - # during lazy_init() and stay at their mixed precision type before/after - # forward/backward. As a result state_dict() should maintain this. - if fsdp_module._is_root and fsdp_module._mixed_precision_enabled_for_buffers(): - buffers, buffer_dtypes = _get_buffers_and_dtypes_for_computation( - fsdp_module, fsdp_module - ) - _cast_buffers_to_dtype_and_device( - buffers, buffer_dtypes, fsdp_module.compute_device - ) return processed_state_dict diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index 6f5537aad520821..9934e718934258a 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -50,9 +50,7 @@ _init_state_dict_state, ) from torch.distributed.fsdp._runtime_utils import ( - _cast_buffers_to_dtype_and_device, _clear_grads_if_needed, - _get_buffers_and_dtypes_for_checkpoint, _lazy_init, _post_forward, _post_forward_reshard, @@ -512,6 +510,7 @@ def __init__( _pre_load_state_dict_hook, with_module=True ) self.register_load_state_dict_post_hook(_post_load_state_dict_hook) + self._full_param_ctx: Optional[Generator] = None @property def module(self) -> nn.Module: @@ -813,104 +812,8 @@ def _shared_param_fqns(self) -> Iterator[Tuple[str, str, str]]: yield fqn, param_name, module_name def state_dict(self, *args, **kwargs): - """ - This is the entry point of all three FSDP ``state_dict`` APIs: full, - local, and sharded. For the full state dict - (``StateDictType.FULL_STATE_DICT``), FSDP attempts to unshard the model - on all ranks, which may result in an OOM error if the full model cannot - fit on a single GPU. In that case, users may pass in a - :class:`FullStateDictConfig` to only save the checkpoint on rank 0 and/ - or to offload it to CPU memory layer by layer, enabling much larger - checkpoints. If the full model cannot fit in CPU memory, then users may - instead take a local state dict (``StateDictType.LOCAL_STATE_DICT``) - that only saves the local shard of the model. The sharded state dict - (``StateDictType.SHARDED_STATE_DICT``) saves the model parameters as - ``ShardedTensor`` s. The ``state_dict`` type can be configured using - the :meth:`state_dict_type` context manager. - - Example:: - - >>> # xdoctest: +SKIP("undefined variables") - >>> import torch - >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - >>> from torch.distributed.fsdp import StateDictType - >>> torch.cuda.set_device(device_id) - >>> my_module = nn.Linear(...) - >>> sharded_module = FSDP(my_module) - >>> full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - >>> with FSDP.state_dict_type(sharded_module, StateDictType.FULL_STATE_DICT, full_state_dict_config): - >>> full_dict = sharded_module.state_dict() - >>> full_dict.keys() - >>> odict_keys(['weight', 'bias']) - >>> # using local state dict - >>> with FSDP.state_dict_type(sharded_module, StateDictType.LOCAL_STATE_DICT): - >>> local_dict = sharded_module.state_dict() - >>> local_dict.keys() - >>> odict_keys(['flat_param', 'inner.flat_param']) - - .. warning:: This needs to be called on all ranks since it uses - collective communications. - """ - # TODO (rohan-varma): separate these out once a state_dict pre-hook - # is available. - if torch.cuda.is_available(): - torch.cuda.synchronize() _lazy_init(self, self) - if self._is_root: - _clear_grads_if_needed(self._fsdp_handles(self)) - if self._state_dict_type == StateDictType.FULL_STATE_DICT: - # Get config args - full_state_dict_config = ( - self._state_dict_config - if self._state_dict_config is not None - else FullStateDictConfig() - ) - rank0_only = full_state_dict_config.rank0_only - offload_to_cpu = full_state_dict_config.offload_to_cpu - summon_ctx = ( - self._summon_full_params( - recurse=False, - writeback=False, - offload_to_cpu=offload_to_cpu, - rank0_only=rank0_only, - ) - if self.training_state != TrainingState.SUMMON_FULL_PARAMS - else contextlib.suppress() - ) - with summon_ctx: - # Since buffers stay in their low precision throughout runtime, - # we must explicitly restore them to their original dtypes for - # model checkpointing. We have the root module cast for all - # submodules. - # TODO: Investigate if this can and should be refactored into - # `summon_full_params()`. - if self._is_root and self._mixed_precision_enabled_for_buffers(): - buffers, buffer_dtypes = _get_buffers_and_dtypes_for_checkpoint( - self, self - ) - _cast_buffers_to_dtype_and_device( - buffers, buffer_dtypes, self.compute_device - ) - state_dict = super().state_dict(*args, **kwargs) - - # TODO: support offload to CPU in post state dict hook. - if not rank0_only or self.rank == 0: - return state_dict - else: - return {} - - elif ( - self._state_dict_type == StateDictType.LOCAL_STATE_DICT - or self._state_dict_type == StateDictType.SHARDED_STATE_DICT - ): - if self._has_params and not self._handles[0].uses_sharded_strategy: - raise RuntimeError( - "sharded_state_dict/local_state_dict can only be called " - "when parameters are flatten and sharded." - ) - return super().state_dict(*args, **kwargs) - else: - raise ValueError(f"Unknown StateDictType {self._state_dict_type}.") + return super().state_dict(*args, **kwargs) def forward(self, *args: Any, **kwargs: Any) -> Any: """ From 86b7aa26f0bb8878d925a625af45d16d4bb2f2af Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 11 Nov 2022 03:49:27 +0000 Subject: [PATCH 11/62] Fix FakeTensorProp on Module with Parameters or Buffers (#88700) In `FakeTensorMode.__torch_dispatch__`, the output is now always computed by meta kernels in ```python try: with in_kernel_invocation_manager(self): r = func(*args, **kwargs) # <----- "r" can be a real tensor. except NotImplementedError as not_implemented_error: # no meta kernel registered, fallback to kernel for the device if not self.allow_fallback_kernels: raise not_implemented_error return run_fallback_kernel(self, func, args, kwargs, not_implemented_error) return self.wrap_meta_outputs_with_default_device_logic(r, func, args, kwargs) ``` For example, I observed a CPU tensor is generated when executing `aten.addmm` when running `FakeTensorProp`. Therefore, I'd like to allow `FakeTensorMode` to wrap real tensor as `FakeTensor` during the computation. Does this PR look a good direction to fix this problem? If yes, I can go ahead and add some tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88700 Approved by: https://github.com/eellison, https://github.com/ezyang --- test/test_fake_tensor.py | 59 +++++++++++++++++++++++++++++ torch/fx/passes/fake_tensor_prop.py | 12 +++++- 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index ad9042196bff13d..3d47cc8ea0e51bf 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -2,6 +2,7 @@ from torch.testing._internal.common_utils import TestCase, run_tests, skipIfCrossRef, skipIfRocm import torch +import torch._dynamo import itertools import numpy as np from torch.testing._internal.jit_utils import RUN_CUDA @@ -11,6 +12,7 @@ FakeTensorConverter, DynamicOutputShapeException, ) +from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.testing import FileCheck from torch import nn import unittest @@ -663,5 +665,62 @@ def test_like_ops(self): op = self.get_aten_op(schema) self.assertIn(op, torch._subclasses.fake_tensor._like_tensor_constructors) +class FakeTensorPropTest(TestCase): + def test_fake_tensor_prop_on_nn_module(self): + class ToyNnModuleWithParameters(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer1 = torch.nn.Linear(4, 3) + self.layer2 = torch.nn.Linear(3, 2) + + def forward(self, value): + value = self.layer1(value) + value = torch.relu(value) + value = self.layer2(value) + return value + + model = ToyNnModuleWithParameters() + value = torch.randn(5, 4) + # Convert nn.Module to GraphModule so that FakeTensorProp runs. + graph_model = torch.fx.symbolic_trace(model, (value,)) + # The following block runs FakeTensorProp on graph_module w/to the same FakeTensorMode + # + # TODO(wschin): there should be an API to run FakeTensorProp for GraphModule + # with parameters and buffers. + with FakeTensorMode() as fake_tensor_mode: + + def to_fake_tensor(x): + if isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor): + return fake_tensor_mode.from_tensor(x) + return x + + fake_parameters_and_buffers = { + k: to_fake_tensor(v) + for k, v in itertools.chain( + graph_model.named_parameters(), graph_model.named_buffers() + ) + } + with torch.nn.utils.stateless._reparametrize_module( + graph_model, fake_parameters_and_buffers + ): + # This case uses the **same** fake tensor mode to + # 1. create fake parameters and fake buffers, and + # 2. run FakeTensorProp + # The result should be correct. + result = FakeTensorProp(graph_model, fake_tensor_mode).propagate(value) + self.assertTrue(isinstance(result, FakeTensor)) + self.assertEqual(result.shape, (5, 2)) + # This case uses the **different** fake tensor modes to + # 1. create fake parameters and fake buffers, and + # 2. run FakeTensorProp + # The following code should fail. + failed = False + try: + FakeTensorProp(graph_model).propagate(value) + except AssertionError: + # AssertionError: tensor's device must be `meta`, got cpu instead + failed = True + self.assertTrue(failed) + if __name__ == "__main__": run_tests() diff --git a/torch/fx/passes/fake_tensor_prop.py b/torch/fx/passes/fake_tensor_prop.py index b034b5341b068fd..403db5b9a009b93 100644 --- a/torch/fx/passes/fake_tensor_prop.py +++ b/torch/fx/passes/fake_tensor_prop.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch.fx from torch.fx import Node from torch.fx._compatibility import compatibility @@ -17,7 +19,13 @@ class FakeTensorProp(torch.fx.Interpreter): Args: module (GraphModule): The module to be executed + mode (Optional[FakeTensorMode]): The dispatch mode used to execute computation indicated by each FX Node. """ + def __init__(self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None): + super().__init__(module) + if mode is None: + mode = FakeTensorMode() + self._mode = mode def run_node(self, n: Node): result = super().run_node(n) @@ -25,6 +33,6 @@ def run_node(self, n: Node): return result def propagate(self, *args): - with FakeTensorMode.push() as mode: - fake_args = [mode.from_tensor(a) for a in args] + with self._mode: + fake_args = [self._mode.from_tensor(a) for a in args] return super().run(*fake_args) From 310335de48ab9d8bcd33b98f3f71ef88ae4bd45c Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Fri, 11 Nov 2022 04:02:44 +0000 Subject: [PATCH 12/62] Update lr_scheduler.pyi to match lr_scheduler.py (#88818) Following #88503, we should also update the pyi file Pull Request resolved: https://github.com/pytorch/pytorch/pull/88818 Approved by: https://github.com/soulitzer --- torch/optim/lr_scheduler.pyi | 37 +++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/torch/optim/lr_scheduler.pyi b/torch/optim/lr_scheduler.pyi index 97603e064a70c09..00d9eb512ae11e3 100644 --- a/torch/optim/lr_scheduler.pyi +++ b/torch/optim/lr_scheduler.pyi @@ -1,7 +1,7 @@ from typing import Iterable, Any, Optional, Callable, Union, List from .optimizer import Optimizer -class _LRScheduler: +class LRScheduler: optimizer: Optimizer = ... base_lrs: List[float] = ... last_epoch: int = ... @@ -14,46 +14,49 @@ class _LRScheduler: def step(self, epoch: Optional[int] = ...) -> None: ... def print_lr(self, is_verbose: bool, group: dict, lr: float, epoch: Optional[int] = ...) -> None: ... -class LambdaLR(_LRScheduler): +class _LRScheduler(LRScheduler): + ... + +class LambdaLR(LRScheduler): lr_lambdas: List[Callable[[int], float]] = ... def __init__(self, optimizer: Optimizer, lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]], last_epoch: int = ..., verbose: bool = ...) -> None: ... -class MultiplicativeLR(_LRScheduler): +class MultiplicativeLR(LRScheduler): lr_lambdas: List[Callable[[int], float]] = ... def __init__(self, optimizer: Optimizer, lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]], last_epoch: int = ..., verbose: bool = ...) -> None: ... -class StepLR(_LRScheduler): +class StepLR(LRScheduler): step_size: int = ... gamma: float = ... def __init__(self, optimizer: Optimizer, step_size: int, gamma: float = ..., last_epoch: int = ..., verbose: bool = ...) -> None: ... -class MultiStepLR(_LRScheduler): +class MultiStepLR(LRScheduler): milestones: Iterable[int] = ... gamma: float = ... def __init__(self, optimizer: Optimizer, milestones: Iterable[int], gamma: float = ..., last_epoch: int = ..., verbose: bool = ...) -> None: ... -class ConstantLR(_LRScheduler): +class ConstantLR(LRScheduler): factor: float = ... total_iters: int = ... def __init__(self, optimizer: Optimizer, factor: float=..., total_iters: int=..., last_epoch: int=..., verbose: bool = ...) -> None: ... -class LinearLR(_LRScheduler): +class LinearLR(LRScheduler): start_factor: float = ... end_factor: float = ... total_iters: int = ... def __init__(self, optimizer: Optimizer, start_factor: float=..., end_factor: float= ..., total_iters: int= ..., last_epoch: int= ..., verbose: bool = ...) -> None: ... -class ExponentialLR(_LRScheduler): +class ExponentialLR(LRScheduler): gamma: float = ... def __init__(self, optimizer: Optimizer, gamma: float, last_epoch: int = ..., verbose: bool = ...) -> None: ... -class ChainedScheduler(_LRScheduler): - def __init__(self, schedulers: List[_LRScheduler]) -> None: ... +class ChainedScheduler(LRScheduler): + def __init__(self, schedulers: List[LRScheduler]) -> None: ... -class SequentialLR(_LRScheduler): - def __init__(self, optimizer: Optimizer, schedulers: List[_LRScheduler], milestones: List[int], last_epoch: int=..., verbose: bool=...) -> None: ... +class SequentialLR(LRScheduler): + def __init__(self, optimizer: Optimizer, schedulers: List[LRScheduler], milestones: List[int], last_epoch: int=..., verbose: bool=...) -> None: ... -class CosineAnnealingLR(_LRScheduler): +class CosineAnnealingLR(LRScheduler): T_max: int = ... eta_min: float = ... def __init__(self, optimizer: Optimizer, T_max: int, eta_min: float = ..., last_epoch: int = ..., verbose: bool = ...) -> None: ... @@ -82,7 +85,7 @@ class ReduceLROnPlateau: def state_dict(self) -> dict: ... def load_state_dict(self, state_dict: dict) -> None: ... -class CyclicLR(_LRScheduler): +class CyclicLR(LRScheduler): max_lrs: List[float] = ... total_size: float = ... step_ratio: float = ... @@ -95,7 +98,7 @@ class CyclicLR(_LRScheduler): def __init__(self, optimizer: Optimizer, base_lr: Union[float, List[float]], max_lr: Union[float, List[float]], step_size_up: int = ..., step_size_down: Optional[int] = ..., mode: str = ..., gamma: float = ..., scale_fn: Optional[Callable[[float], float]] = ..., scale_mode: str = ..., cycle_momentum: bool = ..., base_momentum: float = ..., max_momentum: float = ..., last_epoch: int = ..., verbose: bool = ...) -> None: ... def scale_fn(self, x: Any) -> float: ... -class CosineAnnealingWarmRestarts(_LRScheduler): +class CosineAnnealingWarmRestarts(LRScheduler): T_0: int = ... T_i: int = ... T_mult: Optional[int] = ... @@ -104,14 +107,14 @@ class CosineAnnealingWarmRestarts(_LRScheduler): def __init__(self, optimizer: Optimizer, T_0: int, T_mult: int = ..., eta_min: float = ..., last_epoch: int = ..., verbose: bool = ...) -> None: ... def step(self, epoch: Optional[Any] = ...): ... -class OneCycleLR(_LRScheduler): +class OneCycleLR(LRScheduler): total_steps: int = ... anneal_func: Callable[[float, float, float], float] = ... cycle_momentum: bool = ... use_beta1: bool = ... def __init__(self, optimizer: Optimizer, max_lr: Union[float, List[float]], total_steps: int = ..., epochs: int = ..., steps_per_epoch: int = ..., pct_start: float = ..., anneal_strategy: str = ..., cycle_momentum: bool = ..., base_momentum: Union[float, List[float]] = ..., max_momentum: Union[float, List[float]] = ..., div_factor: float = ..., final_div_factor: float = ..., three_phase: bool = ..., last_epoch: int = ..., verbose: bool = ...) -> None: ... -class PolynomialLR(_LRScheduler): +class PolynomialLR(LRScheduler): total_iters: int = ... power: float = ... def __init__(self, optimizer: Optimizer, total_iters: int = ..., power: float = ..., last_epoch: int = ..., verbose: bool = ...) -> None: ... From 0de8f047c1cc950c59b0448b9b78dafc0202c43f Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 11 Nov 2022 04:19:08 +0000 Subject: [PATCH 13/62] Revert "[dynamo] fixes dict changed during runtime error (#87526)" This reverts commit cf04b36ce8f531730210b03eaa347977a1c2d75c. Reverted https://github.com/pytorch/pytorch/pull/87526 on behalf of https://github.com/anijain2305 due to error reported --- test/dynamo/test_aot_cudagraphs.py | 3 +++ torch/_dynamo/convert_frame.py | 15 +++------------ 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/test/dynamo/test_aot_cudagraphs.py b/test/dynamo/test_aot_cudagraphs.py index fdb7c88762b8b0e..cb1d2a0e601fffa 100644 --- a/test/dynamo/test_aot_cudagraphs.py +++ b/test/dynamo/test_aot_cudagraphs.py @@ -71,6 +71,7 @@ def fn(x, y): y = torch.randn(3, device="cuda") fn(x, y) + @patch("torch._dynamo.config.suppress_errors", True) @patch_all() def test_dtoh(self): def model(x, y): @@ -104,6 +105,7 @@ def fn(x, y): y = torch.randn((), device="cpu") fn(x, y) + @patch("torch._dynamo.config.suppress_errors", True) @patch("functorch._src.config.use_functionalize", True) @patch_all(ok=False) # input mutation not supported yet def test_mutate_input(self): @@ -143,6 +145,7 @@ def fn(x, y): y = torch.randn(1, device="cuda") fn(x, y) + @patch("torch._dynamo.config.suppress_errors", True) @patch_all() def test_factory(self): def model(y): diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index db9b23f2da7e3c9..f1ce83727a19f6c 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -156,11 +156,7 @@ def has_tensor(obj): seen_ids[obj_id] = any([has_tensor(v) for v in obj]) return seen_ids[obj_id] elif istype(obj, dict): - # Some packages like pytest can be updated during runtime. So, make a - # copy of values to avoid issues like "RuntimeError: dictionary - # changed size during iteration" - values = list(obj.values()) - seen_ids[obj_id] = any([has_tensor(v) for v in values]) + seen_ids[obj_id] = any([has_tensor(v) for v in obj.values()]) return seen_ids[obj_id] elif istype(obj, (str, int, float, type(None), bool)): seen_ids[obj_id] = False @@ -168,13 +164,8 @@ def has_tensor(obj): elif is_namedtuple(obj): seen_ids[obj_id] = any([has_tensor(getattr(obj, v)) for v in obj._fields]) return seen_ids[obj_id] - elif ( - not is_allowed(obj) - and not hasattr(obj, "__get__") # overridden get can mutate the object - and hasattr(obj, "__dict__") - and istype(obj.__dict__, dict) - ): - seen_ids[obj_id] = has_tensor(obj.__dict__) + elif not is_allowed(obj) and hasattr(obj, "__dict__") and len(obj.__dict__): + seen_ids[obj_id] = any([has_tensor(v) for v in obj.__dict__.values()]) return seen_ids[obj_id] else: # if config.debug: From a6d72f44a4e8b6e9d2e878f30fd8b1d3e1197f0e Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Wed, 9 Nov 2022 17:27:22 +0000 Subject: [PATCH 14/62] [ONNX] Add onnx::Max into standard Op for scalar type alignment (#88750) Easy fix for onnx::Max ScalarType Pull Request resolved: https://github.com/pytorch/pytorch/pull/88750 Approved by: https://github.com/justinchuby, https://github.com/BowenBao --- aten/src/ATen/core/interned_strings.h | 1 + torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp | 1 + 2 files changed, 2 insertions(+) diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 80919e52b58fddb..2abc6217516de8b 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -239,6 +239,7 @@ namespace c10 { _(onnx, LSTM) \ _(onnx, MatMul) \ _(onnx, Min) \ + _(onnx, Max) \ _(onnx, Mul) \ _(onnx, Pow) \ _(onnx, RNN) \ diff --git a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp index 657c27f70c7d9d6..3af0360b7e01128 100644 --- a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp +++ b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp @@ -48,6 +48,7 @@ static const std::unordered_set standardOps = { onnx::Div, onnx::Gemm, onnx::Min, + onnx::Max, onnx::Mod, onnx::Mul, onnx::Pow, From 396c3b1d88d7624938a2bb0b287f2a19f1e89bb4 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Fri, 11 Nov 2022 05:23:48 +0000 Subject: [PATCH 15/62] Use `atomicAdd` for `bfloat16` in Ampere and above (#84981) WIP to fix extremely slow `scatter_add` issue vs. fp16. The current changes seem to improve performance, but it still appears to lag behind the fp16 equivalent. CC @ngimel @ptrblck Pull Request resolved: https://github.com/pytorch/pytorch/pull/84981 Approved by: https://github.com/ngimel --- aten/src/ATen/cuda/Atomic.cuh | 17 ++++++-- aten/src/ATen/native/cuda/KernelUtils.cuh | 48 ++++++++++++++++++++++- 2 files changed, 60 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/cuda/Atomic.cuh b/aten/src/ATen/cuda/Atomic.cuh index 42975411e841e1c..3d60b672e9725ba 100644 --- a/aten/src/ATen/cuda/Atomic.cuh +++ b/aten/src/ATen/cuda/Atomic.cuh @@ -6,6 +6,10 @@ #include +#if !(defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#include +#endif + template struct AtomicFPOp; @@ -219,10 +223,15 @@ static inline __device__ at::Half gpuAtomicAdd(at::Half *address, at::Half val) } static inline __device__ at::BFloat16 gpuAtomicAdd(at::BFloat16 *address, at::BFloat16 val) { - return AtomicFPOp()(address, val, - [](at::BFloat16 bsum, at::BFloat16 val) { - return bsum + val; - }); +#if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))) +return AtomicFPOp()(address, val, + [](at::BFloat16 bsum, at::BFloat16 val) { + return bsum + val; + }); +#else + __nv_bfloat16 r = atomicAdd(reinterpret_cast<__nv_bfloat16*>(address), *reinterpret_cast<__nv_bfloat16*>(&val)); + return *reinterpret_cast(&r); +#endif } #if defined(CUDA_VERSION) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000) diff --git a/aten/src/ATen/native/cuda/KernelUtils.cuh b/aten/src/ATen/native/cuda/KernelUtils.cuh index 1e36e2db74d541b..d2e956d1a3e44c1 100644 --- a/aten/src/ATen/native/cuda/KernelUtils.cuh +++ b/aten/src/ATen/native/cuda/KernelUtils.cuh @@ -1,6 +1,10 @@ #pragma once #include +#if !(defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#include +#endif + namespace at { namespace native { @@ -66,7 +70,49 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd( template < typename scalar_t, typename index_t, - typename std::enable_if::value>::type* = + typename std::enable_if::value>::type* = + nullptr> +__device__ __forceinline__ void fastSpecializedAtomicAdd( + scalar_t* tensor, + index_t index, + const index_t numel, + scalar_t value) { +#if ( \ + (defined(USE_ROCM)) || \ + (defined(CUDA_VERSION) && (CUDA_VERSION < 11000)) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))) + gpuAtomicAddNoReturn( + reinterpret_cast(tensor) + index, + static_cast(value)); +#else + // Accounts for the chance tensor falls on an odd 16 bit alignment (ie, not 32 bit aligned) + __nv_bfloat16* target_addr = reinterpret_cast<__nv_bfloat16*>(tensor + index); + bool low_byte = (reinterpret_cast(target_addr) % sizeof(__nv_bfloat162) == 0); + + if (low_byte && index < (numel - 1)) { + __nv_bfloat162 value2; + value2.x = *reinterpret_cast<__nv_bfloat16*>(&value); + value2.y = __int2bfloat16_rz(0); + atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr), value2); + + } else if (!low_byte && index > 0) { + __nv_bfloat162 value2; + value2.x = __int2bfloat16_rz(0); + value2.y = *reinterpret_cast<__nv_bfloat16*>(&value); + atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr - 1), value2); + + } else { + atomicAdd( + reinterpret_cast<__nv_bfloat16*>(tensor) + index, *reinterpret_cast<__nv_bfloat16*>(&value)); + } +#endif +} + + +template < + typename scalar_t, + typename index_t, + typename std::enable_if::value && !std::is_same::value >::type* = nullptr> __device__ __forceinline__ void fastSpecializedAtomicAdd( scalar_t* tensor, From b843f4db0a26aae6536e6b971f73bcc5af21c90a Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Wed, 9 Nov 2022 17:41:10 +0000 Subject: [PATCH 16/62] [ONNX] Add test case for onnx::Max scalar type (#88751) Referenced by minimum cases Pull Request resolved: https://github.com/pytorch/pytorch/pull/88751 Approved by: https://github.com/wschin, https://github.com/BowenBao --- test/onnx/test_pytorch_onnx_onnxruntime.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 1e36163d0394cce..e4fc3f83b288df7 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -8728,6 +8728,28 @@ def forward(self, x, y): y = torch.full_like(x, True) self.run_test(MinimumModel(), (x, y)) + @skipIfUnsupportedMinOpsetVersion(12) + def test_maximum_dtypes(self): + class MaximumModel(torch.nn.Module): + def forward(self, x, y): + return torch.maximum(x, y) + + x = torch.randn((5, 5), dtype=torch.float16) + y = torch.randn((5, 5), dtype=torch.float) + self.run_test(MaximumModel(), (x, y)) + + x = torch.randn((5, 5), dtype=torch.float16) + y = torch.randint(10, (5, 5), dtype=torch.int16) + self.run_test(MaximumModel(), (x, y)) + + x = torch.randint(10, (5, 5), dtype=torch.int16) + y = torch.randint(10, (5, 5), dtype=torch.int32) + self.run_test(MaximumModel(), (x, y)) + + x = torch.randint(10, (5, 5), dtype=torch.int) + y = torch.full_like(x, True) + self.run_test(MaximumModel(), (x, y)) + @skipIfUnsupportedMinOpsetVersion(9) def test_any(self): class M(torch.nn.Module): From d15a6b0c975b9e1e90ed4e951071e5269c10ac5b Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Fri, 11 Nov 2022 08:51:26 +0000 Subject: [PATCH 17/62] Error on ZeroTensor serialization (#88803) Follow-up : https://github.com/pytorch/pytorch/pull/88182#issuecomment-1308628415 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88803 Approved by: https://github.com/anjali411 --- test/cpp/api/serialize.cpp | 8 ++++++++ test/test_serialization.py | 22 ++++++++++++++++++++++ torch/csrc/jit/serialization/pickler.h | 6 ++++++ 3 files changed, 36 insertions(+) diff --git a/test/cpp/api/serialize.cpp b/test/cpp/api/serialize.cpp index 05bb0f941d402d3..20d572853d3a160 100644 --- a/test/cpp/api/serialize.cpp +++ b/test/cpp/api/serialize.cpp @@ -288,6 +288,14 @@ TEST(SerializeTest, MathBits) { ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec()); ASSERT_TRUE(actual.allclose(expected)); } + + { + // We don't support serializing `ZeroTensor` as it is not public facing yet. + // If in future, `ZeroTensor` serialization is supported, this test should + // start failing! + auto t = torch::_efficientzerotensor({5, 5}); + ASSERT_THROWS_WITH(save_and_load(t), "ZeroTensor is not serializable,"); + } } TEST(SerializeTest, BasicToFile) { diff --git a/test/test_serialization.py b/test/test_serialization.py index af0317e87a145e8..779d6fb5c20c5da 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -931,6 +931,28 @@ def _save_load_check(t): t_n_c = torch._neg_view(torch.conj(t)) _save_load_check(t_n_c) + @parametrize('weights_only', (False, True)) + def test_serialization_efficient_zerotensor(self, weights_only): + # We don't support serializing `ZeroTensor` as it is not public + # facing yet. + # If in future, `ZeroTensor` serialization is supported, this test + # should start failing! + t = torch._efficientzerotensor((4, 5)) + + def _save_load_check(t): + with BytesIOContext() as f: + torch.save(t, f) + f.seek(0) + # Unsafe load should work + self.assertEqual(torch.load(f, weights_only=weights_only), t) + + # NOTE: `torch.save` fails before we hit the TORCH_CHECK in `getTensoMetadata` + # as nullptr storage is disabled. + err_msg = (r'python bindings to nullptr storage \(e.g., from torch.Tensor._make_wrapper_subclass\)' + ' are currently unsafe and thus disabled') + with self.assertRaisesRegex(RuntimeError, err_msg): + _save_load_check(t) + def run(self, *args, **kwargs): with serialization_method(use_zip=True): return super(TestSerialization, self).run(*args, **kwargs) diff --git a/torch/csrc/jit/serialization/pickler.h b/torch/csrc/jit/serialization/pickler.h index c289cae12b64963..26f9fcf423965dc 100644 --- a/torch/csrc/jit/serialization/pickler.h +++ b/torch/csrc/jit/serialization/pickler.h @@ -300,6 +300,12 @@ bool checkHasValidSetGetState(const std::shared_ptr& cls); // For now, it only takes care of `conj` and `neg` bit. inline std::unordered_map getTensorMetadata( const at::Tensor& t) { + // We don't support serializing `ZeroTensor` as it is not public + // facing yet. + TORCH_CHECK( + !t._is_zerotensor(), + "ZeroTensor is not serializable,", + " please file an issue if required."); std::unordered_map metadata{}; // Only add meta-data if the value is not default. From ee91c328da5739ce03b3127cd7c542ce505212b8 Mon Sep 17 00:00:00 2001 From: Michael Gschwind Date: Fri, 11 Nov 2022 12:19:31 +0000 Subject: [PATCH 18/62] Fix cuda/cpu check on NoneType (#88854) Summary: Fix cuda/cpu check on NoneType Test Plan: sabdcastle/ github CI/CD Differential Revision: D41203955 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88854 Approved by: https://github.com/drisspg, https://github.com/ngimel --- torch/nn/modules/activation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 7b0e7e3effaac4c..e6b3b778e5fbcb7 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -1111,7 +1111,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O # generator expressions. if torch.overrides.has_torch_function(tensor_args): why_not_fast_path = "some Tensor argument has_torch_function" - elif not all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]): + elif not all([(x is None or x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]): why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" elif torch.is_grad_enabled() and any([x.requires_grad for x in tensor_args]): why_not_fast_path = ("grad is enabled and at least one of query or the " From 324ac93a43a93f671bb34b835926b22d13442735 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 8 Nov 2022 00:16:14 +0000 Subject: [PATCH 19/62] [FSDP][state_dict][2/N] Move state_dict related enums/dataclasses/states to state_dict_utils.py, api.py and init_state_dict() (#88481) **Motivation**: Several Enums, Dataclasses and states defined in fully_sharded_data_paralle.py should be moved to a place where the composable FSDP can access. This PR does the move. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88481 Approved by: https://github.com/rohan-varma, https://github.com/awgu --- torch/distributed/fsdp/_init_utils.py | 12 +- torch/distributed/fsdp/_state_dict_utils.py | 72 +++++++--- torch/distributed/fsdp/api.py | 96 ++++++++++++- .../fsdp/fully_sharded_data_parallel.py | 127 +----------------- 4 files changed, 164 insertions(+), 143 deletions(-) diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index c89f65c3a5b829f..966e61f7fe1231b 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -3,6 +3,7 @@ from typing import ( Callable, Dict, + Generator, Iterable, Iterator, List, @@ -33,8 +34,11 @@ from torch.distributed.fsdp.api import ( BackwardPrefetch, CPUOffload, + FullStateDictConfig, MixedPrecision, ShardingStrategy, + StateDictConfig, + StateDictType, ) from torch.distributed.fsdp.flat_param import ( _HandlesKey, @@ -206,7 +210,13 @@ def _init_prefetching_state( def _init_state_dict_state(state: _FSDPState) -> _FSDPState: - # TODO: after rebase + state._state_dict_type = StateDictType.FULL_STATE_DICT + state_dict_config: StateDictConfig = FullStateDictConfig() + state._state_dict_config = state_dict_config + full_param_ctx: Optional[Generator] = None + # TODO: For composable API, this should be a dict that maps from a module to + # handles. + state._full_param_ctx = full_param_ctx return state diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index 1109f1e881506d2..c90bd4d409b1ff9 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -1,7 +1,7 @@ import functools import math import warnings -from typing import Any, Callable, cast, Dict +from typing import Any, Callable, cast, Dict, Iterator, Tuple import torch import torch.distributed as dist @@ -20,6 +20,7 @@ from torch.distributed.fsdp._common_utils import ( clean_tensor_name, FSDP_PREFIX, + FSDP_WRAPPED_MODULE, TrainingState, ) from torch.distributed.fsdp._runtime_utils import ( @@ -28,6 +29,7 @@ _get_buffer_dtypes, _lazy_init, ) +from torch.distributed.fsdp.api import FullStateDictConfig, StateDictType from torch.distributed.utils import _replace_by_prefix from ._fsdp_extensions import ( @@ -38,6 +40,33 @@ from .flat_param import FlatParamHandle +def _convert_to_wrapped_module_name(module_name: str) -> str: + module_name = module_name.replace(f"{FSDP_PREFIX}", "") + module_name = module_name.replace(f"{FSDP_WRAPPED_MODULE}", "") + if module_name: + module_name = f"{module_name}." + # Activation checkpoint adds a prefix that has to be + # removed as well. + module_name = module_name.replace(checkpoint_wrapper._CHECKPOINT_PREFIX, "") + return module_name + + +def _param_fqns(module) -> Iterator[Tuple[str, str, str]]: + if not module._has_params: + return + for param_name, module_name in module._handles[0].parameter_module_names(): + module_name = _convert_to_wrapped_module_name(module_name) + fqn = f"{module_name}{param_name}" + yield fqn, param_name, module_name + + +def _shared_param_fqns(module) -> Iterator[Tuple[str, str, str]]: + for param_name, module_name in module._handles[0].shared_parameter_module_names(): + module_name = _convert_to_wrapped_module_name(module_name) + fqn = f"{module_name}{param_name}" + yield fqn, param_name, module_name + + def _enter_full_param_ctx( module, recurse: bool = False, @@ -71,7 +100,10 @@ def _enter_full_param_ctx( def _exit_full_param_ctx(module) -> None: """A helper function to exit ``summon_full_params`` context.""" - module._assert_state([TrainingState.SUMMON_FULL_PARAMS]) + assert module.training_state == TrainingState.SUMMON_FULL_PARAMS, ( + "Exiting the summon_full_params context but the state is not " + "SUMMON_FULL_PARAMS." + ) assert module._full_param_ctx is not None module._full_param_ctx.__exit__(None, None, None) module._full_param_ctx = None @@ -124,7 +156,9 @@ def _common_summon_post_state_dict_hook( hook. """ _replace_by_prefix(state_dict, prefix + f"{FSDP_PREFIX}", prefix) - module._assert_state([TrainingState.SUMMON_FULL_PARAMS]) + assert ( + module.training_state == TrainingState.SUMMON_FULL_PARAMS + ), "Inside the post_state_dict_hook but the state is not SUMMON_FULL_PARAMS." # Return early for trivial cases if not state_dict or not module._has_params: _exit_full_param_ctx(module) @@ -141,8 +175,8 @@ def _common_summon_post_state_dict_hook( # all-gather and does not need to save the # state dict. We simply check # rank0_only to ensure this issue. rank0_only = ( - module._state_dict_type == fsdp_file.StateDictType.FULL_STATE_DICT - and cast(fsdp_file.FullStateDictConfig, module._state_dict_config).rank0_only + module._state_dict_type == StateDictType.FULL_STATE_DICT + and cast(FullStateDictConfig, module._state_dict_config).rank0_only ) # no_fsdp_return means the state_dict returned by this rank should contain # only non-FSDP controlled parameters and buffers. @@ -159,7 +193,7 @@ def _common_summon_post_state_dict_hook( # Loop only the parameters saved in this instance's wrapped module to # avoid processing buffers. - for fqn, param_name, module_name in module._param_fqns: + for fqn, param_name, module_name in _param_fqns(module): # TODO: remove the parameter retrieval. See ``_full_pre_state_dict_hook``. param = functools.reduce(getattr, fqn.split("."), module.module) fqn = f"{prefix}{fqn}" @@ -224,9 +258,7 @@ def _full_pre_state_dict_hook( _common_summon_pre_state_dict_hook( module, offload_to_cpu=module._state_dict_config.offload_to_cpu, - rank0_only=cast( - fsdp_file.FullStateDictConfig, module._state_dict_config - ).rank0_only, + rank0_only=cast(FullStateDictConfig, module._state_dict_config).rank0_only, ) @@ -473,9 +505,9 @@ def _sharded_pre_load_state_dict_hook( ) nonsharded_tensors = [] - shared_fqns = [fqn for fqn, _, _ in module._shared_param_fqns] + shared_fqns = [fqn for fqn, _, _ in _shared_param_fqns(module)] loaded_shapes = [] - for fqn, _, _ in module._param_fqns: + for fqn, _, _ in _param_fqns(module): full_fqn = f"{prefix}{FSDP_PREFIX}{fqn}" param = state_dict.pop(full_fqn) if fqn in shared_fqns: @@ -552,9 +584,9 @@ def _post_state_dict_hook( what postprocessing will be done. """ _post_state_dict_hook_fn = { - fsdp_file.StateDictType.FULL_STATE_DICT: _full_post_state_dict_hook, - fsdp_file.StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook, - fsdp_file.StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook, + StateDictType.FULL_STATE_DICT: _full_post_state_dict_hook, + StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook, + StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook, } fsdp_module = cast(fsdp_file.FullyShardedDataParallel, module) processed_state_dict = _post_state_dict_hook_fn[fsdp_module._state_dict_type]( @@ -576,9 +608,9 @@ def _pre_load_state_dict_hook( will be done. """ _pre_load_state_dict_hook_fn = { - fsdp_file.StateDictType.FULL_STATE_DICT: _full_pre_load_state_dict_hook, - fsdp_file.StateDictType.LOCAL_STATE_DICT: _local_pre_load_state_dict_hook, - fsdp_file.StateDictType.SHARDED_STATE_DICT: _sharded_pre_load_state_dict_hook, + StateDictType.FULL_STATE_DICT: _full_pre_load_state_dict_hook, + StateDictType.LOCAL_STATE_DICT: _local_pre_load_state_dict_hook, + StateDictType.SHARDED_STATE_DICT: _sharded_pre_load_state_dict_hook, } # Code that is common for all state_dict impls fsdp_module = cast(fsdp_file.FullyShardedDataParallel, module) @@ -593,9 +625,9 @@ def _pre_load_state_dict_hook( @torch.no_grad() def _post_load_state_dict_hook(module: nn.Module, *args: Any) -> None: _post_load_state_dict_hook_fn = { - fsdp_file.StateDictType.FULL_STATE_DICT: _full_post_load_state_dict_hook, - fsdp_file.StateDictType.LOCAL_STATE_DICT: _local_post_load_state_dict_hook, - fsdp_file.StateDictType.SHARDED_STATE_DICT: _sharded_post_load_state_dict_hook, + StateDictType.FULL_STATE_DICT: _full_post_load_state_dict_hook, + StateDictType.LOCAL_STATE_DICT: _local_post_load_state_dict_hook, + StateDictType.SHARDED_STATE_DICT: _sharded_post_load_state_dict_hook, } # Code that is common for all state_dict impls fsdp_module = cast(fsdp_file.FullyShardedDataParallel, module) diff --git a/torch/distributed/fsdp/api.py b/torch/distributed/fsdp/api.py index 9e1327c80633c42..18f3cd3069ddfa0 100644 --- a/torch/distributed/fsdp/api.py +++ b/torch/distributed/fsdp/api.py @@ -10,7 +10,17 @@ import torch -__all__ = ["ShardingStrategy", "BackwardPrefetch", "MixedPrecision", "CPUOffload"] +__all__ = [ + "ShardingStrategy", + "BackwardPrefetch", + "MixedPrecision", + "CPUOffload", + "StateDictType", + "StateDictConfig", + "FullStateDictConfig", + "LocalStateDictConfig", + "ShardedStateDictConfig", +] class ShardingStrategy(Enum): @@ -149,3 +159,87 @@ class CPUOffload: """ offload_params: bool = False + + +class StateDictType(Enum): + """ + This enum indicates that which type of ``state_dict`` the FSDP module is + currently processing (returning or loading). + The default value is FULL_STATE_DICT to comply the PyTorch convention. + ..note:: + FSDP currently supports three types of ``state_dict``: + 1. ``state_dict/load_state_dict`: this pair of APIs return and load + the non-sharded, unflattened parameters. The semantics is the + same as using DDP. + 2. ``_local_state_dict/_load_local_state_dict``: this pair of APIs return + and load local sharded, flattened parameters. The values returned + by ``_local_state_dict`` can be directly used by FSDP and is only + meaningful to FSDP (because parameters are flattened). Note that + these APIs are meant for use via the :func:`state_dict_type` + context manager as follows: + >>> # xdoctest: +SKIP("undefined variables") + >>> with fsdp.state_dict_type(StateDictType.LOCAL_STATE_DICT): + ... state = fsdp.state_dict() # loads local state dict + 3. ``_sharded_state_dict/_load_sharded_state_dict``: this pair of APIs + return and load sharded, unflattened parameters. The ``state_dict`` + return by ``sharded_state_dict`` can be used by all other parallel + schemes (resharding may be required). + """ + + FULL_STATE_DICT = auto() + LOCAL_STATE_DICT = auto() + SHARDED_STATE_DICT = auto() + + +@dataclass +class StateDictConfig: + """ + ``StateDictConfig`` is the base class for all state_dict configuration classes. + Users should instantiate a child version (i.e. ``FullStateDictConfig``) in + order to configure settings for the particular type of ``state_dict`` + implementation FSDP will use. + """ + + offload_to_cpu: bool = False + + +@dataclass +class FullStateDictConfig(StateDictConfig): + """ + ``FullStateDictConfig`` is a config class meant to be used with + ``StateDictType.FULL_STATE_DICT``. Currently, it accepts two parameters, + ``offload_to_cpu`` and ``rank0_only`` which can be configured to offload + the full ``state_dict`` to CPU and to materialize the ``state_dict`` on + rank 0 only. When used, it is recommended to enable both of these flags + together to optimize memory savings when taking checkpoints. Note that + this config class is meant for user via the :func:`state_dict_type` + context manager as follows: + >>> # xdoctest: +SKIP("undefined variables") + >>> fsdp = FSDP(model, auto_wrap_policy=...) + >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + >>> with FullyShardedDataParallel.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg): + >>> state = fsdp.state_dict() + >>> # state will be empty on non rank 0 and contain CPU tensors on rank 0. + >>> # To reload checkpoint for inference, finetuning, transfer learning, etc: + >>> model = model_fn() # Initialize model on CPU in preparation for wrapping with FSDP + >>> if dist.get_rank() == 0: + >>> # Load checkpoint only on rank 0 to avoid memory redundancy + >>> state_dict = torch.load("my_checkpoint.pt") + >>> model.load_state_dict(state_dict) + >>> # All ranks initialize FSDP module as usual. ``sync_module_states`` argument + >>> # communicates loaded checkpoint states from rank 0 to rest of the world. + >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True) + >>> # After this point, all ranks have FSDP model with loaded checkpoint. + """ + + rank0_only: bool = False + + +@dataclass +class LocalStateDictConfig(StateDictConfig): + pass + + +@dataclass +class ShardedStateDictConfig(StateDictConfig): + pass diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index 9934e718934258a..773686081a4d2a3 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -5,7 +5,6 @@ import traceback import warnings from contextlib import contextmanager -from dataclasses import dataclass from enum import auto, Enum from typing import ( Any, @@ -25,7 +24,6 @@ import torch.nn as nn from torch.distributed import ProcessGroup from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - _CHECKPOINT_PREFIX, _CHECKPOINT_WRAPPED_MODULE, ActivationWrapper, ) @@ -68,8 +66,13 @@ from torch.distributed.fsdp.api import ( BackwardPrefetch, CPUOffload, + FullStateDictConfig, + LocalStateDictConfig, MixedPrecision, + ShardedStateDictConfig, ShardingStrategy, + StateDictConfig, + StateDictType, ) from ._optim_utils import ( @@ -103,11 +106,6 @@ __all__ = [ "FullyShardedDataParallel", - "StateDictType", - "StateDictConfig", - "FullStateDictConfig", - "LocalStateDictConfig", - "ShardedStateDictConfig", "OptimStateKeyType", ] @@ -115,90 +113,6 @@ FLAT_PARAM = "_flat_param" -class StateDictType(Enum): - """ - This enum indicates that which type of ``state_dict`` the FSDP module is - currently processing (returning or loading). - The default value is FULL_STATE_DICT to comply the PyTorch convention. - ..note:: - FSDP currently supports three types of ``state_dict``: - 1. ``state_dict/load_state_dict`: this pair of APIs return and load - the non-sharded, unflattened parameters. The semantics is the - same as using DDP. - 2. ``_local_state_dict/_load_local_state_dict``: this pair of APIs return - and load local sharded, flattened parameters. The values returned - by ``_local_state_dict`` can be directly used by FSDP and is only - meaningful to FSDP (because parameters are flattened). Note that - these APIs are meant for use via the :func:`state_dict_type` - context manager as follows: - >>> # xdoctest: +SKIP("undefined variables") - >>> with fsdp.state_dict_type(StateDictType.LOCAL_STATE_DICT): - ... state = fsdp.state_dict() # loads local state dict - 3. ``_sharded_state_dict/_load_sharded_state_dict``: this pair of APIs - return and load sharded, unflattened parameters. The ``state_dict`` - return by ``sharded_state_dict`` can be used by all other parallel - schemes (resharding may be required). - """ - - FULL_STATE_DICT = auto() - LOCAL_STATE_DICT = auto() - SHARDED_STATE_DICT = auto() - - -@dataclass -class StateDictConfig: - """ - ``StateDictConfig`` is the base class for all state_dict configuration classes. - Users should instantiate a child version (i.e. ``FullStateDictConfig``) in - order to configure settings for the particular type of ``state_dict`` - implementation FSDP will use. - """ - - offload_to_cpu: bool = False - - -@dataclass -class FullStateDictConfig(StateDictConfig): - """ - ``FullStateDictConfig`` is a config class meant to be used with - ``StateDictType.FULL_STATE_DICT``. Currently, it accepts two parameters, - ``offload_to_cpu`` and ``rank0_only`` which can be configured to offload - the full ``state_dict`` to CPU and to materialize the ``state_dict`` on - rank 0 only. When used, it is recommended to enable both of these flags - together to optimize memory savings when taking checkpoints. Note that - this config class is meant for user via the :func:`state_dict_type` - context manager as follows: - >>> # xdoctest: +SKIP("undefined variables") - >>> fsdp = FSDP(model, auto_wrap_policy=...) - >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - >>> with FullyShardedDataParallel.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg): - >>> state = fsdp.state_dict() - >>> # state will be empty on non rank 0 and contain CPU tensors on rank 0. - >>> # To reload checkpoint for inference, finetuning, transfer learning, etc: - >>> model = model_fn() # Initialize model on CPU in preparation for wrapping with FSDP - >>> if dist.get_rank() == 0: - >>> # Load checkpoint only on rank 0 to avoid memory redundancy - >>> state_dict = torch.load("my_checkpoint.pt") - >>> model.load_state_dict(state_dict) - >>> # All ranks initialize FSDP module as usual. ``sync_module_states`` argument - >>> # communicates loaded checkpoint states from rank 0 to rest of the world. - >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True) - >>> # After this point, all ranks have FSDP model with loaded checkpoint. - """ - - rank0_only: bool = False - - -@dataclass -class LocalStateDictConfig(StateDictConfig): - pass - - -@dataclass -class ShardedStateDictConfig(StateDictConfig): - pass - - class OptimStateKeyType(Enum): PARAM_NAME = auto() PARAM_ID = auto() @@ -502,15 +416,12 @@ def __init__( # `_state_dict_type` controls the `state_dict()` behavior, which is # implemented using post-save and pre-load hooks - _init_state_dict_state(self) # TODO: currently a no-op; need to refactor below - self._state_dict_type = StateDictType.FULL_STATE_DICT - self._state_dict_config = FullStateDictConfig() + _init_state_dict_state(self) self._register_state_dict_hook(_post_state_dict_hook) self._register_load_state_dict_pre_hook( _pre_load_state_dict_hook, with_module=True ) self.register_load_state_dict_post_hook(_post_load_state_dict_hook) - self._full_param_ctx: Optional[Generator] = None @property def module(self) -> nn.Module: @@ -785,32 +696,6 @@ def state_dict_type( module, prev_state_dict_type, prev_state_dict_config ) - def _convert_to_wrapped_module_name(self, module_name: str) -> str: - module_name = module_name.replace(f"{FSDP_PREFIX}", "") - module_name = module_name.replace(f"{FSDP_WRAPPED_MODULE}", "") - if module_name: - module_name = f"{module_name}." - # Activation checkpoint adds a prefix that has to be - # removed as well. - module_name = module_name.replace(_CHECKPOINT_PREFIX, "") - return module_name - - @property - def _param_fqns(self) -> Iterator[Tuple[str, str, str]]: - if not self._has_params: - return - for param_name, module_name in self._handles[0].parameter_module_names(): - module_name = self._convert_to_wrapped_module_name(module_name) - fqn = f"{module_name}{param_name}" - yield fqn, param_name, module_name - - @property - def _shared_param_fqns(self) -> Iterator[Tuple[str, str, str]]: - for param_name, module_name in self._handles[0].shared_parameter_module_names(): - module_name = self._convert_to_wrapped_module_name(module_name) - fqn = f"{module_name}{param_name}" - yield fqn, param_name, module_name - def state_dict(self, *args, **kwargs): _lazy_init(self, self) return super().state_dict(*args, **kwargs) From 91b71cdbe4f31006fad91f9dd460123677a7c625 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Wed, 9 Nov 2022 20:39:50 +0000 Subject: [PATCH 20/62] [dynamo] Add torch.device to is_safe_constant (#88766) Test Plan: ``` PYTORCH_TEST_WITH_DYNAMO=1 python test/test_torch.py -k test_advancedindex_mixed_cpu_devices_cuda ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/88766 Approved by: https://github.com/jansel --- torch/_dynamo/utils.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index ef2c1c38ea8ba5b..067a808073743b6 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -583,7 +583,19 @@ def is_safe_constant(v): if istype(v, (tuple, frozenset)): return all(map(is_safe_constant, v)) return istype( - v, (types.CodeType, int, float, bool, str, bytes, type(None), slice, type(type)) + v, + ( + types.CodeType, + int, + float, + bool, + str, + bytes, + type(None), + slice, + type(type), + torch.device, + ), ) From b92acee8f83c7852194d6979362aea0c240709da Mon Sep 17 00:00:00 2001 From: soulitzer Date: Thu, 10 Nov 2022 19:08:42 -0500 Subject: [PATCH 21/62] Add context manager to allow mutation on saved tensors (#79056) Pull Request resolved: https://github.com/pytorch/pytorch/pull/79056 Approved by: https://github.com/albanD --- test/test_autograd.py | 178 ++++++++++++++++++++++++++++++++++++++++ torch/autograd/graph.py | 163 +++++++++++++++++++++++++++++++++++- 2 files changed, 338 insertions(+), 3 deletions(-) diff --git a/test/test_autograd.py b/test/test_autograd.py index f5d890fad2d7f1d..e08047860e42333 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -8778,6 +8778,184 @@ def test_warning_in_backward(self, device): with self.assertWarnsRegex(UserWarning, "Warn from backward"): b.backward() +class TestAllowMutationOnSaved(TestCase): + def assertClonedLenEqual(self, ctx, n): + self.assertEqual(len(list(ctx.cloned.items())), n) + + def assertTIDMapLenEqual(self, ctx, n): + self.assertEqual(len(list(ctx.tid_to_weakhandle.items())), n) + + def test_basic(self): + a = torch.rand(2, 3, requires_grad=True) + + def fn(a): + b = a.clone() + out = (b**2).sum() + b.sin_() + out.sum().backward() + return a.grad + msg = "variables needed for gradient computation has been modified by an inplace" + with self.assertRaisesRegex(RuntimeError, msg): + fn(a) + + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + da = fn(a) + + self.assertTrue(torch.allclose(a * 2, da)) + self.assertClonedLenEqual(ctx, 0) + + def test_views(self): + a = torch.rand(2, 3, requires_grad=True) + + def fn(a): + b = a.clone() + c = b.view_as(b) + out = (b**2).sum() # How does this work? + c.sin_() + out.sum().backward() + return a.grad + + msg = "variables needed for gradient computation has been modified by an inplace" + with self.assertRaisesRegex(RuntimeError, msg): + fn(a) + + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + da = fn(a) + + self.assertClonedLenEqual(ctx, 0) + self.assertTrue(torch.allclose(a * 2, da)) + + def test_save_base_and_modify_view(self): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.rand(2, 3, requires_grad=True) + b = a.clone() + c = b[:1] + out = b**2 + # modify the view + c *= 10 + # self.assertClonedLenEqual(ctx, 1) + out.sum().backward() + self.assertClonedLenEqual(ctx, 0) + + self.assertClonedLenEqual(ctx, 0) + self.assertTrue(torch.allclose(a * 2, a.grad)) + + def test_save_view_modify_base(self): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.rand(2, 3, requires_grad=True) + b = a.clone() + c = b[:] + out = (c**2).sum() + b *= 2 + out.backward() + self.assertTrue(torch.allclose(a * 2, a.grad)) + + def test_double_backward(self): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.rand(2, 3, requires_grad=True) + b = a.clone() + out = (b**2).sum() + b.sin_() + torch.autograd.grad(out, a, create_graph=True) + da, = torch.autograd.grad(out, a, create_graph=True) + d2a, = torch.autograd.grad(da.sum(), a) + + self.assertTrue(torch.allclose(torch.ones_like(a) * 2, d2a)) + self.assertClonedLenEqual(ctx, 0) + + def test_saved_but_not_anymore(self): + # Make sure we don't clone if the tensor was once saved, but + # by the time we do in-place, it is no longer saved + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.randn(2, 3, requires_grad=True).clone() + out = (a**2).sum() + self.assertTIDMapLenEqual(ctx, 1) + self.assertClonedLenEqual(ctx, 0) + out.backward() + a.sin_() + self.assertClonedLenEqual(ctx, 0) + out = (a**2).sum() + a.sin_() + self.assertClonedLenEqual(ctx, 1) + del out + self.assertClonedLenEqual(ctx, 0) + + def test_saved_same_tensor_many_times(self): + # We should only clone once + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.randn(2, 3, requires_grad=True).clone() + b = a**2 + c = a**2 + a.sin_() + self.assertClonedLenEqual(ctx, 1) + del b, c + self.assertClonedLenEqual(ctx, 0) + + def test_saved_same_tensor_different_versions(self): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.randn(2, 3, requires_grad=True).clone() + b = a**2 + a.sin_() + c = a**2 + a.sin_() + self.assertClonedLenEqual(ctx, 2) + del b + self.assertClonedLenEqual(ctx, 1) + del c + self.assertClonedLenEqual(ctx, 0) + + def test_with_math_views(self): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.tensor([1 + 1j], requires_grad=True).clone() + b = a.conj() + out = (b**2).sum() + a.sin_() + out.backward() + + a = torch.tensor([1 + 1j], requires_grad=True).clone() + b = a.conj() + out = (b**2).sum() + # in this case, it is no longer a view it seems + b.sin_() + out.backward() + + def test_with_out_variant(self): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.tensor([1.], requires_grad=True) + b = torch.tensor([1.]) + c = torch.tensor([2.]) + out = a * b + self.assertTIDMapLenEqual(ctx, 1) + torch.sin(c, out=b) + self.assertClonedLenEqual(ctx, 1) + out.backward() + self.assertClonedLenEqual(ctx, 0) + + def test_backward_out_of_context(self): + # Out of context + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.rand(2, 3, requires_grad=True) + out = (a**2).sum() + + msg = "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context" + with self.assertRaisesRegex(RuntimeError, msg): + out.backward() + + # Different context + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.rand(2, 3, requires_grad=True) + out = (a**2).sum() + + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + with self.assertRaisesRegex(RuntimeError, msg): + out.backward() + + def test_disallow_nesting(self): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + msg = "allow_mutation_on_saved_tensors contexts cannot be nested" + with self.assertRaisesRegex(RuntimeError, msg): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + pass class TestAutogradInferenceMode(TestCase): def _is_inference_tensor(self, tensor): diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py index 9c333c70bcf2239..fc490a9d8e31c18 100644 --- a/torch/autograd/graph.py +++ b/torch/autograd/graph.py @@ -1,15 +1,17 @@ import torch import contextlib -from typing import Callable, Any, Dict, Tuple, Optional, Sequence, List +from typing import Callable, Any, Dict, Tuple, Optional, Sequence, List, Set from torch.utils.hooks import RemovableHandle - -__all__ = ["saved_tensors_hooks", "save_on_cpu"] +from torch.utils._python_dispatch import TorchDispatchMode +from collections import defaultdict +import weakref __all__ = [ "saved_tensors_hooks", "save_on_cpu", "disable_saved_tensors_hooks", "register_multi_grad_hook", + "allow_mutation_on_saved_tensors", ] class saved_tensors_hooks(): @@ -270,3 +272,158 @@ def __setstate__(self, state): handles.append(t.register_hook(get_inner_hook(i))) return Handle(tuple(handles)) + + +# NOTE [Allow mutation on tensors saved for backward] +# +# 1. Tensor gets saved for backward +# - remember the python object id and the version of the tensor +# - remember aliasing information (data_ptr of base + version) +# - save the original so we control its lifetime +# 2. Any time a tensor gets in-placed +# - for each tensor aliased to it: +# - check using its object id and version to see if it has been saved +# - if it has been saved, clone it +# - delete the reference to the original +# 3. during backward +# - if the clone exists, the tensor must've been modified in-place +_allow_mutation_on_saved_tensors_enabled = False + +def _get_tid(t) -> Tuple[int, int, int]: + return (id(t), t.data_ptr(), t._version) + +def _get_sid(t) -> Tuple[int, int]: + return (t.data_ptr(), t._version) + +class _Handle(): + pass + +class _swap_with_cloned(saved_tensors_hooks): + def __init__(self, ctx): + def pack_hook(t): + tid = _get_tid(t) + sid = _get_sid(t) + # Tensors saved for backward have an entry in _tid_to_weakhandle + handle: Optional[_Handle] = None + + # Save aliasing information + ctx.sid_to_tid[sid].add(tid) + + # NB: The same tensor (of the same version) can be saved multiple times + if tid not in ctx.tid_to_weakhandle: + handle = _Handle() + ctx.tid_to_weakhandle[tid] = handle + ctx.original[handle] = t + else: + # Store an additional strong reference to the handle + handle = ctx.tid_to_weakhandle[tid] + return handle + + def unpack_hook(tup): + handle = tup + error_msg = ( + "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context" + "in which the graph was originally recorded.") + assert _allow_mutation_on_saved_tensors_enabled, error_msg + if handle in ctx.cloned: + res = ctx.cloned[handle] + else: + assert handle in ctx.original, error_msg + res = ctx.original[handle] + return res + + super().__init__(pack_hook, unpack_hook) + +class _CloneArgBeforeMutateMode(TorchDispatchMode): + def __init__(self, ctx): + self.ctx = ctx + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + + for idx, arg in enumerate(func._schema.arguments): + if arg.alias_info is not None and arg.alias_info.is_write: + t = kwargs["out"] if arg.is_out else args[idx] + tid = _get_tid(t) + sid = _get_sid(t) + ctx = self.ctx + if sid in ctx.sid_to_tid: + for tid in ctx.sid_to_tid[sid]: + if tid not in ctx.tid_to_weakhandle: + # We know that if tid is in sid_to_tid, then it must also be in + # tid_to_weakhandle. However, it is possible for the tensor to be + # saved at one point, but cleared by backward before it is modified + # in-place. Consider the following example: + # + # >>> a = torch.randn(2, 3, requires_grad=True).clone() + # >>> out = (a**2).sum() + # >>> out.backward() + # >>> a.sin_() + continue + handle = ctx.tid_to_weakhandle[tid] + if handle in ctx.cloned: + # The same exact tensor has been cloned already + continue + ctx.cloned[handle] = ctx.original[handle].clone() + del ctx.original[handle] + + rs = func(*args, **kwargs) + return rs + +class _AllowMutationOnSavedContext(): + def __init__(self): + self.cloned: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + self.original: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + self.tid_to_weakhandle: weakref.WeakValueDictionary = weakref.WeakValueDictionary() + self.sid_to_tid: Dict[Tuple[int, int], Set[Tuple[int, int, int]]] = defaultdict(set) + + def clear(self): + self.cloned.clear() + self.original.clear() + self.tid_to_weakhandle.clear() + self.sid_to_tid.clear() + +@contextlib.contextmanager +def allow_mutation_on_saved_tensors(): + """Context manager under which mutating tensors saved for backward is allowed + + Under this context manager, tensors saved for backward are cloned on mutation, + so the original version can still be used during backward. Normally, mutating a tensor + saved for backward will result in an error raised when it's used during backward. + + To ensure the correct behavior, both the forward and backward should be run under + the same context manager. + + returns: + An _AllowMutationOnSavedContext object storing the state managed by this + context manager. This object can be useful for debugging purposes. The state + managed by the context manager is automatically cleared upon exiting. + + Example:: + + >>> import torch + >>> with torch.autograd.graph.allow_mutation_on_saved_tensors(): + ... # forward + ... a = torch.ones(2, 3, requires_grad=True) + ... b = a.clone() + ... out = (b**2).sum() + ... b.sin_() + ... # backward + ... out.sum().backward() + ... + tensor([[0.8415, 0.8415, 0.8415], + [0.8415, 0.8415, 0.8415]], grad_fn=) + """ + global _allow_mutation_on_saved_tensors_enabled + + ctx = _AllowMutationOnSavedContext() + + with _swap_with_cloned(ctx), _CloneArgBeforeMutateMode(ctx): + try: + if _allow_mutation_on_saved_tensors_enabled: + raise RuntimeError("allow_mutation_on_saved_tensors contexts cannot be nested") + _allow_mutation_on_saved_tensors_enabled = True + yield ctx + finally: + ctx.clear() + _allow_mutation_on_saved_tensors_enabled = False From 3c7f96665e784a793d2d1a120ea8fe370b3f6d81 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 10 Nov 2022 19:54:56 +0000 Subject: [PATCH 22/62] [FSDP][state_dict][3/N] Change how state_dict utils access attributes in _FSDPState (#88635) **What This PR Does** _state_dict_utils currently accesses the FSDP states through module. To enable composable FSDP state_dict, these accesses need to go through _FSDPState. module is still required for most APIs as state_dict has to access per-module information. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88635 Approved by: https://github.com/awgu --- torch/distributed/fsdp/_common_utils.py | 18 ++ torch/distributed/fsdp/_state_dict_utils.py | 260 ++++++++++++-------- 2 files changed, 177 insertions(+), 101 deletions(-) diff --git a/torch/distributed/fsdp/_common_utils.py b/torch/distributed/fsdp/_common_utils.py index c93c8abb5ebd8fa..f6ccc3e9243f8e0 100644 --- a/torch/distributed/fsdp/_common_utils.py +++ b/torch/distributed/fsdp/_common_utils.py @@ -61,6 +61,24 @@ def _all_handles(state: _FSDPState) -> List: ) +@no_type_check +def _module_handles(state: _FSDPState, module: nn.Module) -> List: + """ + Given a module and returns the flat handles that map to this module. If the + module is FullyShardedDataParallel, the module._handles will be returned. + """ + if _is_composable(state): + return state._module_to_handles[module][:] + else: + return module._handles[:] + + +@no_type_check +def _has_fsdp_params(state: _FSDPState, module: nn.Module) -> bool: + """Given a module and returns if this module has parameters sharded by FSDP.""" + return len(_module_handles(state, module)) > 0 + + def clean_tensor_name(tensor_name: str) -> str: """ Cleans the parameter or buffer name by removing any module wrapper diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index c90bd4d409b1ff9..0bfd149b0112c9c 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -1,7 +1,7 @@ import functools import math import warnings -from typing import Any, Callable, cast, Dict, Iterator, Tuple +from typing import Any, Callable, cast, Dict, Iterator, no_type_check, Tuple import torch import torch.distributed as dist @@ -18,6 +18,9 @@ ShardedTensor, ) from torch.distributed.fsdp._common_utils import ( + _FSDPState, + _has_fsdp_params, + _module_handles, clean_tensor_name, FSDP_PREFIX, FSDP_WRAPPED_MODULE, @@ -51,24 +54,28 @@ def _convert_to_wrapped_module_name(module_name: str) -> str: return module_name -def _param_fqns(module) -> Iterator[Tuple[str, str, str]]: - if not module._has_params: +def _param_fqns(module, fsdp_state: _FSDPState) -> Iterator[Tuple[str, str, str]]: + if not _has_fsdp_params(fsdp_state, module): return - for param_name, module_name in module._handles[0].parameter_module_names(): + for param_name, module_name in _module_handles(fsdp_state, module)[ + 0 + ].parameter_module_names(): module_name = _convert_to_wrapped_module_name(module_name) fqn = f"{module_name}{param_name}" yield fqn, param_name, module_name -def _shared_param_fqns(module) -> Iterator[Tuple[str, str, str]]: - for param_name, module_name in module._handles[0].shared_parameter_module_names(): +def _shared_param_fqns(module, fsdp_state) -> Iterator[Tuple[str, str, str]]: + for param_name, module_name in _module_handles(fsdp_state, module)[ + 0 + ].shared_parameter_module_names(): module_name = _convert_to_wrapped_module_name(module_name) fqn = f"{module_name}{param_name}" yield fqn, param_name, module_name def _enter_full_param_ctx( - module, + fsdp_state: _FSDPState, recurse: bool = False, writeback: bool = False, rank0_only: bool = False, @@ -80,53 +87,56 @@ def _enter_full_param_ctx( requires to enter the context in the pre-hook but leave the context in the post-hook. This API enters the context of ``summon_full_params``. """ - assert module._full_param_ctx is None, ( - "Entering the ``summon_full_params`` context but module._full_param_ctx " + assert fsdp_state._full_param_ctx is None, ( + "Entering the ``summon_full_params`` context but fsdp_state._full_param_ctx " "is not None." ) - assert module.training_state != TrainingState.SUMMON_FULL_PARAMS, ( + assert fsdp_state.training_state != TrainingState.SUMMON_FULL_PARAMS, ( "Entering the summon_full_params context but the state is already " "SUMMON_FULL_PARAMS." ) - module._full_param_ctx = module._summon_full_params( + fsdp_state._full_param_ctx = fsdp_state._summon_full_params( recurse=recurse, writeback=writeback, rank0_only=rank0_only, offload_to_cpu=offload_to_cpu, with_grads=with_grads, ) - module._full_param_ctx.__enter__() + fsdp_state._full_param_ctx.__enter__() -def _exit_full_param_ctx(module) -> None: +@no_type_check +def _exit_full_param_ctx(fsdp_state: _FSDPState) -> None: """A helper function to exit ``summon_full_params`` context.""" - assert module.training_state == TrainingState.SUMMON_FULL_PARAMS, ( + assert fsdp_state.training_state == TrainingState.SUMMON_FULL_PARAMS, ( "Exiting the summon_full_params context but the state is not " "SUMMON_FULL_PARAMS." ) - assert module._full_param_ctx is not None - module._full_param_ctx.__exit__(None, None, None) - module._full_param_ctx = None + assert fsdp_state._full_param_ctx is not None + fsdp_state._full_param_ctx.__exit__(None, None, None) + fsdp_state._full_param_ctx = None def _common_pre_state_dict_hook( module, + fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> None: """Performs the pre-state_dict tasks shared by all state_dict types.""" if torch.cuda.is_available(): torch.cuda.synchronize() - _lazy_init(module, module) + # TODO: need to check if this is always correct for composable FSDP. + _lazy_init(fsdp_state, module) # TODO: change to this call after pre_state_dict_hook is in `nn.Module`. - # if module.is_root: - # _clear_grads_if_needed(module._fsdp_handles(module)) - if module._has_params: - _clear_grads_if_needed([module._handles[0]]) + # if fsdp_state.is_root: + # _clear_grads_if_needed(_all_handles(fsdp_state)) + if _has_fsdp_params(fsdp_state, module): + _clear_grads_if_needed([_module_handles(fsdp_state, module)[0]]) def _common_summon_pre_state_dict_hook( - module, + fsdp_state: _FSDPState, offload_to_cpu: bool, rank0_only: bool, ) -> None: @@ -135,7 +145,7 @@ def _common_summon_pre_state_dict_hook( ``summon_full_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this hook. """ _enter_full_param_ctx( - module, + fsdp_state, recurse=False, writeback=False, offload_to_cpu=offload_to_cpu, @@ -144,8 +154,10 @@ def _common_summon_pre_state_dict_hook( # TODO: change to the decorator style. See ``_full_pre_state_dict_hook``. +@no_type_check def _common_summon_post_state_dict_hook( module, + fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, param_hook: Callable, @@ -157,17 +169,17 @@ def _common_summon_post_state_dict_hook( """ _replace_by_prefix(state_dict, prefix + f"{FSDP_PREFIX}", prefix) assert ( - module.training_state == TrainingState.SUMMON_FULL_PARAMS + fsdp_state.training_state == TrainingState.SUMMON_FULL_PARAMS ), "Inside the post_state_dict_hook but the state is not SUMMON_FULL_PARAMS." # Return early for trivial cases - if not state_dict or not module._has_params: - _exit_full_param_ctx(module) + if not state_dict or not _has_fsdp_params(fsdp_state, module): + _exit_full_param_ctx(fsdp_state) return state_dict # TODO: Once pre_state_dict hook is supported, this pop should be removed. # For `use_orig_params=True`, the `FlatParameter` is not registered, so # there is no entry in the state dict for it to pop. - if not module._use_orig_params: + if not fsdp_state._use_orig_params: state_dict.pop(f"{prefix}{fsdp_file.FLAT_PARAM}") # If a rank does not have unsharded parameters(when `rank0_only=True` @@ -175,25 +187,25 @@ def _common_summon_post_state_dict_hook( # all-gather and does not need to save the # state dict. We simply check # rank0_only to ensure this issue. rank0_only = ( - module._state_dict_type == StateDictType.FULL_STATE_DICT - and cast(FullStateDictConfig, module._state_dict_config).rank0_only + fsdp_state._state_dict_type == StateDictType.FULL_STATE_DICT + and cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only ) # no_fsdp_return means the state_dict returned by this rank should contain # only non-FSDP controlled parameters and buffers. - no_fsdp_return = rank0_only and module.rank != 0 - if no_fsdp_return and not module._use_orig_params: - for clean_key in module._buffer_names: + no_fsdp_return = rank0_only and fsdp_state.rank != 0 + if no_fsdp_return and not fsdp_state._use_orig_params: + for clean_key in fsdp_state._buffer_names: # This is a hack to support activation checkpoint. clean_key = clean_key.replace( f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", "" ) state_dict.pop(f"{prefix}{clean_key}", None) - _exit_full_param_ctx(module) + _exit_full_param_ctx(fsdp_state) return state_dict # Loop only the parameters saved in this instance's wrapped module to # avoid processing buffers. - for fqn, param_name, module_name in _param_fqns(module): + for fqn, param_name, module_name in _param_fqns(module, fsdp_state): # TODO: remove the parameter retrieval. See ``_full_pre_state_dict_hook``. param = functools.reduce(getattr, fqn.split("."), module.module) fqn = f"{prefix}{fqn}" @@ -205,16 +217,16 @@ def _common_summon_post_state_dict_hook( f"FSDP assumes {fqn} is in the state_dict but the state_dict only " f"has {state_dict.keys()}. " f"prefix={prefix}, module_name={module_name}, " - f"param_name={param_name} rank={module.rank}." + f"param_name={param_name} rank={fsdp_state.rank}." ) - param_hook(module, state_dict, prefix, fqn) - _exit_full_param_ctx(module) + param_hook(state_dict, prefix, fqn) + _exit_full_param_ctx(fsdp_state) cpu_device = torch.device("cpu") buffer_clean_fqns = [] buffers = [] - for clean_key in module._buffer_names: + for clean_key in fsdp_state._buffer_names: # This is a hack to support activation checkpoint. clean_key = clean_tensor_name(clean_key) fqn = f"{prefix}{clean_key}" @@ -225,22 +237,29 @@ def _common_summon_post_state_dict_hook( state_dict.pop(fqn) else: buffer = state_dict[fqn] - if module._state_dict_config.offload_to_cpu and buffer.device != cpu_device: + if ( + fsdp_state._state_dict_config.offload_to_cpu + and buffer.device != cpu_device + ): state_dict[fqn] = buffer.to(cpu_device) # TODO: for composable FSDP, this should be clean_tensor_name(clean_key), buffer_clean_fqns.append(clean_key) buffers.append(state_dict[fqn]) - if buffers and module._mixed_precision_enabled_for_buffers(): - buffer_dtypes = _get_buffer_dtypes(module, buffer_clean_fqns) - _cast_buffers_to_dtype_and_device(buffers, buffer_dtypes, module.compute_device) + if buffers and fsdp_state._mixed_precision_enabled_for_buffers(): + buffer_dtypes = _get_buffer_dtypes(fsdp_state, buffer_clean_fqns) + _cast_buffers_to_dtype_and_device( + buffers, buffer_dtypes, fsdp_state.compute_device + ) for buffers, clean_fqn in zip(buffers, buffer_clean_fqns): fqn = f"{prefix}{clean_fqn}" state_dict[fqn] = buffer.clone() return state_dict +@no_type_check def _full_pre_state_dict_hook( module, + fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> None: @@ -254,16 +273,18 @@ def _full_pre_state_dict_hook( TODO: clean the callsites and hacks after ``pre_state_dict_hook` ` is supported in ``nn.Module``. """ - _common_pre_state_dict_hook(module, state_dict, prefix) + _common_pre_state_dict_hook(module, fsdp_state, state_dict, prefix) _common_summon_pre_state_dict_hook( - module, - offload_to_cpu=module._state_dict_config.offload_to_cpu, - rank0_only=cast(FullStateDictConfig, module._state_dict_config).rank0_only, + fsdp_state, + offload_to_cpu=fsdp_state._state_dict_config.offload_to_cpu, + rank0_only=cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only, ) +@no_type_check def _full_post_state_dict_hook( module, + fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> Dict[str, Any]: @@ -274,10 +295,9 @@ def _full_post_state_dict_hook( the ``FSDP_WRAPPED_MODULE`` prefix. """ # TODO: remove the hack. See ``_full_pre_state_dict_hook``. - _full_pre_state_dict_hook(module, state_dict, prefix) + _full_pre_state_dict_hook(module, fsdp_state, state_dict, prefix) def param_hook( - module, state_dict: Dict[str, Any], prefix: str, fqn: str, @@ -292,7 +312,7 @@ def param_hook( # Clone non-ignored parameters before exiting the # `_summon_full_params()` context - if clean_key not in module._ignored_param_names and not getattr( + if clean_key not in fsdp_state._ignored_param_names and not getattr( state_dict[fqn], "_has_been_cloned", False ): try: @@ -300,31 +320,37 @@ def param_hook( state_dict[fqn]._has_been_cloned = True # type: ignore[attr-defined] except BaseException as e: warnings.warn( - f"Failed to clone() tensor with name {fqn} on rank {module.rank}. " + f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. " "This may mean that this state_dict entry could point to invalid " "memory regions after returning from state_dict() call if this " "parameter is managed by FSDP. Please check clone " f"implementation of {fqn}. Error: {str(e)}" ) - return _common_summon_post_state_dict_hook(module, state_dict, prefix, param_hook) + return _common_summon_post_state_dict_hook( + module, fsdp_state, state_dict, prefix, param_hook + ) def _full_pre_load_state_dict_hook( module, + fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> None: - _enter_full_param_ctx(module, recurse=False, writeback=True) + _enter_full_param_ctx(fsdp_state, recurse=False, writeback=True) _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}") -def _full_post_load_state_dict_hook(module, *args, **kwargs) -> None: - _exit_full_param_ctx(module) +def _full_post_load_state_dict_hook( + module, fsdp_state: _FSDPState, *args, **kwargs +) -> None: + _exit_full_param_ctx(fsdp_state) def _local_pre_state_dict_hook( module, + fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> None: @@ -333,16 +359,21 @@ def _local_pre_state_dict_hook( hook is not supported by the PyTorch core. So this API is called from `_local_post_state_dict_hook()` to simulate the case. """ - if module._has_params and not module._handles[0].uses_sharded_strategy: + if ( + _has_fsdp_params(fsdp_state, module) + and not _module_handles(fsdp_state, module)[0].uses_sharded_strategy + ): raise RuntimeError( "``local_state_dict`` can only be used when parameters are flatten " "and sharded." ) - _common_pre_state_dict_hook(module, state_dict, prefix) + _common_pre_state_dict_hook(module, fsdp_state, state_dict, prefix) +@no_type_check def _local_post_state_dict_hook( module, + fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> Dict[str, Any]: @@ -352,42 +383,45 @@ def _local_post_state_dict_hook( will happen. The underlying storage is the same. """ # TODO: remove the hack. See ``_full_pre_state_dict_hook``. - _local_pre_state_dict_hook(module, state_dict, prefix) + _local_pre_state_dict_hook(module, fsdp_state, state_dict, prefix) _replace_by_prefix(state_dict, f"{prefix}{FSDP_PREFIX}", prefix) - if not module._has_params: + if not _has_fsdp_params(fsdp_state, module): return state_dict # state_dict[f"{prefix}{FLAT_PARAM}"] exists and has the same tensor # value as the flat_param but it is a pure Tensor because # nn.Module.state_dict() will detach the parameter. Therefore, we need # to get flat_param to get the metadata. - assert module._handles, "Should have returned early" - flat_param = module._handles[0].flat_param + assert _module_handles(fsdp_state, module), "Should have returned early" + flat_param = _module_handles(fsdp_state, module)[0].flat_param # Construct a ShardedTensor from the flat_param. full_numel = flat_param._unpadded_unsharded_size.numel() # type: ignore[attr-defined] - shard_offset = flat_param.numel() * module.rank + shard_offset = flat_param.numel() * fsdp_state.rank valid_data_size = flat_param.numel() - flat_param._shard_numel_padded if valid_data_size > 0 and flat_param._shard_numel_padded > 0: flat_param = flat_param.narrow(0, 0, valid_data_size) local_shards = [ - Shard.from_tensor_and_offsets(flat_param, [shard_offset], module.rank) + Shard.from_tensor_and_offsets(flat_param, [shard_offset], fsdp_state.rank) ] sharded_tensor = init_from_local_shards( - local_shards, full_numel, process_group=module.process_group + local_shards, full_numel, process_group=fsdp_state.process_group ) # type: ignore[assignment] - if module._state_dict_config.offload_to_cpu: + if fsdp_state._state_dict_config.offload_to_cpu: sharded_tensor = sharded_tensor.cpu() state_dict[f"{prefix}{fsdp_file.FLAT_PARAM}"] = sharded_tensor return state_dict -def _local_post_load_state_dict_hook(module, *args, **kwargs) -> None: +def _local_post_load_state_dict_hook( + module, fsdp_state: _FSDPState, *args, **kwargs +) -> None: pass def _local_pre_load_state_dict_hook( module, + fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> None: @@ -399,7 +433,7 @@ def _local_pre_load_state_dict_hook( _replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_PREFIX}") fqn = f"{prefix}{FSDP_PREFIX}{fsdp_file.FLAT_PARAM}" if fqn not in state_dict: - assert not module._has_params, ( + assert not _has_fsdp_params(fsdp_state, module), ( "No `FlatParameter` in `state_dict` for this FSDP instance " "but it has parameters" ) @@ -416,7 +450,7 @@ def _local_pre_load_state_dict_hook( # Get the metadata of the flat_param to decide whether to pad the loaded # tensor. - flat_param = module._handles[0].flat_param + flat_param = _module_handles(fsdp_state, module)[0].flat_param assert flat_param is not None if flat_param._shard_numel_padded not in (0, flat_param.numel()): assert load_tensor.numel() < flat_param.numel(), ( @@ -429,6 +463,7 @@ def _local_pre_load_state_dict_hook( def _sharded_pre_state_dict_hook( module, + fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> None: @@ -436,23 +471,28 @@ def _sharded_pre_state_dict_hook( Hook that runs before model.state_dict() is called. Check ``_full_pre_load_state_dict_hook`` for the detail. """ - if module._has_params and not module._handles[0].uses_sharded_strategy: + if ( + _has_fsdp_params(fsdp_state, module) + and not _module_handles(fsdp_state, module)[0].uses_sharded_strategy + ): raise RuntimeError( "``sharded_state_dict`` can only be used when parameters are flatten " "and sharded." ) - _common_pre_state_dict_hook(module, state_dict, prefix) + _common_pre_state_dict_hook(module, fsdp_state, state_dict, prefix) # Setting offload_to_cpu here does not work even if offload_to_cpu is True. # We have to create ShardedTensor first then move it to CPU. _common_summon_pre_state_dict_hook( - module, + fsdp_state, offload_to_cpu=False, rank0_only=False, ) +@no_type_check def _sharded_post_state_dict_hook( module, + fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> Dict[str, Any]: @@ -462,31 +502,38 @@ def _sharded_post_state_dict_hook( """ # TODO: remove the hack. See ``_full_pre_state_dict_hook``. - _sharded_pre_state_dict_hook(module, state_dict, prefix) + _sharded_pre_state_dict_hook(module, fsdp_state, state_dict, prefix) - def param_hook(module, state_dict: Dict[str, Any], prefix: str, fqn: str): + def param_hook(state_dict: Dict[str, Any], prefix: str, fqn: str): param = state_dict[fqn] sharded_tensor = _ext_chunk_tensor( tensor=param, - rank=module.rank, - world_size=module.world_size, + rank=fsdp_state.rank, + world_size=fsdp_state.world_size, num_devices_per_node=torch.cuda.device_count(), - pg=module.process_group, + pg=fsdp_state.process_group, ) - if module._state_dict_config.offload_to_cpu: + if fsdp_state._state_dict_config.offload_to_cpu: sharded_tensor = sharded_tensor.cpu() state_dict[fqn] = sharded_tensor - return _common_summon_post_state_dict_hook(module, state_dict, prefix, param_hook) + return _common_summon_post_state_dict_hook( + module, fsdp_state, state_dict, prefix, param_hook + ) -def _sharded_post_load_state_dict_hook(module, *args, **kwargs) -> None: - if module._use_orig_params: - module._register_orig_params() +@no_type_check +def _sharded_post_load_state_dict_hook( + module, fsdp_state: _FSDPState, *args, **kwargs +) -> None: + if fsdp_state._use_orig_params: + fsdp_state._register_orig_params() +@no_type_check def _sharded_pre_load_state_dict_hook( module, + fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> None: @@ -495,19 +542,19 @@ def _sharded_pre_load_state_dict_hook( a new FlatParameter and shards the new FlatParameter to the local chunk. """ _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}") - if not module._has_params: + if not _has_fsdp_params(fsdp_state, module): return - if not module._handles[0].uses_sharded_strategy: + if not _module_handles(fsdp_state, module)[0].uses_sharded_strategy: raise RuntimeError( "load_sharded_state_dict can only be called when parameters " "are flatten and sharded." ) nonsharded_tensors = [] - shared_fqns = [fqn for fqn, _, _ in _shared_param_fqns(module)] + shared_fqns = [fqn for fqn, _, _ in _shared_param_fqns(module, fsdp_state)] loaded_shapes = [] - for fqn, _, _ in _param_fqns(module): + for fqn, _, _ in _param_fqns(module, fsdp_state): full_fqn = f"{prefix}{FSDP_PREFIX}{fqn}" param = state_dict.pop(full_fqn) if fqn in shared_fqns: @@ -517,12 +564,12 @@ def _sharded_pre_load_state_dict_hook( loaded_shapes.append(param.size()) assert len(shards) < 2, ( "Expects 0 or 1 shard per rank " - f"but got {len(shards)} shards on rank {module.rank}." + f"but got {len(shards)} shards on rank {fsdp_state.rank}." ) param_numel = param.size().numel() dim_0_size = param.size()[0] chunk_size = ( - math.ceil(dim_0_size / module.world_size) * param_numel // dim_0_size + math.ceil(dim_0_size / fsdp_state.world_size) * param_numel // dim_0_size ) if len(shards) == 1: local_tensor = shards[0].tensor.flatten() @@ -534,14 +581,16 @@ def _sharded_pre_load_state_dict_hook( else: local_tensor = torch.zeros(chunk_size, dtype=param.dtype).cuda() tensor = torch.empty( - chunk_size * module.world_size, dtype=local_tensor.dtype + chunk_size * fsdp_state.world_size, dtype=local_tensor.dtype ).cuda() - dist.all_gather_into_tensor(tensor, local_tensor, group=module.process_group) + dist.all_gather_into_tensor( + tensor, local_tensor, group=fsdp_state.process_group + ) tensor = tensor.narrow(0, 0, param_numel).reshape(param.size()) nonsharded_tensors.append(tensor) # Create a new flat_param from the loaded, non-sharded tensors. - flat_param = module._handles[0].flat_param + flat_param = _module_handles(fsdp_state, module)[0].flat_param loaded_flat_param = FlatParamHandle.flatten_params( nonsharded_tensors, requires_grad=False ) @@ -549,8 +598,8 @@ def _sharded_pre_load_state_dict_hook( # Get the chunk from the loaded flat_param for the local rank. loaded_flat_tensor, num_to_pad = FlatParamHandle._get_shard( loaded_flat_param, - module.rank, - module.world_size, + fsdp_state.rank, + fsdp_state.world_size, ) loaded_flat_tensor.to(flat_param.device) assert all(s1 == s2 for s1, s2 in zip(loaded_shapes, flat_param._shapes)), ( @@ -567,10 +616,11 @@ def _sharded_pre_load_state_dict_hook( f"from the local chunk {flat_param._shard_numel_padded}." ) state_dict[f"{prefix}{FSDP_PREFIX}{fsdp_file.FLAT_PARAM}"] = loaded_flat_tensor - if module._use_orig_params: - module._deregister_orig_params() + if fsdp_state._use_orig_params: + fsdp_state._deregister_orig_params() +@no_type_check @torch.no_grad() def _post_state_dict_hook( module: nn.Module, @@ -580,21 +630,24 @@ def _post_state_dict_hook( ) -> Dict[str, Any]: """ _post_state_dict_hook() is called after the state_dict() of this - FSDP module is executed. ``module._state_dict_type`` is used to decide + FSDP module is executed. ``fsdp_state._state_dict_type`` is used to decide what postprocessing will be done. """ + # TODO: get the composable state from module + fsdp_state: _FSDPState = module _post_state_dict_hook_fn = { StateDictType.FULL_STATE_DICT: _full_post_state_dict_hook, StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook, StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook, } fsdp_module = cast(fsdp_file.FullyShardedDataParallel, module) - processed_state_dict = _post_state_dict_hook_fn[fsdp_module._state_dict_type]( - fsdp_module, state_dict, prefix + processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type]( + fsdp_module, fsdp_state, state_dict, prefix ) return processed_state_dict +@no_type_check @torch.no_grad() def _pre_load_state_dict_hook( module: nn.Module, @@ -604,9 +657,11 @@ def _pre_load_state_dict_hook( ) -> None: """ ``_pre_state_dict_hook` is called before ``module._load_from_state_dict()`` - is called. ``module._state_dict_type`` is used to decide what preprocessing + is called. ``fsdp_state._state_dict_type`` is used to decide what preprocessing will be done. """ + # TODO: get the composable state from module + fsdp_state: _FSDPState = module _pre_load_state_dict_hook_fn = { StateDictType.FULL_STATE_DICT: _full_pre_load_state_dict_hook, StateDictType.LOCAL_STATE_DICT: _local_pre_load_state_dict_hook, @@ -617,13 +672,16 @@ def _pre_load_state_dict_hook( if torch.cuda.is_available(): torch.cuda.synchronize() # Dispatch into state_dict specific implementation of pre-hook. - _pre_load_state_dict_hook_fn[fsdp_module._state_dict_type]( - fsdp_module, state_dict, prefix + _pre_load_state_dict_hook_fn[fsdp_state._state_dict_type]( + fsdp_module, fsdp_state, state_dict, prefix ) +@no_type_check @torch.no_grad() def _post_load_state_dict_hook(module: nn.Module, *args: Any) -> None: + # TODO: get the composable state from module + fsdp_state: _FSDPState = module _post_load_state_dict_hook_fn = { StateDictType.FULL_STATE_DICT: _full_post_load_state_dict_hook, StateDictType.LOCAL_STATE_DICT: _local_post_load_state_dict_hook, @@ -633,4 +691,4 @@ def _post_load_state_dict_hook(module: nn.Module, *args: Any) -> None: fsdp_module = cast(fsdp_file.FullyShardedDataParallel, module) # Dispatch into state_dict type specific implementation of post-hook for # loading state_dict. - _post_load_state_dict_hook_fn[fsdp_module._state_dict_type](fsdp_module) + _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](fsdp_module, fsdp_state) From d615d1228932eaa5e026f5399e099f2036d2379b Mon Sep 17 00:00:00 2001 From: anjali411 Date: Fri, 11 Nov 2022 15:24:28 +0000 Subject: [PATCH 23/62] Add meta impl for topk (#88694) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88694 Approved by: https://github.com/ezyang --- test/functorch/test_aotdispatch.py | 1 - test/test_proxy_tensor.py | 1 - torch/_meta_registrations.py | 17 +++++++++++++++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 09b65a32bfee9b7..4da39210343e7e4 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -1214,7 +1214,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('take_along_dim', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('take', ''), # aten.take.default - couldn't find symbolic meta function/decomposition xfail('tensordot', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('topk', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('trace', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('trapezoid', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('trapz', ''), # Cannot call sizes() on tensor with symbolic sizes/strides diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 72c7249f4f14582..d1a5c9498bcaa4e 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1337,7 +1337,6 @@ def f(a, b, c, d, e): xfail('take_along_dim', ''), # dtype of indices should be Long but got Float xfail('take', ''), # aten.take.default - couldn't find symbolic meta function/decomposition xfail('tensordot', ''), # aten.size.default - couldn't find symbolic meta function/decomposition - xfail('topk', ''), # aten.topk.default - couldn't find symbolic meta function/decomposition xfail('trapz', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('trapezoid', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/decomposition diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 04c522ab9e3b4db..5d583de67d19688 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1777,6 +1777,23 @@ def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None): ) +@register_meta(aten.topk.default) +def topk_meta(self, k, dim=-1, largest=True, sorted=True): + # From aten/src/ATen/native/Sorting.cpp + dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True) + check( + k >= 0 and k <= (self.size(dim) if self.dim() > 0 else 1), + lambda: "selected index k out of range", + ) + sliceSize = 1 if self.dim() == 0 else self.size(dim) + check(k >= 0 and k <= sliceSize, lambda: "k not in range for dimension") + + topKSize = list(self.shape) + if len(topKSize) > 0: + topKSize[dim] = k + return self.new_empty(topKSize), self.new_empty(topKSize, dtype=torch.int64) + + # We must also trigger meta registrations from PrimTorch ref # decompositions import torch._refs From 1e8f95ace16cb617d71f8f8254c1d5bafd9f586c Mon Sep 17 00:00:00 2001 From: Nikita Karetnikov Date: Fri, 11 Nov 2022 13:51:18 +0100 Subject: [PATCH 24/62] Symintify `broadcast_to` (#88776) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88776 Approved by: https://github.com/ezyang --- .../ATen/functorch/BatchRulesDecompositions.cpp | 2 +- aten/src/ATen/native/TensorShape.cpp | 4 ++-- aten/src/ATen/native/native_functions.yaml | 4 +++- test/functorch/test_aotdispatch.py | 11 ++--------- test/test_proxy_tensor.py | 15 --------------- 5 files changed, 8 insertions(+), 28 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index 66aaa53bfcc1fb8..e31b36d112418ed 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -63,7 +63,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { OP_DECOMPOSE2(bitwise_or, Scalar); OP_DECOMPOSE2(bitwise_xor, Scalar); OP_DECOMPOSE(broadcast_tensors); - OP_DECOMPOSE(broadcast_to); + m.impl("broadcast_to", native::broadcast_to_symint); OP_DECOMPOSE(cartesian_prod); OP_DECOMPOSE(cdist); OP_DECOMPOSE(clip); diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 31b4011c1281311..deb9b949aa5d3e4 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -537,8 +537,8 @@ Tensor sparse_broadcast_to(const Tensor& self, IntArrayRef size) { return at::sparse_coo_tensor(new_indices, new_values, size)._coalesced_(is_coalesced); } -Tensor broadcast_to(const Tensor& self, IntArrayRef size) { - return self.expand(size); +Tensor broadcast_to_symint(const Tensor& self, SymIntArrayRef size) { + return self.expand_symint(size); } std::vector broadcast_tensors(TensorList tensors) { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 0ea606f5e1fb597..de087c0b8a8965a 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1195,8 +1195,10 @@ device_check: NoCheck device_guard: False -- func: broadcast_to(Tensor(a) self, int[] size) -> Tensor(a) +- func: broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a) variants: function, method + dispatch: + CompositeImplicitAutograd: broadcast_to_symint - func: _sparse_broadcast_to(Tensor(a) self, int[] size) -> Tensor(a) variants: function diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 4da39210343e7e4..f4782b8a595df92 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -1093,20 +1093,13 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('masked.cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition xfail('masked.cumsum', ''), # aten.cumsum.default - couldn't find symbolic meta function/decomposition xfail('masked_fill', ''), # could not find kernel - xfail('masked.log_softmax', ''), # argument 'size' (position 2) must be tuple of ints, not ... xfail('masked.logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposi... xfail('masked.logsumexp', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('masked.mean', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=t... - xfail('masked.median', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('masked.norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides + # Seems flaky: https://github.com/pytorch/pytorch/issues/88883 + skip('masked.median', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked.prod', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked_scatter', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decompos... - xfail('masked.softmax', ''), # argument 'size' (position 2) must be tuple of ints, not torc... - xfail('masked.softmin', ''), # argument 'size' (position 2) must be tuple of ints, not torc... - xfail('masked.std', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=to... - xfail('masked.sum', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('masked.var', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=to... xfail('matmul', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decompo... xfail('median', ''), # could not find kernel diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index d1a5c9498bcaa4e..86beb651cb2d1b3 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1114,23 +1114,8 @@ def f(a, b, c, d, e): xfail('linalg.eig'), xfail('linalg.eigvals'), skip('masked.logsumexp', ''), # Tensors of type TensorImpl do not have numel - xfail('masked.amax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.amin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.argmax', ''), # broadcast_to(): argument 'size' (position 2) must be tuple of ints, but found ... - xfail('masked.argmin', ''), # broadcast_to(): argument 'size' (position 2) must be tuple of ints, but found ... xfail('masked.cumprod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.cumsum', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.log_softmax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition xfail('masked.logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposition - xfail('masked.mean', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, ... - xfail('masked.median', ''), # aten.nanmedian.dim - couldn't find symbolic meta function/decomposition - xfail('masked.norm', ''), # aten.linalg_vector_norm.default - couldn't find symbolic meta function/decomposition - xfail('masked.prod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.softmax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.softmin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.std', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d... - xfail('masked.sum', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.var', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d... xfail('addmv', ''), # aten.addmv.default - couldn't find symbolic meta function/decomposition xfail('addr', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('aminmax', ''), # aten.aminmax.default - couldn't find symbolic meta function/decomposition From a6832b08a3f6c1b425a075fe204a1f21361f33d9 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Tue, 8 Nov 2022 19:23:21 +0000 Subject: [PATCH 25/62] Regularize bernouilli_ with bernouilli decomp (#88349) Fix for https://github.com/pytorch/torchdynamo/issues/1796. Just like the other [bernouilli decomp](https://github.com/pytorch/pytorch/blob/master/torch/_inductor/decomposition.py#L302) we need to pass `dtype=float32` to avoid `"check_uniform_bounds" not implemented` errors. Are we planning on enabling `TEST_WITH_TORCHINDUCTOR` ? Do I need to change anything with the tests ? Pull Request resolved: https://github.com/pytorch/pytorch/pull/88349 Approved by: https://github.com/desertfire --- torch/_inductor/decomposition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index d7aa5e35f50107f..e8a20c0dbd26eda 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -325,7 +325,7 @@ def bernoulli_p(self, p=0.5, *, generator=None): @register_extra_random_decomp([aten.bernoulli_]) def bernoulli_(self, p=0.5): - return self.copy_(torch.rand_like(self) < p) + return self.copy_(torch.rand_like(self, dtype=torch.float32) < p) @functools.lru_cache(None) From 89a326ff7ea56a1d735d26800b07a10e35c2dff4 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Fri, 11 Nov 2022 16:57:05 +0000 Subject: [PATCH 26/62] Explicitly check filelike arg of `torch.save` (#88867) Fixes #88793 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88867 Approved by: https://github.com/ezyang --- test/test_serialization.py | 9 +++++++++ torch/serialization.py | 7 +++++++ 2 files changed, 16 insertions(+) diff --git a/test/test_serialization.py b/test/test_serialization.py index 779d6fb5c20c5da..5ccc6f47b4c5d06 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -585,6 +585,15 @@ def test_serialization_filelike_exceptions(self): with self.assertRaises(TypeError): # Tries to serialize str into tensor with wrong callable write property torch.save('foo', x) + s_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + s = torch.CharStorage(s_data) + with self.assertRaises(AttributeError): + # Tries to serialize list into CharStorage + torch.save(s_data, s) + x = torch.randint(10, (3, 3), dtype=torch.float).cpu().numpy() + with self.assertRaises(AttributeError): + # Tries to serialize ndarray into ndarray + torch.save(x, x) def test_serialization_storage_slice(self): diff --git a/torch/serialization.py b/torch/serialization.py index d123a955ad96643..3078e57587be63d 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -375,6 +375,12 @@ def _check_dill_version(pickle_module) -> None: pickle_module.__version__ )) +def _check_save_filelike(f): + if not isinstance(f, (str, os.PathLike)) and not hasattr(f, 'write'): + raise AttributeError(( + "expected 'f' to be string, path, or a file-like object with " + "a 'write' attribute")) + def save( obj: object, f: FILE_LIKE, @@ -422,6 +428,7 @@ def save( >>> torch.save(x, buffer) """ _check_dill_version(pickle_module) + _check_save_filelike(f) if _use_new_zipfile_serialization: with _open_zipfile_writer(f) as opened_zipfile: From adfbd831cf59111c3d3a4a50ba6372bba94b63d1 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 11 Nov 2022 17:03:25 +0000 Subject: [PATCH 27/62] Revert "[Autograd] Use in-place input accumulation fast path for dense Tensors. (#88339)" This reverts commit 8f66ae413f8c9d7f2418d7f0b9f69d409c455b46. Reverted https://github.com/pytorch/pytorch/pull/88339 on behalf of https://github.com/mehtanirav due to Internal test failures --- torch/csrc/autograd/input_buffer.cpp | 54 ++++++++-------------------- 1 file changed, 14 insertions(+), 40 deletions(-) diff --git a/torch/csrc/autograd/input_buffer.cpp b/torch/csrc/autograd/input_buffer.cpp index 7e6df0cea8da070..6cc6acefc9d45ba 100644 --- a/torch/csrc/autograd/input_buffer.cpp +++ b/torch/csrc/autograd/input_buffer.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #include #include @@ -67,18 +66,6 @@ void record_stream_any_impl(Variable& var, c10::Stream& stream) { } } } - -bool can_accumulate_inplace(const Variable& v) { - return ( - // `v` is a "vanilla" Tensor - !(at::isTensorSubclassLike(v) || v._is_zerotensor() || v.is_nested()) && - - // with a favorable memory layout - v.is_non_overlapping_and_dense() && - - // and we hold the last reference - v.use_count() == 1 && v.storage().use_count() == 1); -} } // anonymous namespace static void accumulate( @@ -87,38 +74,25 @@ static void accumulate( Variable&& var) { TORCH_INTERNAL_ASSERT(pos < buffer.size()); auto& old_var = buffer[pos]; - // If we hold the last reference to `old_var` AND its storage we will try to - // repurpose it to store the output. (Or, if `old_var` is sparse then `var` - // becomes the candidate output Tensor.) We only do this if: - // 1) GradMode is disabled since Autograd has special handling for inplace - // mutation which we don't want to trigger. - // - // 2) We hold the last reference. - // (Both `.use_count` and `.storage().use_count()` are one) - // - // 3) The candidate tensor is a contiguous, non-overlapping, dense, and - // otherwise stock standard Tensor. - // - // 4) The candidate is mutable. Currently only ZeroTensors are immutable. - // - // 5) The other Tensor is not a Tensor subclass (except sparse), since - // it's hard to predict the semantics of arbitrary subclass behavior. - - if (at::GradMode::is_enabled()) { - buffer[pos] = old_var + var; - } else if ( - // ATen doesn't route sparse additions correctly... - old_var.is_sparse() || old_var.is_sparse_csr()) { - if (can_accumulate_inplace(var)) { + // ATen doesn't route sparse additions correctly... + // do dense + sparse in-place if possible + if (old_var.is_sparse()) { + // It is safe to change the Tensor inplace if the Tensor is only used in + // this buffer (this could be the gradient passed by the user) and that no + // other Tensor is using the same storage. + if (!var.is_sparse() && var.is_contiguous() && var.use_count() == 1 && + var.storage().use_count() == 1) { buffer[pos] = var.add_(old_var); } else { buffer[pos] = var + old_var; } - } else if ( - can_accumulate_inplace(old_var) && !at::isTensorSubclassLike(var)) { - buffer[pos] = old_var.add_(var); } else { - buffer[pos] = old_var + var; + if (var.is_sparse() && !old_var.is_sparse() && old_var.is_contiguous() && + old_var.use_count() == 1 && old_var.storage().use_count() == 1) { + buffer[pos] = old_var.add_(var); + } else { + buffer[pos] = old_var + var; + } } } From 8ff2e34ca6905404aba35a432acf667ee6a13c6e Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Fri, 11 Nov 2022 04:25:11 +0000 Subject: [PATCH 28/62] Take input striding for conv forward based on eager output (#88706) From discussion with @Chillee and @ngimel we'll likely need further fixes to ensure that we hit channels last kernels but this is still worth landing in its own right. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88706 Approved by: https://github.com/ngimel --- test/inductor/test_torchinductor.py | 26 +++++++++++ torch/_inductor/ir.py | 72 +++++++++++++++++------------ 2 files changed, 69 insertions(+), 29 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 121f3d31f39c201..aea8013bdfac8dd 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4601,6 +4601,8 @@ def fn(a): CommonTemplate.install(CudaTests, "cuda") class CudaReproTests(TestCase): + common = check_model_cuda + def test_index_put_issue(self): def forward( self, @@ -4637,6 +4639,30 @@ def forward( compiled = compile_fx_inner(mod, inps) compiled(inps) + @requires_cuda() + def test_input_channels_last(self): + m = torch.nn.Sequential( + torch.nn.Conv2d(3, 3, 1, 1), + ToTuple(), + ).cuda() + inp = ( + torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last).cuda() + ) + + self.common( + m, + (inp,), + check_lowp=False, + ) + + @torch._dynamo.optimize() + def foo(m, inp): + return m(inp) + + self.assertTrue( + foo(m, inp)[0].is_contiguous(memory_format=torch.channels_last) + ) + # https://github.com/pytorch/torchdynamo/issues/1681#issuecomment-1283433527 @requires_cuda() def test_unspec_inputs_interop(self): diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 448c057ecb0e15d..240c196a73b6d35 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -19,7 +19,12 @@ import torch.fx import torch.utils._pytree as pytree -from torch._prims_common import is_boolean_dtype, is_float_dtype +from torch._prims_common import ( + is_boolean_dtype, + is_float_dtype, + make_channels_last_strides_for, + make_contiguous_strides_for, +) from torch._subclasses.fake_tensor import FakeTensorMode from . import config, dependencies @@ -133,7 +138,7 @@ def ir_node_to_tensor(x, guard_shape=True): if is_storage_and_layout(x): stride = [shape_fn(s) for s in x.get_layout().stride] else: - stride = torch._prims_common.make_contiguous_strides_for(size) + stride = make_contiguous_strides_for(size) dtype = x.get_dtype() device = x.get_device() t = torch.empty_strided( @@ -2462,6 +2467,9 @@ def require_stride_order(cls, x, order): x.get_layout(), FixedLayout ) and x.get_layout().is_stride_ordered(order): return x + # TODO - Storage to InputBuffer + if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order): + return x x = cls.copy_input(x) as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=order) assert is_stride_order_storage_and_layout(x, order) @@ -3052,9 +3060,32 @@ def create( output_padding_: List[int], groups: int, ): + with torch._subclasses.FakeTensorMode(): + x_fake = ir_node_to_tensor(x, guard_shape=True) + weight_fake = ir_node_to_tensor(weight, guard_shape=True) + bias_fake = ( + ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias + ) + output = torch.ops.aten.convolution( + x_fake, + weight_fake, + bias_fake, + stride_, + padding_, + dilation_, + transposed, + output_padding_, + groups, + ) + req_stride_order = get_stride_order(output.stride()) + + if config.triton.convolution == "aten": + weight = cls.require_stride_order(weight, req_stride_order) + x = cls.require_stride_order(x, req_stride_order) + else: + x = cls.require_stride1(cls.realize_input(x)) + weight = cls.require_stride1(cls.realize_input(weight)) - weight = cls.require_stride1(cls.realize_input(weight)) - x = cls.require_stride_order(x, get_stride_order(weight.get_stride())) stride = tuple(stride_) padding = tuple(padding_) dilation = tuple(dilation_) @@ -3062,22 +3093,6 @@ def create( output_padding = tuple(output_padding_) assert isinstance(groups, int) - # TODO - enable FakeTensorMode for propagation more globally. incorrect stride metas for fallback - # kernels will lead to runtime failures - with FakeTensorMode(): - output, *_ = cls.process_kernel( - torch.ops.aten.convolution, - x, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - ) - output_size = output.shape weight_shape = [ @@ -3122,6 +3137,7 @@ def create( # for conv2d or conv3d, prefer channels last format if kernel == "triton_ops.conv": output_layout_str = "torch.channels_last" + elif config.tune_layout and len(x.get_size()) == 4: from .codegen.autotuner import tuned_conv_layout @@ -3151,14 +3167,19 @@ def create( if len(stride_order) < len(output_size): # add batch dim if it exists stride_order = [len(stride_order)] + stride_order + strides = make_channels_last_strides_for(output_size) else: stride_order = list(reversed(range(len(output_size)))) + strides = make_contiguous_strides_for(output_size) - output_layout = FlexibleLayout( + if config.triton.convolution != "aten": + x = cls.require_stride_order(x, stride_order) + + output_layout = FixedLayout( x.get_device(), x.get_dtype(), output_size, - stride_order, + strides, ) if bias is not None: @@ -3178,13 +3199,6 @@ def create( kernel, ) - def apply_constraint(self): - x = self.inputs[0] - # FixedLayout of input - x = self.require_stride_order(x, self.layout.preferred_stride_order) - self.inputs[0] = x - self.freeze_layout_with_stride_order(self.layout.preferred_stride_order) - def map_args(self): # x, w, bias in_args = [x.codegen_reference() for x in self.inputs] From 5f0783bd6d27a0a239263b943d626c533b8b9a90 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Fri, 11 Nov 2022 17:43:46 +0000 Subject: [PATCH 29/62] Fix ATen Fallback for BUILD_CAFFE2=0 for ONNX-only ops (#88504) Follow-up for #87735 Once again, because BUILD_CAFFE2=0 is not tested for ONNX exporter, one scenario slipped through. A use case where the model can be exported without aten fallback when operator_export_type=ONNX_ATEN_FALLBACK and BUILD_CAFFE2=0 A new unit test has been added, but it won't prevent regressions if BUILD_CAFFE2=0 is not executed on CI again Fixes #87313 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88504 Approved by: https://github.com/justinchuby, https://github.com/BowenBao --- test/onnx/test_pytorch_onnx_no_runtime.py | 220 +++++++++++++--------- torch/onnx/utils.py | 19 +- 2 files changed, 149 insertions(+), 90 deletions(-) diff --git a/test/onnx/test_pytorch_onnx_no_runtime.py b/test/onnx/test_pytorch_onnx_no_runtime.py index 622f42effb4ab47..89526c71ca3871b 100644 --- a/test/onnx/test_pytorch_onnx_no_runtime.py +++ b/test/onnx/test_pytorch_onnx_no_runtime.py @@ -18,7 +18,7 @@ import torch import torch.nn.functional as F from torch import Tensor -from torch.onnx import symbolic_helper, utils +from torch.onnx import OperatorExportTypes, symbolic_helper, utils from torch.onnx._globals import GLOBALS from torch.onnx._internal import registration from torch.testing._internal import common_quantization, common_utils, jit_utils @@ -935,6 +935,139 @@ def forward(self, x, w): torch.onnx.export_to_pretty_string(Mod(), (torch.rand(3, 4), torch.rand(4, 5))) + @common_utils.skipIfNoCaffe2 + def test_caffe2_aten_fallback_must_fallback(self): + class ModelWithAtenNotONNXOp(torch.nn.Module): + def forward(self, x, y): + abcd = x + y + defg = torch.linalg.qr(abcd) + return defg + + # TODO: Refactor common_utils._decide_skip_caffe2 to support parametrize + for operator_export_type in ( + OperatorExportTypes.ONNX_ATEN, + OperatorExportTypes.ONNX_ATEN_FALLBACK, + ): + x = torch.rand(3, 4) + y = torch.rand(3, 4) + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenNotONNXOp(), + (x, y), + f, + do_constant_folding=False, + operator_export_type=operator_export_type, + # support for linalg.qr was added in later op set versions. + opset_version=9, + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + self.assertAtenOp(onnx_model, "linalg_qr") + + @common_utils.skipIfNoCaffe2 + def test_caffe2_onnx_aten_must_not_fallback(self): + class ModelWithAtenFmod(torch.nn.Module): + def forward(self, x, y): + return torch.fmod(x, y) + + # TODO: Refactor common_utils._decide_skip_caffe2 to support parametrize + for operator_export_type in ( + OperatorExportTypes.ONNX_ATEN_FALLBACK, + OperatorExportTypes.ONNX_ATEN, + ): + x = torch.randn(3, 4, dtype=torch.float32) + y = torch.randn(3, 4, dtype=torch.float32) + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenFmod(), + (x, y), + f, + do_constant_folding=False, + operator_export_type=operator_export_type, + opset_version=10, # or higher + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + assert onnx_model.graph.node[0].op_type == "Mod" + + @common_utils.skipIfCaffe2 + def test_aten_fallback_must_fallback(self): + class ModelWithAtenNotONNXOp(torch.nn.Module): + def forward(self, x, y): + abcd = x + y + defg = torch.linalg.qr(abcd) + return defg + + x = torch.rand(3, 4) + y = torch.rand(3, 4) + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenNotONNXOp(), + (x, y), + f, + do_constant_folding=False, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, + # support for linalg.qr was added in later op set versions. + opset_version=9, + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + self.assertAtenOp(onnx_model, "linalg_qr") + + @common_utils.skipIfCaffe2 + def test_onnx_aten(self): + class ModelWithAtenFmod(torch.nn.Module): + def forward(self, x, y): + return torch.fmod(x, y) + + x = torch.randn(3, 4, dtype=torch.float32) + y = torch.randn(3, 4, dtype=torch.float32) + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenFmod(), + (x, y), + f, + do_constant_folding=False, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN, + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + self.assertAtenOp(onnx_model, "fmod", "Tensor") + + @common_utils.skipIfCaffe2 + def test_onnx_aten_fallback_must_not_fallback(self): + # For BUILD_CAFFE2=0, aten fallback only when not exportable + class ONNXExportable(torch.nn.Module): + def __init__(self): + super(ONNXExportable, self).__init__() + self.quant = torch.quantization.QuantStub() + self.fc1 = torch.nn.Linear(12, 8) + self.fc2 = torch.nn.Linear(8, 4) + self.fc3 = torch.nn.Linear(4, 6) + self.dequant = torch.quantization.DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = x.view((-1, 12)) + h = F.relu(self.fc1(x)) + h = F.relu(self.fc2(h)) + h = F.relu(self.fc3(h)) + h = self.dequant(h) + return h + + dummy_input = torch.randn(12) + f = io.BytesIO() + torch.onnx.export( + ONNXExportable(), + (dummy_input,), + f, + do_constant_folding=False, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + all_aten_nodes = [ + p + for p in onnx_model.graph.node + if p.op_type == "ATen" and p.domain == "org.pytorch.aten" + ] + self.assertEqual(len(all_aten_nodes), 0) + class TestQuantizeEagerONNXExport(common_utils.TestCase): def _test_lower_graph_impl(self, model, data): @@ -997,91 +1130,6 @@ def test_lower_graph_conv3d(self): data = torch.from_numpy(data_numpy).to(dtype=torch.float) self._test_lower_graph_impl(model, data) - @common_utils.skipIfNoCaffe2 - def test_caffe2_aten_fallback(self): - class ModelWithAtenNotONNXOp(torch.nn.Module): - def forward(self, x, y): - abcd = x + y - defg = torch.linalg.qr(abcd) - return defg - - x = torch.rand(3, 4) - y = torch.rand(3, 4) - f = io.BytesIO() - torch.onnx.export( - ModelWithAtenNotONNXOp(), - (x, y), - f, - do_constant_folding=False, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, - # support for linalg.qr was added in later op set versions. - opset_version=9, - ) - onnx_model = onnx.load(io.BytesIO(f.getvalue())) - self.assertAtenOp(onnx_model, "linalg_qr") - - @common_utils.skipIfNoCaffe2 - def test_caffe2_onnx_aten(self): - class ModelWithAtenFmod(torch.nn.Module): - def forward(self, x, y): - return torch.fmod(x, y) - - x = torch.randn(3, 4, dtype=torch.float32) - y = torch.randn(3, 4, dtype=torch.float32) - f = io.BytesIO() - torch.onnx.export( - ModelWithAtenFmod(), - (x, y), - f, - do_constant_folding=False, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN, - opset_version=10, # or higher - ) - onnx_model = onnx.load(io.BytesIO(f.getvalue())) - assert onnx_model.graph.node[0].op_type == "Mod" - - @common_utils.skipIfCaffe2 - def test_aten_fallback(self): - class ModelWithAtenNotONNXOp(torch.nn.Module): - def forward(self, x, y): - abcd = x + y - defg = torch.linalg.qr(abcd) - return defg - - x = torch.rand(3, 4) - y = torch.rand(3, 4) - f = io.BytesIO() - torch.onnx.export( - ModelWithAtenNotONNXOp(), - (x, y), - f, - do_constant_folding=False, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, - # support for linalg.qr was added in later op set versions. - opset_version=9, - ) - onnx_model = onnx.load(io.BytesIO(f.getvalue())) - self.assertAtenOp(onnx_model, "linalg_qr") - - @common_utils.skipIfCaffe2 - def test_onnx_aten(self): - class ModelWithAtenFmod(torch.nn.Module): - def forward(self, x, y): - return torch.fmod(x, y) - - x = torch.randn(3, 4, dtype=torch.float32) - y = torch.randn(3, 4, dtype=torch.float32) - f = io.BytesIO() - torch.onnx.export( - ModelWithAtenFmod(), - (x, y), - f, - do_constant_folding=False, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN, - ) - onnx_model = onnx.load(io.BytesIO(f.getvalue())) - self.assertAtenOp(onnx_model, "fmod", "Tensor") - if __name__ == "__main__": common_utils.run_tests() diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index ff0ef755968d3a5..b30b71812aaefae 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -1752,10 +1752,21 @@ def _should_aten_fallback( ) is_caffe2_build = _C_onnx._CAFFE2_ATEN_FALLBACK - return name.startswith("aten::") and ( - ((is_onnx_aten_export or is_aten_fallback_export) and not is_caffe2_build) - or (not is_exportable_aten_op and is_aten_fallback_export) - ) + if not name.startswith("aten::"): + return False + + if is_caffe2_build: + if ( + is_onnx_aten_export or is_aten_fallback_export + ) and not is_exportable_aten_op: + return True + else: + if is_onnx_aten_export or ( + is_aten_fallback_export and not is_exportable_aten_op + ): + return True + + return False @_beartype.beartype From 3d1c5c89ed27ff16601aecf7834a6bd06f578c45 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 10 Nov 2022 21:19:21 +0000 Subject: [PATCH 30/62] [FSDP][state_dict][4/N] Move the core logic of summon full parameters to _unshard_params_utils.py (#88636) **What** `_summon_full_parameters` is required for state_dict. To enable composable FSDP state_dict, `_summon_full_params` must be accessible without FullyShardedDataParall. This PR move the core logic of `_summon_full_params` to `_unshard_params_utils`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88636 Approved by: https://github.com/awgu --- test/distributed/fsdp/test_fsdp_state_dict.py | 2 +- .../fsdp/test_fsdp_summon_full_params.py | 4 +- torch/distributed/fsdp/_state_dict_utils.py | 34 ++- .../distributed/fsdp/_unshard_param_utils.py | 254 ++++++++++++++++++ .../fsdp/fully_sharded_data_parallel.py | 201 ++------------ 5 files changed, 290 insertions(+), 205 deletions(-) create mode 100644 torch/distributed/fsdp/_unshard_param_utils.py diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py index 48dad3118db749f..ba51ae66ed1b21e 100644 --- a/test/distributed/fsdp/test_fsdp_state_dict.py +++ b/test/distributed/fsdp/test_fsdp_state_dict.py @@ -25,7 +25,7 @@ StateDictType, ) from torch.distributed.fsdp._shard_utils import _gather_state_dict -from torch.distributed.fsdp.fully_sharded_data_parallel import FLAT_PARAM +from torch.distributed.fsdp._unshard_param_utils import FLAT_PARAM from torch.distributed.fsdp.wrap import enable_wrap, transformer_auto_wrap_policy, wrap from torch.nn import Linear, Module, TransformerDecoderLayer, TransformerEncoderLayer from torch.nn.parallel import DistributedDataParallel diff --git a/test/distributed/fsdp/test_fsdp_summon_full_params.py b/test/distributed/fsdp/test_fsdp_summon_full_params.py index 0d4e98069117a67..18055dbebffbf56 100644 --- a/test/distributed/fsdp/test_fsdp_summon_full_params.py +++ b/test/distributed/fsdp/test_fsdp_summon_full_params.py @@ -212,7 +212,7 @@ def forward(self, fsdp_module): model = FSDP(MyModule()).cuda(self.rank) with self.assertRaisesRegex( - ValueError, "current state is TrainingState.FORWARD" + ValueError, "Current handle state is HandleTrainingState.FORWARD" ): model(model) @@ -231,7 +231,7 @@ def bad_backwards_hook(tensor): output.register_hook(bad_backwards_hook) with self.assertRaisesRegex( - ValueError, "current state is TrainingState.FORWARD_BACKWARD" + ValueError, "Current handle state is HandleTrainingState.BACKWARD_PRE" ): output.backward() diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index 0bfd149b0112c9c..eee5522340b46ec 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -24,7 +24,6 @@ clean_tensor_name, FSDP_PREFIX, FSDP_WRAPPED_MODULE, - TrainingState, ) from torch.distributed.fsdp._runtime_utils import ( _cast_buffers_to_dtype_and_device, @@ -40,6 +39,11 @@ _ext_pre_load_state_dict_transform, _extensions as _user_extensions, ) +from ._unshard_param_utils import ( + _deregister_orig_params, + _register_orig_params, + FLAT_PARAM, +) from .flat_param import FlatParamHandle @@ -91,10 +95,6 @@ def _enter_full_param_ctx( "Entering the ``summon_full_params`` context but fsdp_state._full_param_ctx " "is not None." ) - assert fsdp_state.training_state != TrainingState.SUMMON_FULL_PARAMS, ( - "Entering the summon_full_params context but the state is already " - "SUMMON_FULL_PARAMS." - ) fsdp_state._full_param_ctx = fsdp_state._summon_full_params( recurse=recurse, writeback=writeback, @@ -108,10 +108,6 @@ def _enter_full_param_ctx( @no_type_check def _exit_full_param_ctx(fsdp_state: _FSDPState) -> None: """A helper function to exit ``summon_full_params`` context.""" - assert fsdp_state.training_state == TrainingState.SUMMON_FULL_PARAMS, ( - "Exiting the summon_full_params context but the state is not " - "SUMMON_FULL_PARAMS." - ) assert fsdp_state._full_param_ctx is not None fsdp_state._full_param_ctx.__exit__(None, None, None) fsdp_state._full_param_ctx = None @@ -168,9 +164,6 @@ def _common_summon_post_state_dict_hook( hook. """ _replace_by_prefix(state_dict, prefix + f"{FSDP_PREFIX}", prefix) - assert ( - fsdp_state.training_state == TrainingState.SUMMON_FULL_PARAMS - ), "Inside the post_state_dict_hook but the state is not SUMMON_FULL_PARAMS." # Return early for trivial cases if not state_dict or not _has_fsdp_params(fsdp_state, module): _exit_full_param_ctx(fsdp_state) @@ -180,7 +173,7 @@ def _common_summon_post_state_dict_hook( # For `use_orig_params=True`, the `FlatParameter` is not registered, so # there is no entry in the state dict for it to pop. if not fsdp_state._use_orig_params: - state_dict.pop(f"{prefix}{fsdp_file.FLAT_PARAM}") + state_dict.pop(f"{prefix}{FLAT_PARAM}") # If a rank does not have unsharded parameters(when `rank0_only=True` # and `rank != 0`), then the rank only needed to participate in the @@ -338,6 +331,7 @@ def _full_pre_load_state_dict_hook( state_dict: Dict[str, Any], prefix: str, ) -> None: + _lazy_init(fsdp_state, module) _enter_full_param_ctx(fsdp_state, recurse=False, writeback=True) _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}") @@ -409,7 +403,7 @@ def _local_post_state_dict_hook( ) # type: ignore[assignment] if fsdp_state._state_dict_config.offload_to_cpu: sharded_tensor = sharded_tensor.cpu() - state_dict[f"{prefix}{fsdp_file.FLAT_PARAM}"] = sharded_tensor + state_dict[f"{prefix}{FLAT_PARAM}"] = sharded_tensor return state_dict @@ -430,8 +424,9 @@ def _local_pre_load_state_dict_hook( state_dict. The flat_param should be a ShardedTensor. This hook converts the ShardedTensor to a tensor. No copy happen unless padding is required. """ + _lazy_init(fsdp_state, module) _replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_PREFIX}") - fqn = f"{prefix}{FSDP_PREFIX}{fsdp_file.FLAT_PARAM}" + fqn = f"{prefix}{FSDP_PREFIX}{FLAT_PARAM}" if fqn not in state_dict: assert not _has_fsdp_params(fsdp_state, module), ( "No `FlatParameter` in `state_dict` for this FSDP instance " @@ -527,7 +522,7 @@ def _sharded_post_load_state_dict_hook( module, fsdp_state: _FSDPState, *args, **kwargs ) -> None: if fsdp_state._use_orig_params: - fsdp_state._register_orig_params() + _register_orig_params(module, fsdp_state) @no_type_check @@ -541,6 +536,7 @@ def _sharded_pre_load_state_dict_hook( The hook combines the unflattened, sharded parameters (ShardedTensor) to a new FlatParameter and shards the new FlatParameter to the local chunk. """ + _lazy_init(fsdp_state, module) _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}") if not _has_fsdp_params(fsdp_state, module): return @@ -605,7 +601,7 @@ def _sharded_pre_load_state_dict_hook( assert all(s1 == s2 for s1, s2 in zip(loaded_shapes, flat_param._shapes)), ( f"The original shapes in FSDP are {flat_param._shapes}. " f"The loaded shapes are {loaded_shapes}. " - f"FSDP extension is {'NOT' if _user_extensions is None else ''} None." + f"FSDP extension is {'NOT' if _user_extensions is not None else ''} None." ) assert flat_param.numel() == loaded_flat_tensor.numel(), ( f"The loaded local chunk has different numel({loaded_flat_tensor.numel()}) " @@ -615,9 +611,9 @@ def _sharded_pre_load_state_dict_hook( f"The loaded local chunk has different padding({num_to_pad}) " f"from the local chunk {flat_param._shard_numel_padded}." ) - state_dict[f"{prefix}{FSDP_PREFIX}{fsdp_file.FLAT_PARAM}"] = loaded_flat_tensor + state_dict[f"{prefix}{FSDP_PREFIX}{FLAT_PARAM}"] = loaded_flat_tensor if fsdp_state._use_orig_params: - fsdp_state._deregister_orig_params() + _deregister_orig_params(module, fsdp_state) @no_type_check diff --git a/torch/distributed/fsdp/_unshard_param_utils.py b/torch/distributed/fsdp/_unshard_param_utils.py new file mode 100644 index 000000000000000..950841850b620f4 --- /dev/null +++ b/torch/distributed/fsdp/_unshard_param_utils.py @@ -0,0 +1,254 @@ +import contextlib +import warnings +from typing import cast, Generator, List + +import torch +import torch.nn as nn +from torch.distributed.fsdp._common_utils import ( + _FSDPState, + _has_fsdp_params, + _module_handles, + HandleTrainingState, +) +from torch.distributed.fsdp._runtime_utils import ( + _clear_grads_if_needed, + _reshard, + _reshard_grads, + _unshard, + _unshard_grads, +) +from ._utils import p_assert +from .flat_param import FlatParamHandle + +FLAT_PARAM = "_flat_param" + + +@torch.no_grad() +def _writeback_to_local_shard( + handles: List[FlatParamHandle], + writeback_grad: bool, +): + """ + For each handle, writes back the this rank's shard of the unsharded + flattened parameter to the sharded flattened parameter. If + ``writeback_grad=True``, then writes back to the sharded gradient as + well. + + Precondition: Each handle's ``FlatParameter`` 's data points to the + padded unsharded flattened parameter. + """ + for handle in handles: + # For `NO_SHARD`, `_local_shard` is the unsharded flattened + # parameter and `grad` is the unsharded gradient, so there is no + # need to writeback for either + if not handle.uses_sharded_strategy: + continue + assert ( + handle.flat_param.ndim == 1 + ), f"Expects `flat_param` to be flattened but got {handle.flat_param.shape}" + + # Get the unpadded shard instead of the padded shard to persist + # user changes to the padding (though FSDP does not explicitly + # support this) + param_shard, _ = FlatParamHandle._get_unpadded_shard( + handle.flat_param, + handle.rank, + handle.world_size, + ) + handle.flat_param._local_shard[: param_shard.numel()].copy_(param_shard) # type: ignore[attr-defined] + if writeback_grad: + existing_grad = handle.sharded_grad + if existing_grad is not None: + assert handle.flat_param.grad is not None + grad_shard, _ = FlatParamHandle._get_unpadded_shard( + handle.flat_param.grad, + handle.rank, + handle.world_size, + ) + existing_grad[: grad_shard.numel()].copy_(grad_shard) + + +def _deregister_flat_param(state: _FSDPState, module: nn.Module) -> None: + """ + De-registers the flattened parameter from the wrapped module, hiding it + from ``nn.Module`` methods. + + We do not use ``del`` because we want ``FLAT_PARAM`` to always be an + attribute but dynamically change whether it is visible to ``nn.Module`` + methods. + """ + if _has_fsdp_params(state, module): + # TODO: figure out the case for the composable APIs. + cast(nn.Module, module.module)._parameters.pop(FLAT_PARAM, None) + + +def _register_flat_param(state: _FSDPState, module: nn.Module) -> None: + """ + Registers the flattened parameter to the wrapped module, making it + visible to ``nn.Module`` methods. + + We do not use :meth:`nn.Module.register_parameter` because we want + ``FLAT_PARAM`` to always be an attribute but dynamically change whether + it is visible to ``nn.Module`` methods. + """ + handles = _module_handles(state, module) + if _has_fsdp_params(state, module): + # TODO: figure out the case for the composable APIs. + cast(nn.Module, module.module)._parameters[FLAT_PARAM] = handles[0].flat_param + + +@contextlib.contextmanager +def _unflatten_as_params(state: _FSDPState, module: nn.Module) -> Generator: + """ + Assumes that the flattened parameter is unsharded. When in the context, + de-registers the flattened parameter and unflattens the original + parameters as ``nn.Parameter`` views into the flattened parameter. + After the context, re-registers the flattened parameter and restores + the original parameters as ``Tensor`` views into the flattened + parameter. + """ + handles = _module_handles(state, module) + if not handles: + yield + else: + _deregister_flat_param(state, module) + try: + with handles[0].unflatten_as_params(): + yield + finally: + if not handles[0]._use_orig_params: + _register_flat_param(state, module) + + +@contextlib.contextmanager +def _unshard_params( + module: nn.Module, + state: _FSDPState, + writeback: bool = True, + rank0_only: bool = False, + offload_to_cpu: bool = False, + with_grads: bool = False, +): + if with_grads and (offload_to_cpu or not state._use_orig_params): + raise NotImplementedError( + f"with_grads={with_grads} " + f"use_orig_params={state._use_orig_params} " + f"offload_to_cpu={offload_to_cpu} " + f"is not supported yet" + ) + if writeback and rank0_only: + raise ValueError( + "writeback=True and rank0_only=True is not supported, as model " + "parameter shapes will be different across ranks, and writing " + "to them can lead to inconsistencies across ranks when the " + "context is exited." + ) + if offload_to_cpu and not rank0_only: + warnings.warn( + "offload_to_cpu and rank0_only=False will result in " + "full parameters being redundantly copied to CPU memory for " + "GPUs that reside on the same machine, which may incur the risk of " + "CPU OOM. It is recommended to use ``offload_to_cpu`` with " + "rank0_only=True." + ) + + torch.cuda.synchronize() + # If handles are shared by other module(s), the handle may be already unsharded. + handles = [ + handle + for handle in _module_handles(state, module) + if handle._training_state != HandleTrainingState.SUMMON_FULL_PARAMS + ] + if not handles: + yield + return + + for handle in handles: + if handle._training_state != HandleTrainingState.IDLE: + raise ValueError(f"Current handle state is {handle._training_state}") + + for handle in handles: + handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS + + _clear_grads_if_needed(handles) + free_unsharded_flat_params = [handle.needs_unshard() for handle in handles] + # No need to call `wait_stream()` since we unshard in the computation + # stream directly + computation_stream = torch.cuda.current_stream() + _unshard(state, handles, computation_stream, computation_stream) + if with_grads: + _unshard_grads(handles) + + if rank0_only and state.rank != 0: + # Free the unsharded flattened parameter early + _reshard(state, handles, free_unsharded_flat_params) + if with_grads: + _reshard_grads(handles) + try: + yield + finally: + for handle in handles: + handle._training_state = HandleTrainingState.IDLE + else: + # Unflatten the unsharded flattened parameters + with contextlib.ExitStack() as stack: + # Invariant: rank == 0 or !rank0_only + for handle in handles: + if offload_to_cpu and handle.uses_sharded_strategy: + stack.enter_context(handle.to_cpu()) + # TODO (awgu): Since PyTorch enforces that a parameter + # and its gradients need to match metadata (e.g. + # device), we must move gradients to CPU *after* we + # move parameters. + # TODO (awgu): This FPW call assumes 1 `FlatParameter` + if not state._use_orig_params: + stack.enter_context(_unflatten_as_params(state, module)) + try: + yield + finally: + stack.close() + if writeback: + _writeback_to_local_shard(handles, with_grads) + _reshard(state, handles, free_unsharded_flat_params) + if with_grads: + _reshard_grads(handles) + for handle in handles: + handle._training_state = HandleTrainingState.IDLE + + +def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None: + """ + Deregisters the original parameters; registers the ``FlatParameter``. + """ + handles = _module_handles(state, module) + p_assert( + len(handles) <= 1, + "Expects <=1 handle per FSDP instance; needs to be refactored " + "for >1 handle (e.g. non-recursive wrapping)", + ) + if not handles: + return + handle = handles[0] + p_assert( + handle._use_orig_params, + f"Inconsistent `_use_orig_params` -- FSDP: {state._use_orig_params} " + f"handle: {handle._use_orig_params}", + ) + handle._deregister_orig_params() + _register_flat_param(state, module) + + +def _register_orig_params(state: _FSDPState, module: nn.Module) -> None: + """ + Deregisters the ``FlatParameter``; registers the original parameters. + """ + handles = _module_handles(state, module) + if not handles: + return + handle = handles[0] + _deregister_flat_param(state, module) + if handle.is_sharded(handle.flat_param): + handle._use_sharded_views() + handle._use_sharded_grad_views() + else: + handle._use_unsharded_views(as_params=True) diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index 773686081a4d2a3..510f90de20234ee 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -48,18 +48,14 @@ _init_state_dict_state, ) from torch.distributed.fsdp._runtime_utils import ( - _clear_grads_if_needed, _lazy_init, _post_forward, _post_forward_reshard, _pre_forward, _pre_forward_unshard, _reshard, - _reshard_grads, _root_pre_forward, _should_free_in_backward, - _unshard, - _unshard_grads, _wait_for_computation_stream, ) from torch.distributed.fsdp._wrap_utils import _auto_wrap @@ -92,6 +88,12 @@ _post_state_dict_hook, _pre_load_state_dict_hook, ) +from ._unshard_param_utils import ( + _deregister_orig_params, + _register_flat_param, + _register_orig_params, + _unshard_params, +) from ._utils import p_assert from .flat_param import FlatParameter, FlatParamHandle from .wrap import ParamExecOrderWrapPolicy @@ -409,7 +411,7 @@ def __init__( self._fsdp_wrapped_module = module if not use_orig_params: _check_orig_params_flattened(self, self._ignored_params) - self._register_flat_param() + _register_flat_param(self, self) # Delete to avoid keeping references after the constructor delattr(self, "_ignored_params") @@ -864,153 +866,20 @@ def _summon_full_params( yield return - torch.cuda.synchronize() _lazy_init(self, self) - self._assert_state([TrainingState.IDLE]) - for handle in self._handles: - assert handle._training_state == HandleTrainingState.IDLE - self.training_state = TrainingState.SUMMON_FULL_PARAMS - for handle in self._handles: - handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS - - if self._is_root: - _clear_grads_if_needed(self._fsdp_handles(self)) - free_unsharded_flat_params = [ - handle.needs_unshard() for handle in self._handles - ] - # No need to call `wait_stream()` since we unshard in the computation - # stream directly - computation_stream = torch.cuda.current_stream() - _unshard(self, self._handles, computation_stream, computation_stream) - if with_grads: - _unshard_grads(self._handles) - - if rank0_only and self.rank != 0: - # Free the unsharded flattened parameter early - _reshard(self, self._handles, free_unsharded_flat_params) - if with_grads: - _reshard_grads(self._handles) + with _unshard_params( + module=self, + state=self, + writeback=writeback, + rank0_only=rank0_only, + offload_to_cpu=offload_to_cpu, + with_grads=with_grads, + ): try: + self.training_state = TrainingState.SUMMON_FULL_PARAMS yield finally: self.training_state = TrainingState.IDLE - for handle in self._handles: - handle._training_state = HandleTrainingState.IDLE - else: - # Unflatten the unsharded flattened parameters - with contextlib.ExitStack() as stack: - # Invariant: rank == 0 or !rank0_only - for handle in self._handles: - if offload_to_cpu and handle.uses_sharded_strategy: - stack.enter_context(handle.to_cpu()) - # TODO (awgu): Since PyTorch enforces that a parameter - # and its gradients need to match metadata (e.g. - # device), we must move gradients to CPU *after* we - # move parameters. - # TODO (awgu): This FPW call assumes 1 `FlatParameter` - if not self._use_orig_params: - stack.enter_context(self._unflatten_as_params()) - try: - yield - finally: - stack.close() - if writeback: - self._writeback_to_local_shard(self._handles, with_grads) - _reshard(self, self._handles, free_unsharded_flat_params) - if with_grads: - _reshard_grads(self._handles) - self.training_state = TrainingState.IDLE - for handle in self._handles: - handle._training_state = HandleTrainingState.IDLE - - @torch.no_grad() - def _writeback_to_local_shard( - self, - handles: List[FlatParamHandle], - writeback_grad: bool, - ): - """ - For each handle, writes back the this rank's shard of the unsharded - flattened parameter to the sharded flattened parameter. If - ``writeback_grad=True``, then writes back to the sharded gradient as - well. - - Precondition: Each handle's ``FlatParameter`` 's data points to the - padded unsharded flattened parameter. - """ - for handle in handles: - # For `NO_SHARD`, `_local_shard` is the unsharded flattened - # parameter and `grad` is the unsharded gradient, so there is no - # need to writeback for either - if not handle.uses_sharded_strategy: - continue - assert ( - handle.flat_param.ndim == 1 - ), f"Expects `flat_param` to be flattened but got {handle.flat_param.shape}" - - # Get the unpadded shard instead of the padded shard to persist - # user changes to the padding (though FSDP does not explicitly - # support this) - param_shard, _ = FlatParamHandle._get_unpadded_shard( - handle.flat_param, - handle.rank, - handle.world_size, - ) - handle.flat_param._local_shard[: param_shard.numel()].copy_(param_shard) - if writeback_grad: - existing_grad = handle.sharded_grad - if existing_grad is not None: - grad_shard, _ = FlatParamHandle._get_unpadded_shard( - handle.flat_param.grad, - handle.rank, - handle.world_size, - ) - existing_grad[: grad_shard.numel()].copy_(grad_shard) - - @contextlib.contextmanager - def _unflatten_as_params(self) -> Generator: - """ - Assumes that the flattened parameter is unsharded. When in the context, - de-registers the flattened parameter and unflattens the original - parameters as ``nn.Parameter`` views into the flattened parameter. - After the context, re-registers the flattened parameter and restores - the original parameters as ``Tensor`` views into the flattened - parameter. - """ - if not self._handles: - yield - else: - self._deregister_flat_param() - try: - with self._handles[0].unflatten_as_params(): - yield - finally: - if not self._handles[0]._use_orig_params: - self._register_flat_param() - - def _register_flat_param(self): - """ - Registers the flattened parameter to the wrapped module, making it - visible to ``nn.Module`` methods. - - We do not use :meth:`nn.Module.register_parameter` because we want - ``FLAT_PARAM`` to always be an attribute but dynamically change whether - it is visible to ``nn.Module`` methods. - """ - if self._has_params: - self.module._parameters[FLAT_PARAM] = self._handles[0].flat_param - - def _deregister_flat_param(self): - """ - De-registers the flattened parameter from the wrapped module, hiding it - from ``nn.Module`` methods. - - We do not use ``del`` because we want ``FLAT_PARAM`` to always be an - attribute but dynamically change whether it is visible to ``nn.Module`` - methods. - """ - if self._has_params: - self.module._parameters.pop(FLAT_PARAM, None) @contextlib.contextmanager def _deregister_orig_params_ctx(self): @@ -1026,46 +895,12 @@ def _deregister_orig_params_ctx(self): "`_use_orig_params=True`", ) for fsdp_module in self.fsdp_modules(self): - fsdp_module._deregister_orig_params() + _deregister_orig_params(fsdp_module, fsdp_module) try: yield finally: for fsdp_module in self.fsdp_modules(self): - fsdp_module._register_orig_params() - - def _deregister_orig_params(self): - """ - Deregisters the original parameters; registers the ``FlatParameter``. - """ - p_assert( - len(self._handles) <= 1, - "Expects <=1 handle per FSDP instance; needs to be refactored " - "for >1 handle (e.g. non-recursive wrapping)", - ) - if not self._handles: - return - handle = self._handles[0] - p_assert( - handle._use_orig_params, - f"Inconsistent `_use_orig_params` -- FSDP: {self._use_orig_params} " - f"handle: {handle._use_orig_params}", - ) - handle._deregister_orig_params() - self._register_flat_param() - - def _register_orig_params(self): - """ - Deregisters the ``FlatParameter``; registers the original parameters. - """ - if not self._handles: - return - handle = self._handles[0] - self._deregister_flat_param() - if handle.is_sharded(handle.flat_param): - handle._use_sharded_views() - handle._use_sharded_grad_views() - else: - handle._use_unsharded_views(as_params=True) + _register_orig_params(fsdp_module, fsdp_module) def _apply(self, *args, **kwargs): """ From 9d7d21f5691979f728f42a709e1a47ab3e905342 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Tue, 8 Nov 2022 10:22:31 -0800 Subject: [PATCH 31/62] [ONNX] Add stack info to diagnostics (#87258) ~~Investigating strange bug releasing 'graph' right when returning from `_C._jit_pass_onnx`.~~ ~~Can be repro-ed locally via `test_cpp_diagnose`, with changes in this PR.~~ Resolved by https://github.com/pytorch/pytorch/pull/87829. This PR adds methods to record stack backtrace information to diagnostics. * #87830 Pull Request resolved: https://github.com/pytorch/pytorch/pull/87258 Approved by: https://github.com/abock --- test/onnx/internal/test_diagnostics.py | 77 +++++++++++++++---- .../onnx/_internal/diagnostics/_diagnostic.py | 61 ++++++++++++--- .../_internal/diagnostics/infra/__init__.py | 2 + .../_internal/diagnostics/infra/_infra.py | 49 ++++++------ .../onnx/_internal/diagnostics/infra/utils.py | 35 +++++++++ 5 files changed, 169 insertions(+), 55 deletions(-) create mode 100644 torch/onnx/_internal/diagnostics/infra/utils.py diff --git a/test/onnx/internal/test_diagnostics.py b/test/onnx/internal/test_diagnostics.py index fbd888329a50e93..ea9a789e91c1f28 100644 --- a/test/onnx/internal/test_diagnostics.py +++ b/test/onnx/internal/test_diagnostics.py @@ -3,6 +3,7 @@ import contextlib import dataclasses import io +import typing import unittest from typing import AbstractSet, Tuple @@ -110,23 +111,15 @@ class TestOnnxDiagnostics(common_utils.TestCase): def setUp(self): engine = diagnostics.engine engine.clear() + self._sample_rule = diagnostics.rules.missing_custom_symbolic_function super().setUp() - def test_assert_diagnostic_raises_when_diagnostic_not_found(self): - with self.assertRaises(AssertionError): - with assert_diagnostic( - self, - diagnostics.engine, - diagnostics.rules.node_missing_onnx_shape_inference, - diagnostics.levels.WARNING, - ): - pass - - def test_cpp_diagnose_emits_warning(self): + def _trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp( + self, + ) -> diagnostics.ExportDiagnostic: class CustomAdd(torch.autograd.Function): @staticmethod def forward(ctx, x, y): - ctx.save_for_backward(x, y) return x + y @staticmethod @@ -137,6 +130,30 @@ class M(torch.nn.Module): def forward(self, x): return CustomAdd.apply(x, x) + # trigger warning for missing shape inference. + rule = diagnostics.rules.node_missing_onnx_shape_inference + torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO()) + + context = diagnostics.engine.contexts[-1] + for diagnostic in context.diagnostics: + if ( + diagnostic.rule == rule + and diagnostic.level == diagnostics.levels.WARNING + ): + return typing.cast(diagnostics.ExportDiagnostic, diagnostic) + raise AssertionError("No diagnostic found.") + + def test_assert_diagnostic_raises_when_diagnostic_not_found(self): + with self.assertRaises(AssertionError): + with assert_diagnostic( + self, + diagnostics.engine, + diagnostics.rules.node_missing_onnx_shape_inference, + diagnostics.levels.WARNING, + ): + pass + + def test_cpp_diagnose_emits_warning(self): with assert_diagnostic( self, diagnostics.engine, @@ -144,7 +161,7 @@ def forward(self, x): diagnostics.levels.WARNING, ): # trigger warning for missing shape inference. - torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO()) + self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp() def test_py_diagnose_emits_error(self): class M(torch.nn.Module): @@ -168,15 +185,43 @@ def forward(self, x): def test_diagnostics_engine_records_diagnosis_reported_outside_of_export( self, ): - sample_rule = diagnostics.rules.missing_custom_symbolic_function sample_level = diagnostics.levels.ERROR with assert_diagnostic( self, diagnostics.engine, - sample_rule, + self._sample_rule, sample_level, ): - diagnostics.context.diagnose(sample_rule, sample_level) + diagnostics.context.diagnose(self._sample_rule, sample_level) + + def test_diagnostics_records_python_call_stack(self): + diagnostic = diagnostics.ExportDiagnostic( + self._sample_rule, diagnostics.levels.NOTE + ) + stack = diagnostic.python_call_stack + assert stack is not None # for mypy + self.assertGreater(len(stack.frames), 0) + frame = stack.frames[0] + assert frame.location.snippet is not None # for mypy + self.assertIn("self._sample_rule", frame.location.snippet) + assert frame.location.uri is not None # for mypy + self.assertIn("test_diagnostics.py", frame.location.uri) + + def test_diagnostics_records_cpp_call_stack(self): + diagnostic = ( + self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp() + ) + stack = diagnostic.cpp_call_stack + assert stack is not None # for mypy + self.assertGreater(len(stack.frames), 0) + frame_messages = [frame.location.message for frame in stack.frames] + self.assertTrue( + any( + isinstance(message, str) + and "torch::jit::ONNXShapeTypeInference" in message + for message in frame_messages + ) + ) @dataclasses.dataclass diff --git a/torch/onnx/_internal/diagnostics/_diagnostic.py b/torch/onnx/_internal/diagnostics/_diagnostic.py index ae6615e831cb21c..21e44f2b44671e2 100644 --- a/torch/onnx/_internal/diagnostics/_diagnostic.py +++ b/torch/onnx/_internal/diagnostics/_diagnostic.py @@ -5,11 +5,38 @@ import torch from torch.onnx._internal.diagnostics import infra +from torch.onnx._internal.diagnostics.infra import utils as infra_utils +from torch.utils import cpp_backtrace # This is a workaround for mypy not supporting Self from typing_extensions. _ExportDiagnostic = TypeVar("_ExportDiagnostic", bound="ExportDiagnostic") +def _cpp_call_stack(frames_to_skip: int = 0, frames_to_log: int = 32): + """Returns the current C++ call stack. + + This function utilizes `torch.utils.cpp_backtrace` to get the current C++ call stack. + The returned C++ call stack is a concatenated string of the C++ call stack frames. + Each frame is separated by a newline character, in the same format of + r"frame #[0-9]+: (?P.*)". More info at `c10/util/Backtrace.cpp`. + + """ + frames = cpp_backtrace.get_cpp_backtrace(frames_to_skip, frames_to_log).split("\n") + frame_messages = [] + for frame in frames: + segments = frame.split(":", 1) + if len(segments) == 2: + frame_messages.append(segments[1].strip()) + else: + frame_messages.append("") + return infra.Stack( + frames=[ + infra.StackFrame(location=infra.Location(message=message)) + for message in frame_messages + ] + ) + + class ExportDiagnostic(infra.Diagnostic): """Base class for all export diagnostics. @@ -18,24 +45,34 @@ class ExportDiagnostic(infra.Diagnostic): diagnostic. """ + python_call_stack: Optional[infra.Stack] = None + cpp_call_stack: Optional[infra.Stack] = None + def __init__( self, *args, **kwargs, ) -> None: super().__init__(*args, **kwargs) - - def with_cpp_stack(self: _ExportDiagnostic) -> _ExportDiagnostic: - # TODO: Implement this. - # self.stacks.append(...) - raise NotImplementedError() - return self - - def with_python_stack(self: _ExportDiagnostic) -> _ExportDiagnostic: - # TODO: Implement this. - # self.stacks.append(...) - raise NotImplementedError() - return self + self.record_python_call_stack(frames_to_skip=1) + self.record_cpp_call_stack(frames_to_skip=1) + + def record_python_call_stack(self, frames_to_skip) -> None: + """Records the current Python call stack in the diagnostic.""" + frames_to_skip += 1 # Skip this function. + stack = infra_utils.python_call_stack(frames_to_skip=frames_to_skip) + stack.message = "Python call stack" + self.with_stack(stack) + self.python_call_stack = stack + + def record_cpp_call_stack(self, frames_to_skip) -> None: + """Records the current C++ call stack in the diagnostic.""" + # No need to skip this function because python frame is not recorded + # in cpp call stack. + stack = _cpp_call_stack(frames_to_skip=frames_to_skip) + stack.message = "C++ call stack" + self.with_stack(stack) + self.cpp_call_stack = stack def with_model_source_location( self: _ExportDiagnostic, diff --git a/torch/onnx/_internal/diagnostics/infra/__init__.py b/torch/onnx/_internal/diagnostics/infra/__init__.py index ac9e6e99a9746bc..4f9dd9e5fa0b3b2 100644 --- a/torch/onnx/_internal/diagnostics/infra/__init__.py +++ b/torch/onnx/_internal/diagnostics/infra/__init__.py @@ -8,6 +8,7 @@ Rule, RuleCollection, Stack, + StackFrame, ) from .engine import DiagnosticEngine @@ -22,4 +23,5 @@ "Rule", "RuleCollection", "Stack", + "StackFrame", ] diff --git a/torch/onnx/_internal/diagnostics/infra/_infra.py b/torch/onnx/_internal/diagnostics/infra/_infra.py index 6966ccccbb26426..b8a4c5032f52390 100644 --- a/torch/onnx/_internal/diagnostics/infra/_infra.py +++ b/torch/onnx/_internal/diagnostics/infra/_infra.py @@ -110,11 +110,12 @@ def format_message(self, *args, **kwargs) -> str: @dataclasses.dataclass class Location: - uri: str - message: str + uri: Optional[str] = None line: Optional[int] = None + message: Optional[str] = None start_column: Optional[int] = None end_column: Optional[int] = None + snippet: Optional[str] = None def sarif(self) -> sarif.Location: """Returns the SARIF representation of this location.""" @@ -124,43 +125,37 @@ def sarif(self) -> sarif.Location: region=sarif.Region( start_line=self.line, start_column=self.start_column, - end_line=self.line, end_column=self.end_column, + snippet=sarif.ArtifactContent(text=self.snippet), ), ), - message=sarif.Message(text=self.message), + message=sarif.Message(text=self.message) + if self.message is not None + else None, ) +@dataclasses.dataclass +class StackFrame: + location: Location + + def sarif(self) -> sarif.StackFrame: + """Returns the SARIF representation of this stack frame.""" + return sarif.StackFrame(location=self.location.sarif()) + + @dataclasses.dataclass class Stack: - frame_locations: List[Location] = dataclasses.field(default_factory=list) + frames: List[StackFrame] = dataclasses.field(default_factory=list) + message: Optional[str] = None def sarif(self) -> sarif.Stack: """Returns the SARIF representation of this stack.""" return sarif.Stack( - frames=[ - sarif.StackFrame(location=loc.sarif()) for loc in self.frame_locations - ] - ) - - def add_frame( - self, - uri: str, - message: str, - line: Optional[int] = None, - start_column: Optional[int] = None, - end_column: Optional[int] = None, - ) -> None: - """Adds a frame to the stack.""" - self.frame_locations.append( - Location( - uri=uri, - message=message, - line=line, - start_column=start_column, - end_column=end_column, - ) + frames=[frame.sarif() for frame in self.frames], + message=sarif.Message(text=self.message) + if self.message is not None + else None, ) diff --git a/torch/onnx/_internal/diagnostics/infra/utils.py b/torch/onnx/_internal/diagnostics/infra/utils.py new file mode 100644 index 000000000000000..c32de1c6b8ad90f --- /dev/null +++ b/torch/onnx/_internal/diagnostics/infra/utils.py @@ -0,0 +1,35 @@ +import inspect + +from torch.onnx._internal.diagnostics.infra import _infra + + +def python_frame(frame: inspect.FrameInfo) -> _infra.StackFrame: + """Returns a StackFrame for the given inspect.FrameInfo.""" + snippet = ( + frame.code_context[frame.index] + if frame.code_context is not None and frame.index is not None + else None + ) + + return _infra.StackFrame( + location=_infra.Location( + uri=frame.filename, + line=frame.lineno, + snippet=snippet, + ) + ) + + +def python_call_stack(frames_to_skip: int = 0, frames_to_log: int = 32) -> _infra.Stack: + """Returns the current Python call stack.""" + if frames_to_skip < 0: + raise ValueError("frames_to_skip must be non-negative") + if frames_to_log < 0: + raise ValueError("frames_to_log must be non-negative") + frames_to_skip += 1 # Skip this function. + stack = _infra.Stack() + stack.frames = [ + python_frame(frame) + for frame in inspect.stack()[frames_to_skip : frames_to_skip + frames_to_log] + ] + return stack From 4e5d7afe84c01ed730f0f43395d7fa0542e81f3a Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 11 Nov 2022 19:08:30 +0000 Subject: [PATCH 32/62] Revert "add DisableTorchFunction that matches DisableTorchDispatch (#88219)" This reverts commit c0ecce15b5a54ff0185f9976e6bfb6f3a7de698d. Reverted https://github.com/pytorch/pytorch/pull/88219 on behalf of https://github.com/izaitsevfb due to BC-breaking change, D41211901 --- aten/src/ATen/PythonTorchFunctionTLS.cpp | 11 +-- aten/src/ATen/PythonTorchFunctionTLS.h | 12 +-- test/allowlist_for_publicAPI.json | 1 - test/test_overrides.py | 21 ---- test/test_public_bindings.py | 1 - torch/_C/__init__.pyi.in | 1 - torch/__init__.py | 2 +- torch/csrc/Module.cpp | 4 - torch/csrc/autograd/init.cpp | 9 +- torch/csrc/utils/disable_torch_function.cpp | 100 ++------------------ torch/csrc/utils/disable_torch_function.h | 1 - 11 files changed, 24 insertions(+), 139 deletions(-) diff --git a/aten/src/ATen/PythonTorchFunctionTLS.cpp b/aten/src/ATen/PythonTorchFunctionTLS.cpp index 00f372f370e62fa..c9487c6958cbf76 100644 --- a/aten/src/ATen/PythonTorchFunctionTLS.cpp +++ b/aten/src/ATen/PythonTorchFunctionTLS.cpp @@ -26,12 +26,12 @@ int64_t PythonTorchFunctionTLS::stack_len() { return pythonTorchFunctionState.stack_.size(); } -void PythonTorchFunctionTLS::set_disabled_state(TorchFunctionDisabledState disabled_state) { - pythonTorchFunctionState.disabled_state_ = disabled_state; +void PythonTorchFunctionTLS::set_disabled(bool disabled) { + pythonTorchFunctionState.disabled_ = disabled; } -TorchFunctionDisabledState PythonTorchFunctionTLS::get_disabled_state() { - return pythonTorchFunctionState.disabled_state_; +bool PythonTorchFunctionTLS::is_disabled() { + return pythonTorchFunctionState.disabled_; } void PythonTorchFunctionTLS::set_state(const PythonTorchFunctionTLS& state) { @@ -43,8 +43,7 @@ const PythonTorchFunctionTLS& PythonTorchFunctionTLS::get_state() { } bool torch_function_mode_enabled() { - return PythonTorchFunctionTLS::get_disabled_state() != TorchFunctionDisabledState::ALL_DISABLED && - PythonTorchFunctionTLS::stack_len() > 0; + return PythonTorchFunctionTLS::stack_len() > 0; } } // namespace impl diff --git a/aten/src/ATen/PythonTorchFunctionTLS.h b/aten/src/ATen/PythonTorchFunctionTLS.h index a1e3a61ea20233b..5940fb6f2dee246 100644 --- a/aten/src/ATen/PythonTorchFunctionTLS.h +++ b/aten/src/ATen/PythonTorchFunctionTLS.h @@ -6,11 +6,9 @@ namespace at { namespace impl { -enum TorchFunctionDisabledState { ENABLED, SUBCLASSES_DISABLED, ALL_DISABLED }; - struct TORCH_API PythonTorchFunctionTLS { - static void set_disabled_state(TorchFunctionDisabledState disabled_state_); - static TorchFunctionDisabledState get_disabled_state(); + static void set_disabled(bool); + static bool is_disabled(); static void push_onto_stack(std::shared_ptr mode); static const std::shared_ptr pop_stack(); @@ -22,11 +20,11 @@ struct TORCH_API PythonTorchFunctionTLS { private: // The mode TLS is split into - // - disabled_state, which says which part of torch function are disabled + // - disabled_, which says whether or not to disable all torch function + // modes // - stack_, which is a vector of modes representing the stack of user // defined modes - TorchFunctionDisabledState disabled_state_ = - TorchFunctionDisabledState::ENABLED; + bool disabled_; std::vector> stack_; }; diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 45ba9ae94676d9f..8a66dc12d4b6f9d 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -1128,7 +1128,6 @@ "BFloat16Tensor", "ComplexDoubleStorage", "ComplexFloatStorage", - "DisableTorchFunction", "DisableTorchFunctionSubclass", "Generator", "HalfStorage", diff --git a/test/test_overrides.py b/test/test_overrides.py index 3b3a5ed063c70e4..01c763a548fc8b1 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -1453,27 +1453,6 @@ class B(torch.Tensor): self.assertTrue(called) - def test_disable_subclass_mode(self): - called = False - - class A(TorchFunctionMode): - def __torch_function__(self, func, types, args=(), kwargs=None): - nonlocal called - if kwargs is None: - kwargs = {} - called = True - return func(*args, **kwargs) - - class B(torch.Tensor): - pass - - x = B(torch.randn(5)) - with A(): - with torch._C.DisableTorchFunction(): - self.assertNotIsInstance(torch.sum(x), B) - - self.assertFalse(called) - def test_disable_enable_subclass(self): called = False diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 46c7396b9b07fdc..6897c3102df607a 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -99,7 +99,6 @@ def test_no_new_bindings(self): "device", "DeviceObjType", "DictType", - "DisableTorchFunction", "DisableTorchFunctionSubclass", "DispatchKey", "DispatchKeySet", diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index bc4bf03d8161f23..79dd6386c3789e3 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -108,7 +108,6 @@ class layout: ... # Defined in torch/csrc/utils/disable_torch_function.cpp -def DisableTorchFunction(): ... def DisableTorchFunctionSubclass(): ... # Defined in torch/csrc/utils/tensor_layouts.cpp diff --git a/torch/__init__.py b/torch/__init__.py index 6049967b6f18e6f..ec23499dce659d6 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -315,7 +315,7 @@ def get_pyobj(self): if (isinstance(obj, Callable) or inspect.isclass(obj)): # type: ignore[arg-type] if (obj.__module__ != 'torch'): # TODO: fix their module from C++ side - if name not in ['DisableTorchFunctionSubclass', 'DisableTorchFunction', 'Generator']: + if name not in ['DisableTorchFunctionSubclass', 'Generator']: obj.__module__ = 'torch' if not TYPE_CHECKING: diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 0a9aa53a0bbc4bb..efe6c18ea0cd4f6 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -1597,10 +1597,6 @@ Call this whenever a new thread is created in order to propagate values from "DisableTorchFunctionSubclass", (PyObject*)THPModule_DisableTorchFunctionSubclassType(), /* incref= */ false)); - ASSERT_TRUE(set_module_attr( - "DisableTorchFunction", - (PyObject*)THPModule_DisableTorchFunctionType(), - /* incref= */ false)); torch::set_disabled_torch_function_impl( PyObject_GetAttrString(module, "_disabled_torch_function_impl")); ASSERT_TRUE(torch::disabled_torch_function_impl() != nullptr); diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 6271cfd5cb997d8..d26db95f1295cf8 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -60,14 +60,13 @@ struct DisableAutocast { struct EnableTorchFunction { EnableTorchFunction() - : old_(at::impl::PythonTorchFunctionTLS::get_disabled_state()) { - at::impl::PythonTorchFunctionTLS::set_disabled_state( - at::impl::TorchFunctionDisabledState::ENABLED); + : old_(at::impl::PythonTorchFunctionTLS::is_disabled()) { + at::impl::PythonTorchFunctionTLS::set_disabled(false); } ~EnableTorchFunction() { - at::impl::PythonTorchFunctionTLS::set_disabled_state(old_); + at::impl::PythonTorchFunctionTLS::set_disabled(old_); } - at::impl::TorchFunctionDisabledState old_; + bool old_; }; struct EnablePythonDispatcher { diff --git a/torch/csrc/utils/disable_torch_function.cpp b/torch/csrc/utils/disable_torch_function.cpp index 589b069250a36ef..516e6b89d43af59 100644 --- a/torch/csrc/utils/disable_torch_function.cpp +++ b/torch/csrc/utils/disable_torch_function.cpp @@ -11,8 +11,7 @@ PyObject* disabled_torch_function = nullptr; PyObject* disabled_torch_dispatch = nullptr; bool torch_function_enabled() { - return at::impl::PythonTorchFunctionTLS::get_disabled_state() == - at::impl::TorchFunctionDisabledState::ENABLED; + return !at::impl::PythonTorchFunctionTLS::is_disabled(); } PyObject* disabled_torch_function_impl() { @@ -35,23 +34,20 @@ void set_disabled_torch_dispatch_impl(PyObject* value) { typedef struct { PyObject_HEAD /* Type-specific fields go here. */ - at::impl::TorchFunctionDisabledState old_state; + bool old_state; } DisableTorchFunctionSubclass; PyObject* DisableTorchFunctionSubclass__enter( PyObject* self, PyObject* unused) { - const auto old_state = at::impl::PythonTorchFunctionTLS::get_disabled_state(); - ((DisableTorchFunctionSubclass*)self)->old_state = old_state; - if (old_state == at::impl::TorchFunctionDisabledState::ENABLED) { - at::impl::PythonTorchFunctionTLS::set_disabled_state( - at::impl::TorchFunctionDisabledState::SUBCLASSES_DISABLED); - } + ((DisableTorchFunctionSubclass*)self)->old_state = + at::impl::PythonTorchFunctionTLS::is_disabled(); + at::impl::PythonTorchFunctionTLS::set_disabled(true); Py_RETURN_NONE; } PyObject* DisableTorchFunctionSubclass__exit(PyObject* self, PyObject* unused) { - at::impl::PythonTorchFunctionTLS::set_disabled_state( + at::impl::PythonTorchFunctionTLS::set_disabled( ((DisableTorchFunctionSubclass*)self)->old_state); Py_RETURN_NONE; } @@ -119,81 +115,6 @@ PyObject* THPModule_DisableTorchFunctionSubclassType() { return (PyObject*)(&DisableTorchFunctionSubclassType); } -typedef struct { - PyObject_HEAD - /* Type-specific fields go here. */ - at::impl::TorchFunctionDisabledState old_state; -} DisableTorchFunction; - -PyObject* DisableTorchFunction__enter(PyObject* self, PyObject* unused) { - ((DisableTorchFunctionSubclass*)self)->old_state = - at::impl::PythonTorchFunctionTLS::get_disabled_state(); - at::impl::PythonTorchFunctionTLS::set_disabled_state( - at::impl::TorchFunctionDisabledState::ALL_DISABLED); - Py_RETURN_NONE; -} - -PyObject* DisableTorchFunction__exit(PyObject* self, PyObject* unused) { - at::impl::PythonTorchFunctionTLS::set_disabled_state( - ((DisableTorchFunctionSubclass*)self)->old_state); - Py_RETURN_NONE; -} - -static PyMethodDef DisableTorchFunction_methods[] = { // NOLINT - {"__enter__", DisableTorchFunction__enter, METH_NOARGS, nullptr}, - {"__exit__", DisableTorchFunction__exit, METH_VARARGS, nullptr}, - {nullptr, nullptr, 0, nullptr}}; - -PyTypeObject DisableTorchFunctionType = { - PyVarObject_HEAD_INIT( - nullptr, - 0) "torch._C.DisableTorchFunction", /* tp_name */ - sizeof(DisableTorchFunction), /* tp_basicsize */ - 0, /* tp_itemsize */ - nullptr, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - nullptr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - DisableTorchFunction_methods, /* tp_methods */ - nullptr, /* tp_members */ - nullptr, /* tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - PyType_GenericAlloc, /* tp_alloc */ - PyType_GenericNew, /* tp_new */ -}; - -PyObject* THPModule_DisableTorchFunctionType() { - if (PyType_Ready(&DisableTorchFunctionType) < 0) { - return nullptr; - } - - return (PyObject*)(&DisableTorchFunctionType); -} - PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* a) { HANDLE_TH_ERRORS PyObject *func = nullptr, *types = nullptr, *args = nullptr, @@ -216,14 +137,11 @@ PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* a) { // These are all C-API calls so no exceptions will be raised // and therefore no need for RAII approach to storing // the old value. - auto old_value = at::impl::PythonTorchFunctionTLS::get_disabled_state(); - if (old_value == at::impl::TorchFunctionDisabledState::ENABLED) { - at::impl::PythonTorchFunctionTLS::set_disabled_state( - at::impl::TorchFunctionDisabledState::SUBCLASSES_DISABLED); - } + bool old_value = at::impl::PythonTorchFunctionTLS::is_disabled(); + at::impl::PythonTorchFunctionTLS::set_disabled(true); // kwargs can safely be nullptr here. PyObject* result = PyObject_Call(func, py_args.ptr(), kwargs); - at::impl::PythonTorchFunctionTLS::set_disabled_state(old_value); + at::impl::PythonTorchFunctionTLS::set_disabled(old_value); return result; END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/utils/disable_torch_function.h b/torch/csrc/utils/disable_torch_function.h index 8fc5118830eb7d7..881a7adb13ebf99 100644 --- a/torch/csrc/utils/disable_torch_function.h +++ b/torch/csrc/utils/disable_torch_function.h @@ -29,7 +29,6 @@ struct DisableTorchDispatch { } // namespace torch PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused); -PyObject* THPModule_DisableTorchFunctionType(); PyObject* THPModule_DisableTorchFunctionSubclassType(); PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* args); PyObject* THPModule_disable_torch_dispatch(PyObject* self, PyObject* args); From ba4d5aae06bde7c0ad045e54b7ad86f4542efb86 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 11 Nov 2022 19:13:05 +0000 Subject: [PATCH 33/62] Revert "rename DisableTorchFunction to DisableTorchFunctionSubclass (#88218)" This reverts commit 7f28be10e5e71efda37800384fa897785499bed1. Reverted https://github.com/pytorch/pytorch/pull/88218 on behalf of https://github.com/izaitsevfb due to BC-breaking change, D41211901 --- test/allowlist_for_publicAPI.json | 2 +- test/profiler/test_profiler_tree.py | 2 +- test/test_overrides.py | 4 +-- test/test_public_bindings.py | 2 +- torch/_C/__init__.pyi.in | 2 +- torch/__init__.py | 2 +- torch/_dynamo/variables/builder.py | 2 +- torch/_dynamo/variables/misc.py | 2 +- torch/_dynamo/variables/tensor.py | 2 +- torch/_subclasses/fake_tensor.py | 2 +- torch/_tensor.py | 2 +- torch/csrc/Module.cpp | 4 +-- torch/csrc/autograd/init.cpp | 1 + torch/csrc/utils/disable_torch_function.cpp | 32 +++++++++---------- torch/csrc/utils/disable_torch_function.h | 2 +- torch/distributed/_shard/common_op_utils.py | 4 +-- torch/distributed/_shard/partial_tensor.py | 2 +- torch/distributed/_shard/replicated_tensor.py | 4 +-- .../_shard/sharded_tensor/_ops/tensor_ops.py | 2 +- torch/masked/maskedtensor/core.py | 2 +- 20 files changed, 38 insertions(+), 39 deletions(-) diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 8a66dc12d4b6f9d..ba4a2e96df21943 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -1128,7 +1128,7 @@ "BFloat16Tensor", "ComplexDoubleStorage", "ComplexFloatStorage", - "DisableTorchFunctionSubclass", + "DisableTorchFunction", "Generator", "HalfStorage", "HalfTensor", diff --git a/test/profiler/test_profiler_tree.py b/test/profiler/test_profiler_tree.py index 210530250f924ca..d4a31c645613154 100644 --- a/test/profiler/test_profiler_tree.py +++ b/test/profiler/test_profiler_tree.py @@ -26,7 +26,7 @@ "torch/profiler/profiler.py(...): start": KEEP_ELLIPSES, "torch/profiler/profiler.py(...): stop_trace": KEEP_ELLIPSES, "torch/profiler/profiler.py(...): _transit_action": KEEP_ELLIPSES, - "": PRUNE_ALL, + "": PRUNE_ALL, "cudaStreamIsCapturing": PRUNE_ALL, "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags": PRUNE_ALL, } diff --git a/test/test_overrides.py b/test/test_overrides.py index 01c763a548fc8b1..7082f75a2141f55 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -1448,7 +1448,7 @@ class B(torch.Tensor): x = B(torch.randn(5)) with A(): - with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): self.assertNotIsInstance(torch.sum(x), B) self.assertTrue(called) @@ -1460,7 +1460,7 @@ class A(torch.Tensor): pass x = A(torch.randn(5)) - with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): g = torch._C._EnableTorchFunction() try: self.assertIsInstance(torch.sum(x), A) diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 6897c3102df607a..4d2df65126983e8 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -99,7 +99,7 @@ def test_no_new_bindings(self): "device", "DeviceObjType", "DictType", - "DisableTorchFunctionSubclass", + "DisableTorchFunction", "DispatchKey", "DispatchKeySet", "dtype", diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 79dd6386c3789e3..2d20da2a04f30d9 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -108,7 +108,7 @@ class layout: ... # Defined in torch/csrc/utils/disable_torch_function.cpp -def DisableTorchFunctionSubclass(): ... +def DisableTorchFunction(): ... # Defined in torch/csrc/utils/tensor_layouts.cpp strided : layout = ... diff --git a/torch/__init__.py b/torch/__init__.py index ec23499dce659d6..19be59282cca4e0 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -315,7 +315,7 @@ def get_pyobj(self): if (isinstance(obj, Callable) or inspect.isclass(obj)): # type: ignore[arg-type] if (obj.__module__ != 'torch'): # TODO: fix their module from C++ side - if name not in ['DisableTorchFunctionSubclass', 'Generator']: + if name not in ['DisableTorchFunction', 'Generator']: obj.__module__ = 'torch' if not TYPE_CHECKING: diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 9d8789746855447..d3c5140fa4a979e 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -506,7 +506,7 @@ def wrap_tensor(self, value: torch.Tensor): ) # Disable __torch_function__ to prevent cloning of `value` to hit # us - with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): if is_constant_source(self.get_source()): return self.tx.output.register_attr_or_module( value, diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 6e4325b6c0f431c..da327122a6a7015 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -538,7 +538,7 @@ def call_function( options = VariableTracker.propagate(self, new_args, new_kwargs.values()) # Disable __torch_function__ here to prevent the clone of the # example tensor from going into the override. - with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): if isinstance(args[0], TorchVariable): return TensorVariable.create( tx=tx, diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 0974f24ee9694d2..e87b1d87bac9bb5 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -743,7 +743,7 @@ def inline_torch_function_unwrapped( # Disable __torch_function__ here to prevent the clone of the # example tensor from going into the override. - with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): return tx.inline_user_function_return(tf_func_var, tf_args, {}) diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 79af51efc5b8eec..14f5cd2de0a7ab2 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1093,5 +1093,5 @@ def __torch_function__(self, func, types, args=(), kwargs=None): memo[id(tensor)] = out return out else: - with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): return func(*args, **kwargs) diff --git a/torch/_tensor.py b/torch/_tensor.py index 41b6569c06d86bf..793034bb64edef9 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -1297,7 +1297,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): if not all(issubclass(cls, t) for t in types): return NotImplemented - with _C.DisableTorchFunctionSubclass(): + with _C.DisableTorchFunction(): ret = func(*args, **kwargs) if func in get_default_nowrap_functions(): return ret diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index efe6c18ea0cd4f6..b8693a484ed9da2 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -1594,8 +1594,8 @@ Call this whenever a new thread is created in order to propagate values from (PyObject*)THPDefaultCPUGenerator, /* incref= */ false)); ASSERT_TRUE(set_module_attr( - "DisableTorchFunctionSubclass", - (PyObject*)THPModule_DisableTorchFunctionSubclassType(), + "DisableTorchFunction", + (PyObject*)THPModule_DisableTorchFunctionType(), /* incref= */ false)); torch::set_disabled_torch_function_impl( PyObject_GetAttrString(module, "_disabled_torch_function_impl")); diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index d26db95f1295cf8..ee963232d316635 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -343,6 +343,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { _C_m, "_RestorePythonTLSSnapshot") .def(py::init<>()); + // TODO: line up this binding with DisableTorchFunction py::class_(_C_m, "_DisableTorchDispatch") .def(py::init<>()); py::class_(_C_m, "_EnableTorchFunction") diff --git a/torch/csrc/utils/disable_torch_function.cpp b/torch/csrc/utils/disable_torch_function.cpp index 516e6b89d43af59..682120d7e62232f 100644 --- a/torch/csrc/utils/disable_torch_function.cpp +++ b/torch/csrc/utils/disable_torch_function.cpp @@ -35,20 +35,18 @@ typedef struct { PyObject_HEAD /* Type-specific fields go here. */ bool old_state; -} DisableTorchFunctionSubclass; +} DisableTorchFunction; -PyObject* DisableTorchFunctionSubclass__enter( - PyObject* self, - PyObject* unused) { - ((DisableTorchFunctionSubclass*)self)->old_state = +PyObject* DisableTorchFunction__enter(PyObject* self, PyObject* unused) { + ((DisableTorchFunction*)self)->old_state = at::impl::PythonTorchFunctionTLS::is_disabled(); at::impl::PythonTorchFunctionTLS::set_disabled(true); Py_RETURN_NONE; } -PyObject* DisableTorchFunctionSubclass__exit(PyObject* self, PyObject* unused) { +PyObject* DisableTorchFunction__exit(PyObject* self, PyObject* unused) { at::impl::PythonTorchFunctionTLS::set_disabled( - ((DisableTorchFunctionSubclass*)self)->old_state); + ((DisableTorchFunction*)self)->old_state); Py_RETURN_NONE; } @@ -60,16 +58,16 @@ PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused) { } } -static PyMethodDef DisableTorchFunctionSubclass_methods[] = { // NOLINT - {"__enter__", DisableTorchFunctionSubclass__enter, METH_NOARGS, nullptr}, - {"__exit__", DisableTorchFunctionSubclass__exit, METH_VARARGS, nullptr}, +static PyMethodDef DisableTorchFunction_methods[] = { // NOLINT + {"__enter__", DisableTorchFunction__enter, METH_NOARGS, nullptr}, + {"__exit__", DisableTorchFunction__exit, METH_VARARGS, nullptr}, {nullptr, nullptr, 0, nullptr}}; -PyTypeObject DisableTorchFunctionSubclassType = { +PyTypeObject DisableTorchFunctionType = { PyVarObject_HEAD_INIT( nullptr, - 0) "torch._C.DisableTorchFunctionSubclass", /* tp_name */ - sizeof(DisableTorchFunctionSubclass), /* tp_basicsize */ + 0) "torch._C.DisableTorchFunction", /* tp_name */ + sizeof(DisableTorchFunction), /* tp_basicsize */ 0, /* tp_itemsize */ nullptr, /* tp_dealloc */ 0, /* tp_vectorcall_offset */ @@ -94,7 +92,7 @@ PyTypeObject DisableTorchFunctionSubclassType = { 0, /* tp_weaklistoffset */ nullptr, /* tp_iter */ nullptr, /* tp_iternext */ - DisableTorchFunctionSubclass_methods, /* tp_methods */ + DisableTorchFunction_methods, /* tp_methods */ nullptr, /* tp_members */ nullptr, /* tp_getset */ nullptr, /* tp_base */ @@ -107,12 +105,12 @@ PyTypeObject DisableTorchFunctionSubclassType = { PyType_GenericNew, /* tp_new */ }; -PyObject* THPModule_DisableTorchFunctionSubclassType() { - if (PyType_Ready(&DisableTorchFunctionSubclassType) < 0) { +PyObject* THPModule_DisableTorchFunctionType() { + if (PyType_Ready(&DisableTorchFunctionType) < 0) { return nullptr; } - return (PyObject*)(&DisableTorchFunctionSubclassType); + return (PyObject*)(&DisableTorchFunctionType); } PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* a) { diff --git a/torch/csrc/utils/disable_torch_function.h b/torch/csrc/utils/disable_torch_function.h index 881a7adb13ebf99..3cdc33e90681b47 100644 --- a/torch/csrc/utils/disable_torch_function.h +++ b/torch/csrc/utils/disable_torch_function.h @@ -29,7 +29,7 @@ struct DisableTorchDispatch { } // namespace torch PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused); -PyObject* THPModule_DisableTorchFunctionSubclassType(); +PyObject* THPModule_DisableTorchFunctionType(); PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* args); PyObject* THPModule_disable_torch_dispatch(PyObject* self, PyObject* args); PyObject* THPModule_has_torch_function(PyObject*, PyObject* arg); diff --git a/torch/distributed/_shard/common_op_utils.py b/torch/distributed/_shard/common_op_utils.py index 42d65923a536522..08aa13282abcd75 100644 --- a/torch/distributed/_shard/common_op_utils.py +++ b/torch/distributed/_shard/common_op_utils.py @@ -53,11 +53,11 @@ def tensor_default_op(types, args=(), kwargs=None, pg=None): Handles ``__torch_function__`` dispatch for the default tensor ops that behave the same as ``torch.Tensor`` such as ``torch.Tensor.shape`` or ``torch.Tensor.dtype``. We simply lower to the real op call with - DisableTorchFunctionSubclass context like ``torch.Tensor.__torch_function__`` + DisableTorchFunction context like ``torch.Tensor.__torch_function__`` to avoid recursions. """ if kwargs is None: kwargs = {} - with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): return op(*args, **kwargs) diff --git a/torch/distributed/_shard/partial_tensor.py b/torch/distributed/_shard/partial_tensor.py index 6a48163082c5ed7..dc8d09bdd7f301e 100644 --- a/torch/distributed/_shard/partial_tensor.py +++ b/torch/distributed/_shard/partial_tensor.py @@ -236,7 +236,7 @@ def find_process_group(e): # Need to disable all dispatch to print args and kwargs appropriately. guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined] try: - with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): raise RuntimeError( f"torch function '{func.__name__}', with args: {args} and " f"kwargs: {kwargs} not supported for PartialTensor!") diff --git a/torch/distributed/_shard/replicated_tensor.py b/torch/distributed/_shard/replicated_tensor.py index e3db6b0fac66460..1327f89e00aafe2 100644 --- a/torch/distributed/_shard/replicated_tensor.py +++ b/torch/distributed/_shard/replicated_tensor.py @@ -109,7 +109,7 @@ def dispatch_arg(arg): # We cann't do super().__torch_function__() as it implicitly convert the result # back to tensor subclasses, where in our case, we need to control the output type # base on the inter-op rules we defined. - with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): rs = func(*args, **kwargs) if func in get_default_nowrap_functions(): return rs @@ -157,7 +157,7 @@ def validate(self) -> bool: return True def __setstate__(self, state): - with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): self.data = state self.requires_grad = state.requires_grad from torch.distributed._shard.api import _get_current_process_group diff --git a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py index 9ed83ee33f61940..e52c29238a62bc3 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py @@ -203,7 +203,7 @@ def tensor_requires_grad_set(types, args=(), kwargs=None, pg=None): local_shard.tensor.requires_grad_(requires_grad) # update the wrapper class property - with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): self_st.requires_grad_(requires_grad) # update the metadata in the meanwhile self_st._metadata.tensor_properties.requires_grad = requires_grad diff --git a/torch/masked/maskedtensor/core.py b/torch/masked/maskedtensor/core.py index 0459f24587bd766..3274ef2ef956909 100644 --- a/torch/masked/maskedtensor/core.py +++ b/torch/masked/maskedtensor/core.py @@ -270,7 +270,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): if not all(issubclass(cls, t) for t in types): return NotImplemented - with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): ret = func(*args, **kwargs) if func in get_default_nowrap_functions(): return ret From f74946324e794d2332251d0497dc8ff4f831caa9 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Fri, 11 Nov 2022 21:11:12 +0000 Subject: [PATCH 34/62] [fix] allow saving python attr on Tensor and Parameter via torch.save (#81616) Fixes: https://github.com/pytorch/pytorch/issues/72129 TODO: * [x] Fix for Parameter Benchmark (Measurable diff for small tensors) ``` [-------------- Save and Load --------------] | After PR | Before PR 1 threads: ---------------------------------- () | 111.7 | 106.9 (4, 4) | 114.4 | 109.2 (128, 128) | 135.2 | 128.3 (1024, 1024) | 1431.9 | 1431.3 Times are in microseconds (us). ```
Benchmark Script ```python import torch from torch.testing._internal.common_utils import BytesIOContext from torch.utils import benchmark import pickle shapes = ((), (4, 4), (128, 128), (1024, 1024)) sizes = [1, 64, 1024, 10000] results = [] def save_load_fn(t): with BytesIOContext() as f: torch.save(t, f) f.seek(0) torch.load(f) for shape in shapes: t = torch.randn(shape) label = 'Save and Load' sub_label = f'{shape}' results.append(benchmark.Timer( stmt='save_load_fn(t)', globals={'t': t, 'save_load_fn':save_load_fn}, label=label, sub_label=sub_label, description='Before PR', ).blocked_autorange(min_run_time=2)) compare = benchmark.Compare(results) compare.print() with open('before_pr.pkl', 'wb') as f: pickle.dump(results, f) # with open('after_pr.pkl', 'rb') as f: # after_pr = pickle.load(f) # with open('before_pr.pkl', 'rb') as f: # before_pr = pickle.load(f) # compare = benchmark.Compare(after_pr + before_pr) # compare.print() ```
NOTE : **BC-Breaking** : After this PR, all tensors (also regular tensors) will be serialised using `_rebuild_from_type_v2`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/81616 Approved by: https://github.com/albanD, https://github.com/kurtamohler --- test/test_serialization.py | 22 +++++++ torch/_tensor.py | 43 ++----------- torch/_utils.py | 59 ++++++++++++++++++ torch/_weights_only_unpickler.py | 4 ++ torch/csrc/jit/serialization/unpickler.cpp | 71 ++++++++++++++++++++++ torch/csrc/jit/serialization/unpickler.h | 4 ++ torch/nn/parameter.py | 1 + 7 files changed, 165 insertions(+), 39 deletions(-) diff --git a/test/test_serialization.py b/test/test_serialization.py index 5ccc6f47b4c5d06..dca926be60e706d 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -905,6 +905,28 @@ def test_meta_serialization(self, weights_only): self.assertEqual(state['weight'].size(), big_model.weight.size()) + def test_serialization_python_attr(self): + def _test_save_load_attr(t): + t.foo = 'foo' + t.pi = 3.14 + + with BytesIOContext() as f: + torch.save(t, f) + f.seek(0) + loaded_t = torch.load(f) + + self.assertEqual(t, loaded_t) + self.assertEqual(t.foo, loaded_t.foo) + self.assertEqual(t.pi, loaded_t.pi) + + t = torch.zeros(3, 3) + _test_save_load_attr(t) + # This should start failing once Parameter + # supports saving Python Attribute. + err_msg = "'Parameter' object has no attribute" + with self.assertRaisesRegex(AttributeError, err_msg): + _test_save_load_attr(torch.nn.Parameter(t)) + def test_weights_only_assert(self): class HelloWorld: def __reduce__(self): diff --git a/torch/_tensor.py b/torch/_tensor.py index 793034bb64edef9..39fc56452f5a45f 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -55,9 +55,6 @@ def _rebuild_from_type(func, type, args, dict): def _rebuild_from_type_v2(func, new_type, args, state): - if new_type is Tensor: - return func(*args) - ret = func(*args) if type(ret) is not new_type: ret = ret.as_subclass(new_type) @@ -70,21 +67,7 @@ def _rebuild_from_type_v2(func, new_type, args, state): ): ret.__setstate__(state) else: - if isinstance(state, tuple): - if not len(state) == 2: - raise RuntimeError(f"Invalid serialized state: {state}") - dict_state = state[0] - slots_state = state[1] - else: - dict_state = state - slots_state = None - - for k, v in dict_state.items(): - setattr(ret, k, v) - - if slots_state: - for k, v in slots_state.items(): - setattr(ret, k, v) + ret = torch._utils._set_obj_state(ret, state) return ret @@ -223,31 +206,13 @@ def __deepcopy__(self, memo): return new_tensor def __reduce_ex__(self, proto): - if type(self) is Tensor: + state = torch._utils._get_obj_state(self) + if type(self) is Tensor and not state: + # Fast path for regular tensor without Python state. return self._reduce_ex_internal(proto) if has_torch_function_unary(self): return handle_torch_function(Tensor.__reduce_ex__, (self,), self, proto) func, args = self._reduce_ex_internal(proto) - # Get the state of the python subclass - # This loosely mimicks the function on the object class but since Tensor do not inherit - # from it, we cannot call that function directly - # https://github.com/python/cpython/blob/c83919bd635f4433f1c6ae8504996a9fe3c215e5/Objects/typeobject.c#L4891 - getstate_fn = getattr(self, "__getstate__", None) - if getstate_fn: - state = getstate_fn() - else: - slots_to_save = copyreg._slotnames(self.__class__) # type: ignore[attr-defined] - if slots_to_save: - state = ( - self.__dict__, - { - name: getattr(self, name) - for name in slots_to_save - if hasattr(self, name) - }, - ) - else: - state = self.__dict__ return (_rebuild_from_type_v2, (func, type(self), args, state)) def storage(self): diff --git a/torch/_utils.py b/torch/_utils.py index 3bc8a749b3e6661..9c646a2f85e0c25 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -1,3 +1,4 @@ +import copyreg import sys import traceback import warnings @@ -335,6 +336,64 @@ def _rebuild_parameter(data, requires_grad, backward_hooks): return param +# TODO(kshitij12345): Support serializing nn.Parameter with Python Attributes. +# NOTE: We are just defining it here now for future use. +def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state): + param = torch.nn.Parameter(data, requires_grad) + # NB: This line exists only for backwards compatibility; the + # general expectation is that backward_hooks is an empty + # OrderedDict. See Note [Don't serialize hooks] + param._backward_hooks = backward_hooks + + # Restore state on Parameter like python attr. + param = _set_obj_state(param, state) + return param + + +def _get_obj_state(obj): + # Get the state of the python subclass + # This loosely mimicks the function on the object class but since Tensor do not inherit + # from it, we cannot call that function directly + # https://github.com/python/cpython/blob/c83919bd635f4433f1c6ae8504996a9fe3c215e5/Objects/typeobject.c#L4891 + getstate_fn = getattr(obj, "__getstate__", None) + if getstate_fn: + state = getstate_fn() + else: + slots_to_save = copyreg._slotnames(obj.__class__) # type: ignore[attr-defined] + if slots_to_save: + state = ( + obj.__dict__, + { + name: getattr(obj, name) + for name in slots_to_save + if hasattr(obj, name) + }, + ) + else: + state = obj.__dict__ + + return state + + +def _set_obj_state(obj, state): + if isinstance(state, tuple): + if not len(state) == 2: + raise RuntimeError(f"Invalid serialized state: {state}") + dict_state = state[0] + slots_state = state[1] + else: + dict_state = state + slots_state = None + + for k, v in dict_state.items(): + setattr(obj, k, v) + + if slots_state: + for k, v in slots_state.items(): + setattr(obj, k, v) + return obj + + def _import_dotted_name(name): components = name.split(".") obj = __import__(components[0]) diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index ee00db937fc3deb..acc3554768b0b39 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -103,6 +103,10 @@ def _get_allowed_globals(): torch._utils._rebuild_sparse_csr_tensor, ]: rc[f"torch._utils.{f.__name__}"] = f + + # Handles Tensor Subclasses, Tensor's with attributes. + # NOTE: It calls into above rebuild functions for regular Tensor types. + rc["torch._tensor._rebuild_from_type_v2"] = torch._tensor._rebuild_from_type_v2 return rc diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp index f7e974919f03d6b..4bbf7a783a23247 100644 --- a/torch/csrc/jit/serialization/unpickler.cpp +++ b/torch/csrc/jit/serialization/unpickler.cpp @@ -532,6 +532,21 @@ PickleOpCode Unpickler::readInstruction() { } stack_.emplace_back(std::move(tensor)); } break; + case PickleOpCode::SETITEM: { + // At this OpCode, stack looks like + // | Stack Bottom | + // | ...... | + // | Dict | -> (stack_size - 3) + // | Key | -> (stack_size - 2) + // | Value | -> (stack_size - 1) + auto stack_size = stack_.size(); + auto dict_pos = stack_size - 3; + auto key_pos = stack_size - 2; + auto val_pos = stack_size - 1; + auto dict = stack_.at(dict_pos).toGenericDict(); + dict.insert_or_assign(stack_.at(key_pos), stack_.at(val_pos)); + stack_.erase(stack_.begin() + (key_pos), stack_.end()); + } break; default: { AT_ERROR( "Unknown opcode for unpickling at ", @@ -546,6 +561,23 @@ PickleOpCode Unpickler::readInstruction() { void Unpickler::readGlobal( const std::string& module_name, const std::string& class_name) { + if (this->skip_next_read_global) { + // See [NOTE] skip_next_read_global + this->skip_next_read_global--; + if (this->skip_next_read_global == 1) { + // Pass through to the correct handler + } else if (this->skip_next_read_global == 0) { + // Corresponds to the type of `Tensor` being unpickled + if (module_name != "torch" || class_name != "Tensor") { + TORCH_WARN( + "Trying to load a Subclassed Tensor, it will be converted to at::Tensor in C++"); + } + stack_.emplace_back(int64_t(globals_.size() - 1)); + return; + } else { + TORCH_CHECK(false, "INVALID VALUES") + } + } // TODO [unpickler refactor] __main__ isn't used by the pickler anymore, this // is only here for bc-compatibility reasons if (module_name == "__main__") { @@ -631,6 +663,12 @@ void Unpickler::readGlobal( // Unpickle a tensor bool quantized = class_name == "_rebuild_qtensor"; rebuildTensor(quantized); + } else if ( + module_name == "torch._tensor" && + (class_name == "_rebuild_from_type_v2")) { + // Unpickle a Tensor with Python attributes or + // a Subclassed Tensor. + rebuildTensorFromTypeV2(); } else if ( module_name == "torch._utils" && class_name == "_rebuild_sparse_tensor") { rebuildSparseTensor(); @@ -849,6 +887,39 @@ void Unpickler::rebuildTensor(bool quantized) { }); } +void Unpickler::rebuildTensorFromTypeV2() { + // [NOTE] skip_next_read_global + // When rebuilding Tensor with Python Attr or Subclassed Tensor, + // we receive `(func, type(self), args, state)` on stack for + // `rebuildTensorFromTypeV2`. + // Thus next call to readGlobal corresponds to `func` which is + // the function to rebuild the base tensor. + // The call after `func` to readGlobal corresponds to `type` of the + // Tensor where we raise warning if the type is not `torch.Tensor`. + this->skip_next_read_global = 2; + auto curr_globals_idx = globals_.size(); + globals_.emplace_back([this, curr_globals_idx] { + // args is a tuple with following data + // (function to rebuild base tensor, type of tensor, + // arguments to construct base tensor, Python State (as dict)) + auto args = pop(stack_).toTuple(); + size_t tup_idx = 0; + const auto args_elems = args->elements(); + auto base_tensor_args = args_elems.at(tup_idx + 2).toTuple(); + auto py_state = args_elems.at(tup_idx + 3).toGenericDict(); + if (py_state.size() > 0) { + TORCH_WARN( + "Loading Tensor with Python attributes will return at::Tensor with Python attributes being discarded"); + } + // This calls the function to rebuild the + // base tensor. + // Eg. `rebuildTensor`, `rebuildSpareTensor`. + stack_.emplace_back(base_tensor_args); + globals_[curr_globals_idx + 1](); + stack_.emplace_back(pop(stack_)); + }); +} + #ifdef USE_RPC void Unpickler::rebuildRRef() { globals_.emplace_back([this] { diff --git a/torch/csrc/jit/serialization/unpickler.h b/torch/csrc/jit/serialization/unpickler.h index 5411d421a0c57be..de00e7eacff21e2 100644 --- a/torch/csrc/jit/serialization/unpickler.h +++ b/torch/csrc/jit/serialization/unpickler.h @@ -120,6 +120,7 @@ class TORCH_API Unpickler { const std::string& module_name, const std::string& class_name); void rebuildTensor(bool quantized); + void rebuildTensorFromTypeV2(); void rebuildSparseTensor(); #ifdef USE_DISTRIBUTED void rebuildRRef(); @@ -176,6 +177,9 @@ class TORCH_API Unpickler { // See [type tag serialization] uint64_t version_; + + // See [NOTE] skip_next_read_global + uint8_t skip_next_read_global = 0; }; void restoreAccurateTypeTags(const IValue& root, const c10::TypePtr& type_tag); diff --git a/torch/nn/parameter.py b/torch/nn/parameter.py index e0f400f2642bfb1..68908001238ecec 100644 --- a/torch/nn/parameter.py +++ b/torch/nn/parameter.py @@ -60,6 +60,7 @@ def __repr__(self): return 'Parameter containing:\n' + super(Parameter, self).__repr__() def __reduce_ex__(self, proto): + # TODO(kshitij12345): Support saving Python Attribute # See Note [Don't serialize hooks] return ( torch._utils._rebuild_parameter, From 575e02df5357ef6216b2d2db2424d10432679df2 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 11 Nov 2022 21:19:26 +0000 Subject: [PATCH 35/62] Fix CUDNN_PATH handling on Windows (#88898) Fixes https://github.com/pytorch/pytorch/issues/88873 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88898 Approved by: https://github.com/kit1980 --- torch/utils/cpp_extension.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index aa03da23b38da7b..720935296504f05 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -1686,7 +1686,7 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone): extra_ldflags.append(f'/LIBPATH:{_join_cuda_home("lib", "x64")}') extra_ldflags.append('cudart.lib') if CUDNN_HOME is not None: - extra_ldflags.append(os.path.join(CUDNN_HOME, "lib", "x64")) + extra_ldflags.append(f'/LIBPATH:{os.path.join(CUDNN_HOME, "lib", "x64")}') elif not IS_HIP_EXTENSION: extra_ldflags.append(f'-L{_join_cuda_home("lib64")}') extra_ldflags.append('-lcudart') From 7aa144ac54808419f7a702ef0c5a4445dba4c587 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 10 Nov 2022 21:19:21 +0000 Subject: [PATCH 36/62] [FSDP][state_dict][5/N] Remove the FSDP module dependency from _state_dict_utils (#88637) **What** This PR completely removes the `FullyShardedDataParallel` dependency from `_state_dict_utils` -- `_state_dict_utils` now depends only on `_FSDPState` and all the utils modules. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88637 Approved by: https://github.com/awgu --- torch/distributed/fsdp/_init_utils.py | 6 +- torch/distributed/fsdp/_state_dict_utils.py | 108 ++++++++++---------- 2 files changed, 58 insertions(+), 56 deletions(-) diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index 966e61f7fe1231b..1265ee3578ed40c 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -213,10 +213,8 @@ def _init_state_dict_state(state: _FSDPState) -> _FSDPState: state._state_dict_type = StateDictType.FULL_STATE_DICT state_dict_config: StateDictConfig = FullStateDictConfig() state._state_dict_config = state_dict_config - full_param_ctx: Optional[Generator] = None - # TODO: For composable API, this should be a dict that maps from a module to - # handles. - state._full_param_ctx = full_param_ctx + unshard_params_ctx: Dict[nn.Module, Generator] = {} + state._unshard_params_ctx = unshard_params_ctx return state diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index eee5522340b46ec..54191cb55ece804 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -8,7 +8,6 @@ import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper # Import the entire FSDP file to avoid circular imports -import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file import torch.nn as nn import torch.nn.functional as F @@ -42,6 +41,7 @@ from ._unshard_param_utils import ( _deregister_orig_params, _register_orig_params, + _unshard_params, FLAT_PARAM, ) from .flat_param import FlatParamHandle @@ -58,7 +58,9 @@ def _convert_to_wrapped_module_name(module_name: str) -> str: return module_name -def _param_fqns(module, fsdp_state: _FSDPState) -> Iterator[Tuple[str, str, str]]: +def _param_fqns( + module: nn.Module, fsdp_state: _FSDPState +) -> Iterator[Tuple[str, str, str]]: if not _has_fsdp_params(fsdp_state, module): return for param_name, module_name in _module_handles(fsdp_state, module)[ @@ -69,7 +71,7 @@ def _param_fqns(module, fsdp_state: _FSDPState) -> Iterator[Tuple[str, str, str] yield fqn, param_name, module_name -def _shared_param_fqns(module, fsdp_state) -> Iterator[Tuple[str, str, str]]: +def _shared_param_fqns(module: nn.Module, fsdp_state) -> Iterator[Tuple[str, str, str]]: for param_name, module_name in _module_handles(fsdp_state, module)[ 0 ].shared_parameter_module_names(): @@ -78,7 +80,9 @@ def _shared_param_fqns(module, fsdp_state) -> Iterator[Tuple[str, str, str]]: yield fqn, param_name, module_name -def _enter_full_param_ctx( +@no_type_check +def _enter_unshard_params_ctx( + module: nn.Module, fsdp_state: _FSDPState, recurse: bool = False, writeback: bool = False, @@ -89,32 +93,32 @@ def _enter_full_param_ctx( """ state_dict hooks cannot use the pure context call as the checkpoint flow requires to enter the context in the pre-hook but leave the context in the - post-hook. This API enters the context of ``summon_full_params``. + post-hook. This API enters the context of ``_unshard_params``. """ - assert fsdp_state._full_param_ctx is None, ( - "Entering the ``summon_full_params`` context but fsdp_state._full_param_ctx " + assert module not in fsdp_state._unshard_params_ctx, ( + "Entering the ``_unshard_params`` context but _unshard_params_ctx[module] " "is not None." ) - fsdp_state._full_param_ctx = fsdp_state._summon_full_params( - recurse=recurse, + fsdp_state._unshard_params_ctx[module] = _unshard_params( + module, + fsdp_state, writeback=writeback, rank0_only=rank0_only, offload_to_cpu=offload_to_cpu, with_grads=with_grads, ) - fsdp_state._full_param_ctx.__enter__() + fsdp_state._unshard_params_ctx[module].__enter__() @no_type_check -def _exit_full_param_ctx(fsdp_state: _FSDPState) -> None: - """A helper function to exit ``summon_full_params`` context.""" - assert fsdp_state._full_param_ctx is not None - fsdp_state._full_param_ctx.__exit__(None, None, None) - fsdp_state._full_param_ctx = None +def _exit_unshard_params_ctx(module: nn.Module, fsdp_state: _FSDPState) -> None: + """A helper function to exit ``_unshard_params`` context.""" + fsdp_state._unshard_params_ctx[module].__exit__(None, None, None) + fsdp_state._unshard_params_ctx.pop(module) def _common_pre_state_dict_hook( - module, + module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, @@ -131,16 +135,18 @@ def _common_pre_state_dict_hook( _clear_grads_if_needed([_module_handles(fsdp_state, module)[0]]) -def _common_summon_pre_state_dict_hook( +def _common_unshard_pre_state_dict_hook( + module: nn.Module, fsdp_state: _FSDPState, offload_to_cpu: bool, rank0_only: bool, ) -> None: """ Performs the pre-state_dict tasks shared by all state_dict types that require - ``summon_full_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this hook. + ``_unshard_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this hook. """ - _enter_full_param_ctx( + _enter_unshard_params_ctx( + module, fsdp_state, recurse=False, writeback=False, @@ -151,8 +157,8 @@ def _common_summon_pre_state_dict_hook( # TODO: change to the decorator style. See ``_full_pre_state_dict_hook``. @no_type_check -def _common_summon_post_state_dict_hook( - module, +def _common_unshard_post_state_dict_hook( + module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, @@ -160,13 +166,13 @@ def _common_summon_post_state_dict_hook( ) -> Dict[str, Any]: """ The post-state_dict flow that shared by all state_dict types that require - ``summon_full_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this + ``_unshard_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this hook. """ _replace_by_prefix(state_dict, prefix + f"{FSDP_PREFIX}", prefix) # Return early for trivial cases if not state_dict or not _has_fsdp_params(fsdp_state, module): - _exit_full_param_ctx(fsdp_state) + _exit_unshard_params_ctx(module, fsdp_state) return state_dict # TODO: Once pre_state_dict hook is supported, this pop should be removed. @@ -193,7 +199,7 @@ def _common_summon_post_state_dict_hook( f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", "" ) state_dict.pop(f"{prefix}{clean_key}", None) - _exit_full_param_ctx(fsdp_state) + _exit_unshard_params_ctx(module, fsdp_state) return state_dict # Loop only the parameters saved in this instance's wrapped module to @@ -214,7 +220,7 @@ def _common_summon_post_state_dict_hook( ) param_hook(state_dict, prefix, fqn) - _exit_full_param_ctx(fsdp_state) + _exit_unshard_params_ctx(module, fsdp_state) cpu_device = torch.device("cpu") buffer_clean_fqns = [] @@ -251,7 +257,7 @@ def _common_summon_post_state_dict_hook( @no_type_check def _full_pre_state_dict_hook( - module, + module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, @@ -267,7 +273,8 @@ def _full_pre_state_dict_hook( in ``nn.Module``. """ _common_pre_state_dict_hook(module, fsdp_state, state_dict, prefix) - _common_summon_pre_state_dict_hook( + _common_unshard_pre_state_dict_hook( + module, fsdp_state, offload_to_cpu=fsdp_state._state_dict_config.offload_to_cpu, rank0_only=cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only, @@ -276,7 +283,7 @@ def _full_pre_state_dict_hook( @no_type_check def _full_post_state_dict_hook( - module, + module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, @@ -284,7 +291,7 @@ def _full_post_state_dict_hook( """ Hook that runs after model.state_dict() is called before returning result to user. For FSDP, we may have to clone the tensors in state_dict as params go - back to sharded version after _summon_full_params ends, and also remove + back to sharded version after _unshard_params ends, and also remove the ``FSDP_WRAPPED_MODULE`` prefix. """ # TODO: remove the hack. See ``_full_pre_state_dict_hook``. @@ -303,8 +310,7 @@ def param_hook( if clean_key.startswith(clean_prefix): clean_key = clean_key[len(clean_prefix) :] - # Clone non-ignored parameters before exiting the - # `_summon_full_params()` context + # Clone non-ignored parameters before exiting the `_unshard_params()` context. if clean_key not in fsdp_state._ignored_param_names and not getattr( state_dict[fqn], "_has_been_cloned", False ): @@ -320,30 +326,30 @@ def param_hook( f"implementation of {fqn}. Error: {str(e)}" ) - return _common_summon_post_state_dict_hook( + return _common_unshard_post_state_dict_hook( module, fsdp_state, state_dict, prefix, param_hook ) def _full_pre_load_state_dict_hook( - module, + module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> None: _lazy_init(fsdp_state, module) - _enter_full_param_ctx(fsdp_state, recurse=False, writeback=True) + _enter_unshard_params_ctx(module, fsdp_state, recurse=False, writeback=True) _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}") def _full_post_load_state_dict_hook( - module, fsdp_state: _FSDPState, *args, **kwargs + module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs ) -> None: - _exit_full_param_ctx(fsdp_state) + _exit_unshard_params_ctx(module, fsdp_state) def _local_pre_state_dict_hook( - module, + module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, @@ -366,7 +372,7 @@ def _local_pre_state_dict_hook( @no_type_check def _local_post_state_dict_hook( - module, + module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, @@ -408,13 +414,13 @@ def _local_post_state_dict_hook( def _local_post_load_state_dict_hook( - module, fsdp_state: _FSDPState, *args, **kwargs + module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs ) -> None: pass def _local_pre_load_state_dict_hook( - module, + module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, @@ -457,7 +463,7 @@ def _local_pre_load_state_dict_hook( def _sharded_pre_state_dict_hook( - module, + module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, @@ -477,7 +483,8 @@ def _sharded_pre_state_dict_hook( _common_pre_state_dict_hook(module, fsdp_state, state_dict, prefix) # Setting offload_to_cpu here does not work even if offload_to_cpu is True. # We have to create ShardedTensor first then move it to CPU. - _common_summon_pre_state_dict_hook( + _common_unshard_pre_state_dict_hook( + module, fsdp_state, offload_to_cpu=False, rank0_only=False, @@ -486,7 +493,7 @@ def _sharded_pre_state_dict_hook( @no_type_check def _sharded_post_state_dict_hook( - module, + module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, @@ -512,14 +519,14 @@ def param_hook(state_dict: Dict[str, Any], prefix: str, fqn: str): sharded_tensor = sharded_tensor.cpu() state_dict[fqn] = sharded_tensor - return _common_summon_post_state_dict_hook( + return _common_unshard_post_state_dict_hook( module, fsdp_state, state_dict, prefix, param_hook ) @no_type_check def _sharded_post_load_state_dict_hook( - module, fsdp_state: _FSDPState, *args, **kwargs + module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs ) -> None: if fsdp_state._use_orig_params: _register_orig_params(module, fsdp_state) @@ -527,7 +534,7 @@ def _sharded_post_load_state_dict_hook( @no_type_check def _sharded_pre_load_state_dict_hook( - module, + module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, @@ -636,9 +643,8 @@ def _post_state_dict_hook( StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook, StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook, } - fsdp_module = cast(fsdp_file.FullyShardedDataParallel, module) processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type]( - fsdp_module, fsdp_state, state_dict, prefix + module, fsdp_state, state_dict, prefix ) return processed_state_dict @@ -664,12 +670,11 @@ def _pre_load_state_dict_hook( StateDictType.SHARDED_STATE_DICT: _sharded_pre_load_state_dict_hook, } # Code that is common for all state_dict impls - fsdp_module = cast(fsdp_file.FullyShardedDataParallel, module) if torch.cuda.is_available(): torch.cuda.synchronize() # Dispatch into state_dict specific implementation of pre-hook. _pre_load_state_dict_hook_fn[fsdp_state._state_dict_type]( - fsdp_module, fsdp_state, state_dict, prefix + module, fsdp_state, state_dict, prefix ) @@ -684,7 +689,6 @@ def _post_load_state_dict_hook(module: nn.Module, *args: Any) -> None: StateDictType.SHARDED_STATE_DICT: _sharded_post_load_state_dict_hook, } # Code that is common for all state_dict impls - fsdp_module = cast(fsdp_file.FullyShardedDataParallel, module) # Dispatch into state_dict type specific implementation of post-hook for # loading state_dict. - _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](fsdp_module, fsdp_state) + _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](module, fsdp_state) From dfb4b73e45896851d734e34a9902fd8b151797fe Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Fri, 11 Nov 2022 21:51:10 +0000 Subject: [PATCH 37/62] Fix unused variable 'options' warning in RNN.cpp (#88753) Fixes ``` /home/rbarnes/pytorch/aten/src/ATen/native/cudnn/RNN.cpp:73:17: warning: unused variable 'options' [-Wunused-variable] TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory); ^ ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/88753 Approved by: https://github.com/soumith --- aten/src/ATen/native/cudnn/RNN.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp index c08c5d26b63c73a..426243392b6fc35 100644 --- a/aten/src/ATen/native/cudnn/RNN.cpp +++ b/aten/src/ATen/native/cudnn/RNN.cpp @@ -70,7 +70,7 @@ Tensor _cudnn_init_dropout_state(double dropout, bool train, int64_t dropout_see c10::optional device, c10::optional pin_memory) { // See [Note: hacky wrapper removal for TensorOptions] - TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory); + TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory); AT_ERROR("_cudnn_init_dropout_state: ATen not compiled with cuDNN support"); } From ea0ec9d71ca5428bedfcaf74990c109af8cb9a64 Mon Sep 17 00:00:00 2001 From: efiks <5167930+efiks@users.noreply.github.com> Date: Fri, 11 Nov 2022 21:58:23 +0000 Subject: [PATCH 38/62] [tourch] BatchBoxCox - fix numerical issue in vectorized code (#88875) Summary: Usage of fast math in BatchBoxCox kernel provided different math results between dev and optimized versions which cause few internal test to fail. For now disabling the compiler optimized version and relying on ATEN vectors Differential Revision: D41211784 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88875 Approved by: https://github.com/hyuen --- caffe2/perfkernels/batch_box_cox_avx2.cc | 93 ++++++++++++++---------- 1 file changed, 53 insertions(+), 40 deletions(-) diff --git a/caffe2/perfkernels/batch_box_cox_avx2.cc b/caffe2/perfkernels/batch_box_cox_avx2.cc index 8b93293646dba2b..6171b5bfd0326c6 100644 --- a/caffe2/perfkernels/batch_box_cox_avx2.cc +++ b/caffe2/perfkernels/batch_box_cox_avx2.cc @@ -1,3 +1,4 @@ +#include #ifdef CAFFE2_PERF_USE_MKL #include #include @@ -5,30 +6,68 @@ #include "vectorizer.h" -#ifndef VECTORIZED_KERNEL +// Enable compiler vectorized version only if numerical consistency is not +// required between dev and opt versions - disabled for now +#ifndef FAST_VECTORIZED_KERNEL #define CPU_CAPABILITY_AVX2 #include namespace at::vec { +// Implements the vectorized version of std::max() operation, +// which DOESNOT propagates NaN for second argument template Vectorized max(const Vectorized& a, const Vectorized& b); -// Implements the vectorized version of std::max() operation, -// which DOESNOT propagates NaN for second argument template <> Vectorized max(const Vectorized& a, const Vectorized& b) { // std::max(NaN, nonNan) -> NaN return _mm256_max_pd(b, a); } - template <> Vectorized max(const Vectorized& a, const Vectorized& b) { // std::max(NaN, nonNan) -> NaN return _mm256_max_ps(b, a); } +// Implements recieprocal method based on newton-rapson method +// 1. user RCP approximiation +// 2. update with RCP = RCP * (2 - X * RCP) +template +Vectorized fast_recieprocal(const Vectorized& b); +template +scalar_t fast_recieprocal(scalar_t b); + +template<> +Vectorized fast_recieprocal(const Vectorized& b) { + auto minus2 = _mm256_set1_ps(-2.f); + auto rcp = _mm256_rcp_ps(b); + rcp = _mm256_mul_ps(rcp, _mm256_fnmsub_ps(rcp, b, minus2)); + rcp = _mm256_mul_ps(rcp, _mm256_fnmsub_ps(rcp, b, minus2)); + return rcp; +} + +template <> +float fast_recieprocal(float b) { + auto minus2 = _mm_set_ss(-2.f); + auto b_reg = _mm_set_ss(b); + auto rcp = _mm_rcp_ss(b_reg); + rcp = _mm_mul_ss(rcp, _mm_fnmsub_ss(rcp, b_reg, minus2)); + rcp = _mm_mul_ss(rcp, _mm_fnmsub_ss(rcp, b_reg, minus2)); + return _mm_cvtss_f32(rcp); +} + +template<> +Vectorized fast_recieprocal(const Vectorized& b) { + return b.reciprocal(); +} + +template <> +double fast_recieprocal(double b) { + return 1./b; +} + } #endif @@ -45,14 +84,6 @@ template void PackV(const int N, const T* a, const int* ia, T* y); template void UnpackV(const int N, const T* a, T* y, const int* iy); -template -void Pow(const int N, const T* a, const T* b, T* y); -template -void Add(const int N, const T* a, const T* b, T* y); -template -void Div(const int N, const T* a, const T* b, T* y); -template -void Ln(const int N, const T* a, T* y); #define DELEGATE_PACKV_FUNCTION(T, OriginalFunc) \ template <> \ @@ -72,29 +103,7 @@ DELEGATE_UNPACKV_FUNCTION(float, vsUnpackV) DELEGATE_UNPACKV_FUNCTION(double, vdUnpackV) #undef DELEGATE_UNPACKV_FUNCTION -#define DELEGATE_SIMPLE_BINARY_FUNCTION(T, Funcname, OriginalFunc) \ - template <> \ - void Funcname(const int N, const T* a, const T* b, T* y) { \ - OriginalFunc(N, a, b, y); \ - } -DELEGATE_SIMPLE_BINARY_FUNCTION(float, Pow, vsPow) -DELEGATE_SIMPLE_BINARY_FUNCTION(double, Pow, vdPow) -DELEGATE_SIMPLE_BINARY_FUNCTION(float, Add, vsAdd) -DELEGATE_SIMPLE_BINARY_FUNCTION(double, Add, vdAdd) -DELEGATE_SIMPLE_BINARY_FUNCTION(float, Div, vsDiv) -DELEGATE_SIMPLE_BINARY_FUNCTION(double, Div, vdDiv) -#undef DELEGATE_SIMPLE_BINARY_FUNCTION - -#define DELEGATE_SIMPLE_UNARY_FUNCTION(T, Funcname, OriginalFunc) \ - template <> \ - void Funcname(const int N, const T* a, T* y) { \ - OriginalFunc(N, a, y); \ - } -DELEGATE_SIMPLE_UNARY_FUNCTION(float, Ln, vsLn) -DELEGATE_SIMPLE_UNARY_FUNCTION(double, Ln, vdLn) -#undef DELEGATE_SIMPLE_UNARY_FUNCTION - -#ifndef VECTORIZED_KERNEL +#ifndef FAST_VECTORIZED_KERNEL template void box_cox_zero_lambda( size_t D, @@ -140,7 +149,7 @@ void box_cox_nonzero_lambda( auto sum = data + lambda2; auto max = at::vec::max(sum, k_eps_vec); auto lambda1 = Vec::loadu(lambda1_ptr + j); - auto lambda_over_1 = lambda1.reciprocal(); + auto lambda_over_1 = at::vec::fast_recieprocal(lambda1); auto pow = max.pow(lambda1); auto res = at::vec::fmsub(pow, lambda_over_1, lambda_over_1); res.store(out + j); @@ -148,7 +157,7 @@ void box_cox_nonzero_lambda( for ( ;j < D; ++j) { auto sum = data_ptr[j] + lambda2_ptr[j]; auto max = std::max(sum, k_eps); - auto lambda_over_1 = 1 / lambda1_ptr[j]; + auto lambda_over_1 = at::vec::fast_recieprocal(lambda1_ptr[j]); auto pow = std::pow(max, lambda1_ptr[j]); out[j] = pow * lambda_over_1 - lambda_over_1; } @@ -181,12 +190,16 @@ void box_cox_nonzero_lambda( FAST_MATH auto sum = data_ptr[j] + lambda2_ptr[j]; auto max = std::max(sum, k_eps); - auto lambda_over_1 = 1 / lambda1_ptr[j]; - auto pow = std::pow(max, lambda1_ptr[j]); + auto lamda1 = lambda1_ptr[j]; + auto lambda_over_1 = 1 / lamda1; + if constexpr (std::is_same::value) { + lambda_over_1 = lambda_over_1 * (T{2} - lambda_over_1 * lamda1); + lambda_over_1 = lambda_over_1 * (T{2} - lambda_over_1 * lamda1); + } + auto pow = std::pow(max, lamda1); out[j] = pow * lambda_over_1 - lambda_over_1; } } - #endif template From fbc1878265374a159639993269d40a6e08503278 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Tue, 8 Nov 2022 10:22:32 -0800 Subject: [PATCH 39/62] [ONNX] Pretty print diagnostic logging (#88261) Adds pretty print diagnostic logging. For example ```python import io import torch from torch.onnx._internal import diagnostics class CustomAdd(torch.autograd.Function): @staticmethod def forward(ctx, x, y): return x + y @staticmethod def symbolic(g, x, y): return g.op("custom::CustomAdd", x, y) class M(torch.nn.Module): def forward(self, x): return CustomAdd.apply(x, x) # trigger warning for missing shape inference. # rule = diagnostics.rules.node_missing_onnx_shape_inference torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO()) ``` By default, observe minimum summary of diagnostics ``` ========= Diagnostic Run torch.onnx.export version 1.14.0a0+git90a69c5 ========= verbose: False, log level: Level.ERROR ======================= 0 NONE 0 NOTE 3 WARNING 0 ERROR ======================== 3 WARNING were not printed due to the log level. ``` Adjusting the `verbose` and `level` argument. ```python diagnostics.engine.pretty_print(verbose=True, level=diagnostics.levels.WARNING) ``` Prints full log. ``` =============================== 1 Diagnostic Run =============================== ========= Diagnostic Run torch.onnx.export version 1.14.0a0+git90a69c5 ========= verbose: True, log level: Level.WARNING ======================= 0 NONE 0 NOTE 3 WARNING 0 ERROR ======================== WARNING: node-missing-onnx-shape-inference ========================================== The shape inference of custom::CustomAdd type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. --------------------------- Stack: Python call stack --------------------------- frame: diagnostic = ExportDiagnostic(rule, level, message, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/diagnostics/_diagnostic.py:151 frame: n, utils._params_dict, GLOBALS.export_onnx_opset_version /home/bowbao/pytorch_dev/torch/onnx/_patch_torch.py:82 frame: <@beartype(torch.onnx._patch_torch._graph_op) at 0x7f62184b6710>:78 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: return function(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_deprecation.py:30 frame: return g.op("custom::CustomAdd", x, y) test_pretty_print.py:14 frame: return symbolic_fn(g, *args) /home/bowbao/pytorch_dev/torch/onnx/utils.py:1716 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: graph = _C._jit_pass_onnx(graph, operator_export_type) /home/bowbao/pytorch_dev/torch/onnx/utils.py:663 frame: <@beartype(torch.onnx.utils._optimize_graph) at 0x7f62180e05f0>:85 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: module=module, /home/bowbao/pytorch_dev/torch/onnx/utils.py:1123 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: dynamic_axes=dynamic_axes, /home/bowbao/pytorch_dev/torch/onnx/utils.py:1539 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: export_modules_as_functions=export_modules_as_functions, /home/bowbao/pytorch_dev/torch/onnx/utils.py:519 frame: <@beartype(torch.onnx.utils.export) at 0x7f62180e0170>:347 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO()) test_pretty_print.py:22 ---------------------------- Stack: C++ call stack ----------------------------- frame: () frame: ( + 0x88411b (0x7f625b36011b in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: (torch::jit::UpdateReliable(torch::jit::Value*, std::pair const&) + 0x7d3 (0x7f625b351743 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: (torch::jit::UpdateReliable(torch::jit::Node*) + 0x4f (0x7f625b35198f in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: (torch::jit::ONNXShapeTypeInference(torch::jit::Node*, std::map, std::allocator >, c10::IValue, std::less, std::allocator > >, std::allocator, std::allocator > const, c10::IValue> > > const&, int) + 0xac9 (0x7f625b357179 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: ( + 0xabd026 (0x7f625b599026 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: ( + 0x3c0fda (0x7f625ae9cfda in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: () WARNING: node-missing-onnx-shape-inference ========================================== The shape inference of custom::CustomAdd type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. --------------------------- Stack: Python call stack --------------------------- frame: diagnostic = ExportDiagnostic(rule, level, message, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/diagnostics/_diagnostic.py:151 frame: graph, params_dict, GLOBALS.export_onnx_opset_version /home/bowbao/pytorch_dev/torch/onnx/utils.py:688 frame: <@beartype(torch.onnx.utils._optimize_graph) at 0x7f62180e05f0>:85 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: module=module, /home/bowbao/pytorch_dev/torch/onnx/utils.py:1123 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: dynamic_axes=dynamic_axes, /home/bowbao/pytorch_dev/torch/onnx/utils.py:1539 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: export_modules_as_functions=export_modules_as_functions, /home/bowbao/pytorch_dev/torch/onnx/utils.py:519 frame: <@beartype(torch.onnx.utils.export) at 0x7f62180e0170>:347 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO()) test_pretty_print.py:22 ---------------------------- Stack: C++ call stack ----------------------------- frame: () frame: ( + 0x88411b (0x7f625b36011b in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: (torch::jit::UpdateReliable(torch::jit::Value*, std::pair const&) + 0x7d3 (0x7f625b351743 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: (torch::jit::UpdateReliable(torch::jit::Node*) + 0x4f (0x7f625b35198f in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: (torch::jit::ONNXShapeTypeInference(torch::jit::Node*, std::map, std::allocator >, c10::IValue, std::less, std::allocator > >, std::allocator, std::allocator > const, c10::IValue> > > const&, int) + 0xac9 (0x7f625b357179 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: ( + 0x87d6d1 (0x7f625b3596d1 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: (torch::jit::ONNXShapeTypeInference(std::shared_ptr&, std::map, std::allocator >, c10::IValue, std::less, std::allocator > >, std::allocator, std::allocator > const, c10::IValue> > > const&, int) + 0x33 (0x7f625b359cf3 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: ( + 0xabdbae (0x7f625b599bae in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: ( + 0x3c0fda (0x7f625ae9cfda in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: () WARNING: node-missing-onnx-shape-inference ========================================== The shape inference of custom::CustomAdd type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. --------------------------- Stack: Python call stack --------------------------- frame: diagnostic = ExportDiagnostic(rule, level, message, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/diagnostics/_diagnostic.py:151 frame: graph, params_dict, GLOBALS.export_onnx_opset_version /home/bowbao/pytorch_dev/torch/onnx/utils.py:1179 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: dynamic_axes=dynamic_axes, /home/bowbao/pytorch_dev/torch/onnx/utils.py:1539 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: export_modules_as_functions=export_modules_as_functions, /home/bowbao/pytorch_dev/torch/onnx/utils.py:519 frame: <@beartype(torch.onnx.utils.export) at 0x7f62180e0170>:347 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO()) test_pretty_print.py:22 ---------------------------- Stack: C++ call stack ----------------------------- frame: () frame: ( + 0x88411b (0x7f625b36011b in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: (torch::jit::UpdateReliable(torch::jit::Value*, std::pair const&) + 0x7d3 (0x7f625b351743 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: (torch::jit::UpdateReliable(torch::jit::Node*) + 0x4f (0x7f625b35198f in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: (torch::jit::ONNXShapeTypeInference(torch::jit::Node*, std::map, std::allocator >, c10::IValue, std::less, std::allocator > >, std::allocator, std::allocator > const, c10::IValue> > > const&, int) + 0xac9 (0x7f625b357179 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: ( + 0x87d6d1 (0x7f625b3596d1 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: (torch::jit::ONNXShapeTypeInference(std::shared_ptr&, std::map, std::allocator >, c10::IValue, std::less, std::allocator > >, std::allocator, std::allocator > const, c10::IValue> > > const&, int) + 0x33 (0x7f625b359cf3 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: ( + 0xabdbae (0x7f625b599bae in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: ( + 0x3c0fda (0x7f625ae9cfda in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: () ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/88261 Approved by: https://github.com/abock, https://github.com/justinchuby --- test/onnx/internal/test_diagnostics.py | 2 +- .../onnx/_internal/diagnostics/_diagnostic.py | 18 +-- .../_internal/diagnostics/infra/_infra.py | 110 +++++++++++++++++- .../_internal/diagnostics/infra/engine.py | 15 +++ .../_internal/diagnostics/infra/formatter.py | 18 +++ .../onnx/_internal/diagnostics/infra/utils.py | 2 +- 6 files changed, 140 insertions(+), 25 deletions(-) diff --git a/test/onnx/internal/test_diagnostics.py b/test/onnx/internal/test_diagnostics.py index ea9a789e91c1f28..884b7cb1c3880ee 100644 --- a/test/onnx/internal/test_diagnostics.py +++ b/test/onnx/internal/test_diagnostics.py @@ -19,7 +19,7 @@ def _assert_has_diagnostics( rule_level_pairs: AbstractSet[Tuple[infra.Rule, infra.Level]], ): sarif_log = engine.sarif_log() - unseen_pairs = {(rule.id, level.value) for rule, level in rule_level_pairs} + unseen_pairs = {(rule.id, level.name.lower()) for rule, level in rule_level_pairs} actual_results = [] for run in sarif_log.runs: if run.results is None: diff --git a/torch/onnx/_internal/diagnostics/_diagnostic.py b/torch/onnx/_internal/diagnostics/_diagnostic.py index 21e44f2b44671e2..efe5c0e34911cb8 100644 --- a/torch/onnx/_internal/diagnostics/_diagnostic.py +++ b/torch/onnx/_internal/diagnostics/_diagnostic.py @@ -74,22 +74,6 @@ def record_cpp_call_stack(self, frames_to_skip) -> None: self.with_stack(stack) self.cpp_call_stack = stack - def with_model_source_location( - self: _ExportDiagnostic, - ) -> _ExportDiagnostic: - # TODO: Implement this. - # self.locations.append(...) - raise NotImplementedError() - return self - - def with_export_source_location( - self: _ExportDiagnostic, - ) -> _ExportDiagnostic: - # TODO: Implement this. - # self.locations.append(...) - raise NotImplementedError() - return self - class ExportDiagnosticEngine(infra.DiagnosticEngine): """PyTorch ONNX Export diagnostic engine. @@ -115,7 +99,6 @@ def __init__(self) -> None: name="torch.onnx", version=torch.__version__, diagnostic_type=ExportDiagnostic, - options=None, ) @property @@ -150,6 +133,7 @@ def create_export_diagnostic_context(): try: yield context finally: + context.pretty_print(context.options.log_verbose, context.options.log_level) context = engine.background_context diff --git a/torch/onnx/_internal/diagnostics/infra/_infra.py b/torch/onnx/_internal/diagnostics/infra/_infra.py index b8a4c5032f52390..3414574cce739bd 100644 --- a/torch/onnx/_internal/diagnostics/infra/_infra.py +++ b/torch/onnx/_internal/diagnostics/infra/_infra.py @@ -17,10 +17,10 @@ class Level(enum.Enum): please use infra.Tag instead. """ - NONE = "none" - NOTE = "note" - WARNING = "warning" - ERROR = "error" + NONE = enum.auto() + NOTE = enum.auto() + WARNING = enum.auto() + ERROR = enum.auto() levels = Level @@ -107,6 +107,9 @@ def format_message(self, *args, **kwargs) -> str: """ return self.message_default_template.format(*args, **kwargs) + def pretty_print(self): + pass + @dataclasses.dataclass class Location: @@ -134,6 +137,25 @@ def sarif(self) -> sarif.Location: else None, ) + def pretty_print(self): + """Prints the location in a human-readable format.""" + location_strs = ["frame:"] + if self.snippet is not None: + location_strs.append(self.snippet) + if self.uri is not None: + line_strs = [self.uri] + line_strs.append(str(self.line)) if self.line is not None else "-1" + line_strs.append( + str(self.start_column) + ) if self.start_column is not None else "-1" + line_strs.append( + str(self.end_column) + ) if self.end_column is not None else "-1" + location_strs.append(":".join(line_strs)) + if self.message is not None: + location_strs.append(f"({self.message})") + print(" ".join(location_strs)) + @dataclasses.dataclass class StackFrame: @@ -143,6 +165,10 @@ def sarif(self) -> sarif.StackFrame: """Returns the SARIF representation of this stack frame.""" return sarif.StackFrame(location=self.location.sarif()) + def pretty_print(self): + """Prints the stack frame in a human-readable format.""" + self.location.pretty_print() + @dataclasses.dataclass class Stack: @@ -158,6 +184,12 @@ def sarif(self) -> sarif.Stack: else None, ) + def pretty_print(self): + """Prints the stack in a human-readable format.""" + formatter.pretty_print_title(f"Stack: {self.message}", fill_char="-") + for frame in self.frames: + frame.pretty_print() + # This is a workaround for mypy not supporting Self from typing_extensions. _Diagnostic = TypeVar("_Diagnostic", bound="Diagnostic") @@ -182,6 +214,9 @@ def sarif(self) -> sarif.Graph: properties=PatchedPropertyBag(name=self.name, description=self.description), ) + def pretty_print(self): + pass + @dataclasses.dataclass class Diagnostic: @@ -201,7 +236,7 @@ def sarif(self) -> sarif.Result: message = f"{message}\n{self.additional_message}" sarif_result = sarif.Result( message=sarif.Message(text=message), - level=self.level.value, + level=self.level.name.lower(), # type: ignore[arg-type] rule_id=self.rule.id, ) sarif_result.locations = [location.sarif() for location in self.locations] @@ -235,6 +270,31 @@ def with_additional_message(self: _Diagnostic, message: str) -> _Diagnostic: self.additional_message = f"{self.additional_message}\n{message}" return self + def pretty_print(self, verbose: bool = False, log_level: Level = Level.ERROR): + """Prints the diagnostics in a human-readable format. + + Args: + verbose: If True, prints all information. E.g. stack frames, graphs, etc. + Otherwise, only prints compact information. E.g., rule name and display message. + level: The minimum level of diagnostics to print. + """ + if self.level.value < log_level.value: + return + formatter.pretty_print_item_title(f"{self.level.name}: {self.rule.name}") + print(self.message) + + if not verbose: + print("\n") + return + + for location in self.locations: + location.pretty_print() + for stack in self.stacks: + stack.pretty_print() + for graph in self.graphs: + graph.pretty_print() + print() + @dataclasses.dataclass class RuleCollection: @@ -284,12 +344,15 @@ class DiagnosticOptions: Options for diagnostic context. """ + log_verbose: bool = dataclasses.field(default=False) + log_level: Level = dataclasses.field(default=Level.ERROR) + @dataclasses.dataclass class DiagnosticContext: name: str version: str - options: Optional[DiagnosticOptions] = None + options: DiagnosticOptions = dataclasses.field(default_factory=DiagnosticOptions) diagnostic_type: Type[Diagnostic] = dataclasses.field(default=Diagnostic) diagnostics: List[Diagnostic] = dataclasses.field(init=False, default_factory=list) _invocation: Invocation = dataclasses.field(init=False) @@ -350,3 +413,38 @@ def diagnose( diagnostic = self.diagnostic_type(rule, level, message, **kwargs) self.add_diagnostic(diagnostic) return diagnostic + + def pretty_print( + self, verbose: bool = False, log_level: Level = Level.ERROR + ) -> None: + """Prints the diagnostics in a human-readable format. + + Args: + verbose: Whether to print the diagnostics in verbose mode. See Diagnostic.pretty_print. + level: The minimum level of diagnostics to print. + """ + formatter.pretty_print_title( + f"Diagnostic Run {self.name} version {self.version}" + ) + print(f"verbose: {verbose}, log level: {log_level}") + diagnostic_stats = {level: 0 for level in Level} + for diagnostic in self.diagnostics: + diagnostic_stats[diagnostic.level] += 1 + formatter.pretty_print_title( + " ".join(f"{diagnostic_stats[level]} {level.name}" for level in Level) + ) + + for diagnostic in self.diagnostics: + diagnostic.pretty_print(verbose, log_level) + + unprinted_diagnostic_stats = [ + (level, count) + for level, count in diagnostic_stats.items() + if count > 0 and level.value < log_level.value + ] + if unprinted_diagnostic_stats: + print( + f"{' '.join(f'{count} {level.name}' for level, count in unprinted_diagnostic_stats)} " + "were not printed due to the log level." + ) + print() diff --git a/torch/onnx/_internal/diagnostics/infra/engine.py b/torch/onnx/_internal/diagnostics/infra/engine.py index 2678268fbaf9ac8..51a6057565bba3f 100644 --- a/torch/onnx/_internal/diagnostics/infra/engine.py +++ b/torch/onnx/_internal/diagnostics/infra/engine.py @@ -85,8 +85,23 @@ def create_diagnostic_context( Returns: A new diagnostic context. """ + if options is None: + options = infra.DiagnosticOptions() context = infra.DiagnosticContext( name, version, options, diagnostic_type=diagnostic_type ) self.contexts.append(context) return context + + def pretty_print( + self, verbose: bool = False, level: infra.Level = infra.Level.ERROR + ) -> None: + """Pretty prints all diagnostics in the diagnostic contexts. + + Args: + verbose: Whether to print the diagnostics in verbose mode. See Diagnostic.pretty_print. + level: The minimum level of diagnostics to print. + """ + formatter.pretty_print_title(f"{len(self.contexts)} Diagnostic Run") + for context in self.contexts: + context.pretty_print(verbose, level) diff --git a/torch/onnx/_internal/diagnostics/infra/formatter.py b/torch/onnx/_internal/diagnostics/infra/formatter.py index 2f35489f8d454d8..292a2b6a47a5abd 100644 --- a/torch/onnx/_internal/diagnostics/infra/formatter.py +++ b/torch/onnx/_internal/diagnostics/infra/formatter.py @@ -57,3 +57,21 @@ def sarif_to_json(attr_cls_obj: _SarifClass) -> str: dict = dataclasses.asdict(attr_cls_obj) dict = _convert_key(dict, _camel_case_to_snake_case) return json.dumps(dict, indent=4) + + +def pretty_print_title(title: str, width: int = 80, fill_char: str = "=") -> None: + """Pretty prints title in below format: + + ==================== title ==================== + """ + print(f" {title} ".center(width, fill_char)) + + +def pretty_print_item_title(title: str, fill_char: str = "=") -> None: + """Pretty prints title in below format: + + title + ===== + """ + print(title) + print(fill_char * len(title)) diff --git a/torch/onnx/_internal/diagnostics/infra/utils.py b/torch/onnx/_internal/diagnostics/infra/utils.py index c32de1c6b8ad90f..6a85df91046398b 100644 --- a/torch/onnx/_internal/diagnostics/infra/utils.py +++ b/torch/onnx/_internal/diagnostics/infra/utils.py @@ -6,7 +6,7 @@ def python_frame(frame: inspect.FrameInfo) -> _infra.StackFrame: """Returns a StackFrame for the given inspect.FrameInfo.""" snippet = ( - frame.code_context[frame.index] + frame.code_context[frame.index].strip() if frame.code_context is not None and frame.index is not None else None ) From f39cad50b765b6fd2f4927a4d1552fff5928c61e Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 11 Nov 2022 22:07:34 +0000 Subject: [PATCH 40/62] Make InductorCPU usable in internally (#88870) Test Plan: `buck2 test mode/opt //caffe2/test:test_inductor -- --exact 'caffe2/test:test_inductor - test_dtype_mismatch_issue_cuda (caffe2.test.inductor.test_torchinductor.CudaTests)'` Differential Revision: D41206109 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88870 Approved by: https://github.com/izaitsevfb --- torch/_inductor/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 87e2793782be888..8f9f2c4f461dd1c 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -92,6 +92,7 @@ class cpp: "g++-10", "clang++", "g++", + "g++.par", ) From be8d88f8d0c6825b1b19354ffbaa4466aae0d3b8 Mon Sep 17 00:00:00 2001 From: Kevin Tse Date: Thu, 10 Nov 2022 18:33:09 -0500 Subject: [PATCH 41/62] [DataLoader] Removing DataLoader2 related code (#88848) Removing these lines of code as `DataLoader2` has been added to [TorchData](https://github.com/pytorch/data). I'm importing this to confirm it will not impact internal codes. Differential Revision: [D41201578](https://our.internmc.facebook.com/intern/diff/D41201578) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88848 Approved by: https://github.com/ejguan --- docs/source/data.rst | 3 - test/test_dataloader.py | 111 ----------- torch/utils/data/__init__.py | 5 - torch/utils/data/communication/__init__.py | 6 - torch/utils/data/communication/eventloop.py | 70 ------- torch/utils/data/communication/iter.py | 181 ----------------- torch/utils/data/communication/map.py | 159 --------------- torch/utils/data/communication/messages.py | 75 ------- torch/utils/data/communication/protocol.py | 205 -------------------- torch/utils/data/communication/queue.py | 51 ----- torch/utils/data/dataloader_experimental.py | 150 -------------- 11 files changed, 1016 deletions(-) delete mode 100644 torch/utils/data/communication/__init__.py delete mode 100644 torch/utils/data/communication/eventloop.py delete mode 100644 torch/utils/data/communication/iter.py delete mode 100644 torch/utils/data/communication/map.py delete mode 100644 torch/utils/data/communication/messages.py delete mode 100644 torch/utils/data/communication/protocol.py delete mode 100644 torch/utils/data/communication/queue.py delete mode 100644 torch/utils/data/dataloader_experimental.py diff --git a/docs/source/data.rst b/docs/source/data.rst index de2d44920f573c4..b44096d101964c2 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -441,9 +441,6 @@ Example:: .. autoclass:: torch.utils.data.distributed.DistributedSampler -.. This module is experimental and should be private, adding it here for now -.. py:module:: torch.utils.data.communication - .. These modules are documented as part of torch/data listing them here for .. now until we have a clearer fix .. py:module:: torch.utils.data.datapipes diff --git a/test/test_dataloader.py b/test/test_dataloader.py index 270ca89764ed1dc..6a7ff90527d3db2 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -20,19 +20,16 @@ ChainDataset, ConcatDataset, DataLoader, - DataLoader2, Dataset, IterableDataset, IterDataPipe, Subset, TensorDataset, - communication, _utils ) from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL from torch.utils.data.dataset import random_split from torch.utils.data.datapipes.iter import IterableWrapper -from torch.utils.data.datapipes.map import SequenceWrapper from torch._utils import ExceptionWrapper from torch.testing._internal.common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, IS_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm, slowTest, @@ -2222,114 +2219,6 @@ def test_excessive_thread_creation_warning(self): r"excessive worker creation might get DataLoader running slow or even freeze"): dataloader = DataLoader(self.dataset, batch_size=2, num_workers=1000) -# Define a global function for testing purposes since local functions cannot be pickled -def identity(x): - return x - -@unittest.skipIf( - TEST_WITH_TSAN, - "Fails with TSAN with the following error: starting new threads after multi-threaded " - "fork is not supported. Dying (set die_after_fork=0 to override)") -class TestDataLoader2(TestCase): - @skipIfNoDill - def test_basics(self): - # TODO(VitalyFedyunin): This test will start breaking if we remove guaranteed order - # of traversing workers - dp = IterableWrapper(list(range(1000))).sharding_filter() - dl = DataLoader(dp, batch_size=3, collate_fn=identity, num_workers=2) - dl2 = DataLoader2(dp, batch_size=3, collate_fn=identity, num_workers=2) - dl2_threading = DataLoader2(dp, batch_size=3, collate_fn=identity, num_workers=2, parallelism_mode='thread') - self.assertEqual(list(dl), list(dl2)) - self.assertEqual(list(dl), list(dl2_threading)) - - class Sorter(IterDataPipe): - def __init__(self, datapipe): - self.datapipe = datapipe - - def __iter__(self): - return iter(sorted(self.datapipe)) - - def test_shuffle(self): - items = list(range(1000)) - dp = IterableWrapper(items).sharding_filter().shuffle() - - dl = DataLoader2(dp, batch_size=None, num_workers=2, shuffle=False) - self.assertEqual(items, list(dl)) - - dl = DataLoader2(dp, batch_size=None, num_workers=2, shuffle=True) - self.assertNotEqual(items, list(dl)) - self.assertEqual(items, sorted(list(dl))) - - dl = DataLoader2(dp, batch_size=None, num_workers=2, shuffle=True) - self.assertNotEqual(items, list(dl)) - self.assertEqual(items, sorted(list(dl))) - - dl = DataLoader2(self.Sorter(dp), batch_size=None, num_workers=2, shuffle=True) - self.assertEqual(list(dl), items) - - dl = DataLoader2(self.Sorter(dp), batch_size=None, num_workers=2, shuffle=True) - self.assertEqual(list(dl), items) - - -@unittest.skipIf( - TEST_WITH_TSAN, - "Fails with TSAN with the following error: starting new threads after multi-threaded " - "fork is not supported. Dying (set die_after_fork=0 to override)") -class TestDataLoader2_EventLoop(TestCase): - @skipIfNoDill - def test_basic_threading(self): - def clean_me(process, req_queue, res_queue): - req_queue.put(communication.messages.TerminateRequest()) - _ = res_queue.get() - process.join() - - it = list(range(100)) - numbers_dp = IterableWrapper(it) - (process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.SpawnThreadForDataPipeline(numbers_dp) - - process.start() - local_datapipe = communication.iter.QueueWrapper( - communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue)) - - actual = list(local_datapipe) - clean_me(process, req_queue, res_queue) - - self.assertEqual(list(range(100)), actual) - - @skipIfNoDill - def test_basic_mapdatapipe_threading(self): - def clean_me(process, req_queue, res_queue): - req_queue.put(communication.messages.TerminateRequest()) - _ = res_queue.get() - process.join() - - input_len = 100 - it = list(range(input_len)) - numbers_dp = SequenceWrapper(it) - (process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.SpawnThreadForDataPipeline( - numbers_dp) - - process.start() - - # Functional Test: Ensure that you can retrieve every element from the Queue and DataPipe - local_datapipe = communication.map.QueueWrapperForMap( - communication.protocol.MapDataPipeQueueProtocolClient(req_queue, res_queue)) - actual = list(local_datapipe) - self.assertEqual([(x, x) for x in range(100)], actual) - - # Functional Test: raise Error when input - local_datapipe = communication.map.QueueWrapperForMap( - communication.protocol.MapDataPipeQueueProtocolClient(req_queue, res_queue)) - with self.assertRaisesRegex(IndexError, "out of bound"): - local_datapipe[1000] - - # __len__ Test: Ensure that the correct length is returned - local_datapipe = communication.map.QueueWrapperForMap( - communication.protocol.MapDataPipeQueueProtocolClient(req_queue, res_queue)) - self.assertEqual(input_len, len(local_datapipe)) - - clean_me(process, req_queue, res_queue) - class IntegrationTestDataLoaderDataPipe(TestCase): r""" diff --git a/torch/utils/data/__init__.py b/torch/utils/data/__init__.py index 6fe6147ddc545ee..bc054a947069fe5 100644 --- a/torch/utils/data/__init__.py +++ b/torch/utils/data/__init__.py @@ -39,8 +39,6 @@ runtime_validation, runtime_validation_disabled, ) -from torch.utils.data.dataloader_experimental import DataLoader2 -from torch.utils.data import communication __all__ = ['BatchSampler', 'ChainDataset', @@ -48,7 +46,6 @@ 'DFIterDataPipe', 'DataChunk', 'DataLoader', - 'DataLoader2', 'Dataset', 'DistributedSampler', 'IterDataPipe', @@ -63,8 +60,6 @@ 'WeightedRandomSampler', '_DatasetKind', 'argument_validation', - 'collate', - 'communication', 'default_collate', 'default_convert', 'functional_datapipe', diff --git a/torch/utils/data/communication/__init__.py b/torch/utils/data/communication/__init__.py deleted file mode 100644 index 1b9cae401189724..000000000000000 --- a/torch/utils/data/communication/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from . import eventloop -from . import iter -from . import map -from . import messages -from . import protocol -from . import queue diff --git a/torch/utils/data/communication/eventloop.py b/torch/utils/data/communication/eventloop.py deleted file mode 100644 index 9bf241d334dfe38..000000000000000 --- a/torch/utils/data/communication/eventloop.py +++ /dev/null @@ -1,70 +0,0 @@ -import torch -import threading -import pickle - -from torch.utils.data import IterDataPipe, communication, MapDataPipe - -try: - import dill - # XXX: By default, dill writes the Pickler dispatch table to inject its - # own logic there. This globally affects the behavior of the standard library - # pickler for any user who transitively depends on this module! - # Undo this extension to avoid altering the behavior of the pickler globally. - dill.extend(use_dill=False) - HAS_DILL = True -except ImportError: - HAS_DILL = False - -__all__ = [ - "DataPipeToQueuesLoop", - "SpawnProcessForDataPipeline", - "SpawnThreadForDataPipeline", -] - -def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue): - if isinstance(source_datapipe, IterDataPipe): - pipe_type = communication.iter - protocol_type = communication.protocol.IterDataPipeQueueProtocolServer - elif isinstance(source_datapipe, MapDataPipe): - pipe_type = communication.map # type: ignore[misc] - protocol_type = communication.protocol.MapDataPipeQueueProtocolServer # type: ignore[assignment] - else: - raise Exception('Only supports IterDataPipe or MapDataPipe, got', source_datapipe) - - torch.set_num_threads(1) - for _ in pipe_type.DataPipeBehindQueues(source_datapipe, protocol_type(req_queue, res_queue), - blocking_request_get=True): - pass - - -def SpawnProcessForDataPipeline(multiprocessing_ctx, datapipe): - req_queue = multiprocessing_ctx.Queue() - res_queue = multiprocessing_ctx.Queue() - process = multiprocessing_ctx.Process( - target=DataPipeToQueuesLoop, args=(datapipe, req_queue, res_queue)) - return process, req_queue, res_queue - - -def SpawnThreadForDataPipeline(datapipe): - r""" - Given a DataPipe, creates a copy of the DataPipe, starts a new Thread with DataPipeToQueuesLoop as target, - and return the process, req_queue, res_queue, thread_local_datapipe. - """ - req_queue = communication.queue.ThreadingQueue() - res_queue = communication.queue.ThreadingQueue() - - try: - new_datapipe = pickle.loads(pickle.dumps(datapipe)) - except Exception as pe: - if HAS_DILL: - try: - new_datapipe = dill.loads(dill.dumps(datapipe)) - except Exception as de: - raise Exception('Unable to dill DataPipe to make thread local copy', de) - - else: - raise Exception('Unable to pickle DataPipe to make thread local copy (consider installing `dill`)', pe) - - process = threading.Thread(target=DataPipeToQueuesLoop, args=( - new_datapipe, req_queue, res_queue), daemon=True) - return process, req_queue, res_queue, new_datapipe diff --git a/torch/utils/data/communication/iter.py b/torch/utils/data/communication/iter.py deleted file mode 100644 index 94f7cd2ec703597..000000000000000 --- a/torch/utils/data/communication/iter.py +++ /dev/null @@ -1,181 +0,0 @@ -import time -import types - -from torch.utils.data import IterDataPipe, communication - -DEFAULT_NON_BLOCKING_SLEEP = 0.001 - -__all__ = [ - "DataPipeBehindQueues", - "EnsureNonBlockingDataPipe", - "InvalidStateResetRequired", - "NonBlocking", - "NotAvailable", - "QueueWrapper", - "default_not_available_hook", -] - - -def default_not_available_hook(): - time.sleep(DEFAULT_NON_BLOCKING_SLEEP) - - -class NotAvailable(Exception): - pass - - -class InvalidStateResetRequired(Exception): - """ - Returned by DataPipe when it is expecting to get reset request, - for example RouterDataPipe expecting all workers to request reset' - """ - pass - - -class NonBlocking(IterDataPipe): - not_available_hook = default_not_available_hook - - def __iter__(self): - self.reset_iterator() - return self - - def __next__(self): - while True: - try: - return self.nonblocking_next() - except StopIteration: - raise StopIteration - except NotAvailable: - if NonBlocking.not_available_hook is not None: - NonBlocking.not_available_hook() - - def nonblocking_next(self): - raise NotImplementedError( - "nonblocking_next is not implemented for %s" % self.__class__) - - def reset_iterator(self): - raise NotImplementedError( - "reset_iterator is not implemented for %s" % self.__class__) - - @staticmethod - def register_not_available_hook(hook_function): - NonBlocking.not_available_hook = hook_function - - -def EnsureNonBlockingDataPipe(validated_datapipe): - if not isinstance(validated_datapipe, IterDataPipe): - raise Exception('Not Iterable DataPipe ' + - str(validated_datapipe.__class__)) - if isinstance(validated_datapipe, NonBlocking): - return validated_datapipe - if not hasattr(validated_datapipe, '_as_iterator'): - validated_datapipe._as_iterator = None # type: ignore[attr-defined] - if not hasattr(validated_datapipe, 'nonblocking_next'): - def nonblocking_next(self): - if self._as_iterator is None: - self._as_iterator = iter(self) - return next(self._as_iterator) - validated_datapipe.nonblocking_next = types.MethodType( # type: ignore[attr-defined] - nonblocking_next, validated_datapipe) - if not hasattr(validated_datapipe, 'reset_iterator'): - def reset_iterator(self): - self._as_iterator = None - validated_datapipe.reset_iterator = types.MethodType( # type: ignore[attr-defined] - reset_iterator, validated_datapipe) - return validated_datapipe - - -def DataPipeBehindQueues(source_datapipe, protocol, full_stop=False, blocking_request_get=False): - """ - Indefinitely iterates over req_queue and passing values from source_datapipe to res_queue - If raise_stop is true, raises exception when StopIteration received from the source_datapipe - """ - if not isinstance(protocol, communication.protocol.IterDataPipeQueueProtocolServer): - raise Exception('Expecting IterDataPipeQueueProtocolServer, got', protocol) - source_datapipe = EnsureNonBlockingDataPipe(source_datapipe) - forever = True - while forever: - try: - # Non-blocking call is Extremely slow here for python.mp, need to figure out a good workaround - request = protocol.get_new_request(block=blocking_request_get) - except communication.protocol.EmptyQueue: - yield True - continue - - if isinstance(request, communication.messages.ResetIteratorRequest): - source_datapipe.reset_iterator() - protocol.response_reset_iterator() - - elif isinstance(request, communication.messages.TerminateRequest): - forever = False - protocol.response_terminate() - - elif isinstance(request, communication.messages.GetNextRequest): - while forever: - try: - value = source_datapipe.nonblocking_next() - except NotAvailable: - yield True - continue - except StopIteration: - protocol.response_stop_iteration() - if full_stop: - forever = False - else: - yield True - break - except InvalidStateResetRequired: - protocol.response_invalid_state() - if full_stop: - forever = False - else: - yield True - break - protocol.response_next(value) - yield True # Returns control - break - else: - raise Exception('Unrecognized type of request received', request) - - -class QueueWrapper(NonBlocking): - """ - Creates iter.DataPipe which reads data from the DataLoader.Queue - """ - - def __init__(self, protocol, response_wait_time=0.00001): - if not isinstance(protocol, communication.protocol.IterDataPipeQueueProtocolClient): - raise Exception('Got', protocol) - self.protocol = protocol - self.counter = 0 - self._stop_iteration = False - self._response_wait_time = response_wait_time - - def reset_iterator(self): - self._stop_iteration = False - self.counter = 0 - self.protocol.request_reset_iterator() - while True: - try: - self.protocol.get_response_reset_iterator() - break - except communication.protocol.EmptyQueue: - if NonBlocking.not_available_hook is not None: - NonBlocking.not_available_hook() - - def nonblocking_next(self): - if self._stop_iteration: - raise Exception( - '`next` or `nonblocking_next` called after receiving StopIteration') - if self.protocol.can_take_request(): - self.protocol.request_next() - try: - response = self.protocol.get_response_next(block=True, timeout=self._response_wait_time) - except communication.protocol.EmptyQueue: - raise NotAvailable - if isinstance(response, communication.messages.StopIterationResponse): - self._stop_iteration = True - raise StopIteration - if isinstance(response, communication.messages.InvalidStateResponse): - raise NotAvailable - return response.value diff --git a/torch/utils/data/communication/map.py b/torch/utils/data/communication/map.py deleted file mode 100644 index 8af63bf0c73ecfd..000000000000000 --- a/torch/utils/data/communication/map.py +++ /dev/null @@ -1,159 +0,0 @@ -import time -import types - -from torch.utils.data import communication, MapDataPipe - -DEFAULT_NON_BLOCKING_SLEEP = 0.001 - -__all__ = [ - "DataPipeBehindQueues", - "EnsureNonBlockingMapDataPipe", - "NonBlockingMap", - "NotAvailable", - "QueueWrapperForMap", - "default_not_available_hook", -] - - -def default_not_available_hook(): - time.sleep(DEFAULT_NON_BLOCKING_SLEEP) - - -class NotAvailable(Exception): - pass - - -class NonBlockingMap(MapDataPipe): - not_available_hook = default_not_available_hook - - def __getitem__(self, index): - while True: - try: - return self.nonblocking_getitem(index) - except NotAvailable: - if NonBlockingMap.not_available_hook is not None: - NonBlockingMap.not_available_hook() - - def __len__(self): - try: - return self.nonblocking_len() - except NotAvailable: - if NonBlockingMap.not_available_hook is not None: - NonBlockingMap.not_available_hook() - - def nonblocking_len(self): - raise NotImplementedError( - "nonblocking_len is not implemented for %s" % self.__class__) - - def nonblocking_getitem(self, index): - raise NotImplementedError( - "nonblocking_getitem is not implemented for %s" % self.__class__) - - @staticmethod - def register_not_available_hook(hook_function): - NonBlockingMap.not_available_hook = hook_function - - -def EnsureNonBlockingMapDataPipe(validated_datapipe): - if not isinstance(validated_datapipe, MapDataPipe): - raise Exception(f'Not Map DataPipe - got {validated_datapipe.__class__}') - if isinstance(validated_datapipe, NonBlockingMap): - return validated_datapipe - if not hasattr(validated_datapipe, 'nonblocking_len'): - def nonblocking_len(self): - return self.__len__() - validated_datapipe.nonblocking_len = types.MethodType( # type: ignore[attr-defined] - nonblocking_len, validated_datapipe) - if not hasattr(validated_datapipe, 'nonblocking_getitem'): - def nonblocking_getitem(self, index): - return self.__getitem__(index) - validated_datapipe.nonblocking_getitem = types.MethodType( # type: ignore[attr-defined] - nonblocking_getitem, validated_datapipe) - return validated_datapipe - - -def DataPipeBehindQueues(source_datapipe, protocol, full_stop=False, blocking_request_get=False): - """ - Indefinitely iterates over req_queue and passing values from source_datapipe to res_queue - If raise_stop is true, raises exception when StopIteration received from the source_datapipe - """ - if not isinstance(protocol, communication.protocol.MapDataPipeQueueProtocolServer): - raise Exception('Expecting MapDataPipeQueueProtocolServer, got', protocol) - source_datapipe = EnsureNonBlockingMapDataPipe(source_datapipe) - forever = True - while forever: - try: - # Non-blocking call is Extremely slow here for python.mp, need to figure out a good workaround - request = protocol.get_new_request(block=blocking_request_get) - except communication.protocol.EmptyQueue: - yield True - continue - - if isinstance(request, communication.messages.TerminateRequest): - forever = False - protocol.response_terminate() - - elif isinstance(request, communication.messages.LenRequest): - size = source_datapipe.nonblocking_len() - protocol.response_len(size) - - elif isinstance(request, communication.messages.GetItemRequest): - while forever: - try: - value = source_datapipe.nonblocking_getitem(request.key) - except NotAvailable: - yield True - continue - except IndexError as e: - # Alternatively, we can just allow the underlying DataPipe to throw an exception? - protocol.response_index_out_of_bound() - if full_stop: - forever = False - else: - yield True - break - protocol.response_item(request.key, value) - yield True # Returns control - break - else: - raise Exception('Unrecognized type of request received', request) - - -class QueueWrapperForMap(NonBlockingMap): - """ - Creates map.DataPipe which reads data from the DataLoader.Queue - """ - def __init__(self, protocol, response_wait_time=0.00001): - if not isinstance(protocol, communication.protocol.MapDataPipeQueueProtocolClient): - raise Exception('Got', protocol) - self.protocol = protocol - self.counter = 0 - self._stop_iteration = False - self._response_wait_time = response_wait_time - - def nonblocking_getitem(self, index): - if self._stop_iteration: - raise Exception( - '`getitem` or `nonblocking_getitem` called after receiving StopIteration') - if self.protocol.can_take_request(): - self.protocol.request_item(index) - try: - response = self.protocol.get_response_item(block=True, timeout=self._response_wait_time) - except communication.protocol.EmptyQueue: - raise NotAvailable - if isinstance(response, communication.messages.StopIterationResponse): - self._stop_iteration = True - raise IndexError(f"Index {index} is out of bound.") - return response.key, response.value - - def nonblocking_len(self): - if self._stop_iteration: - raise Exception( - '`len` or `nonblocking_len` called after receiving StopIteration') - if self.protocol.can_take_request(): - self.protocol.request_len() - try: - response = self.protocol.get_response_len(block=True, timeout=self._response_wait_time) - except communication.protocol.EmptyQueue: - raise NotAvailable - return response.len diff --git a/torch/utils/data/communication/messages.py b/torch/utils/data/communication/messages.py deleted file mode 100644 index 449cf23cfc01c1f..000000000000000 --- a/torch/utils/data/communication/messages.py +++ /dev/null @@ -1,75 +0,0 @@ -class DataLoaderQueueMessage(object): - pass - - -class Request(DataLoaderQueueMessage): - pass - - -class Response(DataLoaderQueueMessage): - pass - - -class ResetIteratorRequest(Request): - pass - - -class ResetIteratorResponse(Response): - pass - - -class TerminateRequest(Request): - pass - - -class TerminateResponse(Response): - pass - - -class LenRequest(Request): - pass - - -class LenResponse(Response): - __slots__ = ('len') - - def __init__(self, len): - self.len = len - - -class GetItemRequest(Request): - __slots__ = ('key') - - def __init__(self, key): - self.key = key - - -class GetItemResponse(Response): - __slots__ = ('key', 'value') - - def __init__(self, key, value): - self.key = key - self.value = value - - -class GetNextRequest(Request): - pass - - -class GetNextResponse(Response): - __slots__ = ('value') - - def __init__(self, value): - self.value = value - - -class StopIterationResponse(Response): - pass - - -class InvalidStateResponse(Response): - """ - Returned by DataPipe when it is expecting to get reset request, - for example RouterDataPipe expecting all workers to request reset' - """ - pass diff --git a/torch/utils/data/communication/protocol.py b/torch/utils/data/communication/protocol.py deleted file mode 100644 index 5bf5fe1af0626e7..000000000000000 --- a/torch/utils/data/communication/protocol.py +++ /dev/null @@ -1,205 +0,0 @@ -from torch.utils.data import communication - - -class Protocol(object): - __slots__ = ('request_queue', 'response_queue') - - def __init__(self, request_queue, response_queue): - self.request_queue = request_queue - self.response_queue = response_queue - - -class ProtocolClient(Protocol): - """ - ProtocolClient takes charge of putting requests into req_queue and returning results from res_queue. - """ - _req_sent = None - - def __init__(self, request_queue, response_queue): - self.request_queue = request_queue - self.response_queue = response_queue - self._req_sent = None - - def can_take_request(self): - return self._req_sent is None - - def waiting_for_response(self): - return self._req_sent is not None - - def request_sent(self, request=True): - if not self.can_take_request(): - raise Exception('Protocol only supports one request in the Queue') - self._req_sent = request - - def request_served(self, result=None): - if not self.waiting_for_response(): - raise Exception( - 'Expected no peding requests, but something got served', result) - self._req_sent = None - - -class ProtocolServer(Protocol): - """ - ProtocolServer takes charge of getting requests from req_queue and fetching data from source datapipe. - """ - _req_received = None - - def __init__(self, request_queue, response_queue): - self.request_queue = request_queue - self.response_queue = response_queue - self._req_received = None - - def have_pending_request(self): - return self._req_received is not None - - def get_new_request(self, block=False): - if self.have_pending_request(): - raise Exception( - 'Trying to get next request, while having one unserved') - try: - response = self.request_queue.get(block=block) - except Exception as e: # TODO: Catch only timeout exceptions - raise EmptyQueue('queue is empty') - self._req_received = response - return response - # TODO: Validate supported requests - - def response_terminate(self): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - if not isinstance(self._req_received, communication.messages.TerminateRequest): - raise Exception( - "Replaying with terminate status to other type of message") - self.response_queue.put(communication.messages.TerminateResponse()) - self._req_received = None - - -class MapDataPipeQueueProtocolServer(ProtocolServer): - def response_item(self, key, value): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.GetItemResponse(key, value)) - self._req_received = None - - def response_len(self, size): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.LenResponse(size)) - self._req_received = None - - def response_index_out_of_bound(self): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.StopIterationResponse()) - self._req_received = None - -class MapDataPipeQueueProtocolClient(ProtocolClient): - def request_len(self): - if not self.can_take_request(): - raise Exception('Can not request len while we are still waiting response for previous request') - request = communication.messages.LenRequest() - self.request_queue.put(request) - self.request_sent(request) - - def request_item(self, index): - if not self.can_take_request(): - raise Exception('Can not request item while we are still waiting response for previous request') - request = communication.messages.GetItemRequest(index) - self.request_queue.put(request) - self.request_sent(request) - - def get_response_len(self, block=False, timeout=None): - if not self.waiting_for_response(): - raise Exception('Can not expect any response without submitted request') - try: - response = self.response_queue.get(block=block, timeout=timeout) - except TimeoutError: - raise EmptyQueue('queue is empty') - self.request_served(response) - if not isinstance(response, communication.messages.LenResponse): - raise Exception('Invalid response received') - return response - - def get_response_item(self, block=False, timeout=None): - if not self.waiting_for_response(): - raise Exception('Can not expect any response without submitted request') - try: - response = self.response_queue.get(block=block, timeout=timeout) - except TimeoutError: - raise EmptyQueue('queue is empty') - self.request_served(response) - # if not isinstance(response, communication.messages.GetItemResponse): - # raise Exception('Invalid response received') - return response - - -class EmptyQueue(Exception): - pass - - -class IterDataPipeQueueProtocolServer(ProtocolServer): - def response_reset_iterator(self): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - if not isinstance(self._req_received, communication.messages.ResetIteratorRequest): - raise Exception( - "Replaying with reset status to other type of message") - self.response_queue.put(communication.messages.ResetIteratorResponse()) - self._req_received = None - - def response_next(self, value): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.GetNextResponse(value)) - self._req_received = None - - def response_stop_iteration(self): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.StopIterationResponse()) - self._req_received = None - - def response_invalid_state(self): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.InvalidStateResponse()) - self._req_received = None - - -class IterDataPipeQueueProtocolClient(ProtocolClient): - def request_reset_iterator(self): - if not self.can_take_request(): - raise Exception('Can not reset while we are still waiting response for previous request') - request = communication.messages.ResetIteratorRequest() - self.request_queue.put(request) - self.request_sent(request) - - def request_next(self): - if not self.can_take_request(): - raise Exception('Can not request next item while we are still waiting response for previous request') - request = communication.messages.GetNextRequest() - self.request_queue.put(request) - self.request_sent(request) - - def get_response_reset_iterator(self, block=False): - try: - response = self.response_queue.get(block=block) - except Exception as e: # TODO: Catch only timeout exceptions - raise EmptyQueue('queue is empty') - self.request_served(response) - - if not isinstance(response, communication.messages.ResetIteratorResponse): - raise Exception('Invalid response received') - - def get_response_next(self, block=False, timeout=None): - if not self.waiting_for_response(): - raise Exception( - 'Can not expect any response without submitted request') - try: - response = self.response_queue.get(block=block, timeout=timeout) - except Exception as e: # TODO: Catch only timeout exceptions - raise EmptyQueue('queue is empty') - self.request_served(response) - - # TODO(VitalyFedyunin): Add possible response types validation here - return response diff --git a/torch/utils/data/communication/queue.py b/torch/utils/data/communication/queue.py deleted file mode 100644 index 85c33d4799cd8ab..000000000000000 --- a/torch/utils/data/communication/queue.py +++ /dev/null @@ -1,51 +0,0 @@ -import threading -import time - - -class LocalQueue(): - ops = 0 - stored = 0 - uid = 0 - empty = 0 - - def __init__(self, name='unnamed'): - self.items = [] - self.name = name - self.uid = LocalQueue.uid - LocalQueue.uid += 1 - - def put(self, item, block=True): - LocalQueue.ops += 1 - LocalQueue.stored += 1 - self.items.append(item) - - def get(self, block=True, timeout=0): - # TODO(VitalyFedyunin): Add support of block and timeout arguments - LocalQueue.ops += 1 - if not len(self.items): - LocalQueue.empty += 1 - raise Exception('LocalQueue is empty') - LocalQueue.stored -= 1 - return self.items.pop() - - -class ThreadingQueue(): - def __init__(self, name='unnamed'): - self.lock = threading.Lock() - self.items = [] - self.name = name - - def put(self, item, block=True): - with self.lock: - self.items.append(item) - - def get(self, block=True, timeout=0): - # TODO(VitalyFedyunin): Add support of block and timeout arguments - while True: - with self.lock: - if len(self.items) > 0: - return self.items.pop() - if not block: - raise Exception("Not available") - # TODO(VitalyFedyunin): Figure out what to do if nothing in the queue - time.sleep(0.000001) diff --git a/torch/utils/data/dataloader_experimental.py b/torch/utils/data/dataloader_experimental.py deleted file mode 100644 index 8a8d536b79857c7..000000000000000 --- a/torch/utils/data/dataloader_experimental.py +++ /dev/null @@ -1,150 +0,0 @@ -import time - -from typing import Any, List - -import torch.utils.data.backward_compatibility - -import torch.utils.data.graph_settings -from torch.utils.data import DataLoader, IterDataPipe, communication -from torch.utils.data.datapipes.iter import IterableWrapper - -__all__ = [ - "DataLoader2", -] - - -class _ThreadingDataLoader2: - - def __init__(self, datapipe, num_workers=0, collate_fn=None): - self.threads = [] - self.datapipes = [] - self.collate_fn = collate_fn - for worker_id in range(num_workers): - (thread, req_queue, res_queue, thread_localdatapipe) = communication.eventloop.SpawnThreadForDataPipeline(datapipe) - torch.utils.data.graph_settings.apply_sharding(thread_localdatapipe, num_workers, worker_id) - thread.start() - self.threads.append((thread, req_queue, res_queue)) # These queues are independent - local_datapipe = communication.iter.QueueWrapper( - communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue)) - self.datapipes.append(local_datapipe) - - def __iter__(self): - not_available = False - forever = True - exclude_datapipes: List[Any] = [] - while len(exclude_datapipes) < len(self.datapipes): - for dp in self.datapipes: - if dp not in exclude_datapipes: - try: - value = dp.nonblocking_next() - yield value - except StopIteration: - exclude_datapipes.append(dp) - except communication.iter.NotAvailable: - not_available = True - if not_available: - time.sleep(0.001) - - def __del__(self): - self._cleanup_all_threads() - - def _cleanup_all_threads(self): - def clean_me(thread, req_queue, res_queue): - req_queue.put(communication.messages.TerminateRequest()) - _ = res_queue.get() - thread.join() - - for thread, req_queue, res_queue in self.threads: - clean_me(thread, req_queue, res_queue) - -class DataLoader2: - def __new__(cls, - dataset, - batch_size=1, - shuffle=None, - sampler=None, - batch_sampler=None, - num_workers=0, - collate_fn=None, - pin_memory=False, - drop_last=False, - timeout=0, - worker_init_fn=None, - *, - prefetch_factor=2, - persistent_workers=False, - batch_outside_worker=False, - parallelism_mode='mp'): - if isinstance(dataset, IterDataPipe): - data_loader: Any = None - if batch_sampler is not None: - raise Exception( - 'batch_sampler is not yet supported by DataPipes') - if sampler is not None: - raise Exception( - 'sampler is not yet supported by DataPipes') - datapipe = dataset - datapipe = torch.utils.data.graph_settings.apply_shuffle_settings(datapipe, shuffle=shuffle) # type: ignore[assignment] - if batch_outside_worker and pin_memory: - raise Exception( - 'pin_memory is not yet compatible with batch_outside_worker') - if not batch_outside_worker: - if batch_size is not None: - datapipe = datapipe.batch(batch_size, drop_last=drop_last) - if collate_fn is None: - collate_fn = torch.utils.data._utils.collate.default_collate - - # Note: It is safe to pass shuffle=True to the old DataLoader, as shuffle does nothing - # for Iterable, but required to set Pipes correctly. - data_loader = DataLoader(datapipe, - batch_size=None, # Replaced by .batch DataPipe - shuffle=shuffle, - sampler=None, - batch_sampler=None, - num_workers=num_workers, - collate_fn=collate_fn, - pin_memory=pin_memory, - drop_last=False, # Replaced by .batch DataPipe - timeout=timeout, - worker_init_fn=worker_init_fn, - prefetch_factor=prefetch_factor, - persistent_workers=persistent_workers) - elif parallelism_mode == 'thread': - if collate_fn is not None and not batch_outside_worker: - datapipe = datapipe.map(collate_fn) - if pin_memory: - raise Exception( - 'pin_memory is not yet supported by DataPipes with Threading') - if worker_init_fn is not None: - raise Exception( - 'worker_init_fn is not yet supported by DataPipes with Threading') - data_loader = _ThreadingDataLoader2(datapipe, - num_workers=num_workers, - collate_fn=collate_fn) - else: - raise Exception('Unsupported parallelism mode', parallelism_mode) - if not batch_outside_worker: - return data_loader - else: - if collate_fn is None: - collate_fn = torch.utils.data._utils.collate.default_collate - datapipe = IterableWrapper(data_loader).batch( - batch_size, drop_last=drop_last).map(collate_fn) - return datapipe - else: - if parallelism_mode == 'thread': - raise Exception( - 'thread parallelism mode is not supported for old DataSets') - return DataLoader(dataset, - batch_size=batch_size, - shuffle=shuffle, - sampler=sampler, - batch_sampler=batch_sampler, - num_workers=num_workers, - collate_fn=collate_fn, - pin_memory=pin_memory, - drop_last=drop_last, - timeout=timeout, - worker_init_fn=worker_init_fn, - prefetch_factor=prefetch_factor, - persistent_workers=persistent_workers) From 6fe47b682fe1ba2dd2c7da02ff1bb06f8670e3a7 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 11 Nov 2022 22:31:32 +0000 Subject: [PATCH 42/62] [Dynamo] Fix str(Guard.obj_weakref) bug to re-ennable support overriding __getattr__ (#88564) See my inline comments! Pull Request resolved: https://github.com/pytorch/pytorch/pull/88564 Approved by: https://github.com/ezyang, https://github.com/anijain2305 --- test/dynamo/test_misc.py | 2 -- torch/_dynamo/guards.py | 27 ++++++++++++++++++++++++++- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 4df7153b8fb2b48..a8bf86e46411ba2 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -579,8 +579,6 @@ def fn(count): self.assertEqual(cnts.frame_count, 0) self.assertEqual(cnts.op_count, 0) - # KeyError: '__name__' - @patch.object(torch._dynamo.config, "suppress_errors", True) def test_user_getattr1(self): class MyConfig(dict): def __getattr__(self, name): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 9edd6f60560df87..382734412b2badc 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -101,13 +101,38 @@ def sort_key(self): def __lt__(self, other): return self.sort_key() < other.sort_key() + @staticmethod + def weakref_to_str(obj_weakref): + """ + This is a workaround of a Python weakref bug. + + `obj_weakref` is instance returned by `weakref.ref`, + `str(obj_weakref)` is buggy if the original obj overrides __getattr__, e.g: + + class MyConfig(dict): + def __getattr__(self, x): + return self[x] + + obj = MyConfig(offset=5) + obj_weakref = weakref.ref(obj) + str(obj_weakref) # raise error: KeyError: '__name__' + """ + if isinstance(obj_weakref, weakref.ReferenceType): + obj = obj_weakref() + if obj is not None: + return f"" + else: + return f"" + else: + return str(obj_weakref) + def __str__(self): s = f""" {self.source.name.lower()} {repr(self.name)} {self.create_fn.__name__} {{ 'guard_types': {self.guard_types}, 'code': {self.code_list}, - 'obj_weakref': {self.obj_weakref} + 'obj_weakref': {self.weakref_to_str(self.obj_weakref)} 'guarded_class': {self.guarded_class_weakref} }} """ From a7fa423f48af8af220e9286a6b4c374d533f77e0 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Fri, 11 Nov 2022 14:41:35 +0000 Subject: [PATCH 43/62] copy_: Short-circuit when self and src view the same data (#88884) This comes up if you use inplace operators on a slice, e.g. ```python import torch a = torch.rand(1000000, device="cuda") a[::2] *= 2 ``` The last line looks as if it should be fully inplace, but is actually equivalent to: ```python tmp = a[::2] tmp *= 2 a[::2] = tmp ``` Which results in `mul_` and `copy_` being called. With this PR, the redundant copy becomes a no-op and the above example is 2x faster. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88884 Approved by: https://github.com/ngimel --- aten/src/ATen/native/Copy.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index a44f39c5bb2ebfa..c6b82426d3bf670 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -220,6 +220,18 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking) return at::metal::metal_copy_(self, src); } + // Exit early if self and src are views of the same data + const bool is_same_data = ( + self.is_alias_of(src) && + self.storage_offset() == src.storage_offset() && + self.strides().equals(src.strides()) && + self.sizes().equals(src.sizes()) && + self.scalar_type() == src.scalar_type() + ); + if (is_same_data) { + return self; + } + auto iter = TensorIteratorConfig() .add_output(self) From 7c3adddd6c3fe1bda4a9e5bfb9f992a802329551 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Wed, 9 Nov 2022 12:20:16 -0800 Subject: [PATCH 44/62] [functorch] delete some unused files (#88763) Some post-merge cleanup. - packaging/ was for building standalone windows binaries - our flake8 config got superceded by PyTorch's. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88763 Approved by: https://github.com/samdow --- functorch/.flake8 | 20 - functorch/packaging/build_wheel.sh | 19 - functorch/packaging/pkg_helpers.bash | 414 ------------------ .../windows/internal/cuda_install.bat | 264 ----------- .../windows/internal/driver_update.bat | 25 -- .../windows/internal/vc_env_helper.bat | 43 -- .../windows/internal/vc_install_helper.sh | 16 - 7 files changed, 801 deletions(-) delete mode 100644 functorch/.flake8 delete mode 100644 functorch/packaging/build_wheel.sh delete mode 100644 functorch/packaging/pkg_helpers.bash delete mode 100644 functorch/packaging/windows/internal/cuda_install.bat delete mode 100644 functorch/packaging/windows/internal/driver_update.bat delete mode 100644 functorch/packaging/windows/internal/vc_env_helper.bat delete mode 100644 functorch/packaging/windows/internal/vc_install_helper.sh diff --git a/functorch/.flake8 b/functorch/.flake8 deleted file mode 100644 index a6d73773e3b5566..000000000000000 --- a/functorch/.flake8 +++ /dev/null @@ -1,20 +0,0 @@ -[flake8] -select = B,C,E,F,P,T4,W,B9 -max-line-length = 120 -# C408 ignored because we like the dict keyword argument syntax -# E501 is not flexible enough, we're using B950 instead -ignore = - E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303, - # shebang has extra meaning in fbcode lints, so I think it's not worth trying - # to line this up with executable bit - EXE001, - # these ignores are from flake8-bugbear; please fix! - B007,B008, - # these ignores are from flake8-comprehensions; please fix! - C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415 -exclude = - ./.git, - ./benchmarks, - ./docs, - ./examples, - ./notebooks diff --git a/functorch/packaging/build_wheel.sh b/functorch/packaging/build_wheel.sh deleted file mode 100644 index 074e7dde771417b..000000000000000 --- a/functorch/packaging/build_wheel.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash -set -ex - -script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -. "$script_dir/pkg_helpers.bash" - -export BUILD_TYPE=wheel -setup_env 0.2.0 -setup_wheel_python -pip_install numpy pyyaml future ninja -pip_install --upgrade setuptools -setup_pip_pytorch_version -python setup.py clean - -if [[ "$OSTYPE" == "msys" ]]; then - "$script_dir/windows/internal/vc_env_helper.bat" python setup.py bdist_wheel -else - python setup.py bdist_wheel -fi diff --git a/functorch/packaging/pkg_helpers.bash b/functorch/packaging/pkg_helpers.bash deleted file mode 100644 index 329891a07216cbf..000000000000000 --- a/functorch/packaging/pkg_helpers.bash +++ /dev/null @@ -1,414 +0,0 @@ -# A set of useful bash functions for common functionality we need to do in -# many build scripts - - -# Setup CUDA environment variables, based on CU_VERSION -# -# Inputs: -# CU_VERSION (cpu, cu92, cu100) -# NO_CUDA_PACKAGE (bool) -# BUILD_TYPE (conda, wheel) -# -# Outputs: -# VERSION_SUFFIX (e.g., "") -# PYTORCH_VERSION_SUFFIX (e.g., +cpu) -# WHEEL_DIR (e.g., cu100/) -# CUDA_HOME (e.g., /usr/local/cuda-9.2, respected by torch.utils.cpp_extension) -# FORCE_CUDA (respected by torchvision setup.py) -# NVCC_FLAGS (respected by torchvision setup.py) -# -# Precondition: CUDA versions are installed in their conventional locations in -# /usr/local/cuda-* -# -# NOTE: Why VERSION_SUFFIX versus PYTORCH_VERSION_SUFFIX? If you're building -# a package with CUDA on a platform we support CUDA on, VERSION_SUFFIX == -# PYTORCH_VERSION_SUFFIX and everyone is happy. However, if you are building a -# package with only CPU bits (e.g., torchaudio), then VERSION_SUFFIX is always -# empty, but PYTORCH_VERSION_SUFFIX is +cpu (because that's how you get a CPU -# version of a Python package. But that doesn't apply if you're on OS X, -# since the default CU_VERSION on OS X is cpu. -setup_cuda() { - - # First, compute version suffixes. By default, assume no version suffixes - export VERSION_SUFFIX="" - export PYTORCH_VERSION_SUFFIX="" - export WHEEL_DIR="" - # Wheel builds need suffixes (but not if they're on OS X, which never has suffix) - if [[ "$BUILD_TYPE" == "wheel" ]] && [[ "$(uname)" != Darwin ]]; then - export PYTORCH_VERSION_SUFFIX="+$CU_VERSION" - # Match the suffix scheme of pytorch, unless this package does not have - # CUDA builds (in which case, use default) - if [[ -z "$NO_CUDA_PACKAGE" ]]; then - export VERSION_SUFFIX="$PYTORCH_VERSION_SUFFIX" - export WHEEL_DIR="$CU_VERSION/" - fi - fi - - # Now work out the CUDA settings - case "$CU_VERSION" in - cu115) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.5" - else - export CUDA_HOME=/usr/local/cuda-11.5/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" - ;; - cu113) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.3" - else - export CUDA_HOME=/usr/local/cuda-11.3/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" - ;; - cu112) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.2" - else - export CUDA_HOME=/usr/local/cuda-11.2/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" - ;; - cu111) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.1" - else - export CUDA_HOME=/usr/local/cuda-11.1/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" - ;; - cu110) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.0" - else - export CUDA_HOME=/usr/local/cuda-11.0/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0" - ;; - cu102) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.2" - else - export CUDA_HOME=/usr/local/cuda-10.2/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" - ;; - cu101) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.1" - else - export CUDA_HOME=/usr/local/cuda-10.1/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" - ;; - cu100) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.0" - else - export CUDA_HOME=/usr/local/cuda-10.0/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" - ;; - cu92) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v9.2" - else - export CUDA_HOME=/usr/local/cuda-9.2/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0" - ;; - cpu) - ;; - rocm*) - export FORCE_CUDA=1 - ;; - *) - echo "Unrecognized CU_VERSION=$CU_VERSION" - exit 1 - ;; - esac - if [[ -n "$CUDA_HOME" ]]; then - # Adds nvcc binary to the search path so that CMake's `find_package(CUDA)` will pick the right one - export PATH="$CUDA_HOME/bin:$PATH" - export FORCE_CUDA=1 - fi -} - -# Populate build version if necessary, and add version suffix -# -# Inputs: -# BUILD_VERSION (e.g., 0.2.0 or empty) -# VERSION_SUFFIX (e.g., +cpu) -# -# Outputs: -# BUILD_VERSION (e.g., 0.2.0.dev20190807+cpu) -# -# Fill BUILD_VERSION if it doesn't exist already with a nightly string -# Usage: setup_build_version 0.2.0 -setup_build_version() { - if [[ -z "$BUILD_VERSION" ]]; then - export BUILD_VERSION="$1.dev$(date "+%Y%m%d")$VERSION_SUFFIX" - else - export BUILD_VERSION="$BUILD_VERSION$VERSION_SUFFIX" - fi - - # Set build version based on tag if on tag - if [[ -n "${CIRCLE_TAG}" ]]; then - # Strip tag - export BUILD_VERSION="$(echo "${CIRCLE_TAG}" | sed -e 's/^v//' -e 's/-.*$//')${VERSION_SUFFIX}" - fi -} - -# Set some useful variables for OS X, if applicable -setup_macos() { - if [[ "$(uname)" == Darwin ]]; then - export MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ - fi -} - - -# Top-level entry point for things every package will need to do -# -# Usage: setup_env 0.2.0 -setup_env() { - setup_cuda - setup_build_version "$1" - setup_macos -} - -# Function to retry functions that sometimes timeout or have flaky failures -retry () { - $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) -} - -# Inputs: -# PYTHON_VERSION (3.7, 3.8, 3.9) -# UNICODE_ABI (bool) -# -# Outputs: -# PATH modified to put correct Python version in PATH -# -# Precondition: If Linux, you are in a soumith/manylinux-cuda* Docker image -setup_wheel_python() { - if [[ "$(uname)" == Darwin || "$OSTYPE" == "msys" ]]; then - eval "$(conda shell.bash hook)" - conda env remove -n "env$PYTHON_VERSION" || true - conda create ${CONDA_CHANNEL_FLAGS} -yn "env$PYTHON_VERSION" python="$PYTHON_VERSION" - conda activate "env$PYTHON_VERSION" - # Install libpng from Anaconda (defaults) - conda install ${CONDA_CHANNEL_FLAGS} libpng "jpeg<=9b" -y - else - # Install native CentOS libJPEG, freetype and GnuTLS - yum install -y libjpeg-turbo-devel freetype gnutls - case "$PYTHON_VERSION" in - 3.7) python_abi=cp37-cp37m ;; - 3.8) python_abi=cp38-cp38 ;; - 3.9) python_abi=cp39-cp39 ;; - 3.10) python_abi=cp310-cp310 ;; - *) - echo "Unrecognized PYTHON_VERSION=$PYTHON_VERSION" - exit 1 - ;; - esac - # Download all the dependencies required to compile image and video_reader - # extensions - - mkdir -p ext_libraries - pushd ext_libraries - popd - export PATH="/opt/python/$python_abi/bin:$(pwd)/ext_libraries/bin:$PATH" - fi -} - -# Install with pip a bit more robustly than the default -pip_install() { - retry pip install --progress-bar off "$@" -} - -# Install torch with pip, respecting PYTORCH_VERSION, and record the installed -# version into PYTORCH_VERSION, if applicable -setup_pip_pytorch_version() { - if [[ -z "$PYTORCH_VERSION" ]]; then - # Install latest prerelease version of torch, per our nightlies, consistent - # with the requested cuda version - pip_install --pre torch -f "https://download.pytorch.org/whl/nightly/${WHEEL_DIR}torch_nightly.html" - if [[ "$CUDA_VERSION" == "cpu" ]]; then - # CUDA and CPU are ABI compatible on the CPU-only parts, so strip - # in this case - export PYTORCH_VERSION="$(pip show torch | grep ^Version: | sed 's/Version: *//' | sed 's/+.\+//')" - else - export PYTORCH_VERSION="$(pip show torch | grep ^Version: | sed 's/Version: *//')" - fi - else - pip_install "torch==$PYTORCH_VERSION$PYTORCH_VERSION_SUFFIX" \ - -f "https://download.pytorch.org/whl/${CU_VERSION}/torch_stable.html" \ - -f "https://download.pytorch.org/whl/${UPLOAD_CHANNEL}/${CU_VERSION}/torch_${UPLOAD_CHANNEL}.html" - fi -} - -# Fill PYTORCH_VERSION with the latest conda nightly version, and -# CONDA_CHANNEL_FLAGS with appropriate flags to retrieve these versions -# -# You MUST have populated PYTORCH_VERSION_SUFFIX before hand. -setup_conda_pytorch_constraint() { - if [[ -z "$PYTORCH_VERSION" ]]; then - export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c pytorch-nightly -c pytorch" - export PYTORCH_VERSION="$(conda search --json 'pytorch[channel=pytorch-nightly]' | \ - python -c "import os, sys, json, re; cuver = os.environ.get('CU_VERSION'); \ - cuver_1 = cuver.replace('cu', 'cuda') if cuver != 'cpu' else cuver; \ - cuver_2 = (cuver[:-1] + '.' + cuver[-1]).replace('cu', 'cuda') if cuver != 'cpu' else cuver; \ - print(re.sub(r'\\+.*$', '', \ - [x['version'] for x in json.load(sys.stdin)['pytorch'] \ - if (x['platform'] == 'darwin' or cuver_1 in x['fn'] or cuver_2 in x['fn']) \ - and 'py' + os.environ['PYTHON_VERSION'] in x['fn']][-1]))")" - if [[ -z "$PYTORCH_VERSION" ]]; then - echo "PyTorch version auto detection failed" - echo "No package found for CU_VERSION=$CU_VERSION and PYTHON_VERSION=$PYTHON_VERSION" - exit 1 - fi - else - export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c pytorch -c pytorch-${UPLOAD_CHANNEL}" - fi - if [[ "$CU_VERSION" == cpu ]]; then - export CONDA_PYTORCH_BUILD_CONSTRAINT="- pytorch==$PYTORCH_VERSION${PYTORCH_VERSION_SUFFIX}" - export CONDA_PYTORCH_CONSTRAINT="- pytorch==$PYTORCH_VERSION" - else - export CONDA_PYTORCH_BUILD_CONSTRAINT="- pytorch==${PYTORCH_VERSION}${PYTORCH_VERSION_SUFFIX}" - export CONDA_PYTORCH_CONSTRAINT="- pytorch==${PYTORCH_VERSION}${PYTORCH_VERSION_SUFFIX}" - fi - if [[ "$OSTYPE" == msys && "$CU_VERSION" == cu92 ]]; then - export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c defaults -c numba/label/dev" - fi -} - -# Translate CUDA_VERSION into CUDA_CUDATOOLKIT_CONSTRAINT -setup_conda_cudatoolkit_constraint() { - export CONDA_BUILD_VARIANT="cuda" - if [[ "$(uname)" == Darwin ]]; then - export CONDA_BUILD_VARIANT="cpu" - else - case "$CU_VERSION" in - cu115) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.5,<11.6 # [not osx]" - ;; - cu113) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.3,<11.4 # [not osx]" - ;; - cu112) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.2,<11.3 # [not osx]" - ;; - cu111) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.1,<11.2 # [not osx]" - ;; - cu110) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.0,<11.1 # [not osx]" - ;; - cu102) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=10.2,<10.3 # [not osx]" - ;; - cu101) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=10.1,<10.2 # [not osx]" - ;; - cu100) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=10.0,<10.1 # [not osx]" - ;; - cu92) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=9.2,<9.3 # [not osx]" - ;; - cpu) - export CONDA_CUDATOOLKIT_CONSTRAINT="" - export CONDA_BUILD_VARIANT="cpu" - ;; - *) - echo "Unrecognized CU_VERSION=$CU_VERSION" - exit 1 - ;; - esac - fi -} - -setup_conda_cudatoolkit_plain_constraint() { - export CONDA_BUILD_VARIANT="cuda" - export CMAKE_USE_CUDA=1 - if [[ "$(uname)" == Darwin ]]; then - export CONDA_BUILD_VARIANT="cpu" - export CMAKE_USE_CUDA=0 - else - case "$CU_VERSION" in - cu115) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.5" - ;; - cu113) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.3" - ;; - cu112) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.2" - ;; - cu111) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.1" - ;; - cu102) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=10.2" - ;; - cu101) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=10.1" - ;; - cu100) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=10.0" - ;; - cu92) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=9.2" - ;; - cpu) - export CONDA_CUDATOOLKIT_CONSTRAINT="" - export CONDA_BUILD_VARIANT="cpu" - export CMAKE_USE_CUDA=0 - ;; - *) - echo "Unrecognized CU_VERSION=$CU_VERSION" - exit 1 - ;; - esac - fi -} - -# Build the proper compiler package before building the final package -setup_visual_studio_constraint() { - if [[ "$OSTYPE" == "msys" ]]; then - export VSTOOLCHAIN_PACKAGE=vs$VC_YEAR - conda build $CONDA_CHANNEL_FLAGS --no-anaconda-upload packaging/$VSTOOLCHAIN_PACKAGE - cp packaging/$VSTOOLCHAIN_PACKAGE/conda_build_config.yaml packaging/torchvision/conda_build_config.yaml - fi -} - -setup_junit_results_folder() { - if [[ "$CI" == "true" ]]; then - export CONDA_PYTORCH_BUILD_RESULTS_DIRECTORY="${SOURCE_ROOT_DIR}/build_results/results.xml" - fi -} - - -download_copy_ffmpeg() { - if [[ "$OSTYPE" == "msys" ]]; then - # conda install -yq ffmpeg=4.2 -c pytorch - # curl -L -q https://anaconda.org/pytorch/ffmpeg/4.3/download/win-64/ffmpeg-4.3-ha925a31_0.tar.bz2 --output ffmpeg-4.3-ha925a31_0.tar.bz2 - # bzip2 --decompress --stdout ffmpeg-4.3-ha925a31_0.tar.bz2 | tar -x --file=- - # cp Library/bin/*.dll ../torchvision - echo "FFmpeg is disabled currently on Windows" - else - if [[ "$(uname)" == Darwin ]]; then - conda install -yq ffmpeg=4.2 -c pytorch - conda install -yq wget - else - # pushd ext_libraries - # wget -q https://anaconda.org/pytorch/ffmpeg/4.2/download/linux-64/ffmpeg-4.2-hf484d3e_0.tar.bz2 - # tar -xjvf ffmpeg-4.2-hf484d3e_0.tar.bz2 - # rm -rf ffmpeg-4.2-hf484d3e_0.tar.bz2 - # ldconfig - # which ffmpeg - # popd - echo "FFmpeg is disabled currently on Linux" - fi - fi -} diff --git a/functorch/packaging/windows/internal/cuda_install.bat b/functorch/packaging/windows/internal/cuda_install.bat deleted file mode 100644 index 41960224ebaedff..000000000000000 --- a/functorch/packaging/windows/internal/cuda_install.bat +++ /dev/null @@ -1,264 +0,0 @@ -@echo on - -if "%CU_VERSION%" == "cpu" ( - echo Skipping for CPU builds - exit /b 0 -) - -set SRC_DIR=%~dp0\.. - -if not exist "%SRC_DIR%\temp_build" mkdir "%SRC_DIR%\temp_build" - -rem in unit test workflow, we get CUDA_VERSION, for example 11.1 -if defined CUDA_VERSION ( - set CUDA_VER=%CUDA_VERSION:.=% -) else ( - set CUDA_VER=%CU_VERSION:cu=% -) - -set /a CUDA_VER=%CU_VERSION:cu=% -set CUDA_VER_MAJOR=%CUDA_VER:~0,-1% -set CUDA_VER_MINOR=%CUDA_VER:~-1,1% -set CUDA_VERSION_STR=%CUDA_VER_MAJOR%.%CUDA_VER_MINOR% - - -if %CUDA_VER% EQU 92 goto cuda92 -if %CUDA_VER% EQU 100 goto cuda100 -if %CUDA_VER% EQU 101 goto cuda101 -if %CUDA_VER% EQU 102 goto cuda102 -if %CUDA_VER% EQU 110 goto cuda110 -if %CUDA_VER% EQU 111 goto cuda111 -if %CUDA_VER% EQU 112 goto cuda112 -if %CUDA_VER% EQU 113 goto cuda113 -if %CUDA_VER% EQU 115 goto cuda115 - - -echo CUDA %CUDA_VERSION_STR% is not supported -exit /b 1 - -:cuda92 -if not exist "%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cuda_9.2.148_win10.exe --output "%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" - set "ARGS=nvcc_9.2 cuobjdump_9.2 nvprune_9.2 cupti_9.2 cublas_9.2 cublas_dev_9.2 cudart_9.2 cufft_9.2 cufft_dev_9.2 curand_9.2 curand_dev_9.2 cusolver_9.2 cusolver_dev_9.2 cusparse_9.2 cusparse_dev_9.2 nvgraph_9.2 nvgraph_dev_9.2 npp_9.2 npp_dev_9.2 nvrtc_9.2 nvrtc_dev_9.2 nvml_dev_9.2" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cudnn-9.2-windows10-x64-v7.2.1.38.zip --output "%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" -) - -goto cuda_common - -:cuda100 - -if not exist "%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cuda_10.0.130_411.31_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" - set "ARGS=nvcc_10.0 cuobjdump_10.0 nvprune_10.0 cupti_10.0 cublas_10.0 cublas_dev_10.0 cudart_10.0 cufft_10.0 cufft_dev_10.0 curand_10.0 curand_dev_10.0 cusolver_10.0 cusolver_dev_10.0 cusparse_10.0 cusparse_dev_10.0 nvgraph_10.0 nvgraph_dev_10.0 npp_10.0 npp_dev_10.0 nvrtc_10.0 nvrtc_dev_10.0 nvml_dev_10.0" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cudnn-10.0-windows10-x64-v7.4.1.5.zip --output "%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" -) - -goto cuda_common - -:cuda101 - -if not exist "%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_10.1.243_426.00_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" - set "ARGS=nvcc_10.1 cuobjdump_10.1 nvprune_10.1 cupti_10.1 cublas_10.1 cublas_dev_10.1 cudart_10.1 cufft_10.1 cufft_dev_10.1 curand_10.1 curand_dev_10.1 cusolver_10.1 cusolver_dev_10.1 cusparse_10.1 cusparse_dev_10.1 nvgraph_10.1 nvgraph_dev_10.1 npp_10.1 npp_dev_10.1 nvjpeg_10.1 nvjpeg_dev_10.1 nvrtc_10.1 nvrtc_dev_10.1 nvml_dev_10.1" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-10.1-windows10-x64-v7.6.4.38.zip --output "%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" -) - -goto cuda_common - -:cuda102 - -if not exist "%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_10.2.89_441.22_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" - set "ARGS=nvcc_10.2 cuobjdump_10.2 nvprune_10.2 cupti_10.2 cublas_10.2 cublas_dev_10.2 cudart_10.2 cufft_10.2 cufft_dev_10.2 curand_10.2 curand_dev_10.2 cusolver_10.2 cusolver_dev_10.2 cusparse_10.2 cusparse_dev_10.2 nvgraph_10.2 nvgraph_dev_10.2 npp_10.2 npp_dev_10.2 nvjpeg_10.2 nvjpeg_dev_10.2 nvrtc_10.2 nvrtc_dev_10.2 nvml_dev_10.2" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-10.2-windows10-x64-v7.6.5.32.zip --output "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" -) - -rem The below only for cu102, if it's used in other version, e.g. cu111, torch.cuda.is_availabe() would be False. -if not exist "%SRC_DIR%\temp_build\gpu_driver_dlls.7z" ( - curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "%SRC_DIR%\temp_build\gpu_driver_dlls.zip" - if errorlevel 1 exit /b 1 -) - -echo Installing GPU driver DLLs -7z x %SRC_DIR%\temp_build\gpu_driver_dlls.zip -aoa -o"C:\Windows\System32" - -goto cuda_common - -:cuda110 - -if not exist "%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.0.2_451.48_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" - set "ARGS=nvcc_11.0 cuobjdump_11.0 nvprune_11.0 nvprof_11.0 cupti_11.0 cublas_11.0 cublas_dev_11.0 cudart_11.0 cufft_11.0 cufft_dev_11.0 curand_11.0 curand_dev_11.0 cusolver_11.0 cusolver_dev_11.0 cusparse_11.0 cusparse_dev_11.0 npp_11.0 npp_dev_11.0 nvjpeg_11.0 nvjpeg_dev_11.0 nvrtc_11.0 nvrtc_dev_11.0 nvml_dev_11.0" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-11.0-windows-x64-v8.0.4.30.zip --output "%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" -) - -goto cuda_common - -:cuda111 - -if not exist "%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.1.1_456.81_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe" - set "ARGS=nvcc_11.1 cuobjdump_11.1 nvprune_11.1 nvprof_11.1 cupti_11.1 cublas_11.1 cublas_dev_11.1 cudart_11.1 cufft_11.1 cufft_dev_11.1 curand_11.1 curand_dev_11.1 cusolver_11.1 cusolver_dev_11.1 cusparse_11.1 cusparse_dev_11.1 npp_11.1 npp_dev_11.1 nvjpeg_11.1 nvjpeg_dev_11.1 nvrtc_11.1 nvrtc_dev_11.1 nvml_dev_11.1" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-11.1-windows-x64-v8.0.5.39.zip --output "%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" -) - -goto cuda_common - -:cuda112 - -if not exist "%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.2.0_460.89_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" - set "ARGS=nvcc_11.2 cuobjdump_11.2 nvprune_11.2 nvprof_11.2 cupti_11.2 cublas_11.2 cublas_dev_11.2 cudart_11.2 cufft_11.2 cufft_dev_11.2 curand_11.2 curand_dev_11.2 cusolver_11.2 cusolver_dev_11.2 cusparse_11.2 cusparse_dev_11.2 npp_11.2 npp_dev_11.2 nvjpeg_11.2 nvjpeg_dev_11.2 nvrtc_11.2 nvrtc_dev_11.2 nvml_dev_11.2" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" ( - curl -k -L http://s3.amazonaws.com/ossci-windows/cudnn-11.2-windows-x64-v8.1.0.77.zip --output "%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" -) - -goto cuda_common - -:cuda113 - -set CUDA_INSTALL_EXE=cuda_11.3.0_465.89_win10.exe -if not exist "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" ( - curl -k -L "https://ossci-windows.s3.amazonaws.com/%CUDA_INSTALL_EXE%" --output "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" - set "ARGS=thrust_11.3 nvcc_11.3 cuobjdump_11.3 nvprune_11.3 nvprof_11.3 cupti_11.3 cublas_11.3 cublas_dev_11.3 cudart_11.3 cufft_11.3 cufft_dev_11.3 curand_11.3 curand_dev_11.3 cusolver_11.3 cusolver_dev_11.3 cusparse_11.3 cusparse_dev_11.3 npp_11.3 npp_dev_11.3 nvjpeg_11.3 nvjpeg_dev_11.3 nvrtc_11.3 nvrtc_dev_11.3 nvml_dev_11.3" - -) - -set CUDNN_INSTALL_ZIP=cudnn-11.3-windows-x64-v8.2.0.53.zip -if not exist "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" ( - curl -k -L "http://s3.amazonaws.com/ossci-windows/%CUDNN_INSTALL_ZIP%" --output "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" -) - -goto cuda_common - -:cuda115 - -set CUDA_INSTALL_EXE=cuda_11.5.0_496.13_win10.exe -if not exist "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" ( - curl -k -L "https://ossci-windows.s3.amazonaws.com/%CUDA_INSTALL_EXE%" --output "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" - set "ARGS=thrust_11.5 nvcc_11.5 cuobjdump_11.5 nvprune_11.5 nvprof_11.5 cupti_11.5 cublas_11.5 cublas_dev_11.5 cudart_11.5 cufft_11.5 cufft_dev_11.5 curand_11.5 curand_dev_11.5 cusolver_11.5 cusolver_dev_11.5 cusparse_11.5 cusparse_dev_11.5 npp_11.5 npp_dev_11.5 nvrtc_11.5 nvrtc_dev_11.5 nvml_dev_11.5" -) - -set CUDNN_INSTALL_ZIP=cudnn-11.3-windows-x64-v8.2.0.53.zip -if not exist "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" ( - curl -k -L "http://s3.amazonaws.com/ossci-windows/%CUDNN_INSTALL_ZIP%" --output "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" -) - -goto cuda_common - -:cuda_common - -if not exist "%SRC_DIR%\temp_build\NvToolsExt.7z" ( - curl -k -L https://www.dropbox.com/s/9mcolalfdj4n979/NvToolsExt.7z?dl=1 --output "%SRC_DIR%\temp_build\NvToolsExt.7z" - if errorlevel 1 exit /b 1 -) - -echo Installing CUDA toolkit... -7z x %CUDA_SETUP_FILE% -o"%SRC_DIR%\temp_build\cuda" -pushd "%SRC_DIR%\temp_build\cuda" -sc config wuauserv start= disabled -sc stop wuauserv -sc query wuauserv - -start /wait setup.exe -s %ARGS% -loglevel:6 -log:"%cd%/cuda_install_logs" -echo %errorlevel% - -popd - -echo Installing VS integration... -rem It's for VS 2019 -if "%CUDA_VER_MAJOR%" == "10" ( - xcopy /Y "%SRC_DIR%\temp_build\cuda\CUDAVisualStudioIntegration\extras\visual_studio_integration\MSBuildExtensions\*.*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations" -) -if "%CUDA_VER_MAJOR%" == "11" ( - xcopy /Y "%SRC_DIR%\temp_build\cuda\visual_studio_integration\CUDAVisualStudioIntegration\extras\visual_studio_integration\MSBuildExtensions\*.*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations" -) - -echo Installing NvToolsExt... -7z x %SRC_DIR%\temp_build\NvToolsExt.7z -o"%SRC_DIR%\temp_build\NvToolsExt" -mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" -mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\include" -mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\lib\x64" -xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\bin\x64\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" -xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\include\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\include" -xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\lib\x64\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\lib\x64" - -echo Setting up environment... -set "PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin;%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\libnvvp;%PATH%" -set "CUDA_PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%" -set "CUDA_PATH_V%CUDA_VER_MAJOR%_%CUDA_VER_MINOR%=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%" -set "NVTOOLSEXT_PATH=%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" - -if not exist "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin\nvcc.exe" ( - echo CUDA %CUDA_VERSION_STR% installed failed. - echo --------- RunDll32.exe.log - type "%SRC_DIR%\temp_build\cuda\cuda_install_logs\LOG.RunDll32.exe.log" - echo --------- setup.exe.log ------- - type "%SRC_DIR%\temp_build\cuda\cuda_install_logs\LOG.setup.exe.log" - exit /b 1 -) - -echo Installing cuDNN... -7z x %CUDNN_SETUP_FILE% -o"%SRC_DIR%\temp_build\cudnn" -xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\bin\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin" -xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\lib\x64\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\lib\x64" -xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\include\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\include" - -echo Cleaning temp files -rd /s /q "%SRC_DIR%\temp_build" || ver > nul diff --git a/functorch/packaging/windows/internal/driver_update.bat b/functorch/packaging/windows/internal/driver_update.bat deleted file mode 100644 index 00b43affc01cc30..000000000000000 --- a/functorch/packaging/windows/internal/driver_update.bat +++ /dev/null @@ -1,25 +0,0 @@ -set "DRIVER_DOWNLOAD_LINK=https://ossci-windows.s3.amazonaws.com/461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe" -curl --retry 3 -kL %DRIVER_DOWNLOAD_LINK% --output 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe -if errorlevel 1 exit /b 1 - -start /wait 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe -s -noreboot -if errorlevel 1 exit /b 1 - -del 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe || ver > NUL - -setlocal EnableDelayedExpansion -set NVIDIA_GPU_EXISTS=0 -for /F "delims=" %%i in ('wmic path win32_VideoController get name') do ( - set GPUS=%%i - if not "x!GPUS:NVIDIA=!" == "x!GPUS!" ( - SET NVIDIA_GPU_EXISTS=1 - goto gpu_check_end - ) -) -:gpu_check_end -endlocal & set NVIDIA_GPU_EXISTS=%NVIDIA_GPU_EXISTS% - -if "%NVIDIA_GPU_EXISTS%" == "0" ( - echo "CUDA Driver installation Failed" - exit /b 1 -) diff --git a/functorch/packaging/windows/internal/vc_env_helper.bat b/functorch/packaging/windows/internal/vc_env_helper.bat deleted file mode 100644 index e85a372f93d58c8..000000000000000 --- a/functorch/packaging/windows/internal/vc_env_helper.bat +++ /dev/null @@ -1,43 +0,0 @@ -@echo on - -set VC_VERSION_LOWER=16 -set VC_VERSION_UPPER=17 -if "%VC_YEAR%" == "2017" ( - set VC_VERSION_LOWER=15 - set VC_VERSION_UPPER=16 -) - -for /f "usebackq tokens=*" %%i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -legacy -products * -version [%VC_VERSION_LOWER%^,%VC_VERSION_UPPER%^) -property installationPath`) do ( - if exist "%%i" if exist "%%i\VC\Auxiliary\Build\vcvarsall.bat" ( - set "VS15INSTALLDIR=%%i" - set "VS15VCVARSALL=%%i\VC\Auxiliary\Build\vcvarsall.bat" - goto vswhere - ) -) - -:vswhere -if "%VSDEVCMD_ARGS%" == "" ( - call "%VS15VCVARSALL%" x64 || exit /b 1 -) else ( - call "%VS15VCVARSALL%" x64 %VSDEVCMD_ARGS% || exit /b 1 -) - -@echo on - -set DISTUTILS_USE_SDK=1 - -set args=%1 -shift -:start -if [%1] == [] goto done -set args=%args% %1 -shift -goto start - -:done -if "%args%" == "" ( - echo Usage: vc_env_helper.bat [command] [args] - echo e.g. vc_env_helper.bat cl /c test.cpp -) - -%args% || exit /b 1 diff --git a/functorch/packaging/windows/internal/vc_install_helper.sh b/functorch/packaging/windows/internal/vc_install_helper.sh deleted file mode 100644 index cdae18065b9f6e9..000000000000000 --- a/functorch/packaging/windows/internal/vc_install_helper.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash - -set -ex - -if [[ "$CU_VERSION" == "cu92" ]]; then - export VC_YEAR=2017 - export VSDEVCMD_ARGS="-vcvars_ver=14.13" - powershell packaging/windows/internal/vs2017_install.ps1 -elif [[ "$CU_VERSION" == "cu100" ]]; then - export VC_YEAR=2017 - export VSDEVCMD_ARGS="" - powershell packaging/windows/internal/vs2017_install.ps1 -else - export VC_YEAR=2019 - export VSDEVCMD_ARGS="" -fi From 37c5b42fa6597ebf7dbfb6db4ada2c7803950555 Mon Sep 17 00:00:00 2001 From: Horace He Date: Fri, 11 Nov 2022 19:17:47 +0000 Subject: [PATCH 45/62] Fix matmul decomp to use reshape instead of contiguous().view() (#88832) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88832 Approved by: https://github.com/bertmaher, https://github.com/ngimel --- torch/_decomp/decompositions.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index fe63e0db007a7b1..1a2d332e99fd958 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -2261,9 +2261,7 @@ def matmul(tensor1, tensor2): t2_is_matrix = t2.dim() == 2 if t2_is_matrix: output_shape.append(t2.shape[1]) - # HACK: We need reshape with symint support - t1 = t1.contiguous() - t1_folded = t1.view(folded_dim1, sizes_1[-1]) + t1_folded = t1.reshape(folded_dim1, sizes_1[-1]) if t2_is_matrix: # FIXME This path always does an unnecessary copy when transpose == True as the returned # result from BLAS is already C-transposed @@ -2296,15 +2294,11 @@ def matmul(tensor1, tensor2): expand_batch_product = prod(expand_batch_portion) # HACK: We need reshape with symint support - tensor1_expanded = ( - tensor1.expand(tensor1_expand_size) - .contiguous() - .view(expand_batch_product, n, m1) + tensor1_expanded = tensor1.expand(tensor1_expand_size).reshape( + expand_batch_product, n, m1 ) - tensor2_expanded = ( - tensor2.expand(tensor2_expand_size) - .contiguous() - .view(expand_batch_product, m2, p) + tensor2_expanded = tensor2.expand(tensor2_expand_size).reshape( + expand_batch_product, m2, p ) output_shape = expand_batch_portion From 5ff600aa6e40c6b4d426594bbb1f446f005b7fb3 Mon Sep 17 00:00:00 2001 From: William Wen Date: Sat, 12 Nov 2022 00:22:25 +0000 Subject: [PATCH 46/62] Add comprehensive minifier tests (#88022) Adds tests for https://github.com/pytorch/torchdynamo/issues/1241. To run: `pytest test/dynamo/test_minifier.py`. Actually runs minifier launcher script and repro scripts, rather than just checking for existence of the minifier launcher script. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88022 Approved by: https://github.com/mlazos, https://github.com/anijain2305 --- test/dynamo/test_minifier.py | 630 +++++++++++++++++++++++++++++++---- torch/_dynamo/debug_utils.py | 78 ++++- 2 files changed, 632 insertions(+), 76 deletions(-) diff --git a/test/dynamo/test_minifier.py b/test/dynamo/test_minifier.py index 0cec7d202a9d446..51b79a5e7511ea3 100644 --- a/test/dynamo/test_minifier.py +++ b/test/dynamo/test_minifier.py @@ -1,27 +1,138 @@ # Owner(s): ["module: dynamo"] +import functools import os +import re import shutil +import subprocess +import textwrap import unittest -from unittest.mock import patch import torch import torch._dynamo import torch._dynamo.test_case import torch._dynamo.testing -from torch._dynamo.optimizations.backends import create_backend +import torch._inductor.utils +from torch._dynamo.debug_utils import TEST_REPLACEABLE_COMMENT +_HAS_TRITON = torch._inductor.utils.has_triton() +requires_cuda = functools.partial(unittest.skipIf, not _HAS_TRITON, "requires cuda") -class MockModule(torch.nn.Module): - def __init__(self): - super().__init__() +RELU_COMPILE_ERROR_BACKEND = """\ +from torch._dynamo.optimizations.backends import register_backend - def forward(self, x): - for _ in range(10): - x = torch.sin(x) - x = torch._foobar(x) - for _ in range(10): - x = torch.cos(x) - return x +class DynamoCompileError(Exception): + pass + +@register_backend +def test_relu_compile_error(gm: torch.fx.GraphModule, example_inputs): + for node in gm.graph.nodes: + if node.target == torch.relu: + raise DynamoCompileError("relu found") + return gm +""" + +RELU_RUNTIME_ERROR_BACKEND = """\ +import copy +from torch._dynamo.optimizations.backends import register_backend + +@register_backend +def test_relu_runtime_error(gm: torch.fx.GraphModule, example_inputs): + gm = copy.deepcopy(gm) + for node in gm.graph.nodes: + if node.target == torch.relu: + node.target = torch._assert + node.args = (False, "DynamoRuntimeError") + gm.recompile() + return gm +""" + +RELU_ACCURACY_ERROR_BACKEND = """\ +import copy +from torch._dynamo.optimizations.backends import register_backend + +@register_backend +def test_relu_accuracy_error(gm: torch.fx.GraphModule, example_inputs): + gm = copy.deepcopy(gm) + for node in gm.graph.nodes: + if node.target == torch.relu: + node.target = torch.add + node.args = (node.args[0], 1) + gm.recompile() + + return gm +""" + +RELU_CUSTOM_ERROR_BACKEND = """\ +class CustomError(Exception): + pass + +def test_relu_custom_error(gm: torch.fx.GraphModule, example_inputs): + for node in gm.graph.nodes: + if node.target == torch.relu: + raise CustomError("relu found") + return gm +""" + +CPP_COMPILE_ERROR = """\ +def cpp_compile_error(x): + return "compile error!" +""" + +CPP_RUNTIME_ERROR = """\ +def cpp_runtime_error(x): + return f"{x}; throw 1" +""" + +CPP_ACCURACY_ERROR = """\ +def cpp_accuracy_error(x): + return f"{x} + 1" +""" + +TRITON_COMPILE_ERROR = """\ +def triton_compile_error(x): + return "compile error!" +""" + +# NOTE: there is currently not an easy way to cause a triton runtime error. +TRITON_RUNTIME_ERROR = """\ +def triton_runtime_error(x): + return f"{x}; assert?" +""" + +TRITON_ACCURACY_ERROR = """\ +def triton_accuracy_error(x): + return f"{x} + 1" +""" + +DEBUG_DIR = "/tmp/_torchdynamo_debug_/" + +# Search for the name of the first function defined in a code string. +def get_fn_name(code): + fn_name_match = re.search(r"def (\w+)\(", code) + if fn_name_match is not None: + return fn_name_match.group(1) + return None + + +# Generates code that patches CppOverrides/TritonOverrides. +def gen_codegen_fn_patch_code(old_fn_name, new_fn_code, device): + new_fn_name = get_fn_name(new_fn_code) + if new_fn_name is not None: + patch_code = f"""\ +import torch._inductor.codegen.{"cpp" if device == "cpu" else "triton"} as codegen +overrides = codegen.{"CppOverrides" if device == "cpu" else "TritonOverrides"} +{new_fn_code} +overrides.{old_fn_name} = staticmethod({new_fn_name}) +""" + return f"""\ +{patch_code} +isolate_fails_code_str = \"\"\"\\ +{patch_code} +torch._dynamo.config.debug_dir_root = "{DEBUG_DIR}" +\"\"\" +""" + + return None class MinfierTests(torch._dynamo.test_case.TestCase): @@ -32,9 +143,10 @@ def setUpClass(cls): unittest.mock.patch.object( torch._dynamo.config, "debug_dir_root", - "/tmp/_torchdynamo_debug_/", + DEBUG_DIR, ) ) + os.makedirs(DEBUG_DIR, exist_ok=True) @classmethod def tearDownClass(cls): @@ -47,65 +159,455 @@ def setUp(self): def tearDown(self): super().tearDown() - def test_after_dynamo(self): - @create_backend - def bad_dynamo_backend(subgraph): - import sys - - def f(*args): - # Shifted the forced exception to runtime as this is more common - # in JIT compilers. - for node in subgraph.model.graph.nodes: - if node.op == "call_function" and node.target is torch._foobar: - sys.stdout.write("Dynamo compiled failed\n") - raise NotImplementedError("foobar is not implemented") - return subgraph.model(*args) - - return f - - mod = MockModule() - opt_mod = torch._dynamo.optimize("bad_dynamo_backend")(mod) - repro_file = torch._dynamo.debug_utils.get_minifier_repro_path() - - @patch.object(torch._dynamo.config, "repro_after", "dynamo") - def inner(): - x = torch.randn(4) - try: - opt_mod(x) - except Exception: - pass - - inner() - self.assertTrue(os.path.exists(repro_file)) + # Run `code` in a separate python process. + # Returns the completed process state and the directory containing the + # minifier launcher script, if `code` outputted it. + def _run_test_code(self, code): + proc = subprocess.run( + ["python3", "-c", code], capture_output=True, cwd=DEBUG_DIR + ) + + repro_dir_match = re.search( + r"(\S+)minifier_launcher.py", proc.stderr.decode("utf-8") + ) + if repro_dir_match is not None: + # Print repro directory for debugging generated code. + # Make sure to comment out `shutil.rmtree...` above as well. + print("repro dir:", repro_dir_match.group(1)) + return proc, repro_dir_match.group(1) + return proc, None - # If error_at_aot is True, an error will be produced when AOTAutograd - # attempts to generate the backward graph. - # If error_after_aot is False, an error will be produced in inductor. - def _test_around_aot(self, error_at_aot): - mod = MockModule() - opt_mod = torch._dynamo.optimize("inductor")(mod) + # Patch generated files with testing patches + def _inject_code(self, patch_code, filename): + patch_code = f"""\ +{patch_code} +torch._dynamo.config.debug_dir_root = "{DEBUG_DIR}" +""" + with open(filename, "r") as f: + code = f.read() + code = code.replace(TEST_REPLACEABLE_COMMENT, patch_code) + with open(filename, "w") as f: + f.write(code) + return code - repro_file = torch._dynamo.debug_utils.get_minifier_repro_path() - repro_after = "dynamo" if error_at_aot else "aot" + # Runs the minifier launcher script in `repro_dir`, patched with `patch_code`. + def _run_minifier_launcher(self, patch_code, repro_dir): + self.assertIsNotNone(repro_dir) + launch_file = os.path.join(repro_dir, "minifier_launcher.py") + self.assertTrue(os.path.exists(launch_file)) + launch_code = self._inject_code(patch_code, launch_file) - @patch.object(torch._dynamo.config, "repro_after", repro_after) - def inner(): - x = torch.randn(4) - x.requires_grad = error_at_aot - try: - opt_mod(x) - except Exception: - pass + launch_proc = subprocess.run( + ["python3", launch_file], + capture_output=True, + cwd=repro_dir, + ) - inner() + return launch_proc, launch_code + # Runs the repro script in `repro_dir`, patched with `patch_code` + def _run_repro(self, patch_code, repro_dir): + self.assertIsNotNone(repro_dir) + repro_file = os.path.join(repro_dir, "repro.py") self.assertTrue(os.path.exists(repro_file)) + repro_code = self._inject_code(patch_code, repro_file) + + repro_proc = subprocess.run( + ["python3", repro_file], capture_output=True, cwd=repro_dir + ) + + return repro_proc, repro_code + + # Template for testing code. + # `run_code` is the code to run for the test case. + # `patch_code` is the code to be patched in every generated file. + def _gen_test_code(self, run_code, repro_after, repro_level, patch_code): + return f"""\ +import torch +import torch._dynamo +{patch_code} +torch._dynamo.config.repro_after = "{repro_after}" +torch._dynamo.config.repro_level = {repro_level} +torch._dynamo.config.debug_dir_root = "{DEBUG_DIR}" +{run_code} +""" + + # Runs a full minifier test. + # Minifier tests generally consist of 3 stages: + # 1. Run the problematic code (in a separate process since it could segfault) + # 2. Run the generated minifier launcher script + # 3. Run the generated repro script + def _run_full_test(self, run_code, repro_after, repro_level, patch_code): + test_code = self._gen_test_code(run_code, repro_after, repro_level, patch_code) + test_proc, repro_dir = self._run_test_code(test_code) + self.assertIsNotNone(repro_dir) + launch_proc, launch_code = self._run_minifier_launcher(patch_code, repro_dir) + repro_proc, repro_code = self._run_repro(patch_code, repro_dir) + return ((test_proc, launch_proc, repro_proc), (launch_code, repro_code)) + + # Test that compile, runtime, and accuracy errors after dynamo can be repro'd (both CPU and CUDA) + def _test_after_dynamo(self, device, repro_level, backend_code, error_name): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("{get_fn_name(backend_code)}") + def inner(x): + for _ in range(10): + x = torch.sin(x) + x = torch.relu(x) + for _ in range(10): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20).to("{device}")) + """ + ) + + (test_proc, _, repro_proc), _ = self._run_full_test( + run_code, "dynamo", repro_level, backend_code + ) + + self.assertIn(error_name, test_proc.stderr.decode("utf-8")) + self.assertIn(error_name, repro_proc.stderr.decode("utf-8")) + + def test_after_dynamo_cpu_compile_error(self): + self._test_after_dynamo( + "cpu", 2, RELU_COMPILE_ERROR_BACKEND, "DynamoCompileError" + ) + + def test_after_dynamo_cpu_runtime_error(self): + self._test_after_dynamo( + "cpu", 2, RELU_RUNTIME_ERROR_BACKEND, "DynamoRuntimeError" + ) + + def test_after_dynamo_cpu_accuracy_error(self): + self._test_after_dynamo("cpu", 4, RELU_ACCURACY_ERROR_BACKEND, "AccuracyError") + + @requires_cuda() + def test_after_dynamo_cuda_compile_error(self): + self._test_after_dynamo( + "cuda", 2, RELU_COMPILE_ERROR_BACKEND, "DynamoCompileError" + ) + + @requires_cuda() + def test_after_dynamo_cuda_runtime_error(self): + self._test_after_dynamo( + "cuda", 2, RELU_RUNTIME_ERROR_BACKEND, "DynamoRuntimeError" + ) + + @requires_cuda() + def test_after_dynamo_cuda_accuracy_error(self): + self._test_after_dynamo("cuda", 4, RELU_ACCURACY_ERROR_BACKEND, "AccuracyError") + + # Ensure that the testing backends pass when relu is not present. + def _test_after_dynamo_backend_passes(self, device, repro_level, backend_code): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("{get_fn_name(backend_code)}") + def inner(x): + for _ in range(10): + x = torch.sin(x) + for _ in range(10): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20).to("{device}")) + """ + ) + + test_code = self._gen_test_code(run_code, "dynamo", repro_level, backend_code) + proc, repro_dir = self._run_test_code(test_code) + self.assertEqual(proc.returncode, 0) + self.assertIsNone(repro_dir) + + def test_after_dynamo_cpu_compile_backend_passes(self): + self._test_after_dynamo_backend_passes("cpu", 2, RELU_COMPILE_ERROR_BACKEND) + + def test_after_dynamo_cpu_runtime_backend_passes(self): + self._test_after_dynamo_backend_passes("cpu", 2, RELU_RUNTIME_ERROR_BACKEND) + + def test_after_dynamo_cpu_accuracy_backend_passes(self): + self._test_after_dynamo_backend_passes("cpu", 4, RELU_ACCURACY_ERROR_BACKEND) + + @requires_cuda() + def test_after_dynamo_cuda_compile_backend_passes(self): + self._test_after_dynamo_backend_passes("cuda", 2, RELU_COMPILE_ERROR_BACKEND) + + @requires_cuda() + def test_after_dynamo_cuda_runtime_backend_passes(self): + self._test_after_dynamo_backend_passes("cuda", 2, RELU_RUNTIME_ERROR_BACKEND) + + @requires_cuda() + def test_after_dynamo_cuda_accuracy_backend_passes(self): + self._test_after_dynamo_backend_passes("cuda", 4, RELU_ACCURACY_ERROR_BACKEND) + + # Ensure that generated code with a custom backends generates a runnable minifier + # launcher script that results in a RuntimeError + def test_after_dynamo_custom_backend(self): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize({get_fn_name(RELU_CUSTOM_ERROR_BACKEND)}) + def inner(x): + for _ in range(10): + x = torch.sin(x) + x = torch.relu(x) + for _ in range(10): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20)) + """ + ) + + test_code = self._gen_test_code( + run_code, "dynamo", 2, RELU_CUSTOM_ERROR_BACKEND + ) + _, repro_dir = self._run_test_code(test_code) + launch_proc, launch_code = self._run_minifier_launcher("", repro_dir) + self.assertIn("RuntimeError", launch_proc.stderr.decode("utf-8")) + + # Test that a module with mixed cpu/cuda parts with an error after dynamo can be repro'd + @requires_cuda() + def test_cpu_cuda_module_after_dynamo(self): + backend_name = get_fn_name(RELU_COMPILE_ERROR_BACKEND) + + run_code = textwrap.dedent( + f"""\ + class CpuCudaModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.m_x = torch.nn.Linear(20, 20).cuda() + self.m_y = torch.nn.Linear(20, 20) + self.p_x = torch.nn.Parameter(torch.randn(20, 20).cuda()) + self.p_y = torch.nn.Parameter(torch.randn(20, 20)) + self.register_buffer("b_x", torch.ones(20, 20).cuda()) + self.register_buffer("b_y", torch.ones(20, 20)) + + def forward(self, x, y): + return self.m_x(x) + self.p_x + self.b_x, self.m_y(y) + self.p_y + self.b_y + + mod = CpuCudaModule() + + @torch._dynamo.optimize("{backend_name}") + def inner(x1, y1): + x2 = torch.randn(20, 20).cuda() + y2 = torch.randn(20, 20) + x3, y3 = mod(x1 + x2, y1 + y2) + return torch.relu(x3.cpu() + y3) + + inner(torch.randn(20, 20).cuda(), torch.randn(20, 20)) + """ + ) + + (test_proc, _, repro_proc), (launch_code, _) = self._run_full_test( + run_code, "dynamo", 2, RELU_COMPILE_ERROR_BACKEND + ) + + tb1 = test_proc.stderr.decode("utf-8") + tb2 = repro_proc.stderr.decode("utf-8") + + # Check if generated minifier code covers all cpu/cuda cases + self.assertIsNotNone(re.search(r"args.*cuda", launch_code)) + self.assertIsNotNone(re.search(r"args.*cpu", launch_code)) + # search for Linear(...).cuda() + self.assertIsNotNone(re.search(r"Linear.*cuda", launch_code)) + # search for Linear(...) + self.assertIsNotNone( + re.search(r"Linear(?!.*cuda.*$)", launch_code, re.MULTILINE) + ) + self.assertIsNotNone(re.search(r"register_buffer.*cuda", launch_code)) + self.assertIsNotNone( + re.search(r"register_buffer(?!.*cuda.*$)", launch_code, re.MULTILINE) + ) + self.assertIsNotNone(re.search(r"Parameter.*cuda", launch_code)) + self.assertIsNotNone( + re.search(r"Parameter(?!.*cuda.*$)", launch_code, re.MULTILINE) + ) + # search for + # = torch.randn(...) + # ... = .cuda() + self.assertIsNotNone( + re.search(r"(\w+) = torch.randn.*\1\.cuda", launch_code, re.DOTALL) + ) + # search for + # = torch.randn(...) + # no followup call to .cuda() + self.assertIsNotNone( + re.search( + r"(\w+) = torch.randn(?!.*\1\.cuda\(\).*$)", launch_code, re.DOTALL + ) + ) + + self.assertIn(backend_name, tb1) + self.assertIn(backend_name, tb2) + + # Test if we can actually get a minified graph + def test_if_graph_minified(self): + backend_name = get_fn_name(RELU_COMPILE_ERROR_BACKEND) + + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("{backend_name}") + def inner(x): + for _ in range(20): + x = torch.sin(x) + x = torch.relu(x) + for _ in range(20): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20)) + """ + ) + + (test_proc, _, repro_proc), (launch_code, repro_code) = self._run_full_test( + run_code, "dynamo", 2, RELU_COMPILE_ERROR_BACKEND + ) + + tb1 = test_proc.stderr.decode("utf-8") + tb2 = repro_proc.stderr.decode("utf-8") + + self.assertIn(backend_name, tb1) + self.assertIn(backend_name, tb2) + + # compare the length of the forward functions + match = re.search(r"def forward.*return", launch_code, re.DOTALL) + self.assertIsNotNone(match) + self.assertGreater(match.group(0).count("\n"), 40) + + match = re.search(r"def forward.*return", repro_code, re.DOTALL) + self.assertIsNotNone(match) + self.assertLess(match.group(0).count("\n"), 5) + + # Test that compile and accuracy errors after aot can be repro'd (both CPU and CUDA) + def _test_after_aot(self, device, backend_code, repro_level): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("inductor") + def inner(x): + for _ in range(3): + x = torch.sin(x) + x = torch.relu(x) + for _ in range(3): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20).to("{device}")) + """ + ) + patch_code = gen_codegen_fn_patch_code("relu", backend_code, device) + self.assertIsNotNone(patch_code) + (test_proc, _, repro_proc), _ = self._run_full_test( + run_code, "aot", repro_level, patch_code + ) + return ( + (test_proc.stderr.decode("utf-8"), repro_proc.stderr.decode("utf-8")), + (test_proc.returncode, repro_proc.returncode), + ) + + def test_after_aot_cpu_compile_error(self): + (tb1, tb2), _ = self._test_after_aot("cpu", CPP_COMPILE_ERROR, 2) + self.assertIn("CppCompileError", tb1) + self.assertIn("CppCompileError", tb2) + + def test_after_aot_cpu_accuracy_error(self): + (tb1, tb2), _ = self._test_after_aot("cpu", CPP_ACCURACY_ERROR, 4) + self.assertIn("AccuracyError", tb1) + self.assertIn("AccuracyError", tb2) + + @requires_cuda() + def test_after_aot_cuda_compile_error(self): + (tb1, tb2), _ = self._test_after_aot("cuda", TRITON_COMPILE_ERROR, 2) + self.assertIn("SyntaxError", tb1) + self.assertIn("SyntaxError", tb2) + + @requires_cuda() + def test_after_aot_cuda_accuracy_error(self): + (tb1, tb2), _ = self._test_after_aot("cuda", TRITON_ACCURACY_ERROR, 4) + self.assertIn("AccuracyError", tb1) + self.assertIn("AccuracyError", tb2) + + # Test that runtime errors after aot can be repro'd (CPU only for now) + def _test_after_aot_runtime_error(self, device, backend_code): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("inductor") + def inner(x): + for _ in range(3): + x = torch.sin(x) + x = torch.relu(x) + for _ in range(3): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20).to("{device}")) + """ + ) + patch_code = gen_codegen_fn_patch_code("relu", backend_code, device) + self.assertIsNotNone(patch_code) + + (test_proc, _, repro_proc), _ = self._run_full_test( + run_code, "aot", 3, patch_code + ) + + self.assertNotIn("CompilerError", test_proc.stderr.decode("utf-8")) + + self.assertEqual(test_proc.returncode, repro_proc.returncode) + self.assertNotEqual(test_proc.returncode, 0) + + def test_after_aot_cpu_runtime_error(self): + self._test_after_aot_runtime_error("cpu", CPP_RUNTIME_ERROR) + + # NOTE: there is currently not an easy way to cause a triton runtime error. + @unittest.skip + @requires_cuda() + def test_after_aot_cuda_runtime_error(self): + self._test_after_aot_runtime_error("cuda", TRITON_RUNTIME_ERROR) + + # Ensure that inductor codegen patches pass when relu is not present. + def _test_after_aot_backend_passes(self, device, repro_level, backend_code): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("inductor") + def inner(x): + for _ in range(3): + x = torch.sin(x) + for _ in range(3): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20).to("{device}")) + """ + ) + patch_code = gen_codegen_fn_patch_code("relu", backend_code, device) + self.assertIsNotNone(patch_code) + + test_code = self._gen_test_code(run_code, "aot", repro_level, patch_code) + proc, repro_dir = self._run_test_code(test_code) + self.assertEqual(proc.returncode, 0) + self.assertIsNone(repro_dir) + + def test_after_aot_cpu_compile_backend_passes(self): + self._test_after_aot_backend_passes("cpu", 2, CPP_COMPILE_ERROR) + + def test_after_aot_cpu_runtime_backend_passes(self): + self._test_after_aot_backend_passes("cpu", 2, CPP_RUNTIME_ERROR) + + def test_after_aot_cpu_accuracy_backend_passes(self): + self._test_after_aot_backend_passes("cpu", 4, CPP_ACCURACY_ERROR) + + @requires_cuda() + def test_after_aot_cuda_compile_backend_passes(self): + self._test_after_aot_backend_passes("cuda", 2, TRITON_COMPILE_ERROR) - def test_at_aot(self): - self._test_around_aot(True) + # NOTE: there is currently not an easy way to cause a triton runtime error. + @unittest.skip + @requires_cuda() + def test_after_aot_cuda_runtime_backend_passes(self): + self._test_after_aot_backend_passes("cuda", 2, TRITON_RUNTIME_ERROR) - def test_after_aot(self): - self._test_around_aot(False) + @requires_cuda() + def test_after_aot_cuda_accuracy_backend_passes(self): + self._test_after_aot_backend_passes("cuda", 4, TRITON_ACCURACY_ERROR) if __name__ == "__main__": diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index f09991f9bf3489c..98a269fe8c9eb57 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -84,6 +84,11 @@ def __init__(self): for module_name, module in gm.named_children(): module_str = f"{module.__repr__()}" + # module should be a core torch.nn.Module, so all parameters + # should be on the same device. + example_param = next(module.parameters(), None) + if example_param is not None and example_param.is_cuda: + module_str = f"{module_str}.cuda()" model_str += f"{tab*2}self.{module_name} = {module_str}\n" for buffer_name, buffer in gm._buffers.items(): @@ -95,12 +100,16 @@ def __init__(self): tensor_str = ( f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})" ) + if buffer.is_cuda: + tensor_str = f"{tensor_str}.cuda()" model_str += f"{tab*2}self.register_buffer('{buffer_name}', {tensor_str})\n" for param_name, param in gm._parameters.items(): if param is None: continue tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}))" + if param.is_cuda: + tensor_str = f"{tensor_str}.cuda()" model_str += f"{tab*2}self.{param_name} = {tensor_str}\n" # TODO - Keep this code for now. But, I don't think we will need this. @@ -145,6 +154,9 @@ def _cuda_system_info_comment(): return model_str +TEST_REPLACEABLE_COMMENT = "# REPLACEABLE COMMENT FOR TESTING PURPOSES" + + def generate_compiler_repro_string(gm, args): model_str = textwrap.dedent( f""" @@ -155,6 +167,8 @@ def generate_compiler_repro_string(gm, args): from math import inf from torch.fx.experimental.proxy_tensor import make_fx + {TEST_REPLACEABLE_COMMENT} + """ ) model_str += f"# torch version: {torch.version.__version__}\n" @@ -170,7 +184,7 @@ def generate_compiler_repro_string(gm, args): model_str += ( "args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]\n" ) - model_str += 'mod = make_fx(Repro().to(device="cuda"))(*args)\n' + model_str += "mod = make_fx(Repro())(*args)\n" return model_str @@ -197,7 +211,8 @@ def dump_compiler_graph_state(gm, args, compiler_name): log.warning(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}") with open(file_name, "w") as fd: save_graph_repro(fd, gm, args, compiler_name) - repro_path = os.path.join(config.base_dir, "repro.py") + curdir = os.getcwd() + repro_path = os.path.join(curdir, "repro.py") try: shutil.copyfile(file_name, repro_path) log.warning(f"Copying repro file for convenience to {repro_path}") @@ -216,7 +231,10 @@ def save_graph_repro(fd, gm, args, compiler_name): textwrap.dedent( f""" compiled = {COMPILER_REPRO_OPTIONS[compiler_name][1]}(mod, args) - assert same_two_models(mod, compiled, args, only_fwd=True), "Accuracy failed" + class AccuracyError(Exception): + pass + if not same_two_models(mod, compiled, args, only_fwd=True): + raise AccuracyError("Bad accuracy detected") """ ) ) @@ -231,7 +249,7 @@ def save_graph_repro(fd, gm, args, compiler_name): ) -def isolate_fails(fx_g, args, compiler_name: str, env=None): +def isolate_fails(fx_g, args, compiler_name: str, env=None, patch_code=None): if env is None: env = {} subdir = os.path.join(os.getcwd(), "isolate") @@ -239,7 +257,10 @@ def isolate_fails(fx_g, args, compiler_name: str, env=None): os.makedirs(subdir, exist_ok=True) file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py") with open(file_name, "w") as fd: - fd.write(generate_compiler_repro_string(fx_g, args)) + repro_code = generate_compiler_repro_string(fx_g, args) + if patch_code is not None: + repro_code = repro_code.replace(TEST_REPLACEABLE_COMMENT, patch_code) + fd.write(repro_code) fail_fn = COMPILER_REPRO_OPTIONS[compiler_name][2] fd.write( textwrap.dedent( @@ -263,6 +284,7 @@ def isolate_fails(fx_g, args, compiler_name: str, env=None): stdout, stderr = TemporaryFile(), TemporaryFile() p = subprocess.Popen( ["python", file_name], + cwd=subdir, stdout=stdout, stderr=stderr, env=new_env, @@ -329,6 +351,8 @@ def dump_to_minify(gm, args, compiler_name: str): contents = textwrap.dedent( f""" +isolate_fails_code_str = None + {generate_compiler_repro_string(gm, args)} from functools import partial @@ -343,7 +367,7 @@ def dump_to_minify(gm, args, compiler_name: str): minifier( mod, args, - module_fails=partial(isolate_fails, env=env_variables, compiler_name="{compiler_name}"), + module_fails=partial(isolate_fails, env=env_variables, compiler_name="{compiler_name}", patch_code=isolate_fails_code_str), dump_state=partial(dump_compiler_graph_state, compiler_name="{compiler_name}"), ) """ @@ -351,6 +375,10 @@ def dump_to_minify(gm, args, compiler_name: str): return helper_for_dump_minify(contents) +class AccuracyError(Exception): + pass + + def wrap_compiler_debug(compiler_fn, compiler_name: str): """ Minifier for Fx Graph modules after Aot Autograd has finished. We wrap both @@ -410,7 +438,7 @@ def deferred_for_real_inputs(real_inputs): copy_tensor_attrs, f"{compiler_name}_accuracy", ) - raise ValueError("Bad accuracy detected") + raise AccuracyError("Bad accuracy detected") else: # Call the compiled function with real inputs return inner_compiled_fn(real_inputs) @@ -435,7 +463,8 @@ def deferred_for_real_inputs(real_inputs): copy_tensor_attrs, compiler_name, ) - raise e + log.error("CompilerError") + raise if config.repro_after == "aot": compiled_fn = deferred_for_real_inputs @@ -544,9 +573,14 @@ def generate_dynamo_fx_repro_string( f""" mod.eval() opt_mod.eval() + +class AccuracyError(Exception): + pass + with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}): assert same_two_models(mod, mod, args), "Eager itself failed" - assert same_two_models(mod, opt_mod, args), "Dynamo failed" + if not same_two_models(mod, opt_mod, args): + raise AccuracyError("Dynamo failed") """ ) @@ -561,12 +595,14 @@ def generate_dynamo_fx_repro_string( from {config.dynamo_import}.debug_utils import run_fwd_maybe_bwd from {config.dynamo_import}.debug_utils import same_two_models +{TEST_REPLACEABLE_COMMENT} + args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]} args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args] {model_str} -mod = Repro().cuda() +mod = Repro() opt_mod = {config.dynamo_import}.optimize("{compiler_name}")(mod) {run_code} @@ -705,6 +741,21 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name): if config.repro_level == 4: minifier_backend = "dynamo_accuracy_minifier_backend" + custom_compiler_error = ( + textwrap.dedent( + """\ + raise RuntimeError( + 'Compiler name is None - this likely means that a custom compiler ' + 'was called by torchdynamo. Please remove this error, import your ' + 'custom compiler function, and replace the compiler_name="None" ' + 'line below to compiler_name=' + ) + """ + ) + if compiler_name is None + else "" + ) + contents = textwrap.dedent( f""" import os @@ -718,14 +769,17 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name): from {config.dynamo_import}.optimizations.backends import BACKENDS from {config.dynamo_import}.testing import rand_strided +{TEST_REPLACEABLE_COMMENT} + args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]} args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args] {model_str} -mod = Repro().cuda() +mod = Repro() # Setup debug minifier compiler compiler_fn = BACKENDS["{minifier_backend}"] +{custom_compiler_error} dynamo_minifier_backend = functools.partial( compiler_fn, compiler_name="{compiler_name}", @@ -769,7 +823,7 @@ def debug_wrapper(gm, example_inputs, **kwargs): example_inputs, compiler_name, ) - exc = ValueError("Bad accuracy detected.") + exc = AccuracyError("Bad accuracy detected.") exc.minifier_path = os.path.join( minifier_dir(), "minifier_launcher.py" ) From a3f3ec8fac98151f31373ba3bcfe2d601584a840 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 11 Nov 2022 21:22:49 +0000 Subject: [PATCH 47/62] [FSDP+dynamo]: forward treats parameter-views as params (#88781) Dynamo+AotAutograd needs a way to wrap all tensors (whether inputs or params/buffers) in FakeTensor wrappers, and FSDP's mangling of parameters hides them from this wrapping. This PR unblocks running hf_bert and hf_T5 with FSDP under dynamo, whether using recursive wrapping around transformer layers or only applying FSDP around the whole model. Perf/memory validation and possibly optimization is the next step. `python benchmarks/dynamo/distributed.py --torchbench_model hf_Bert --fsdp --dynamo aot_eager` `python benchmarks/dynamo/distributed.py --torchbench_model hf_Bert --fsdp --dynamo aot_eager --fsdp_wrap` `python benchmarks/dynamo/distributed.py --torchbench_model hf_T5 --fsdp --dynamo aot_eager` `python benchmarks/dynamo/distributed.py --torchbench_model hf_T5 --fsdp --dynamo aot_eager --fsdp_wrap` The problem: Dynamo (Actually aot_autograd) trips up with FSDP becuase it must wrap all input tensors in FakeTensor wrappers, and it only knows to wrap graph inputs or named_(parameters, buffers). FSDP's pre_forward hook sets views (which are not nn.param) into the flatparam as attrs on the module with the same name as the original param, but they will not show up in named_parameters. - in use_orig_params mode, FSDP still de-registers params during pre-forward hook, then re-registers them post-forward - during forward (between the hooks), the params are setattr'd on the module as regular view tensors, not nn.Parameters - note: use_orig_params is the recommended way to use FSDP, and use_orig_params=False is being deprecated. So i only consider use_orig_params=True for this enablement The solution: - adding them to named_buffers is not possible because it interferes with how FSDP's `_apply` works - since they are not actual nn.parameters, register_parameter will complain about registering them - simply seting `module._parameters[name] = view` seems to be a viable workaround, despite being hacky, and FSDP code does modify _parameters directly already. Note: Manual checkpointing still isn't working with FSDP+dynamo, so that will have to be addressed in a follow up. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88781 Approved by: https://github.com/ezyang, https://github.com/awgu --- benchmarks/dynamo/dist_util.py | 20 +-- benchmarks/dynamo/distributed.py | 5 +- test/distributed/test_dynamo_distributed.py | 131 ++++++++++++++++---- torch/distributed/fsdp/flat_param.py | 4 + 4 files changed, 124 insertions(+), 36 deletions(-) diff --git a/benchmarks/dynamo/dist_util.py b/benchmarks/dynamo/dist_util.py index 9e2f086ca8b70e4..d30b5a63cfe5f96 100644 --- a/benchmarks/dynamo/dist_util.py +++ b/benchmarks/dynamo/dist_util.py @@ -20,6 +20,9 @@ except ImportError: from torchbench import setup_torchbench_cwd +from transformers.models.bert.modeling_bert import BertLayer, BertLMPredictionHead +from transformers.models.t5.modeling_t5 import T5Block + def setup(rank, world_size): os.environ["MASTER_ADDR"] = "localhost" @@ -122,26 +125,25 @@ def check_fn(submodule): ) -# from transformers.models.t5.modeling_t5 import T5Block - MODEL_FSDP_WRAP = { - ToyModel: (MyModule,) - # TODO T5: (T5Block,) + "toy_model": (MyModule,), + "hf_Bert": (BertLayer, BertLMPredictionHead), + "hf_T5": (T5Block,), } -def apply_fsdp(model, use_checkpointing=False, use_wrap_policy=True): - blocks = MODEL_FSDP_WRAP[model.__class__] - +def apply_fsdp(args, model, use_checkpointing=False, use_wrap_policy=True): wrap_policy = None + blocks = MODEL_FSDP_WRAP[ + "toy_model" if model.__class__ is ToyModel else args.torchbench_model + ] if use_wrap_policy: # transformer policy is really a generic policy that wraps modules of specified classes wrap_policy = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls=blocks ) - model = FSDP(model, auto_wrap_policy=wrap_policy) + model = FSDP(model, auto_wrap_policy=wrap_policy, use_orig_params=True) if use_checkpointing: fsdp_checkpointing_base(model, blocks) - return model diff --git a/benchmarks/dynamo/distributed.py b/benchmarks/dynamo/distributed.py index c2db15563348aef..32e3b544d87ddcc 100644 --- a/benchmarks/dynamo/distributed.py +++ b/benchmarks/dynamo/distributed.py @@ -50,6 +50,7 @@ def move_tensor(maybe_tensor): if args.fsdp: model = apply_fsdp( + args, model, use_checkpointing=args.fsdp_checkpoint, use_wrap_policy=args.fsdp_wrap, @@ -160,7 +161,9 @@ def experiment(fn, key, world_size, results): ) args = parser.parse_args() - model_name = "ToyModel" if args.toy_model else args.torchbench_model + model_name = args.torchbench_model + if args.toy_model: + model_name = "ToyModel" model, inputs = get_model(args) fn = partial(run_model, args, model, inputs) diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index 3dd3c5de77253b6..b6bc16edb941a04 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -1,4 +1,6 @@ # Owner(s): ["module: dynamo"] +import copy +import functools import logging import os import random @@ -16,7 +18,9 @@ from torch._dynamo.utils import same from torch._dynamo.testing import collect_results from torch._inductor.utils import has_triton +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.nn.parallel import DistributedDataParallel as DDP +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.testing._internal.common_distributed import ( MultiProcessTestCase, import_transformers_or_skip, @@ -175,6 +179,7 @@ def test_ddp_baseline_aot_eager_multiprocess(self): @skip_if_lt_x_gpu(2) @import_transformers_or_skip() + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @patch.object(config, "optimize_ddp", True) @patch.object(torch._inductor.config, "fallback_random", True) def test_hf_bert_ddp(self): @@ -199,6 +204,106 @@ def test_hf_bert_ddp(self): opt_results = collect_results(opt_model, opt_outputs.logits, opt_loss, inputs_flat) self.assertTrue(same(correct_results, opt_results)) + + @skip_if_lt_x_gpu(1) + # TODO(whc) delete aot_eager test, if inductor test lands stably + def test_fsdp_aot_eager(self): + with _per_rank_init(self.rank, self.world_size): + # Test with basic FSDP wrapping (outer wrap around whole model) + m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") + fsdp_m = FSDP(m, use_orig_params=True) + fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m) + outputs = fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + # Test with recursive wrapping, nested FSDP around each Linear + m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") + fsdp_m = FSDP( + m, + auto_wrap_policy=functools.partial( + transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear, ) + ), + use_orig_params=True + ) + fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m) + outputs = fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + @skip_if_lt_x_gpu(1) + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + def test_fsdp_inductor(self): + with _per_rank_init(self.rank, self.world_size): + # Test with basic FSDP wrapping (outer wrap around whole model) + m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") + fsdp_m = FSDP(m, use_orig_params=True) + fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m) + outputs = fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + # Test with recursive wrapping, nested FSDP around each Linear + m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") + fsdp_m = FSDP( + m, + auto_wrap_policy=functools.partial( + transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear, ) + ), + use_orig_params=True + ) + fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m) + outputs = fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + @import_transformers_or_skip() + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert + @patch.object(torch._inductor.config.triton, "cudagraphs", False) + @patch.object(torch._inductor.config, "fallback_random", True) + def test_hf_bert_fsdp(self): + from transformers.models.bert.modeling_bert import BertLayer + + def apply_fsdp(model, wrap_policy): + model = FSDP( + copy.deepcopy(model), + auto_wrap_policy=wrap_policy, + use_orig_params=True + ) + return model + + with _per_rank_init(self.rank, self.world_size): + for (wrap_policy, test_instance) in ( + ( + None, + "FSDP without recursive wrapping" + ), + ( + functools.partial( + transformer_auto_wrap_policy, transformer_layer_cls=(BertLayer, ) + ), + "FSDP with recursive wrapping BertLayer instances" + ) + ): + print(f"Running hf_bert test for {test_instance}") + model, inputs = get_hf_bert(self.rank) + reset_rng_state() + eager_model = apply_fsdp(model, wrap_policy) + correct_outputs = eager_model(**inputs) + correct_loss = correct_outputs.loss + correct_loss.backward() + + reset_rng_state() + opt_model = apply_fsdp(model, wrap_policy) + + opt_model = torch._dynamo.optimize("inductor")(opt_model) + opt_outputs = opt_model(**inputs) + opt_loss = opt_outputs.loss + opt_loss.backward() + + inputs_flat = [inputs[k] for k in inputs] + correct_results = collect_results(eager_model, correct_outputs.logits, correct_loss, inputs_flat) + opt_results = collect_results(opt_model, opt_outputs.logits, opt_loss, inputs_flat) + self.assertTrue(same(correct_results, opt_results)) + + @requires_nccl() class TestDistributed(torch._dynamo.test_case.TestCase): """ @@ -257,32 +362,6 @@ def test_ddp_baseline_inductor(self): outputs = ddp_m(inputs) self.assertTrue(same(correct_outputs, outputs)) - # TODO(whc) move these tests to 'distributed' shard to get nccl, or see if it's available already in pytorch CI? - @unittest.skip( - "can't run with gloo (no support for _allgather_base) and nccl not available in CI" - ) - @patch.object(config, "optimize_ddp", False) - def test_fsdp_baseline_aot_eager(self): - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - - m, inputs, correct_outputs = self.get_model() - fsdp_m = FSDP(m, device_id=self.device_ids[0] if self.device_ids else None) - fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m) - outputs = fsdp_m(inputs) - self.assertTrue(same(correct_outputs, outputs)) - - @unittest.skip("hangs/crashes with inductor currently") - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") - @patch.object(config, "optimize_ddp", False) - def test_fsdp_baseline_inductor(self): - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - - m, inputs, correct_outputs = self.get_model() - fsdp_m = FSDP(m, device_id=self.device_ids[0] if self.device_ids else None) - fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m) - outputs = fsdp_m(inputs) - self.assertTrue(same(correct_outputs, outputs)) - @patch.object(config, "optimize_ddp", True) def test_graph_split(self): """ diff --git a/torch/distributed/fsdp/flat_param.py b/torch/distributed/fsdp/flat_param.py index 0978f0875a28f99..b790590c7943f02 100644 --- a/torch/distributed/fsdp/flat_param.py +++ b/torch/distributed/fsdp/flat_param.py @@ -1306,6 +1306,8 @@ def _use_unsharded_views(self, as_params: bool) -> None: assert tensor is not None # mypy param_var = tensor setattr(module, param_name, param_var) + if self._use_orig_params and self._training_state == HandleTrainingState.FORWARD: + module._parameters[param_name] = param_var # type: ignore[assignment] for i, ( param_name, module, @@ -1336,6 +1338,8 @@ def _use_unsharded_views(self, as_params: bool) -> None: module.register_parameter(param_name, prim_param) else: setattr(module, param_name, prim_param) + if self._use_orig_params and self._training_state == HandleTrainingState.FORWARD: + module._parameters[param_name] = prim_param # type: ignore[assignment] def _use_unsharded_grad_views(self) -> None: """ From 2cd05a2818bacbc2e252052b6b71085e4de16b0d Mon Sep 17 00:00:00 2001 From: Jiaxu Zhu Date: Sat, 12 Nov 2022 01:20:52 +0000 Subject: [PATCH 48/62] Support torch.qint32 in Convert (#88871) Enable the `torch.qint32` when creating `quantize_per_tensor` function call in `convert_fx` Pull Request resolved: https://github.com/pytorch/pytorch/pull/88871 Approved by: https://github.com/jerryzh168 --- torch/ao/quantization/fx/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py index 61bb2cdc1b034ca..a5a989ec21480f6 100644 --- a/torch/ao/quantization/fx/utils.py +++ b/torch/ao/quantization/fx/utils.py @@ -183,7 +183,7 @@ def get_quantize_node_info( if hasattr(activation_post_process, "compute_dtype"): compute_dtype = activation_post_process.compute_dtype # type: ignore[attr-defined] quantize_op : Optional[Union[Callable, str]] = None - if dtype in [torch.quint8, torch.qint8] and \ + if dtype in [torch.quint8, torch.qint8, torch.qint32] and \ not hasattr(activation_post_process, 'compute_dtype'): node_type = "call_function" scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined] From 2b166532f7ac280232daf6c44620e96e258867cf Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 11 Nov 2022 09:00:55 -0500 Subject: [PATCH 49/62] Remove incorrect assert about hermetic state. (#88885) I'm not sure why I thought this assert was valid in the first place, and there's no comment about it. The assert is tantamount to saying, "no tensor objects should become dead via SafePyObject when hermetic mode is on." But suppose we run a Python GC while we're inside hermetic mode. This could result in us disposing non-hermetic tensors, which would hit decref. So the assert seems invalid. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/88885 Approved by: https://github.com/anjali411, https://github.com/malfet --- torch/csrc/autograd/python_variable.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 920d0e7344b589a..002b904d40721e6 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -305,10 +305,6 @@ void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool is_tensor) // THPVariable_clear). // 2. We are decref-ing some other Python object. We don't do // PyObject resurrection on non-Tensors, so we just carry on as usual - if (is_tensor) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - !c10::impl::HermeticPyObjectTLS::get_state()); - } if (is_tensor && Py_REFCNT(pyobj) > 1) { // It's still alive! This can happen if a weak ref resurrected // the PyObject without flipping ownership. At this point it is From 66736ff425d7163df0eed48e9944c8539e92b577 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 11 Nov 2022 09:33:41 -0500 Subject: [PATCH 50/62] Fix bug in OptionalTensorList (#88887) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/88887 Approved by: https://github.com/anjali411 --- aten/src/ATen/core/PythonFallbackKernel.cpp | 5 ++++- test/test_python_dispatch.py | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/core/PythonFallbackKernel.cpp b/aten/src/ATen/core/PythonFallbackKernel.cpp index e16874a83f9661a..2d8834afe59ef77 100644 --- a/aten/src/ATen/core/PythonFallbackKernel.cpp +++ b/aten/src/ATen/core/PythonFallbackKernel.cpp @@ -74,10 +74,13 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { (*interpreter)->dispatch(op, stack); return; } - } else if (ivalue.isTensorList() || (ivalue.isOptionalTensorList() && !ivalue.isNone())) { + } else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) { // NB: use toListRef as it doesn't induce refcount bumps (toTensorListRef // is not a thing) for (const auto& nv : ivalue.toListRef()) { + if (nv.isNone()) { + continue; + } auto* interpreter = nv.unsafeToTensorImpl()->pyobj_interpreter(); if (interpreter) { (*interpreter)->dispatch(op, stack); diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index 380f85f568f72c1..33465217bbbc0ab 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -390,6 +390,24 @@ def test_produce_real_type(self) -> None: $4 = torch._ops.aten.select.int($3, 1, 1) $5 = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)''') + def test_optional_tensor_list(self) -> None: + def weird(xs): + print("woof") + return torch.empty(()) + + my_lib = Library("my_lib", "DEF") + my_lib.define("weird(Tensor?[] self) -> Tensor") + my_lib.impl("weird", weird, "CPU") + with capture_logs() as logs: + x = LoggingTensor(torch.ones(2, 2)) + log_input("x", x) + torch.ops.my_lib.weird.default([None, x]) + + self.assertExpectedInline('\n'.join(logs), '''\ +$0 = input('x') +$1 = torch._ops.my_lib.weird.default([None, LoggingTensor(tensor([[1., 1.], + [1., 1.]]))])''') + def test_list_ret(self) -> None: # test all sequence types are permissible returns for list_type in (list, tuple): From 1e2327baf7a2d9c63bef08e5f996ef983e199429 Mon Sep 17 00:00:00 2001 From: mikey dagitses Date: Sat, 12 Nov 2022 02:23:48 +0000 Subject: [PATCH 51/62] fix fx tests (#88886) Summary: Some source files are missing and TPX couldn't handle the default test names. Test Plan: Rely on CI. Differential Revision: D41218564 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88886 Approved by: https://github.com/zou3519 --- test/fx/test_common_passes.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/test/fx/test_common_passes.py b/test/fx/test_common_passes.py index 9c59abce4da6183..407e707db8797a9 100644 --- a/test/fx/test_common_passes.py +++ b/test/fx/test_common_passes.py @@ -73,10 +73,15 @@ def MutationMetadata(x): if torch.cuda.is_available(): Devices.append("cuda") + +def name_fn(common_pass, f, device): + """Names parameterized test cases.""" + return f'{type(common_pass()).__name__}_{f.__name__}_{device}' + @instantiate_parametrized_tests class TestCommonPass(TestCase): - @parametrize("common_pass,f,device", itertools.product(Passes, Test_Cases, Devices)) + @parametrize("common_pass,f,device", itertools.product(Passes, Test_Cases, Devices), name_fn) def test_correctness(self, common_pass, f, device): inp = torch.randn(10, device=device) @@ -94,7 +99,7 @@ def test_correctness(self, common_pass, f, device): self.assertEqual(result, expected) - @parametrize("common_pass,f,device", itertools.product(Passes, Factory_Test_Cases, Devices)) + @parametrize("common_pass,f,device", itertools.product(Passes, Factory_Test_Cases, Devices), name_fn) def test_correctness_factory(self, common_pass, f, device): inp = torch.randn(10, device=device) traced_m = make_fx(f)(inp, device) From 4108367123c1b44289b5c731c3bb7022976b816d Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Fri, 11 Nov 2022 20:41:36 +0000 Subject: [PATCH 52/62] Exclude poolformer_m36 from the inductor model test (#88908) Summary: The root cause is still to be investigated. Issue tracked at https://github.com/pytorch/torchdynamo/issues/1856 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88908 Approved by: https://github.com/malfet --- benchmarks/dynamo/common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 758f4396b5b1bcb..198877e0313d830 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -156,6 +156,7 @@ "hrnet_w18", # accuracy "lcnet_0500", # accuracy "levit_128", # levit_128 + "poolformer_m36", "rexnet_100", # accuracy "swin_base_patch4_window7_224", "twins_pcpvt_base", # time out From ae4074669ecbf2a6d8bf99db745d29dce98d0c10 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 10 Nov 2022 21:19:22 +0000 Subject: [PATCH 53/62] [FSDP][state_dict][6/N] Remove most FSDP module dependency from _optim_utils (#88638) **What** This PR removes most `FullyShardedDataParallel` dependencies from `optim_utils`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88638 Approved by: https://github.com/awgu --- torch/distributed/fsdp/_optim_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 530a8480d55220a..70fb4156d53780a 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -22,9 +22,11 @@ import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file import torch.nn as nn from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed.fsdp._common_utils import _get_param_to_fqns from torch.distributed.fsdp._fsdp_extensions import _ext_chunk_tensor -from torch.distributed.fsdp._runtime_utils import _clear_grads_if_needed +from torch.distributed.fsdp._runtime_utils import _clear_grads_if_needed, _lazy_init from torch.distributed.fsdp._shard_utils import _gather_state_dict +from torch.distributed.fsdp.api import ShardingStrategy from torch.distributed.fsdp.flat_param import FlatParameter, FlatParamHandle @@ -185,7 +187,7 @@ def _communicate_optim_state( # we take the target rank's value if ( fsdp_module.world_size == 1 - or fsdp_module.sharding_strategy == fsdp_file.ShardingStrategy.NO_SHARD + or fsdp_module.sharding_strategy == ShardingStrategy.NO_SHARD ): tensor_state[state_name] = value continue @@ -293,7 +295,7 @@ def _flatten_optim_state_dict( '"param_groups" to be a valid optimizer state dict' ) flat_param_to_fsdp_module = _get_flat_param_to_fsdp_module(model) - param_to_fqns = fsdp_file._get_param_to_fqns(model) + param_to_fqns = _get_param_to_fqns(model) # Construct the "state" part flat_osd_state: Dict[_OptimStateKey, Any] = {} @@ -897,7 +899,7 @@ def _rekey_sharded_optim_state_dict( if using_optim_input else _get_param_to_param_id(optim) ) - param_to_fqns = fsdp_file._get_param_to_fqns(model) + param_to_fqns = _get_param_to_fqns(model) # All parameter keys in `param_to_flat_param_id` should be in # `param_to_fqns` -- strict inequality follows when not all parameters are # passed to the optimizer @@ -951,7 +953,7 @@ def _get_flat_param_to_fsdp_module(model: torch.nn.Module): flat_param_to_fsdp_module = {} for module in model.modules(): if isinstance(module, fsdp_file.FullyShardedDataParallel): - fsdp_file._lazy_init(module, module) + _lazy_init(module, module) for param in module.params: # may have none flat_param_to_fsdp_module[param] = module return flat_param_to_fsdp_module @@ -1165,9 +1167,7 @@ def _optim_state_dict( # Construct the local mapping between unflattened parameter names # (`_OptimStateKey`s) and parameter IDs and broadcast rank 0's mapping - param_to_fqns: Dict[torch.nn.Parameter, List[str]] = fsdp_file._get_param_to_fqns( - model - ) + param_to_fqns: Dict[torch.nn.Parameter, List[str]] = _get_param_to_fqns(model) flat_param_id_to_param: List[torch.nn.Parameter] = ( _get_param_id_to_param_from_optim_input(model, optim_input) if using_optim_input From b2b0a0d3baf6258fbf728572719937810fd890ce Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 12 Nov 2022 03:21:06 +0000 Subject: [PATCH 54/62] [vision hash update] update the pinned vision hash (#88920) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/master/.github/workflows/_update-commit-hash.yml). Update the pinned vision hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88920 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/vision.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/vision.txt b/.github/ci_commit_pins/vision.txt index 48685938a146b49..b9eda365de0c56e 100644 --- a/.github/ci_commit_pins/vision.txt +++ b/.github/ci_commit_pins/vision.txt @@ -1 +1 @@ -d72e90640ec8514e0369b5419d7f3b74a387b1d7 +deba056203d009fec6b58afb9fa211f6ee3328c8 From d01bf1d1f11ab1fb9ae21a007138e2c4ecc31b63 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Sat, 12 Nov 2022 01:05:46 +0000 Subject: [PATCH 55/62] [FSDP] Introduce `ModuleWrapPolicy` for simplicity (#88450) **BC Breaking Change** This renames `unwrapped_params` to `nonwrapped_numel`. I prefer `nonwrapped` over `unwrapped` because "unwrap" suggests that some wrapping has been undone. I prefer `numel` over `params` because that is unit of measurement; I think we should keep "params" to refer to `nn.Parameter`s themselves. This only breaks anything that passes `unwrapped_params` as a keyword argument, but I did not see anything that did that (except the one internal benchmark file but that does not actually depend on our `pytorch` code). In a follow-up, I want to rename `min_num_params` to `min_nonwrapped_numel` in `size_based_auto_wrap_policy`, which is also BC breaking. Again, this is to differentiate between "params" being `nn.Parameter`s and "numel" being the unit for `param.numel()`. **Overview** This PR introduces `ModuleWrapPolicy` as a lightweight layer over the existing `transformer_auto_wrap_policy`. The most common auto wrapping paradigm is: ``` module_classes: Set[Type[nn.Module]] = ... auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls=module_classes, ) fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...) ``` Now, users can instead write: ``` auto_wrap_policy = ModuleWrapPolicy(module_classes) fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...) ``` This hides the unused arguments expected from the callable (`recurse` and `unwrapped_params`/`nonwrapped_numel`). `ModuleWrapPolicy` inherits from an abstract base class `FSDPPolicy` that expects a `policy` property. This decouples the construct of such `FSDPPolicy` classes and their actual `policy`, which must abide by the `_recursive_wrap` interface. Any existing auto wrap policy can be rewritten as a class that inherits from `FSDPPolicy`, so this approach is fully backward compatible from a functionality perspective. I call this base class `FSDPPolicy` to generalize over the cases where we may not want to actually perform any nested wrapping. In reality, the policy is meant for constructing `FlatParameter`s, which just happened to be induced by a nested wrapping before. Given this, I am changing the constructor argument in `fully_shard()` to simply `policy` instead of `auto_wrap_policy`. This PR migrates usages of `transformer_auto_wrap_policy` within our unit test suite to `ModuleWrapPolicy` as much as possible. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88450 Approved by: https://github.com/zhaojuanmao --- .../_composable/test_fully_shard.py | 27 +- .../fsdp/test_fsdp_clip_grad_norm.py | 10 +- test/distributed/fsdp/test_fsdp_misc.py | 22 +- test/distributed/fsdp/test_fsdp_state_dict.py | 12 +- .../fsdp/test_fsdp_use_orig_params.py | 9 +- test/distributed/fsdp/test_utils.py | 7 +- test/distributed/fsdp/test_wrap.py | 16 + torch/distributed/_composable/fully_shard.py | 8 +- torch/distributed/fsdp/__init__.py | 1 - torch/distributed/fsdp/_init_utils.py | 5 +- torch/distributed/fsdp/_wrap_utils.py | 17 +- torch/distributed/fsdp/flat_param.py | 3 +- .../fsdp/fully_sharded_data_parallel.py | 155 ++-------- torch/distributed/fsdp/wrap.py | 288 ++++++++---------- torch/testing/_internal/common_fsdp.py | 20 +- 15 files changed, 244 insertions(+), 356 deletions(-) diff --git a/test/distributed/_composable/test_fully_shard.py b/test/distributed/_composable/test_fully_shard.py index 27e0fb855fba7cc..ba08deeafcdfb78 100644 --- a/test/distributed/_composable/test_fully_shard.py +++ b/test/distributed/_composable/test_fully_shard.py @@ -1,7 +1,6 @@ # Owner(s): ["oncall: distributed"] import copy -import functools import sys from typing import Any, Tuple @@ -12,7 +11,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp._common_utils import _is_fsdp_flattened from torch.distributed.fsdp._runtime_utils import _root_pre_forward -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import ( @@ -62,10 +61,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return z @staticmethod - def auto_wrap_policy(): - return functools.partial( - transformer_auto_wrap_policy, transformer_layer_cls={SubModel} - ) + def policy(): + return ModuleWrapPolicy({SubModel}) def get_input(self, device=torch.device) -> Tuple[Any, ...]: return (torch.randn((8, 5), device=device),) @@ -85,13 +82,13 @@ def test_auto_wrap_policy(self): local_model = Model(device=torch.device("cuda")) fsdp_wrapped_model = FSDP( copy.deepcopy(local_model), - auto_wrap_policy=Model.auto_wrap_policy(), + auto_wrap_policy=Model.policy(), use_orig_params=True, ) composable_module = copy.deepcopy(local_model) fully_shard( composable_module, - auto_wrap_policy=Model.auto_wrap_policy(), + policy=Model.policy(), ) # Check that the composable module has the same names as the local @@ -138,7 +135,7 @@ def test_device_id(self): assert param.device == cpu_device fully_shard( composable_module, - auto_wrap_policy=Model.auto_wrap_policy(), + policy=Model.policy(), device_id=self.rank, ) for param in composable_module.parameters(): @@ -157,12 +154,12 @@ def test_sync_module_states(self): param.zero_() fsdp_wrapped_model = FSDP( copy.deepcopy(local_model), - auto_wrap_policy=Model.auto_wrap_policy(), + auto_wrap_policy=Model.policy(), use_orig_params=True, ) fully_shard( composable_module, - auto_wrap_policy=Model.auto_wrap_policy(), + policy=Model.policy(), sync_module_states=True, ) for (composable_param, fsdp_wrapped_param) in zip( @@ -197,13 +194,13 @@ def _param_init_fn(module: nn.Module): composable_module = Model(device="meta") fsdp_wrapped_model = FSDP( Model(device="meta"), - auto_wrap_policy=Model.auto_wrap_policy(), + auto_wrap_policy=Model.policy(), param_init_fn=_param_init_fn, use_orig_params=True, ) fully_shard( composable_module, - auto_wrap_policy=Model.auto_wrap_policy(), + policy=Model.policy(), param_init_fn=_param_init_fn, ) for (composable_param, fsdp_wrapped_param) in zip( @@ -227,13 +224,13 @@ def test_training(self): local_model = Model(device=device) fsdp_wrapped_model = FSDP( copy.deepcopy(local_model), - auto_wrap_policy=Model.auto_wrap_policy(), + auto_wrap_policy=Model.policy(), use_orig_params=True, ) composable_module = copy.deepcopy(local_model) fully_shard( composable_module, - auto_wrap_policy=Model.auto_wrap_policy(), + policy=Model.policy(), ) del local_model # not needed anymore LR = 1e-2 diff --git a/test/distributed/fsdp/test_fsdp_clip_grad_norm.py b/test/distributed/fsdp/test_fsdp_clip_grad_norm.py index ddba50a9e4561f0..e587065c5c77f49 100644 --- a/test/distributed/fsdp/test_fsdp_clip_grad_norm.py +++ b/test/distributed/fsdp/test_fsdp_clip_grad_norm.py @@ -1,6 +1,5 @@ # Owner(s): ["oncall: distributed"] -import functools import itertools import sys from typing import Union @@ -12,7 +11,7 @@ CPUOffload, FullyShardedDataParallel as FSDP, ) -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing._internal.common_distributed import skip_if_lt_x_gpu @@ -102,12 +101,11 @@ def _test_ddp_parity( ) ddp_model = DDP(local_model, device_ids=[self.rank]) fsdp_kwargs = { - "auto_wrap_policy": functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={ + "auto_wrap_policy": ModuleWrapPolicy( + { TransformerEncoderLayer, TransformerDecoderLayer, - }, + } ), "cpu_offload": CPUOffload(offload_params=offload_params), "use_orig_params": use_orig_params, diff --git a/test/distributed/fsdp/test_fsdp_misc.py b/test/distributed/fsdp/test_fsdp_misc.py index 79ed6da6240fabf..8c972f8515634fe 100644 --- a/test/distributed/fsdp/test_fsdp_misc.py +++ b/test/distributed/fsdp/test_fsdp_misc.py @@ -15,7 +15,11 @@ FullyShardedDataParallel as FSDP, ShardingStrategy, ) -from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy +from torch.distributed.fsdp.wrap import ( + always_wrap_policy, + ModuleWrapPolicy, + transformer_auto_wrap_policy, +) from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( @@ -211,10 +215,20 @@ def forward(self, x, y): def test_device_id_auto_wrap(self): """Tests that ``auto_wrap_policy`` propagates ``device_id`` to all nested FSDP instances.""" - auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer}, + self.run_subtests( + {"use_callable": [False, True]}, + self._test_device_id_auto_wrap, ) + + def _test_device_id_auto_wrap(self, use_callable: bool): + module_classes = {TransformerEncoderLayer, TransformerDecoderLayer} + if use_callable: + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls=module_classes, + ) + else: + auto_wrap_policy = ModuleWrapPolicy(module_classes) fsdp_kwargs = { "auto_wrap_policy": auto_wrap_policy, "device_id": torch.cuda.current_device(), diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py index ba51ae66ed1b21e..6fafc8e8fdf4a41 100644 --- a/test/distributed/fsdp/test_fsdp_state_dict.py +++ b/test/distributed/fsdp/test_fsdp_state_dict.py @@ -26,7 +26,7 @@ ) from torch.distributed.fsdp._shard_utils import _gather_state_dict from torch.distributed.fsdp._unshard_param_utils import FLAT_PARAM -from torch.distributed.fsdp.wrap import enable_wrap, transformer_auto_wrap_policy, wrap +from torch.distributed.fsdp.wrap import enable_wrap, ModuleWrapPolicy, wrap from torch.nn import Linear, Module, TransformerDecoderLayer, TransformerEncoderLayer from torch.nn.parallel import DistributedDataParallel from torch.optim import SGD @@ -350,9 +350,8 @@ def test_state_dict_with_manual_ac_wrapper( @skip_if_lt_x_gpu(2) @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS) def test_state_dict_with_shared_parameters(self, state_dict_type): - auto_wrap_policy = partial( - transformer_auto_wrap_policy, - transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer}, + auto_wrap_policy = ModuleWrapPolicy( + {TransformerEncoderLayer, TransformerDecoderLayer} ) model_creator = partial( TransformerWithSharedParams.init, @@ -377,9 +376,8 @@ def test_state_dict_rank0_offload_save_load_flow(self, use_orig_params: bool): """Tests saving a model checkpoint only on rank 0 and loading it only on rank 0 with ``sync_module_states=True`` to emulate the workflow to avoid redundant CPU memory usage.""" - auto_wrap_policy = partial( - transformer_auto_wrap_policy, - transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer}, + auto_wrap_policy = ModuleWrapPolicy( + {TransformerEncoderLayer, TransformerDecoderLayer} ) fsdp_kwargs = { "auto_wrap_policy": auto_wrap_policy, diff --git a/test/distributed/fsdp/test_fsdp_use_orig_params.py b/test/distributed/fsdp/test_fsdp_use_orig_params.py index 24829ff408d9baa..0f5ffa564c2d472 100644 --- a/test/distributed/fsdp/test_fsdp_use_orig_params.py +++ b/test/distributed/fsdp/test_fsdp_use_orig_params.py @@ -15,7 +15,7 @@ ShardingStrategy, ) from torch.distributed.fsdp._common_utils import clean_tensor_name -from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy +from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer from torch.nn.parallel.distributed import DistributedDataParallel as DDP from torch.testing._internal.common_distributed import skip_if_lt_x_gpu @@ -117,12 +117,11 @@ def _get_fsdp_transformer_and_optim( # combination with the parameter group construction, ensures different # hyperparameter settings within one `FlatParameter` fsdp_kwargs = { - "auto_wrap_policy": functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={ + "auto_wrap_policy": ModuleWrapPolicy( + { TransformerEncoderLayer, TransformerDecoderLayer, - }, + } ), "use_orig_params": True, "sharding_strategy": sharding_strategy, diff --git a/test/distributed/fsdp/test_utils.py b/test/distributed/fsdp/test_utils.py index e797325ccbc99f6..37c52547e8472ff 100644 --- a/test/distributed/fsdp/test_utils.py +++ b/test/distributed/fsdp/test_utils.py @@ -1,6 +1,5 @@ # Owner(s): ["oncall: distributed"] -import functools import random import sys import unittest @@ -14,7 +13,7 @@ from torch import distributed as dist from torch.distributed.fsdp._utils import _apply_to_tensors from torch.distributed.fsdp._wrap_utils import _get_submodule_to_states -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.distributed.utils import _replace_by_prefix from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -173,9 +172,7 @@ def test_module_wrap_policy(self): # Compute the mapping from submodule to states according to a logical # module wrap policy module_classes = (nn.Sequential,) - auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, transformer_layer_cls=set(module_classes) - ) + auto_wrap_policy = ModuleWrapPolicy(set(module_classes)) submodule_to_states = _get_submodule_to_states( model, auto_wrap_policy, set(), set() ) diff --git a/test/distributed/fsdp/test_wrap.py b/test/distributed/fsdp/test_wrap.py index cd0d11ba9b4b171..e157f041ae1bd81 100644 --- a/test/distributed/fsdp/test_wrap.py +++ b/test/distributed/fsdp/test_wrap.py @@ -5,6 +5,7 @@ import tempfile import unittest from enum import auto, Enum +from typing import Callable, Union import torch import torch.nn as nn @@ -15,10 +16,12 @@ FullyShardedDataParallel as FSDP, ) from torch.distributed.fsdp.wrap import ( + _FSDPPolicy, _or_policy, _wrap_batchnorm_individually, always_wrap_policy, enable_wrap, + ModuleWrapPolicy, size_based_auto_wrap_policy, transformer_auto_wrap_policy, wrap, @@ -373,6 +376,19 @@ def test_transformer_auto_wrap_policy(self): transformer_auto_wrap_policy, transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer}, ) + self._test_transformer_wrapping(auto_wrap_policy) + + @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs") + def test_module_wrap_policy(self): + """Tests the ``ModuleWrapPolicy``.""" + auto_wrap_policy = ModuleWrapPolicy( + {TransformerEncoderLayer, TransformerDecoderLayer} + ) + self._test_transformer_wrapping(auto_wrap_policy) + + def _test_transformer_wrapping( + self, auto_wrap_policy: Union[Callable, _FSDPPolicy] + ): fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy} fsdp_model = TransformerWithSharedParams.init( self.process_group, diff --git a/torch/distributed/_composable/fully_shard.py b/torch/distributed/_composable/fully_shard.py index 2d9e9329795bd63..174b2ca89a788ba 100644 --- a/torch/distributed/_composable/fully_shard.py +++ b/torch/distributed/_composable/fully_shard.py @@ -24,6 +24,7 @@ MixedPrecision, ShardingStrategy, ) +from torch.distributed.fsdp.wrap import _FSDPPolicy @contract @@ -32,7 +33,7 @@ def fully_shard( process_group: Optional[dist.ProcessGroup] = None, mixed_precision: Optional[MixedPrecision] = None, cpu_offload: Optional[CPUOffload] = None, - auto_wrap_policy: Optional[Callable] = None, + policy: Optional[_FSDPPolicy] = None, ignored_modules: Optional[Iterable[torch.nn.Module]] = None, device_id: Optional[Union[int, torch.device]] = None, param_init_fn: Optional[Callable[[nn.Module], None]] = None, @@ -41,6 +42,9 @@ def fully_shard( """ Applies ``FullyShardedDataParallel` (FSDP) semantics to ``module``. """ + # Enforce the new auto wrap policy + if policy is not None and not isinstance(policy, _FSDPPolicy): + raise ValueError(f"Expects an `_FSDPPolicy` but got {policy}") state = fully_shard.state(module) state = _init_ignored_module_states(state, module, ignored_modules) state = _init_process_group_state(state, process_group) @@ -64,7 +68,7 @@ def fully_shard( state = _init_param_handles_from_module( state, module, - auto_wrap_policy, + policy, device_id, param_init_fn, sync_module_states, diff --git a/torch/distributed/fsdp/__init__.py b/torch/distributed/fsdp/__init__.py index 324a3442dea95a3..b1bffdb25a0ebb4 100644 --- a/torch/distributed/fsdp/__init__.py +++ b/torch/distributed/fsdp/__init__.py @@ -11,4 +11,3 @@ ShardingStrategy, StateDictType, ) -from .wrap import ParamExecOrderWrapPolicy diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index 1265ee3578ed40c..7e128251fcc49df 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -47,6 +47,7 @@ HandleConfig, HandleShardingStrategy, ) +from torch.distributed.fsdp.wrap import _FSDPPolicy from torch.distributed.utils import _sync_params_and_buffers from torch.utils.hooks import RemovableHandle @@ -262,7 +263,7 @@ def _init_param_handle_from_module( def _init_param_handles_from_module( state: _FSDPState, root_module: nn.Module, - auto_wrap_policy: Callable, + policy: _FSDPPolicy, device_id: Optional[Union[int, torch.device]], param_init_fn: Optional[Callable[[nn.Module], None]], sync_module_states: bool, @@ -273,7 +274,7 @@ def _init_param_handles_from_module( """ submodule_to_states = _get_submodule_to_states( root_module, - auto_wrap_policy, + policy, state._ignored_modules, state._ignored_params, ) diff --git a/torch/distributed/fsdp/_wrap_utils.py b/torch/distributed/fsdp/_wrap_utils.py index 34d1c9c1ac24309..cdda065df19936b 100644 --- a/torch/distributed/fsdp/_wrap_utils.py +++ b/torch/distributed/fsdp/_wrap_utils.py @@ -1,7 +1,7 @@ import collections import functools import warnings -from typing import Any, Callable, Deque, Dict, List, NamedTuple, Set, Tuple +from typing import Any, Deque, Dict, List, NamedTuple, Set, Tuple import torch import torch.nn as nn @@ -10,6 +10,7 @@ _override_batchnorm_mixed_precision, ) from torch.distributed.fsdp.wrap import ( + _FSDPPolicy, _or_policy, _recursive_wrap, _wrap_batchnorm_individually, @@ -45,6 +46,9 @@ def _auto_wrap( ``fsdp_kwargs`` contains all FSDP arguments except ``module``. """ auto_wrap_policy = auto_wrap_kwargs["auto_wrap_policy"] + # Support new way to pass an auto wrap policy + if isinstance(auto_wrap_policy, _FSDPPolicy): + auto_wrap_policy = auto_wrap_policy.policy root_module = auto_wrap_kwargs["module"] assert auto_wrap_policy is not None # For auto wrapping, submodules should not already be wrapped with FSDP @@ -68,13 +72,13 @@ def _auto_wrap( "instances with mixed precision disabled since some batch norm " "kernels do not support low precision." ) - auto_wrap_kwargs["auto_wrap_policy"] = auto_wrap_policy + auto_wrap_kwargs["auto_wrap_policy"] = auto_wrap_policy _recursive_wrap(**auto_wrap_kwargs, **fsdp_kwargs) def _get_submodule_to_states( root_module: nn.Module, - auto_wrap_policy: Callable, + auto_wrap_policy: _FSDPPolicy, ignored_modules: Set[nn.Module], ignored_params: Set[nn.Parameter], ) -> Dict[nn.Module, SubmoduleState]: @@ -99,7 +103,7 @@ def _get_submodule_to_states( wrapper_cls = functools.partial(_record_module_wrapper_cls, wrapped_modules) _recursive_wrap( root_module, - auto_wrap_policy=auto_wrap_policy, + auto_wrap_policy=auto_wrap_policy.policy, wrapper_cls=wrapper_cls, ignored_modules=ignored_modules, ignored_params=ignored_params, @@ -158,8 +162,9 @@ def _record_module_wrapper_cls( **kwargs, ) -> nn.Module: """ - This defines a wrapper class to be passed to ``_recursive_wrap()`` that - records the wrapped module to the input ``wrapped_modules``. + This defines a pseudo-wrapper class to be passed to ``_recursive_wrap()`` + that records the wrapped module to the input ``wrapped_modules`` without + actually wrapping with a class. """ wrapped_modules.append(module) return module diff --git a/torch/distributed/fsdp/flat_param.py b/torch/distributed/fsdp/flat_param.py index b790590c7943f02..b5892bca683a21e 100644 --- a/torch/distributed/fsdp/flat_param.py +++ b/torch/distributed/fsdp/flat_param.py @@ -838,7 +838,8 @@ def needs_unshard(self) -> bool: return False unsharded_flat_param = self._get_padded_unsharded_flat_param() already_unsharded = ( - unsharded_flat_param._typed_storage()._size() == unsharded_flat_param.numel() + unsharded_flat_param._typed_storage()._size() + == unsharded_flat_param.numel() ) return not already_unsharded diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index 510f90de20234ee..69c8dd92ed8dcb6 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -96,14 +96,6 @@ ) from ._utils import p_assert from .flat_param import FlatParameter, FlatParamHandle -from .wrap import ParamExecOrderWrapPolicy - - -_TORCH_FX_AVAIL = True -if not hasattr(torch, "fx"): - _TORCH_FX_AVAIL = False -if _TORCH_FX_AVAIL: - from ._symbolic_trace import _init_execution_info, _patch_tracer, TracingConfig __all__ = [ @@ -207,37 +199,36 @@ class FullyShardedDataParallel(nn.Module): This configures CPU offloading. If this is set to ``None``, then no CPU offloading happens. See :class:`CPUOffload` for details. (Default: ``None``) - auto_wrap_policy (Optional[Callable[[nn.Module, bool, int], bool]]): - A callable specifying a policy to recursively wrap layers with FSDP. - Note that this policy currently will only apply to child modules of - the passed in module. The remainder modules are always wrapped in - the returned FSDP root instance. - ``size_based_auto_wrap_policy`` written in ``torch.distributed.fsdp.wrap`` is - an example of ``auto_wrap_policy`` callable, this policy wraps layers - with the number of parameters larger than 100M. ``transformer_auto_wrap_policy`` - written in ``torch.distributed.fsdp.wrap`` is an example of ``auto_wrap_policy`` - callable for transformer-like model architectures. Users can supply the customized - ``auto_wrap_policy`` callable that should accept following arguments: - ``module: nn.Module``, ``recurse: bool``, ``unwrapped_params: int``, and return - a ``bool`` specifying whether the passed in ``module``` should be wrapped - (if ``recurse=False``) or whether we should recurse down the subgraph of ``module`` - children (if ``recurse=True``). Extra customized arguments could be added to - the customized ``auto_wrap_policy`` callable as well. It is a good practice to - print out the sharded model and check whether the sharded model is what - the application wants and then adjust accordingly. + auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], _FSDPPolicy]]): + This is either ``None``, an ``_FSDPPolicy``, or a callable of + a fixed signature. If it is ``None``, then ``module`` is wrapped + with only a top-level FSDP instance without any nested wrapping. If + it is an ``_FSDPPolicy``, then the wrapping follows the given + policy. ``ModuleWrapPolicy`` in ``torch.distributed.fsdp.wrap.py`` + is an example. If it is a callable, then it should take in three + arguments ``module: nn.Module``, ``recurse: bool``, and + ``nonwrapped_numel: int`` and should return a ``bool`` specifying + whether the passed-in ``module`` should be wrapped if + ``recurse=False`` or if the traversal should continue down the + subtree if ``recurse=True``. Additional custom arguments may be + added to the callable. The ``size_based_auto_wrap_policy`` in + ``torch.distributed.fsdp.wrap.py`` gives an example callable that + wraps a module if the parameters in its subtree exceed 100M numel. + A good practice is to print the model after wrapping and adjust as + needed. Example:: >>> def custom_auto_wrap_policy( >>> module: nn.Module, >>> recurse: bool, - >>> unwrapped_params: int, - >>> # These are customizable for this policy function. + >>> nonwrapped_numel: int, + >>> # Additional custom arguments >>> min_num_params: int = int(1e8), >>> ) -> bool: - >>> return unwrapped_params >= min_num_params - >>> # Configure a custom min_num_params - >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=1e5) + >>> return nonwrapped_numel >= min_num_params + >>> # Configure a custom `min_num_params` + >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5)) backward_prefetch (Optional[BackwardPrefetch]): This configures explicit backward prefetching of all-gathers. See @@ -337,25 +328,6 @@ def __init__( limit_all_gathers: bool = False, use_orig_params: bool = False, ): - if isinstance(auto_wrap_policy, ParamExecOrderWrapPolicy): - self._init_param_exec_order_wrap_policy( - module=module, - process_group=process_group, - sharding_strategy=sharding_strategy, - cpu_offload=cpu_offload, - auto_wrap_policy=auto_wrap_policy, - backward_prefetch=backward_prefetch, - mixed_precision=mixed_precision, - ignored_modules=ignored_modules, - param_init_fn=param_init_fn, - device_id=device_id, - sync_module_states=sync_module_states, - forward_prefetch=forward_prefetch, - limit_all_gathers=limit_all_gathers, - use_orig_params=use_orig_params, - ) - return - torch._C._log_api_usage_once("torch.distributed.fsdp") super().__init__() @@ -1815,89 +1787,6 @@ def register_comm_hook(self, state: object, hook: callable): submodule._communication_hook_state = state submodule._communication_hook = hook - def _init_param_exec_order_wrap_policy(self, *args, **kwargs) -> None: - auto_wrap_policy = kwargs["auto_wrap_policy"] - module = kwargs["module"] - assert hasattr(auto_wrap_policy, "tracing_config") - if not _TORCH_FX_AVAIL: - assert ( - auto_wrap_policy.tracing_config is None - ), "tracing_config should be None when torch.fx is not enabled" - elif isinstance(auto_wrap_policy.tracing_config, TracingConfig): - tracer = auto_wrap_policy.tracing_config.tracer - execution_info = _init_execution_info(module) - - for m in module.modules(): - assert not isinstance( - m, FullyShardedDataParallel - ), "The input module of _patch_tracer should not contain FSDP modules" - - with _patch_tracer( - tracer=tracer, - root_module=module, - execution_info=execution_info, - ): - try: - tracer.trace(module, auto_wrap_policy.tracing_config.concrete_args) - except BaseException as e: - raise RuntimeError( - "tracer.trace failed inside _init_param_exec_order_wrap_policy" - f" with the error: {e}." - ) - else: - assert ( - auto_wrap_policy.tracing_config is None - ), "tracing_config should either be an instance of TracingConfig or be None" - # The initial FSDP wrapping is done with auto_wrap_policy.init_policy - kwargs["auto_wrap_policy"] = auto_wrap_policy.init_policy - self.__init__(*args, **kwargs) - self._param_exec_order_policy: bool = True - # self._param_exec_order_prep_stage is set to True before we get the execution order - self._param_exec_order_prep_stage: bool = True - # A list that stores the flatten parameters and its name based on the parameter execution order - self._fsdp_params_exec_order: List[FlatParameter] = [] - if _TORCH_FX_AVAIL and isinstance( - auto_wrap_policy.tracing_config, TracingConfig - ): - # Initialize a dict that maps each module to its parent FSDP wrap - module_to_fsdp: Dict[nn.Module, FullyShardedDataParallel] = dict() - for wrap in self.fsdp_modules(self): - module_to_fsdp[wrap.module] = wrap - # Set self._fsdp_params_exec_order based on execution_info.module_forward_order. - # TODO (linjianma): self._fsdp_params_exec_order will be set based on - # the parameter execution order rather than module_forward_order, - # once the non-recursive wrapping policy is fully implemented. - for m in execution_info.module_forward_order: - if m in module_to_fsdp: - for flat_param in module_to_fsdp[m].params: - self._fsdp_params_exec_order.append(flat_param) - self._param_exec_order_prep_stage = False - - for m in self.modules(): - if m is not self and isinstance(m, FullyShardedDataParallel): - # Assignment by reference, so each children FSDP wrap has access to - # the _fsdp_params_exec_order of the root module - m._fsdp_params_exec_order = self._fsdp_params_exec_order - m._param_exec_order_policy = self._param_exec_order_policy - m._param_exec_order_prep_stage = self._param_exec_order_prep_stage - - def _use_param_exec_order_policy(self) -> bool: - return ( - hasattr(self, "_param_exec_order_policy") and self._param_exec_order_policy - ) - - def _is_param_exec_order_prep_stage(self) -> bool: - is_prep_stage = ( - hasattr(self, "_param_exec_order_prep_stage") - and self._param_exec_order_prep_stage - ) - if not is_prep_stage: - for p in self.parameters(): - assert not hasattr( - p, "_params_exec_order_hook_handle" - ), "When not in execution order prep stage, all _params_exec_order_hook_handle should be removed." - return is_prep_stage - def _get_grad_norm( params: List[nn.Parameter], diff --git a/torch/distributed/fsdp/wrap.py b/torch/distributed/fsdp/wrap.py index c529bcde8c859b5..e20c07f18d13242 100644 --- a/torch/distributed/fsdp/wrap.py +++ b/torch/distributed/fsdp/wrap.py @@ -4,7 +4,8 @@ # LICENSE file in the root directory of this source tree. import contextlib -from dataclasses import dataclass +import functools +from abc import ABC, abstractmethod from typing import Any, Callable, cast, Dict, Generator, Optional, Set, Tuple, Type import torch.nn as nn @@ -17,22 +18,84 @@ "size_based_auto_wrap_policy", "enable_wrap", "wrap", - "ParamExecOrderWrapPolicy", + "ModuleWrapPolicy", ] def always_wrap_policy(*args, **kwargs) -> bool: """ - A simple wrapper policy that always returns ``True``, - i.e. when passed as the `auto_wrap_policy` into FSDP, - this will result in all submodules being wrapped as - distinct FSDP instances. + A simple recursive wrap policy that always returns ``True``. This means + that every submodule is wrapped by the wrapper class in + :func:`_recursive_wrap`. """ return True +class _FSDPPolicy(ABC): + """ + This defines an abstract base class that represents an FSDP policy for + constructing ``FlatParameter`` s. + """ + + # The motivation for this abstract base class is to hide the interface + # expected by `_recursive_wrap()` from users (i.e. the `recurse` argument). + def __init__(self): + ... + + @property + @abstractmethod + def policy(self) -> Callable: + ... + + +def _module_wrap_policy( + module: nn.Module, + recurse: bool, + nonwrapped_numel: int, + module_classes: Set[Type[nn.Module]], +) -> bool: + """ + This auto wrap policy wraps every module that is an instance of any type in + ``module_classes`` as its own FSDP instance. The root module given by + ``module`` is always wrapped as an FSDP instance regardless. Since the + wrapping proceeds bottom up, each FSDP instance manages the parameters in + its subtree excluding any already managed by a child FSDP instance. + + Args: + module (nn.Module): Current module being considered. + recurse (bool): If ``False``, then this function must decide whether + ``module`` should be wrapped as an FSDP instance or not. If + ``True``, then the function is still recursing down the module + tree as a part of the DFS. + nonwrapped_numel (int): Parameter numel not yet wrapped. + module_classes (Set[Type[nn.Module]]): Set of module classes that are + wrapped as FSDP instances. + + Returns: + ``True`` if ``recurse=True``, and whether ``module`` should be wrapped + if ``recurse=False``. + """ + if recurse: + return True # always recurse + return isinstance(module, tuple(module_classes)) + + +class ModuleWrapPolicy(_FSDPPolicy): + """This is a wrapper around :func:`_module_wrap_policy`.""" + + def __init__(self, module_classes: Set[Type[nn.Module]]): + self._policy: Callable = functools.partial( + _module_wrap_policy, + module_classes=module_classes, + ) + + @property + def policy(self): + return self._policy + + def lambda_auto_wrap_policy( - module: nn.Module, recurse: bool, unwrapped_params: int, lambda_fn: Callable + module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable ) -> bool: """ A convenient auto wrap policy to wrap submodules based on an arbitrary user @@ -44,70 +107,34 @@ def lambda_auto_wrap_policy( The first three parameters are required by :func:`_recursive_wrap`. Args: - module (nn.Module): - The module to be considered in this decision. - recurse (bool): - Indicate if this is called to make a decision on whether we - should recurse down a subgraph of the module structure. - If False, it means this function is called to make a decision - on whether we should wrap the said module. - unwrapped_params (int): - The number of parameters yet to be wrapped in this module. - - lambda_fn (Callable[nn.Module] -> bool): - If this returns ``True``, this module will be wrapped by - wrapper_cls individually. + module (nn.Module): Current module being considered. + recurse (bool): If ``False``, then this function must decide whether + ``module`` should be wrapped as an FSDP instance or not. If + ``True``, then the function is still recursing down the module + tree as a part of the DFS. + nonwrapped_numel (int): Parameter numel not yet wrapped. + + lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then + this module will be wrapped. """ if recurse: - # always recurse - return True - else: - # if not recursing, decide whether we should wrap for the leaf node or reminder - return lambda_fn(module) + return True # always recurse + return lambda_fn(module) def transformer_auto_wrap_policy( module: nn.Module, recurse: bool, - unwrapped_params: int, + nonwrapped_numel: int, transformer_layer_cls: Set[Type[nn.Module]], ) -> bool: """ - A convenient auto wrap policy for transformer models. If the submodule - is an instance of transformer_layer_cls, the submodule will be wrapped - as a FSDP unit. Otherwise, all the other remainder submodules are wrapped - by the outermost FSDP unit. Right now, FSDP requires submodules that share - weights to be wrapped in the same FSDP unit, this auto wrap policy can - conviniently wrap the shared embeddings into the same FSDP unit for transformer - models. In the near future, FSDP will support submodules that share weights - to be wrapped in the separated FSDP units. - - Return if a module should be wrapped during FSDP auto wrapping. - - The first three parameters are required by :func:`_recursive_wrap`. - - - Args: - module (nn.Module): - The module to be considered in this decision. - recurse (bool): - Indicate if this is called to make a decision on whether we - should recurse down a subgraph of the module structure. - If False, it means this function is called to make a decision - on whether we should wrap the said module. - unwrapped_params (int): - The number of parameters yet to be wrapped in this module. - - transformer_layer_cls (int): - Submodules with one of the `transformer_layer_cls` names - will be wrapped as separated FSDP units + See :func:`_module_wrap_policy`, where ``transformer_layer_cls`` is the + same as ``module_classes``. Note that shared parameters must be wrapped in + the same FSDP instance, so this auto wrap policy can help wrap shared + embeddings into the same FSDP instance for transformer models. """ - if recurse: - # always recurse - return True - else: - # if not recursing, decide whether we should wrap for the leaf node or reminder - return isinstance(module, tuple(transformer_layer_cls)) + return _module_wrap_policy(module, recurse, nonwrapped_numel, transformer_layer_cls) def _wrap_batchnorm_individually( @@ -117,7 +144,7 @@ def _wrap_batchnorm_individually( **kwargs, ) -> bool: """ - A policy that wraps ``BatchNorm`` instances in their own FSDP unit. + A policy that wraps ``BatchNorm`` instances in their own FSDP instance. """ if recurse: # always recurse @@ -131,52 +158,46 @@ def _wrap_batchnorm_individually( def _or_policy( module: nn.Module, recurse: bool, - unwrapped_params: int, + nonwrapped_numel: int, policies, ) -> bool: """ A policy that wraps ``module`` if any policy in the passed in iterable of ``policies`` returns ``True``. """ - return any(policy(module, recurse, unwrapped_params) for policy in policies) + return any(policy(module, recurse, nonwrapped_numel) for policy in policies) def size_based_auto_wrap_policy( module: nn.Module, recurse: bool, - unwrapped_params: int, - # These are customizable for this policy function. + nonwrapped_numel: int, + # Additional custom arguments min_num_params: int = int(1e8), force_leaf_modules: Optional[Set[Type[nn.Module]]] = None, exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None, ) -> bool: - """A size based auto_wrap_policy function for FSDP API. - - Return if a module should be wrapped during FSDP auto wrapping. - - The first three parameters are used by :func:`_recursive_wrap`. If - you write a custom version of this policy function, your version - needs to at least accept the first three parameters and free - to do whatever you want in the function. + """ + A size-based auto wrap policy. Args: - module (nn.Module): - The module to be considered in this decision. - recurse (bool): - Indicate if this is called to make a decision on whether we - should recurse down a subgraph of the module structure. - If False, it means this function is called to make a decision - on whether we should wrap the said module. - unwrapped_params (int): - The number of parameters yet to be wrapped in this module. - - min_num_params (int): - Customizable policy input. It controls the size threshold - on how big should a module be to be considered wrapped. - force_leaf_modules (Set[Type[nn.Module]]): set of module types to - keep as leaves, i.e., their children will never be wrapped. - exclude_wrap_modules (Set[Type[nn.Module]]): - Customizable set of module types to be excluded in wrapping. + module (nn.Module): Current module being considered. + recurse (bool): If ``False``, then this function must decide whether + ``module`` should be wrapped as an FSDP instance or not. If + ``True``, then the function is still recursing down the module + tree as a part of the DFS. + nonwrapped_numel (int): Parameter numel not yet wrapped. + + min_num_params (int): Customizable policy input that controls the size + threshold over which a module is ready to be wrapped. This is in + units of numel. + force_leaf_modules (Set[Type[nn.Module]]): Set of module types to keep + as leaves, i.e. their children will never be wrapped. + exclude_wrap_modules (Set[Type[nn.Module]]): Set of module types to be + excluded in wrapping. + + Returns: + Whether ``module`` should be wrapped. """ force_leaf_modules = ( size_based_auto_wrap_policy.FORCE_LEAF_MODULES # type: ignore[attr-defined] @@ -189,7 +210,10 @@ def size_based_auto_wrap_policy( else exclude_wrap_modules ) - is_large = unwrapped_params >= min_num_params + # Keep the argument `min_num_params` for BC for now, but it represents the + # minimum non-wrapped *numel* before triggering a wrapping + min_nonwrapped_numel = min_num_params + is_large = nonwrapped_numel >= min_nonwrapped_numel if recurse: # We should recurse if the module is big enough but not in force_leaf_modules list. return is_large and not isinstance(module, tuple(force_leaf_modules)) @@ -276,56 +300,6 @@ def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module: return module -@dataclass -class ParamExecOrderWrapPolicy: - """ - This is the class used for the wrapping policy that wraps parameters and performs - the communication scheduling based on the parameter execution order in the forward pass - (also called non-recursive wrapping policy). - - The policy contains multiple wraps. Each wrap contains original parameters that will be executed together, - and the wrap transfers these parameters into one ``FlattenParameter``. In both forward and the backward passes, - the sharded parameters in each wrap will be gathered just before these parameters are used in the passes. - These parameters will then be reshaded once they have been used. - - TODO (linjianma): For now, the parameters contained in each wrap of ``ParamExecOrderWrapPolicy`` - are the parameters in each wrap of the ``init_policy`` (a recursive wrapping policy). - Later we will wrap parameters based on bucket size. - - Args: - init_policy (Callable): - The initial recursive wrapping policy used to guide the wrapping of - this policy. If tracing_config is none, in the first forward and - backward iteration, ``init_policy`` is used to record parameter - execution order. Otherwise, init_policy is only used in FSDP - constructor for module level wrapping. - - The default ``always_wrap_policy`` might not be the best choice for every model. For example, for - transformer based models, setting ``transformer_auto_wrap_policy`` as the ``init_policy`` will guarantee - wrapping each transformer layer into one FSDP unit, and can be easily combined with checkpointing - within each transformer layer. - - tracing_config (Optional[TracingConfig]): - The configuration used to perform symbolic tracing at FSDP - constructor to get the module and parameter execution order. The - type of ``tracing_config`` needs to be either ``None`` or - ``TracingConfig``. If set as ``None``, then symbolic tracing is not - enabled, and one forward as well as backward iteration are needed to - get the parameter execution order. - - ..warning :: Note that not all modules can be successfully traced when - ``tracing_config`` is not None and symbolic tracing is enabled. The two - cases below may be unable to trace: 1. when there is a data-dependent - branch, 2. when the forward pass contains operators that don't support - ``torch.fx.Proxy`` as the input type (e.g. ``arange``, ``zeros``, ``ones``, - ``full``, ``full_like``, ``eye``, ``empty``, ``tensor``). For those cases, - users can set ``tracing_config = None`` to disable symbolic tracing. - """ - - init_policy: Callable = always_wrap_policy - tracing_config: Any = None - - def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module: assert wrapper_cls is not None if hasattr(module, "_wrap_overrides"): @@ -349,13 +323,13 @@ def _recursive_wrap( **kwargs: Any, ) -> Tuple[nn.Module, int]: """ - Automatically wrap child modules of *module* that meet the given - criteria with :func:`auto_wrap`. Does not rely on _ConfigAutoWrap. + Wraps submodules of ``module`` for which ``auto_wrap_policy`` returns + ``True`` with ``wrapper_cls``. + Args: - module (nn.Module): - module to recursively wrap - auto_wrap_policy (Callable): - A callable specifying a policy to recursively wrap layers with FSDP. + module (nn.Module): Module to recursively wrap. + auto_wrap_policy (Callable): A callable representing a policy that + determines which modules to recursively wrap with ``wrapper_cls``. ignored_modules (Set[torch.nn.Module]): Modules to ignore when wrapping. ignored_params (Set[torch.nn.Parameter]): Parameters to ignore when @@ -363,7 +337,7 @@ def _recursive_wrap( in ``ignored_modules``. Returns: (nn.Module, int): - Wrapped module and the number parameters wrapped recursively. + ``module`` after wrapping and the numel recursively wrapped. """ assert auto_wrap_policy is not None, "Must specify auto_wrap_policy." assert wrapper_cls is not None, "Must specify wrapper_cls" @@ -378,11 +352,13 @@ def _recursive_wrap( pass # We count all params, assuming none of them are already wrapped. - num_params = sum(p.numel() for p in module.parameters() if p not in ignored_params) + nonwrapped_numel = sum( + p.numel() for p in module.parameters() if p not in ignored_params + ) assert auto_wrap_policy is not None - if auto_wrap_policy(module=module, recurse=True, unwrapped_params=num_params): - total_wrapped_params = 0 + if auto_wrap_policy(module=module, recurse=True, nonwrapped_numel=nonwrapped_numel): + total_wrapped_numel = 0 # Iterate through the children, recursively wrap if necessary for name, child in module.named_children(): if child in ignored_modules: @@ -397,17 +373,17 @@ def _recursive_wrap( ) setattr(module, name, wrapped_child) # Keep track of how many parameters have been wrapped - total_wrapped_params += num_wrapped_params + total_wrapped_numel += num_wrapped_params # decide if we need to wrap the current module, # since the left over parameters exceed the number of params to wrap - remainder = num_params - total_wrapped_params + remainder = nonwrapped_numel - total_wrapped_numel if not only_wrap_children and auto_wrap_policy( - module=module, recurse=False, unwrapped_params=remainder + module=module, recurse=False, nonwrapped_numel=remainder ): # Leaf node or final wrapping of the remainder both happen here. - return _wrap(module, wrapper_cls, **kwargs), num_params + return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel else: - return module, total_wrapped_params + return module, total_wrapped_numel return module, 0 diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index 0dca22f48092bdc..b4650adff569b8f 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -1,6 +1,5 @@ # Owner(s): ["oncall: distributed"] -import functools import itertools import sys from abc import ABC, abstractmethod @@ -21,11 +20,7 @@ ShardingStrategy, ) from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler -from torch.distributed.fsdp.wrap import ( - always_wrap_policy, - transformer_auto_wrap_policy, - wrap, -) +from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy, wrap from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer from torch.nn.parallel.distributed import DistributedDataParallel as DDP from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS @@ -285,8 +280,8 @@ def init( fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap any modules with FSDP. If ``RECURSIVE``, then wraps with top-level FSDP. By default, the top-level FSDP uses the - ``transformer_auto_wrap_policy()`` for encoder and decoder - layers, but a different auto wrap policy may be specified via + ``ModuleWrapPolicy`` for encoder and decoder layers, but a + different auto wrap policy may be specified via ``fsdp_kwargs``. cuda_init_mode (CUDAInitMode): Determines model movement to CUDA. fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments @@ -302,14 +297,13 @@ def init( group, cuda_init_mode, add_bn, deterministic ) elif fsdp_init_mode == FSDPInitMode.RECURSIVE: - # Default to the `transformer_auto_wrap_policy()` + # Default to the `ModuleWrapPolicy` if "auto_wrap_policy" not in fsdp_kwargs: - auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={ + auto_wrap_policy = ModuleWrapPolicy( + { TransformerEncoderLayer, TransformerDecoderLayer, - }, + } ) else: auto_wrap_policy = fsdp_kwargs.pop("auto_wrap_policy") From c83348597b195f2da1cca0e8318c878b104bce5d Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Sat, 12 Nov 2022 04:45:17 +0000 Subject: [PATCH 56/62] [dynamo][api] Better support of torch.nn.Module (#88629) This is an API change, so please review carefully. With this PR, torchdynamo returns an `OptimizedModule` class object, a subclass of `torch.nn.Module`, when asked to optimize a `nn.Module` object. Most of the methods are redirected to the original `nn.Module`, which is installed as `_mod` in the `OptimizedModule`. This is helpful for many cases ``` mod = MockModule() opt_mod = torch._dynamo.optimize()(mod) print(opt_mod) # Works opt_mod = opt_mod.to(device="cuda") print(opt_mod) # Works opt_mod(input) # Triggers recompile if necessary, earlier we were shedding the TorchDynamo wrapper opt_mod.parameters() # Refers to the original module ``` Topics unclear to me * I have overridden many methods to raise NotImplementedError. A careful review of those will be good. * hooks * For the optimized forward, should we call torchdynamo optimization on `__call__` or `forward` * What else to test Pull Request resolved: https://github.com/pytorch/pytorch/pull/88629 Approved by: https://github.com/Chillee, https://github.com/jansel, https://github.com/msaroufim --- test/dynamo/test_modules.py | 127 +++++++++++++++++++++++++++++++++++ torch/_dynamo/__init__.py | 2 + torch/_dynamo/debug_utils.py | 8 +++ torch/_dynamo/eval_frame.py | 74 ++++++++++++++------ torch/_dynamo/testing.py | 13 ++++ 5 files changed, 204 insertions(+), 20 deletions(-) diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 2fb83b3add6cfbc..930035f99a30c33 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -904,6 +904,133 @@ def forward(self, x): self.assertTrue(torch._dynamo.testing.same(real, graph(rx))) +class MockModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + self.linear = torch.nn.Linear(10, 10) + self.register_buffer("buf0", torch.randn(10, 10)) + + def forward(self, x): + return self.relu(self.linear(x) + self.buf0) + + +class OptimizedModuleTest(torch._dynamo.test_case.TestCase): + def test_nn_module(self): + mod = MockModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_mod = torch._dynamo.optimize(cnt)(mod) + self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule) + + x = torch.randn(10, 10) + self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x))) + self.assertEqual(cnt.frame_count, 1) + + def test_to(self): + mod = MockModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_mod = torch._dynamo.optimize(cnt)(mod) + x = torch.randn(10, 10) + self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x))) + self.assertEqual(cnt.frame_count, 1) + + # Ensure that there is no recompilation + opt_mod(x) + self.assertEqual(cnt.frame_count, 1) + + opt_mod = opt_mod.to(device="cpu").to(dtype=torch.float64) + self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule) + x = torch.randn(10, 10).to(dtype=torch.float64) + opt_mod(x) + # Ensure that there is a recompilation + self.assertEqual(cnt.frame_count, 2) + + def test_attr(self): + class MockModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 10) + self.register_buffer("buf0", torch.randn(10, 10)) + + def forward(self, x): + return self.r(torch.sin(x)) + self.buf0 + + mod = MockModule() + opt_mod = torch._dynamo.optimize("eager")(mod) + + # Check parameteres and buffers + for (p1, p2) in zip(mod.parameters(), opt_mod.parameters()): + self.assertTrue(id(p1) == id(p2)) + + def test_recursion(self): + mod = MockModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_mod = torch._dynamo.optimize(cnt)(mod) + + for _ in range(5): + opt_mod = torch._dynamo.optimize(cnt)(opt_mod) + opt_mod(torch.randn(10, 10)) + self.assertEqual(cnt.frame_count, 1) + + def test_composition(self): + class InnerModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(torch.sin(x)) + + opt_inner_mod = InnerModule() + + class OuterModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mod = opt_inner_mod + + def forward(self, x): + return self.mod(torch.cos(x)) + + outer_mod = OuterModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) + + x = torch.randn(4) + self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) + self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) + self.assertEqual(cnt.frame_count, 1) + + def test_composition_with_opt_mod(self): + class InnerModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(torch.sin(x)) + + inner_mod = InnerModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_inner_mod = torch._dynamo.optimize(cnt)(inner_mod) + + class OuterModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mod = opt_inner_mod + + def forward(self, x): + return self.mod(torch.cos(x)) + + outer_mod = OuterModule() + opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) + + x = torch.randn(4) + self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) + self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) + # There will be a graph break for the inner mod being OptimizedModule + self.assertEqual(cnt.frame_count, 2) + + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index 80f927aeef2fad9..5eee609b0852a21 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -7,6 +7,7 @@ export, optimize, optimize_assert, + OptimizedModule, reset_code, run, skip, @@ -25,6 +26,7 @@ "reset", "list_backends", "skip", + "OptimizedModule", ] diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index 98a269fe8c9eb57..29d830167b109b3 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -515,8 +515,16 @@ def same_two_models(gm, opt_gm, example_inputs, only_fwd=False): """ Check two models have same accuracy. """ + from .eval_frame import OptimizedModule + from .testing import named_parameters_for_optimized_module from .utils import same + if isinstance(gm, OptimizedModule): + gm.named_parameters = named_parameters_for_optimized_module(gm) + + if isinstance(opt_gm, OptimizedModule): + opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm) + ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd) try: diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 8d9e3b7b6aa1469..20e8c7de085e0b7 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -5,6 +5,7 @@ import logging import os import sys +import textwrap import threading import traceback import types @@ -44,6 +45,27 @@ most_recent_backend = None +class OptimizedModule(torch.nn.Module): + """ + Wraps the original nn.Module object and later patches its + forward method to optimized self.forward method. + """ + + def __init__(self, mod): + super().__init__() + # Installs the params/buffer + self._orig_mod = mod + + def __getattr__(self, name): + if name == "_orig_mod": + return self._modules["_orig_mod"] + return getattr(self._orig_mod, name) + + def forward(self, *args, **kwargs): + # This will be monkey patched later + raise RuntimeError("Should not be here") + + def remove_from_cache(f): """ Make sure f.__code__ is not cached to force a recompile @@ -118,31 +140,15 @@ def __call__(self, fn): # Optimize the forward method of torch.nn.Module object if isinstance(fn, torch.nn.Module): mod = fn - optimized_forward = self(mod.forward) - - class TorchDynamoNNModuleWrapper: - """ - A wrapper that redirects the forward call to the optimized - forward, while for rest it redirects the calls to the original - module. - """ - - def __getattr__(self, name): - return getattr(mod, name) - - def forward(self, *args, **kwargs): - return optimized_forward(*args, **kwargs) - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - new_mod = TorchDynamoNNModuleWrapper() + new_mod = OptimizedModule(mod) + new_mod.forward = self(mod.forward) # Save the function pointer to find the original callable while nesting # of decorators. - new_mod._torchdynamo_orig_callable = mod + new_mod._torchdynamo_orig_callable = mod.forward return new_mod assert callable(fn) + callback = self.callback on_enter = self.on_enter backend_ctx_ctor = self.extra_ctx_ctor @@ -184,6 +190,34 @@ def _fn(*args, **kwargs): # If the function is called using torch._dynamo.optimize decorator, we # should prevent any type of skipping. if callback not in (None, False): + if not hasattr(fn, "__code__"): + raise RuntimeError( + textwrap.dedent( + """ + + torch._dynamo.optimize is called on a non function object. + If this is a callable class, please optimize the individual methods that you are interested in optimizing. + + >> class CallableClass: + >> def __init__(self): + >> super().__init__() + >> self.relu = torch.nn.ReLU() + >> + >> def __call__(self, x): + >> return self.relu(torch.sin(x)) + >> + >> def print_hello(self): + >> print("Hello world") + >> + >> mod = CallableClass() + + If you want to optimize the __call__ function + + >> mod.__call__ = torch._dynamo.optimize(mod.__call__) + + """ + ) + ) always_optimize_code_objects[fn.__code__] = True return _fn diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index d6082ce48acf833..b37299ffd5791fe 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -32,6 +32,17 @@ def clone_me(x): return x.detach().clone().requires_grad_(x.requires_grad) +def named_parameters_for_optimized_module(mod): + assert isinstance(mod, eval_frame.OptimizedModule) + return mod._orig_mod.named_parameters + + +def remove_optimized_module_prefix(name): + prefix = "_orig_mod." + assert name.startswith(prefix) + return name[len(prefix) :] + + def collect_results(model, prediction, loss, example_inputs): results = [] results.append(prediction) @@ -44,6 +55,8 @@ def collect_results(model, prediction, loss, example_inputs): grads = dict() params = dict() for name, param in model.named_parameters(): + if isinstance(model, eval_frame.OptimizedModule): + name = remove_optimized_module_prefix(name) param_copy = param grad = param.grad # Treat None and zero grad as same From 34641c4384328ad9a3d2dc928de5b60a239427ee Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 12 Nov 2022 05:16:41 +0000 Subject: [PATCH 57/62] Revert "Add comprehensive minifier tests (#88022)" This reverts commit 5ff600aa6e40c6b4d426594bbb1f446f005b7fb3. Reverted https://github.com/pytorch/pytorch/pull/88022 on behalf of https://github.com/wconstab due to Seems to be causing CI failures relating to minifier test and some /tmp/ path not existing --- test/dynamo/test_minifier.py | 630 ++++------------------------------- torch/_dynamo/debug_utils.py | 78 +---- 2 files changed, 76 insertions(+), 632 deletions(-) diff --git a/test/dynamo/test_minifier.py b/test/dynamo/test_minifier.py index 51b79a5e7511ea3..0cec7d202a9d446 100644 --- a/test/dynamo/test_minifier.py +++ b/test/dynamo/test_minifier.py @@ -1,138 +1,27 @@ # Owner(s): ["module: dynamo"] -import functools import os -import re import shutil -import subprocess -import textwrap import unittest +from unittest.mock import patch import torch import torch._dynamo import torch._dynamo.test_case import torch._dynamo.testing -import torch._inductor.utils -from torch._dynamo.debug_utils import TEST_REPLACEABLE_COMMENT +from torch._dynamo.optimizations.backends import create_backend -_HAS_TRITON = torch._inductor.utils.has_triton() -requires_cuda = functools.partial(unittest.skipIf, not _HAS_TRITON, "requires cuda") -RELU_COMPILE_ERROR_BACKEND = """\ -from torch._dynamo.optimizations.backends import register_backend +class MockModule(torch.nn.Module): + def __init__(self): + super().__init__() -class DynamoCompileError(Exception): - pass - -@register_backend -def test_relu_compile_error(gm: torch.fx.GraphModule, example_inputs): - for node in gm.graph.nodes: - if node.target == torch.relu: - raise DynamoCompileError("relu found") - return gm -""" - -RELU_RUNTIME_ERROR_BACKEND = """\ -import copy -from torch._dynamo.optimizations.backends import register_backend - -@register_backend -def test_relu_runtime_error(gm: torch.fx.GraphModule, example_inputs): - gm = copy.deepcopy(gm) - for node in gm.graph.nodes: - if node.target == torch.relu: - node.target = torch._assert - node.args = (False, "DynamoRuntimeError") - gm.recompile() - return gm -""" - -RELU_ACCURACY_ERROR_BACKEND = """\ -import copy -from torch._dynamo.optimizations.backends import register_backend - -@register_backend -def test_relu_accuracy_error(gm: torch.fx.GraphModule, example_inputs): - gm = copy.deepcopy(gm) - for node in gm.graph.nodes: - if node.target == torch.relu: - node.target = torch.add - node.args = (node.args[0], 1) - gm.recompile() - - return gm -""" - -RELU_CUSTOM_ERROR_BACKEND = """\ -class CustomError(Exception): - pass - -def test_relu_custom_error(gm: torch.fx.GraphModule, example_inputs): - for node in gm.graph.nodes: - if node.target == torch.relu: - raise CustomError("relu found") - return gm -""" - -CPP_COMPILE_ERROR = """\ -def cpp_compile_error(x): - return "compile error!" -""" - -CPP_RUNTIME_ERROR = """\ -def cpp_runtime_error(x): - return f"{x}; throw 1" -""" - -CPP_ACCURACY_ERROR = """\ -def cpp_accuracy_error(x): - return f"{x} + 1" -""" - -TRITON_COMPILE_ERROR = """\ -def triton_compile_error(x): - return "compile error!" -""" - -# NOTE: there is currently not an easy way to cause a triton runtime error. -TRITON_RUNTIME_ERROR = """\ -def triton_runtime_error(x): - return f"{x}; assert?" -""" - -TRITON_ACCURACY_ERROR = """\ -def triton_accuracy_error(x): - return f"{x} + 1" -""" - -DEBUG_DIR = "/tmp/_torchdynamo_debug_/" - -# Search for the name of the first function defined in a code string. -def get_fn_name(code): - fn_name_match = re.search(r"def (\w+)\(", code) - if fn_name_match is not None: - return fn_name_match.group(1) - return None - - -# Generates code that patches CppOverrides/TritonOverrides. -def gen_codegen_fn_patch_code(old_fn_name, new_fn_code, device): - new_fn_name = get_fn_name(new_fn_code) - if new_fn_name is not None: - patch_code = f"""\ -import torch._inductor.codegen.{"cpp" if device == "cpu" else "triton"} as codegen -overrides = codegen.{"CppOverrides" if device == "cpu" else "TritonOverrides"} -{new_fn_code} -overrides.{old_fn_name} = staticmethod({new_fn_name}) -""" - return f"""\ -{patch_code} -isolate_fails_code_str = \"\"\"\\ -{patch_code} -torch._dynamo.config.debug_dir_root = "{DEBUG_DIR}" -\"\"\" -""" - - return None + def forward(self, x): + for _ in range(10): + x = torch.sin(x) + x = torch._foobar(x) + for _ in range(10): + x = torch.cos(x) + return x class MinfierTests(torch._dynamo.test_case.TestCase): @@ -143,10 +32,9 @@ def setUpClass(cls): unittest.mock.patch.object( torch._dynamo.config, "debug_dir_root", - DEBUG_DIR, + "/tmp/_torchdynamo_debug_/", ) ) - os.makedirs(DEBUG_DIR, exist_ok=True) @classmethod def tearDownClass(cls): @@ -159,455 +47,65 @@ def setUp(self): def tearDown(self): super().tearDown() - # Run `code` in a separate python process. - # Returns the completed process state and the directory containing the - # minifier launcher script, if `code` outputted it. - def _run_test_code(self, code): - proc = subprocess.run( - ["python3", "-c", code], capture_output=True, cwd=DEBUG_DIR - ) - - repro_dir_match = re.search( - r"(\S+)minifier_launcher.py", proc.stderr.decode("utf-8") - ) - if repro_dir_match is not None: - # Print repro directory for debugging generated code. - # Make sure to comment out `shutil.rmtree...` above as well. - print("repro dir:", repro_dir_match.group(1)) - return proc, repro_dir_match.group(1) - return proc, None - - # Patch generated files with testing patches - def _inject_code(self, patch_code, filename): - patch_code = f"""\ -{patch_code} -torch._dynamo.config.debug_dir_root = "{DEBUG_DIR}" -""" - with open(filename, "r") as f: - code = f.read() - code = code.replace(TEST_REPLACEABLE_COMMENT, patch_code) - with open(filename, "w") as f: - f.write(code) - return code - - # Runs the minifier launcher script in `repro_dir`, patched with `patch_code`. - def _run_minifier_launcher(self, patch_code, repro_dir): - self.assertIsNotNone(repro_dir) - launch_file = os.path.join(repro_dir, "minifier_launcher.py") - self.assertTrue(os.path.exists(launch_file)) - launch_code = self._inject_code(patch_code, launch_file) - - launch_proc = subprocess.run( - ["python3", launch_file], - capture_output=True, - cwd=repro_dir, - ) - - return launch_proc, launch_code - - # Runs the repro script in `repro_dir`, patched with `patch_code` - def _run_repro(self, patch_code, repro_dir): - self.assertIsNotNone(repro_dir) - repro_file = os.path.join(repro_dir, "repro.py") + def test_after_dynamo(self): + @create_backend + def bad_dynamo_backend(subgraph): + import sys + + def f(*args): + # Shifted the forced exception to runtime as this is more common + # in JIT compilers. + for node in subgraph.model.graph.nodes: + if node.op == "call_function" and node.target is torch._foobar: + sys.stdout.write("Dynamo compiled failed\n") + raise NotImplementedError("foobar is not implemented") + return subgraph.model(*args) + + return f + + mod = MockModule() + opt_mod = torch._dynamo.optimize("bad_dynamo_backend")(mod) + repro_file = torch._dynamo.debug_utils.get_minifier_repro_path() + + @patch.object(torch._dynamo.config, "repro_after", "dynamo") + def inner(): + x = torch.randn(4) + try: + opt_mod(x) + except Exception: + pass + + inner() self.assertTrue(os.path.exists(repro_file)) - repro_code = self._inject_code(patch_code, repro_file) - - repro_proc = subprocess.run( - ["python3", repro_file], capture_output=True, cwd=repro_dir - ) - - return repro_proc, repro_code - - # Template for testing code. - # `run_code` is the code to run for the test case. - # `patch_code` is the code to be patched in every generated file. - def _gen_test_code(self, run_code, repro_after, repro_level, patch_code): - return f"""\ -import torch -import torch._dynamo -{patch_code} -torch._dynamo.config.repro_after = "{repro_after}" -torch._dynamo.config.repro_level = {repro_level} -torch._dynamo.config.debug_dir_root = "{DEBUG_DIR}" -{run_code} -""" - - # Runs a full minifier test. - # Minifier tests generally consist of 3 stages: - # 1. Run the problematic code (in a separate process since it could segfault) - # 2. Run the generated minifier launcher script - # 3. Run the generated repro script - def _run_full_test(self, run_code, repro_after, repro_level, patch_code): - test_code = self._gen_test_code(run_code, repro_after, repro_level, patch_code) - test_proc, repro_dir = self._run_test_code(test_code) - self.assertIsNotNone(repro_dir) - launch_proc, launch_code = self._run_minifier_launcher(patch_code, repro_dir) - repro_proc, repro_code = self._run_repro(patch_code, repro_dir) - return ((test_proc, launch_proc, repro_proc), (launch_code, repro_code)) - - # Test that compile, runtime, and accuracy errors after dynamo can be repro'd (both CPU and CUDA) - def _test_after_dynamo(self, device, repro_level, backend_code, error_name): - run_code = textwrap.dedent( - f"""\ - @torch._dynamo.optimize("{get_fn_name(backend_code)}") - def inner(x): - for _ in range(10): - x = torch.sin(x) - x = torch.relu(x) - for _ in range(10): - x = torch.cos(x) - return x - - inner(torch.randn(20, 20).to("{device}")) - """ - ) - - (test_proc, _, repro_proc), _ = self._run_full_test( - run_code, "dynamo", repro_level, backend_code - ) - - self.assertIn(error_name, test_proc.stderr.decode("utf-8")) - self.assertIn(error_name, repro_proc.stderr.decode("utf-8")) - - def test_after_dynamo_cpu_compile_error(self): - self._test_after_dynamo( - "cpu", 2, RELU_COMPILE_ERROR_BACKEND, "DynamoCompileError" - ) - - def test_after_dynamo_cpu_runtime_error(self): - self._test_after_dynamo( - "cpu", 2, RELU_RUNTIME_ERROR_BACKEND, "DynamoRuntimeError" - ) - - def test_after_dynamo_cpu_accuracy_error(self): - self._test_after_dynamo("cpu", 4, RELU_ACCURACY_ERROR_BACKEND, "AccuracyError") - - @requires_cuda() - def test_after_dynamo_cuda_compile_error(self): - self._test_after_dynamo( - "cuda", 2, RELU_COMPILE_ERROR_BACKEND, "DynamoCompileError" - ) - - @requires_cuda() - def test_after_dynamo_cuda_runtime_error(self): - self._test_after_dynamo( - "cuda", 2, RELU_RUNTIME_ERROR_BACKEND, "DynamoRuntimeError" - ) - - @requires_cuda() - def test_after_dynamo_cuda_accuracy_error(self): - self._test_after_dynamo("cuda", 4, RELU_ACCURACY_ERROR_BACKEND, "AccuracyError") - - # Ensure that the testing backends pass when relu is not present. - def _test_after_dynamo_backend_passes(self, device, repro_level, backend_code): - run_code = textwrap.dedent( - f"""\ - @torch._dynamo.optimize("{get_fn_name(backend_code)}") - def inner(x): - for _ in range(10): - x = torch.sin(x) - for _ in range(10): - x = torch.cos(x) - return x - - inner(torch.randn(20, 20).to("{device}")) - """ - ) - - test_code = self._gen_test_code(run_code, "dynamo", repro_level, backend_code) - proc, repro_dir = self._run_test_code(test_code) - self.assertEqual(proc.returncode, 0) - self.assertIsNone(repro_dir) - - def test_after_dynamo_cpu_compile_backend_passes(self): - self._test_after_dynamo_backend_passes("cpu", 2, RELU_COMPILE_ERROR_BACKEND) - - def test_after_dynamo_cpu_runtime_backend_passes(self): - self._test_after_dynamo_backend_passes("cpu", 2, RELU_RUNTIME_ERROR_BACKEND) - - def test_after_dynamo_cpu_accuracy_backend_passes(self): - self._test_after_dynamo_backend_passes("cpu", 4, RELU_ACCURACY_ERROR_BACKEND) - @requires_cuda() - def test_after_dynamo_cuda_compile_backend_passes(self): - self._test_after_dynamo_backend_passes("cuda", 2, RELU_COMPILE_ERROR_BACKEND) + # If error_at_aot is True, an error will be produced when AOTAutograd + # attempts to generate the backward graph. + # If error_after_aot is False, an error will be produced in inductor. + def _test_around_aot(self, error_at_aot): + mod = MockModule() + opt_mod = torch._dynamo.optimize("inductor")(mod) - @requires_cuda() - def test_after_dynamo_cuda_runtime_backend_passes(self): - self._test_after_dynamo_backend_passes("cuda", 2, RELU_RUNTIME_ERROR_BACKEND) + repro_file = torch._dynamo.debug_utils.get_minifier_repro_path() + repro_after = "dynamo" if error_at_aot else "aot" - @requires_cuda() - def test_after_dynamo_cuda_accuracy_backend_passes(self): - self._test_after_dynamo_backend_passes("cuda", 4, RELU_ACCURACY_ERROR_BACKEND) + @patch.object(torch._dynamo.config, "repro_after", repro_after) + def inner(): + x = torch.randn(4) + x.requires_grad = error_at_aot + try: + opt_mod(x) + except Exception: + pass - # Ensure that generated code with a custom backends generates a runnable minifier - # launcher script that results in a RuntimeError - def test_after_dynamo_custom_backend(self): - run_code = textwrap.dedent( - f"""\ - @torch._dynamo.optimize({get_fn_name(RELU_CUSTOM_ERROR_BACKEND)}) - def inner(x): - for _ in range(10): - x = torch.sin(x) - x = torch.relu(x) - for _ in range(10): - x = torch.cos(x) - return x + inner() - inner(torch.randn(20, 20)) - """ - ) - - test_code = self._gen_test_code( - run_code, "dynamo", 2, RELU_CUSTOM_ERROR_BACKEND - ) - _, repro_dir = self._run_test_code(test_code) - launch_proc, launch_code = self._run_minifier_launcher("", repro_dir) - self.assertIn("RuntimeError", launch_proc.stderr.decode("utf-8")) - - # Test that a module with mixed cpu/cuda parts with an error after dynamo can be repro'd - @requires_cuda() - def test_cpu_cuda_module_after_dynamo(self): - backend_name = get_fn_name(RELU_COMPILE_ERROR_BACKEND) - - run_code = textwrap.dedent( - f"""\ - class CpuCudaModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.m_x = torch.nn.Linear(20, 20).cuda() - self.m_y = torch.nn.Linear(20, 20) - self.p_x = torch.nn.Parameter(torch.randn(20, 20).cuda()) - self.p_y = torch.nn.Parameter(torch.randn(20, 20)) - self.register_buffer("b_x", torch.ones(20, 20).cuda()) - self.register_buffer("b_y", torch.ones(20, 20)) - - def forward(self, x, y): - return self.m_x(x) + self.p_x + self.b_x, self.m_y(y) + self.p_y + self.b_y - - mod = CpuCudaModule() - - @torch._dynamo.optimize("{backend_name}") - def inner(x1, y1): - x2 = torch.randn(20, 20).cuda() - y2 = torch.randn(20, 20) - x3, y3 = mod(x1 + x2, y1 + y2) - return torch.relu(x3.cpu() + y3) - - inner(torch.randn(20, 20).cuda(), torch.randn(20, 20)) - """ - ) - - (test_proc, _, repro_proc), (launch_code, _) = self._run_full_test( - run_code, "dynamo", 2, RELU_COMPILE_ERROR_BACKEND - ) - - tb1 = test_proc.stderr.decode("utf-8") - tb2 = repro_proc.stderr.decode("utf-8") - - # Check if generated minifier code covers all cpu/cuda cases - self.assertIsNotNone(re.search(r"args.*cuda", launch_code)) - self.assertIsNotNone(re.search(r"args.*cpu", launch_code)) - # search for Linear(...).cuda() - self.assertIsNotNone(re.search(r"Linear.*cuda", launch_code)) - # search for Linear(...) - self.assertIsNotNone( - re.search(r"Linear(?!.*cuda.*$)", launch_code, re.MULTILINE) - ) - self.assertIsNotNone(re.search(r"register_buffer.*cuda", launch_code)) - self.assertIsNotNone( - re.search(r"register_buffer(?!.*cuda.*$)", launch_code, re.MULTILINE) - ) - self.assertIsNotNone(re.search(r"Parameter.*cuda", launch_code)) - self.assertIsNotNone( - re.search(r"Parameter(?!.*cuda.*$)", launch_code, re.MULTILINE) - ) - # search for - # = torch.randn(...) - # ... = .cuda() - self.assertIsNotNone( - re.search(r"(\w+) = torch.randn.*\1\.cuda", launch_code, re.DOTALL) - ) - # search for - # = torch.randn(...) - # no followup call to .cuda() - self.assertIsNotNone( - re.search( - r"(\w+) = torch.randn(?!.*\1\.cuda\(\).*$)", launch_code, re.DOTALL - ) - ) - - self.assertIn(backend_name, tb1) - self.assertIn(backend_name, tb2) - - # Test if we can actually get a minified graph - def test_if_graph_minified(self): - backend_name = get_fn_name(RELU_COMPILE_ERROR_BACKEND) - - run_code = textwrap.dedent( - f"""\ - @torch._dynamo.optimize("{backend_name}") - def inner(x): - for _ in range(20): - x = torch.sin(x) - x = torch.relu(x) - for _ in range(20): - x = torch.cos(x) - return x - - inner(torch.randn(20, 20)) - """ - ) - - (test_proc, _, repro_proc), (launch_code, repro_code) = self._run_full_test( - run_code, "dynamo", 2, RELU_COMPILE_ERROR_BACKEND - ) - - tb1 = test_proc.stderr.decode("utf-8") - tb2 = repro_proc.stderr.decode("utf-8") - - self.assertIn(backend_name, tb1) - self.assertIn(backend_name, tb2) - - # compare the length of the forward functions - match = re.search(r"def forward.*return", launch_code, re.DOTALL) - self.assertIsNotNone(match) - self.assertGreater(match.group(0).count("\n"), 40) - - match = re.search(r"def forward.*return", repro_code, re.DOTALL) - self.assertIsNotNone(match) - self.assertLess(match.group(0).count("\n"), 5) - - # Test that compile and accuracy errors after aot can be repro'd (both CPU and CUDA) - def _test_after_aot(self, device, backend_code, repro_level): - run_code = textwrap.dedent( - f"""\ - @torch._dynamo.optimize("inductor") - def inner(x): - for _ in range(3): - x = torch.sin(x) - x = torch.relu(x) - for _ in range(3): - x = torch.cos(x) - return x - - inner(torch.randn(20, 20).to("{device}")) - """ - ) - patch_code = gen_codegen_fn_patch_code("relu", backend_code, device) - self.assertIsNotNone(patch_code) - (test_proc, _, repro_proc), _ = self._run_full_test( - run_code, "aot", repro_level, patch_code - ) - return ( - (test_proc.stderr.decode("utf-8"), repro_proc.stderr.decode("utf-8")), - (test_proc.returncode, repro_proc.returncode), - ) - - def test_after_aot_cpu_compile_error(self): - (tb1, tb2), _ = self._test_after_aot("cpu", CPP_COMPILE_ERROR, 2) - self.assertIn("CppCompileError", tb1) - self.assertIn("CppCompileError", tb2) - - def test_after_aot_cpu_accuracy_error(self): - (tb1, tb2), _ = self._test_after_aot("cpu", CPP_ACCURACY_ERROR, 4) - self.assertIn("AccuracyError", tb1) - self.assertIn("AccuracyError", tb2) - - @requires_cuda() - def test_after_aot_cuda_compile_error(self): - (tb1, tb2), _ = self._test_after_aot("cuda", TRITON_COMPILE_ERROR, 2) - self.assertIn("SyntaxError", tb1) - self.assertIn("SyntaxError", tb2) - - @requires_cuda() - def test_after_aot_cuda_accuracy_error(self): - (tb1, tb2), _ = self._test_after_aot("cuda", TRITON_ACCURACY_ERROR, 4) - self.assertIn("AccuracyError", tb1) - self.assertIn("AccuracyError", tb2) - - # Test that runtime errors after aot can be repro'd (CPU only for now) - def _test_after_aot_runtime_error(self, device, backend_code): - run_code = textwrap.dedent( - f"""\ - @torch._dynamo.optimize("inductor") - def inner(x): - for _ in range(3): - x = torch.sin(x) - x = torch.relu(x) - for _ in range(3): - x = torch.cos(x) - return x - - inner(torch.randn(20, 20).to("{device}")) - """ - ) - patch_code = gen_codegen_fn_patch_code("relu", backend_code, device) - self.assertIsNotNone(patch_code) - - (test_proc, _, repro_proc), _ = self._run_full_test( - run_code, "aot", 3, patch_code - ) - - self.assertNotIn("CompilerError", test_proc.stderr.decode("utf-8")) - - self.assertEqual(test_proc.returncode, repro_proc.returncode) - self.assertNotEqual(test_proc.returncode, 0) - - def test_after_aot_cpu_runtime_error(self): - self._test_after_aot_runtime_error("cpu", CPP_RUNTIME_ERROR) - - # NOTE: there is currently not an easy way to cause a triton runtime error. - @unittest.skip - @requires_cuda() - def test_after_aot_cuda_runtime_error(self): - self._test_after_aot_runtime_error("cuda", TRITON_RUNTIME_ERROR) - - # Ensure that inductor codegen patches pass when relu is not present. - def _test_after_aot_backend_passes(self, device, repro_level, backend_code): - run_code = textwrap.dedent( - f"""\ - @torch._dynamo.optimize("inductor") - def inner(x): - for _ in range(3): - x = torch.sin(x) - for _ in range(3): - x = torch.cos(x) - return x - - inner(torch.randn(20, 20).to("{device}")) - """ - ) - patch_code = gen_codegen_fn_patch_code("relu", backend_code, device) - self.assertIsNotNone(patch_code) - - test_code = self._gen_test_code(run_code, "aot", repro_level, patch_code) - proc, repro_dir = self._run_test_code(test_code) - self.assertEqual(proc.returncode, 0) - self.assertIsNone(repro_dir) - - def test_after_aot_cpu_compile_backend_passes(self): - self._test_after_aot_backend_passes("cpu", 2, CPP_COMPILE_ERROR) - - def test_after_aot_cpu_runtime_backend_passes(self): - self._test_after_aot_backend_passes("cpu", 2, CPP_RUNTIME_ERROR) - - def test_after_aot_cpu_accuracy_backend_passes(self): - self._test_after_aot_backend_passes("cpu", 4, CPP_ACCURACY_ERROR) - - @requires_cuda() - def test_after_aot_cuda_compile_backend_passes(self): - self._test_after_aot_backend_passes("cuda", 2, TRITON_COMPILE_ERROR) + self.assertTrue(os.path.exists(repro_file)) - # NOTE: there is currently not an easy way to cause a triton runtime error. - @unittest.skip - @requires_cuda() - def test_after_aot_cuda_runtime_backend_passes(self): - self._test_after_aot_backend_passes("cuda", 2, TRITON_RUNTIME_ERROR) + def test_at_aot(self): + self._test_around_aot(True) - @requires_cuda() - def test_after_aot_cuda_accuracy_backend_passes(self): - self._test_after_aot_backend_passes("cuda", 4, TRITON_ACCURACY_ERROR) + def test_after_aot(self): + self._test_around_aot(False) if __name__ == "__main__": diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index 29d830167b109b3..089ef172d625d9c 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -84,11 +84,6 @@ def __init__(self): for module_name, module in gm.named_children(): module_str = f"{module.__repr__()}" - # module should be a core torch.nn.Module, so all parameters - # should be on the same device. - example_param = next(module.parameters(), None) - if example_param is not None and example_param.is_cuda: - module_str = f"{module_str}.cuda()" model_str += f"{tab*2}self.{module_name} = {module_str}\n" for buffer_name, buffer in gm._buffers.items(): @@ -100,16 +95,12 @@ def __init__(self): tensor_str = ( f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})" ) - if buffer.is_cuda: - tensor_str = f"{tensor_str}.cuda()" model_str += f"{tab*2}self.register_buffer('{buffer_name}', {tensor_str})\n" for param_name, param in gm._parameters.items(): if param is None: continue tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}))" - if param.is_cuda: - tensor_str = f"{tensor_str}.cuda()" model_str += f"{tab*2}self.{param_name} = {tensor_str}\n" # TODO - Keep this code for now. But, I don't think we will need this. @@ -154,9 +145,6 @@ def _cuda_system_info_comment(): return model_str -TEST_REPLACEABLE_COMMENT = "# REPLACEABLE COMMENT FOR TESTING PURPOSES" - - def generate_compiler_repro_string(gm, args): model_str = textwrap.dedent( f""" @@ -167,8 +155,6 @@ def generate_compiler_repro_string(gm, args): from math import inf from torch.fx.experimental.proxy_tensor import make_fx - {TEST_REPLACEABLE_COMMENT} - """ ) model_str += f"# torch version: {torch.version.__version__}\n" @@ -184,7 +170,7 @@ def generate_compiler_repro_string(gm, args): model_str += ( "args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]\n" ) - model_str += "mod = make_fx(Repro())(*args)\n" + model_str += 'mod = make_fx(Repro().to(device="cuda"))(*args)\n' return model_str @@ -211,8 +197,7 @@ def dump_compiler_graph_state(gm, args, compiler_name): log.warning(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}") with open(file_name, "w") as fd: save_graph_repro(fd, gm, args, compiler_name) - curdir = os.getcwd() - repro_path = os.path.join(curdir, "repro.py") + repro_path = os.path.join(config.base_dir, "repro.py") try: shutil.copyfile(file_name, repro_path) log.warning(f"Copying repro file for convenience to {repro_path}") @@ -231,10 +216,7 @@ def save_graph_repro(fd, gm, args, compiler_name): textwrap.dedent( f""" compiled = {COMPILER_REPRO_OPTIONS[compiler_name][1]}(mod, args) - class AccuracyError(Exception): - pass - if not same_two_models(mod, compiled, args, only_fwd=True): - raise AccuracyError("Bad accuracy detected") + assert same_two_models(mod, compiled, args, only_fwd=True), "Accuracy failed" """ ) ) @@ -249,7 +231,7 @@ class AccuracyError(Exception): ) -def isolate_fails(fx_g, args, compiler_name: str, env=None, patch_code=None): +def isolate_fails(fx_g, args, compiler_name: str, env=None): if env is None: env = {} subdir = os.path.join(os.getcwd(), "isolate") @@ -257,10 +239,7 @@ def isolate_fails(fx_g, args, compiler_name: str, env=None, patch_code=None): os.makedirs(subdir, exist_ok=True) file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py") with open(file_name, "w") as fd: - repro_code = generate_compiler_repro_string(fx_g, args) - if patch_code is not None: - repro_code = repro_code.replace(TEST_REPLACEABLE_COMMENT, patch_code) - fd.write(repro_code) + fd.write(generate_compiler_repro_string(fx_g, args)) fail_fn = COMPILER_REPRO_OPTIONS[compiler_name][2] fd.write( textwrap.dedent( @@ -284,7 +263,6 @@ def isolate_fails(fx_g, args, compiler_name: str, env=None, patch_code=None): stdout, stderr = TemporaryFile(), TemporaryFile() p = subprocess.Popen( ["python", file_name], - cwd=subdir, stdout=stdout, stderr=stderr, env=new_env, @@ -351,8 +329,6 @@ def dump_to_minify(gm, args, compiler_name: str): contents = textwrap.dedent( f""" -isolate_fails_code_str = None - {generate_compiler_repro_string(gm, args)} from functools import partial @@ -367,7 +343,7 @@ def dump_to_minify(gm, args, compiler_name: str): minifier( mod, args, - module_fails=partial(isolate_fails, env=env_variables, compiler_name="{compiler_name}", patch_code=isolate_fails_code_str), + module_fails=partial(isolate_fails, env=env_variables, compiler_name="{compiler_name}"), dump_state=partial(dump_compiler_graph_state, compiler_name="{compiler_name}"), ) """ @@ -375,10 +351,6 @@ def dump_to_minify(gm, args, compiler_name: str): return helper_for_dump_minify(contents) -class AccuracyError(Exception): - pass - - def wrap_compiler_debug(compiler_fn, compiler_name: str): """ Minifier for Fx Graph modules after Aot Autograd has finished. We wrap both @@ -438,7 +410,7 @@ def deferred_for_real_inputs(real_inputs): copy_tensor_attrs, f"{compiler_name}_accuracy", ) - raise AccuracyError("Bad accuracy detected") + raise ValueError("Bad accuracy detected") else: # Call the compiled function with real inputs return inner_compiled_fn(real_inputs) @@ -463,8 +435,7 @@ def deferred_for_real_inputs(real_inputs): copy_tensor_attrs, compiler_name, ) - log.error("CompilerError") - raise + raise e if config.repro_after == "aot": compiled_fn = deferred_for_real_inputs @@ -581,14 +552,9 @@ def generate_dynamo_fx_repro_string( f""" mod.eval() opt_mod.eval() - -class AccuracyError(Exception): - pass - with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}): assert same_two_models(mod, mod, args), "Eager itself failed" - if not same_two_models(mod, opt_mod, args): - raise AccuracyError("Dynamo failed") + assert same_two_models(mod, opt_mod, args), "Dynamo failed" """ ) @@ -603,14 +569,12 @@ class AccuracyError(Exception): from {config.dynamo_import}.debug_utils import run_fwd_maybe_bwd from {config.dynamo_import}.debug_utils import same_two_models -{TEST_REPLACEABLE_COMMENT} - args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]} args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args] {model_str} -mod = Repro() +mod = Repro().cuda() opt_mod = {config.dynamo_import}.optimize("{compiler_name}")(mod) {run_code} @@ -749,21 +713,6 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name): if config.repro_level == 4: minifier_backend = "dynamo_accuracy_minifier_backend" - custom_compiler_error = ( - textwrap.dedent( - """\ - raise RuntimeError( - 'Compiler name is None - this likely means that a custom compiler ' - 'was called by torchdynamo. Please remove this error, import your ' - 'custom compiler function, and replace the compiler_name="None" ' - 'line below to compiler_name=' - ) - """ - ) - if compiler_name is None - else "" - ) - contents = textwrap.dedent( f""" import os @@ -777,17 +726,14 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name): from {config.dynamo_import}.optimizations.backends import BACKENDS from {config.dynamo_import}.testing import rand_strided -{TEST_REPLACEABLE_COMMENT} - args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]} args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args] {model_str} -mod = Repro() +mod = Repro().cuda() # Setup debug minifier compiler compiler_fn = BACKENDS["{minifier_backend}"] -{custom_compiler_error} dynamo_minifier_backend = functools.partial( compiler_fn, compiler_name="{compiler_name}", @@ -831,7 +777,7 @@ def debug_wrapper(gm, example_inputs, **kwargs): example_inputs, compiler_name, ) - exc = AccuracyError("Bad accuracy detected.") + exc = ValueError("Bad accuracy detected.") exc.minifier_path = os.path.join( minifier_dir(), "minifier_launcher.py" ) From 6b775c42dd2d40992611fb5636e787560663902c Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Sat, 12 Nov 2022 07:52:44 +0000 Subject: [PATCH 58/62] [quant][executorch] Support quant fusion for reshape in quant in executorch stack (#88858) Summary: This diff added support for fusing "dq - reshape - q" to a reshape op, the op is needed in wakeword model Test Plan: buck test executorch/exir/tests:quant_fusion_pass Reviewed By: qihqi, JacobSzwejbka Differential Revision: D41111069 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88858 Approved by: https://github.com/JacobSzwejbka --- torch/_C/__init__.pyi.in | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 2d20da2a04f30d9..5833d7d7f2a4166 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -972,11 +972,14 @@ class AggregationType(Enum): AVG = 1 class FileCheck(object): - # TODO (add more FileCheck signature) - def check_source_highlighted(self, highlight: str) -> 'FileCheck': ... def run(self, test_string: str) -> None: ... def check(self, test_string: str) -> 'FileCheck': ... def check_not(self, test_string: str) -> 'FileCheck': ... + def check_same(self, test_string: str) -> 'FileCheck': ... + def check_next(self, test_string: str) -> 'FileCheck': ... + def check_count(self, test_string: str, count: _int, exactly: _bool = False) -> 'FileCheck': ... + def check_dag(self, test_string: str) -> 'FileCheck': ... + def check_source_highlighted(self, test_string: str) -> 'FileCheck': ... ... # Defined in torch/csrc/jit/python/init.cpp From ae2c668cc044d841853e2672d96bfe0afb38a89c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 12 Nov 2022 07:52:53 +0000 Subject: [PATCH 59/62] Revert "[dynamo][api] Better support of torch.nn.Module (#88629)" This reverts commit c83348597b195f2da1cca0e8318c878b104bce5d. Reverted https://github.com/pytorch/pytorch/pull/88629 on behalf of https://github.com/anijain2305 due to job failing on master https://github.com/pytorch/pytorch/actions/runs/3449914495/jobs/5758267231 --- test/dynamo/test_modules.py | 127 ----------------------------------- torch/_dynamo/__init__.py | 2 - torch/_dynamo/debug_utils.py | 8 --- torch/_dynamo/eval_frame.py | 74 ++++++-------------- torch/_dynamo/testing.py | 13 ---- 5 files changed, 20 insertions(+), 204 deletions(-) diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 930035f99a30c33..2fb83b3add6cfbc 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -904,133 +904,6 @@ def forward(self, x): self.assertTrue(torch._dynamo.testing.same(real, graph(rx))) -class MockModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.relu = torch.nn.ReLU() - self.linear = torch.nn.Linear(10, 10) - self.register_buffer("buf0", torch.randn(10, 10)) - - def forward(self, x): - return self.relu(self.linear(x) + self.buf0) - - -class OptimizedModuleTest(torch._dynamo.test_case.TestCase): - def test_nn_module(self): - mod = MockModule() - cnt = torch._dynamo.testing.CompileCounter() - opt_mod = torch._dynamo.optimize(cnt)(mod) - self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule) - - x = torch.randn(10, 10) - self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x))) - self.assertEqual(cnt.frame_count, 1) - - def test_to(self): - mod = MockModule() - cnt = torch._dynamo.testing.CompileCounter() - opt_mod = torch._dynamo.optimize(cnt)(mod) - x = torch.randn(10, 10) - self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x))) - self.assertEqual(cnt.frame_count, 1) - - # Ensure that there is no recompilation - opt_mod(x) - self.assertEqual(cnt.frame_count, 1) - - opt_mod = opt_mod.to(device="cpu").to(dtype=torch.float64) - self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule) - x = torch.randn(10, 10).to(dtype=torch.float64) - opt_mod(x) - # Ensure that there is a recompilation - self.assertEqual(cnt.frame_count, 2) - - def test_attr(self): - class MockModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(10, 10) - self.register_buffer("buf0", torch.randn(10, 10)) - - def forward(self, x): - return self.r(torch.sin(x)) + self.buf0 - - mod = MockModule() - opt_mod = torch._dynamo.optimize("eager")(mod) - - # Check parameteres and buffers - for (p1, p2) in zip(mod.parameters(), opt_mod.parameters()): - self.assertTrue(id(p1) == id(p2)) - - def test_recursion(self): - mod = MockModule() - cnt = torch._dynamo.testing.CompileCounter() - opt_mod = torch._dynamo.optimize(cnt)(mod) - - for _ in range(5): - opt_mod = torch._dynamo.optimize(cnt)(opt_mod) - opt_mod(torch.randn(10, 10)) - self.assertEqual(cnt.frame_count, 1) - - def test_composition(self): - class InnerModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.relu = torch.nn.ReLU() - - def forward(self, x): - return self.relu(torch.sin(x)) - - opt_inner_mod = InnerModule() - - class OuterModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.mod = opt_inner_mod - - def forward(self, x): - return self.mod(torch.cos(x)) - - outer_mod = OuterModule() - cnt = torch._dynamo.testing.CompileCounter() - opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) - - x = torch.randn(4) - self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) - self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) - self.assertEqual(cnt.frame_count, 1) - - def test_composition_with_opt_mod(self): - class InnerModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.relu = torch.nn.ReLU() - - def forward(self, x): - return self.relu(torch.sin(x)) - - inner_mod = InnerModule() - cnt = torch._dynamo.testing.CompileCounter() - opt_inner_mod = torch._dynamo.optimize(cnt)(inner_mod) - - class OuterModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.mod = opt_inner_mod - - def forward(self, x): - return self.mod(torch.cos(x)) - - outer_mod = OuterModule() - opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) - - x = torch.randn(4) - self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) - self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) - # There will be a graph break for the inner mod being OptimizedModule - self.assertEqual(cnt.frame_count, 2) - - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index 5eee609b0852a21..80f927aeef2fad9 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -7,7 +7,6 @@ export, optimize, optimize_assert, - OptimizedModule, reset_code, run, skip, @@ -26,7 +25,6 @@ "reset", "list_backends", "skip", - "OptimizedModule", ] diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index 089ef172d625d9c..f09991f9bf3489c 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -486,16 +486,8 @@ def same_two_models(gm, opt_gm, example_inputs, only_fwd=False): """ Check two models have same accuracy. """ - from .eval_frame import OptimizedModule - from .testing import named_parameters_for_optimized_module from .utils import same - if isinstance(gm, OptimizedModule): - gm.named_parameters = named_parameters_for_optimized_module(gm) - - if isinstance(opt_gm, OptimizedModule): - opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm) - ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd) try: diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 20e8c7de085e0b7..8d9e3b7b6aa1469 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -5,7 +5,6 @@ import logging import os import sys -import textwrap import threading import traceback import types @@ -45,27 +44,6 @@ most_recent_backend = None -class OptimizedModule(torch.nn.Module): - """ - Wraps the original nn.Module object and later patches its - forward method to optimized self.forward method. - """ - - def __init__(self, mod): - super().__init__() - # Installs the params/buffer - self._orig_mod = mod - - def __getattr__(self, name): - if name == "_orig_mod": - return self._modules["_orig_mod"] - return getattr(self._orig_mod, name) - - def forward(self, *args, **kwargs): - # This will be monkey patched later - raise RuntimeError("Should not be here") - - def remove_from_cache(f): """ Make sure f.__code__ is not cached to force a recompile @@ -140,15 +118,31 @@ def __call__(self, fn): # Optimize the forward method of torch.nn.Module object if isinstance(fn, torch.nn.Module): mod = fn - new_mod = OptimizedModule(mod) - new_mod.forward = self(mod.forward) + optimized_forward = self(mod.forward) + + class TorchDynamoNNModuleWrapper: + """ + A wrapper that redirects the forward call to the optimized + forward, while for rest it redirects the calls to the original + module. + """ + + def __getattr__(self, name): + return getattr(mod, name) + + def forward(self, *args, **kwargs): + return optimized_forward(*args, **kwargs) + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + new_mod = TorchDynamoNNModuleWrapper() # Save the function pointer to find the original callable while nesting # of decorators. - new_mod._torchdynamo_orig_callable = mod.forward + new_mod._torchdynamo_orig_callable = mod return new_mod assert callable(fn) - callback = self.callback on_enter = self.on_enter backend_ctx_ctor = self.extra_ctx_ctor @@ -190,34 +184,6 @@ def _fn(*args, **kwargs): # If the function is called using torch._dynamo.optimize decorator, we # should prevent any type of skipping. if callback not in (None, False): - if not hasattr(fn, "__code__"): - raise RuntimeError( - textwrap.dedent( - """ - - torch._dynamo.optimize is called on a non function object. - If this is a callable class, please optimize the individual methods that you are interested in optimizing. - - >> class CallableClass: - >> def __init__(self): - >> super().__init__() - >> self.relu = torch.nn.ReLU() - >> - >> def __call__(self, x): - >> return self.relu(torch.sin(x)) - >> - >> def print_hello(self): - >> print("Hello world") - >> - >> mod = CallableClass() - - If you want to optimize the __call__ function - - >> mod.__call__ = torch._dynamo.optimize(mod.__call__) - - """ - ) - ) always_optimize_code_objects[fn.__code__] = True return _fn diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index b37299ffd5791fe..d6082ce48acf833 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -32,17 +32,6 @@ def clone_me(x): return x.detach().clone().requires_grad_(x.requires_grad) -def named_parameters_for_optimized_module(mod): - assert isinstance(mod, eval_frame.OptimizedModule) - return mod._orig_mod.named_parameters - - -def remove_optimized_module_prefix(name): - prefix = "_orig_mod." - assert name.startswith(prefix) - return name[len(prefix) :] - - def collect_results(model, prediction, loss, example_inputs): results = [] results.append(prediction) @@ -55,8 +44,6 @@ def collect_results(model, prediction, loss, example_inputs): grads = dict() params = dict() for name, param in model.named_parameters(): - if isinstance(model, eval_frame.OptimizedModule): - name = remove_optimized_module_prefix(name) param_copy = param grad = param.grad # Treat None and zero grad as same From 6e5f736d86be09bd86a5da276ce2f5dcbe0bfc09 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Fri, 11 Nov 2022 08:21:48 -0800 Subject: [PATCH 60/62] [15/N] Add allreduce_coalesced custom op with CPU/CUDA implementations (#88846) Differential Revision: [D41227740](https://our.internmc.facebook.com/intern/diff/D41227740) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88846 Approved by: https://github.com/kwen2501 --- test/distributed/test_c10d_common.py | 15 +++++++++++ test/distributed/test_c10d_gloo.py | 4 +++ test/distributed/test_c10d_nccl.py | 5 ++++ torch/csrc/distributed/c10d/Ops.cpp | 36 +++++++++++++++++++++++++ torch/csrc/distributed/c10d/Ops.hpp | 5 ++++ torch/csrc/distributed/c10d/OpsImpl.cpp | 34 +++++++++++++++++++++++ torch/csrc/distributed/c10d/init.cpp | 6 ++--- 7 files changed, 102 insertions(+), 3 deletions(-) diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index cf46f89b353cd34..77ee7487a0afa1f 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -1503,6 +1503,21 @@ def _test_collectives(self, backend): with self.subTest(collective=collective, args=args): self._call_collective_with_varying_tensors(backend, collective, *args) + def _test_allreduce_coalesced(self, backend): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend, + world_size=self.world_size, + rank=self.rank, + store=store, + ) + # TODO: this will be updated in the future to not be backend specific + device = "cuda" if backend == "nccl" else "cpu" + tensors = [torch.ones(10, 10, device=torch.device(device))] + dist.all_reduce_coalesced(tensors, dist.ReduceOp.SUM) + for tensor in tensors: + self.assertEqual(tensor, torch.ones(10, 10) * self.world_size) + class CompilerTest(MultiProcessTestCase): def setUp(self): super(CompilerTest, self).setUp() diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py index e0c7c64f7b83610..ba214a02696f9c2 100644 --- a/test/distributed/test_c10d_gloo.py +++ b/test/distributed/test_c10d_gloo.py @@ -2363,6 +2363,10 @@ class GlooProcessGroupWithDispatchedCollectivesTests(test_c10d_common.ProcessGro def test_collectives(self): self._test_collectives(backend="gloo") + @requires_gloo() + def test_allreduce_coalesced(self): + self._test_allreduce_coalesced(backend="gloo") + class CompilerTest(test_c10d_common.CompilerTest): @property diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 5d412dd3fb1b048..b3790b082ed57fd 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -2953,6 +2953,11 @@ class NcclProcessGroupWithDispatchedCollectivesTests(test_c10d_common.ProcessGro def test_collectives(self): self._test_collectives(backend="nccl") + @requires_nccl() + @skip_if_lt_x_gpu(1) + def test_allreduce_coalesced(self): + self._test_allreduce_coalesced(backend="nccl") + if __name__ == "__main__": assert ( not torch.cuda._initialized diff --git a/torch/csrc/distributed/c10d/Ops.cpp b/torch/csrc/distributed/c10d/Ops.cpp index ea77bb337b4a81b..15e186fe3d22d7a 100644 --- a/torch/csrc/distributed/c10d/Ops.cpp +++ b/torch/csrc/distributed/c10d/Ops.cpp @@ -40,6 +40,19 @@ std::tuple, c10::intrusive_ptr> allreduce_( std::move(tensor_vec), work); } +c10::intrusive_ptr allreduce_coalesced_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; + opts.reduceOp = *reduce_op.get(); + opts.timeout = std::chrono::milliseconds(timeout); + + return process_group->allreduce_coalesced(tensor_vec, opts); +} + c10::intrusive_ptr reduce_( at::TensorList tensors, const c10::intrusive_ptr& process_group, @@ -177,6 +190,10 @@ TORCH_LIBRARY(c10d, m) { m.def( "allreduce_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, allreduce_)); + m.def( + "allreduce_coalesced_", + dispatch( + c10::DispatchKey::CompositeExplicitAutograd, allreduce_coalesced_)); m.def( "allgather_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, allgather_)); @@ -249,6 +266,25 @@ c10::intrusive_ptr allreduce( opts.timeout.count())); } +c10::intrusive_ptr allreduce_coalesced( + const c10::intrusive_ptr& process_group, + at::TensorList tensors, + const AllreduceCoalescedOptions& opts) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::allreduce_coalesced_", "") + .typed( + at::TensorList, + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + const c10::intrusive_ptr<::c10d::ReduceOp>&, + int64_t)>(); + + return op.call( + tensors, + process_group, + c10::make_intrusive(opts.reduceOp), + opts.timeout.count()); +} + c10::intrusive_ptr allgather( const c10::intrusive_ptr& process_group, const std::vector>& output_tensors, diff --git a/torch/csrc/distributed/c10d/Ops.hpp b/torch/csrc/distributed/c10d/Ops.hpp index adc64066a885eac..8ef78126e5b9e17 100644 --- a/torch/csrc/distributed/c10d/Ops.hpp +++ b/torch/csrc/distributed/c10d/Ops.hpp @@ -21,6 +21,11 @@ TORCH_API c10::intrusive_ptr allreduce( at::TensorList tensors, const AllreduceOptions& opts = {}); +TORCH_API c10::intrusive_ptr allreduce_coalesced( + const c10::intrusive_ptr& process_group, + at::TensorList tensors, + const AllreduceCoalescedOptions& opts = {}); + TORCH_API c10::intrusive_ptr allgather( const c10::intrusive_ptr& process_group, const std::vector>& output_tensors, diff --git a/torch/csrc/distributed/c10d/OpsImpl.cpp b/torch/csrc/distributed/c10d/OpsImpl.cpp index 03ec6892857e783..94f5febec14d013 100644 --- a/torch/csrc/distributed/c10d/OpsImpl.cpp +++ b/torch/csrc/distributed/c10d/OpsImpl.cpp @@ -149,6 +149,32 @@ std::tuple, c10::intrusive_ptr> allreduce_cuda_( std::move(tensor_vec), work); } +c10::intrusive_ptr allreduce_coalesced_cpu_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; + opts.reduceOp = *reduce_op.get(); + opts.timeout = std::chrono::milliseconds(timeout); + + return process_group->allreduce_coalesced(tensor_vec, opts); +} + +c10::intrusive_ptr allreduce_coalesced_cuda_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; + opts.reduceOp = *reduce_op.get(); + opts.timeout = std::chrono::milliseconds(timeout); + + return process_group->allreduce_coalesced(tensor_vec, opts); +} + std::tuple>, c10::intrusive_ptr> allgather_cpu_( const std::vector>& output_tensors, @@ -367,6 +393,14 @@ TORCH_LIBRARY_IMPL(c10d, CUDA, m) { m.impl("allreduce_", allreduce_cuda_); } +TORCH_LIBRARY_IMPL(c10d, CPU, m) { + m.impl("allreduce_coalesced_", allreduce_coalesced_cpu_); +} + +TORCH_LIBRARY_IMPL(c10d, CUDA, m) { + m.impl("allreduce_coalesced_", allreduce_coalesced_cuda_); +} + TORCH_LIBRARY_IMPL(c10d, CPU, m) { m.impl("allgather_", allgather_cpu_); } diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 6515a3d9a87d475..673f481d602518e 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1134,10 +1134,10 @@ that adds a prefix to each key inserted to the store. .def( "allreduce_coalesced", - [](::c10d::ProcessGroup& self, - std::vector& xs, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + const std::vector& xs, ::c10d::AllreduceCoalescedOptions opts) { - return self.allreduce_coalesced(xs, opts); + return ::c10d::ops::allreduce_coalesced(self, xs, opts); }, py::arg("tensors"), py::arg("opts") = ::c10d::AllreduceCoalescedOptions(), From 4270bb37dacf7e3b2b784fa4ff4002ee6bf87e56 Mon Sep 17 00:00:00 2001 From: Nikita Karetnikov Date: Sat, 12 Nov 2022 00:41:57 +0100 Subject: [PATCH 61/62] [primTorch] Improve `narrow` and `narrow_copy`: refs, tests, docs (#87045) Pull Request resolved: https://github.com/pytorch/pytorch/pull/87045 Approved by: https://github.com/mruberry --- aten/src/ATen/native/TensorShape.cpp | 13 +- test/test_meta.py | 1 - torch/_refs/__init__.py | 38 +++- torch/_tensor_docs.py | 13 +- torch/_torch_docs.py | 27 +-- .../_internal/common_methods_invocations.py | 163 ++++++++++++++---- 6 files changed, 188 insertions(+), 67 deletions(-) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index deb9b949aa5d3e4..e8c87a2f1f5ce1b 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1196,6 +1196,8 @@ Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t return self.narrow(dim, start, length).clone(at::MemoryFormat::Contiguous); } +// Should just use narrow_copy_out, but this API is used internally at Meta: +// https://github.com/pytorch/pytorch/pull/87045#issuecomment-1309353561 Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){ auto output = at::empty_like(self); return narrow_copy_dense_cpu_out(self, dim, start, length, output); @@ -1205,9 +1207,10 @@ Tensor narrow_copy_sparse(const Tensor& self, int64_t dim, int64_t start, int64_ int64_t allDim = self.dim(); int64_t end = start+length; TORCH_CHECK(allDim > 0, "narrow() cannot be applied to a 0-dim tensor."); + TORCH_CHECK(length >= 0, "narrow(): length must be non-negative."); TORCH_CHECK(dim >= 0 && dim < allDim, "Dimension ", dim, " out of range. Expecting 0 <= dim < ", allDim, "."); - TORCH_CHECK(start >= 0 && length >= 0 && end <= self.size(dim), + TORCH_CHECK(start >= 0 && end <= self.size(dim), "Invalid range to narrow. range(start, start+length) must be a subset of range(0, ", self.size(dim), ").") Tensor indices = self._indices(); int64_t sparse_dim = self.sparse_dim(); @@ -1235,6 +1238,8 @@ Tensor narrow_copy_sparse(const Tensor& self, int64_t dim, int64_t start, int64_ return newTensor._coalesced_(self.is_coalesced()); } +// Should just use narrow_copy_out, but this API is used internally at Meta: +// https://github.com/pytorch/pytorch/pull/87045#issuecomment-1309353561 Tensor& narrow_copy_dense_cpu_out( const Tensor& self, int64_t dim, int64_t start, int64_t length, Tensor& output ) { @@ -1318,22 +1323,24 @@ Tensor& narrow_copy_dense_cpu_out( Tensor narrow(const Tensor& self, int64_t dim, int64_t start, int64_t length) { TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor."); + TORCH_CHECK(length >= 0, "narrow(): length must be non-negative."); auto cur_size = self.size(dim); if (start != cur_size) { // start being the end is valid, but not a valid dim specification. start = maybe_wrap_dim(start, cur_size); } - TORCH_CHECK(length >= 0 && start <= cur_size - length, + TORCH_CHECK(start <= cur_size - length, "start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ")."); return at::slice(self, dim, start, start + length, 1); } Tensor narrow_symint(const Tensor& self, int64_t dim, SymInt start, SymInt length) { TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor."); + TORCH_CHECK(length >= 0, "narrow(): length must be non-negative."); auto cur_size = self.sym_size(dim); if (start != cur_size) { // start being the end is valid, but not a valid dim specification. start = maybe_wrap_dim(start, cur_size); } - TORCH_CHECK(length >= 0 && start <= cur_size - length, + TORCH_CHECK(start <= cur_size - length, "start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ")."); return at::slice_symint(self, dim, start, start + length, 1); } diff --git a/test/test_meta.py b/test/test_meta.py index ef25d184c84286b..ae248a90cffb757 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -745,7 +745,6 @@ def run_meta_crossref( } meta_function_device_skips['cpu'] = { - torch.narrow_copy: {b8, bf16, c128, c32, c64, f16, f32, f64, i16, i32, i64, i8, u8}, torch.native_batch_norm: {f32, f64}, } diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 43b0c74192dee5f..70edbff2237f2e0 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -2750,19 +2750,39 @@ def flipud(a: TensorLikeType) -> TensorLikeType: # CompositeImplicitAutograd - don't register decomp -def narrow(a: TensorLikeType, dim: int, start: int, length: int) -> TensorLikeType: +def narrow( + a: TensorLikeType, dim: int, start: Union[int, TensorLikeType], length: int +) -> TensorLikeType: + # Supports Tensor overload that was added for XLA: + # https://github.com/pytorch/pytorch/issues/31558 + if isinstance(start, TensorLike): + check( + start.dim() == 0 and utils.is_integer_dtype(start.dtype), + lambda: "start must be an 0-dim integral Tensor.", + ) + start = start.item() # type: ignore[assignment] + check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.") + check(length >= 0, lambda: "narrow(): length must be non-negative.") dim = utils.canonicalize_dim(a.ndim, dim) + dim_length = a.size(dim) + # Start being the end is usually invalid since it's out of bounds. So it's + # not allowed by canonicalize_dim. But for narrow it's valid as long as + # the length is 0, which is handled by the check below. + if start != dim_length: + # Negative start means indexing from the end of dim. + # Note: a dimension isn't being canonicalized here, this reuses + # canonicalize_dim because the semantics are similar. + start = utils.canonicalize_dim(dim_length, start) # type: ignore[arg-type] + check( + start <= dim_length - length, # type: ignore[arg-type] + lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).", + ) return prims.slice_in_dim(a, start, start + length, axis=dim) -@register_decomposition(torch.ops.aten.narrow_copy) -@out_wrapper() -def narrow_copy(a: TensorLikeType, dim: int, start: int, length: int) -> TensorLikeType: - # TODO: This must return a sparse tensor if the input is sparse, but refs - # have no sparse support. See narrow_copy_sparse in core. - if a.is_sparse: - raise NotImplementedError("narrow_copy ref doesn't support sparse tensors") - return torch.clone(torch.narrow(a=a, dim=dim, start=start, length=length)) # type: ignore[call-overload] +# TODO: This must return a sparse tensor if the input is sparse, but refs have +# no sparse support. See narrow_copy_sparse in core. +narrow_copy = _make_copy_from_view(narrow) def _normalize( diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 8c734a1f3774b21..726ae5137e6a45b 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -3436,18 +3436,7 @@ def callable(a, b) -> number r""" narrow(dimension, start, length) -> Tensor -See :func:`torch.narrow` - -Example:: - - >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - >>> x.narrow(0, 0, 2) - tensor([[ 1, 2, 3], - [ 4, 5, 6]]) - >>> x.narrow(1, 1, 2) - tensor([[ 2, 3], - [ 5, 6], - [ 8, 9]]) +See :func:`torch.narrow`. """, ) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 40375bae3e2741c..2ff2e9be315dece 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -7980,8 +7980,10 @@ def merge_dicts(*dicts): Args: input (Tensor): the tensor to narrow dim (int): the dimension along which to narrow - start (Tensor or int): the starting dimension - length (int): the distance to the ending dimension + start (int or Tensor): index of the element to start the narrowed dimension + from. Can be negative, which means indexing from the end of `dim`. If + `Tensor`, it must be an 0-dim integral `Tensor` (bools not allowed) + length (int): length of the narrowed dimension, must be weakly positive Example:: @@ -7993,6 +7995,10 @@ def merge_dicts(*dicts): tensor([[ 2, 3], [ 5, 6], [ 8, 9]]) + >>> torch.narrow(x, -1, torch.tensor(-1), 1) + tensor([[3], + [6], + [9]]) """, ) @@ -8008,8 +8014,9 @@ def merge_dicts(*dicts): Args: input (Tensor): the tensor to narrow dim (int): the dimension along which to narrow - start (int): the starting offset - length (int): the distance to the ending dimension + start (int): index of the element to start the narrowed dimension from. Can + be negative, which means indexing from the end of `dim` + length (int): length of the narrowed dimension, must be weakly positive Keyword args: {out} @@ -8027,13 +8034,13 @@ def merge_dicts(*dicts): >>> s = torch.arange(16).reshape(2, 2, 2, 2).to_sparse(2) >>> torch.narrow_copy(s, 0, 0, 1) tensor(indices=tensor([[0, 0], - [0, 1]]), - values=tensor([[[0, 1], - [2, 3]], + [0, 1]]), + values=tensor([[[0, 1], + [2, 3]], - [[4, 5], - [6, 7]]]), - size=(1, 2, 2, 2), nnz=2, layout=torch.sparse_coo) + [[4, 5], + [6, 7]]]), + size=(1, 2, 2, 2), nnz=2, layout=torch.sparse_coo) .. seealso:: diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 5178ec978bd1c63..8ab1ea8a047cda2 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -4391,29 +4391,127 @@ def sample_repeat_tile(op_info, device, dtype, requires_grad, **kwargs): yield SampleInput(make_arg(shape), rep_dim) -def sample_inputs_narrow_copy(op_info, device, dtype, requires_grad, **kwargs): +def sample_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, *, is_narrow, **kwargs): shapes_and_args = ( - ((S, S, S), (1, 2, 2)), - ((S, S, S), (-1, 2, 2)), - ((S, S, S), (1, 0, 0)), - ((S, S, S), (-1, 0, 0)), - ((S, S, S), (2, 1, 2)), + ((S, S, S), 1, 2, 2), + ((S, S, S), -1, 2, 2), + ((S, S, S), 1, 0, 0), + ((S, S, S), -1, 0, 0), + ((S, S, S), 2, 1, 2), ) - for shape, args in shapes_and_args: + for shape, dim, start, length in shapes_and_args: tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad) - yield SampleInput(tensor, args=args) + yield SampleInput(tensor, dim, start, length) + # narrow also accepts the start argument being a Tensor + if is_narrow: + yield SampleInput(tensor, dim, torch.tensor(start), length) +def reference_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, *, is_narrow, **kwargs): + yield from sample_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, is_narrow=is_narrow, **kwargs) -def sample_inputs_narrow(op_info, device, dtype, requires_grad, **kwargs): - ''' - sample_inputs_narrow accepts the same inputs as narrow_copy, in addition - narrow also accepts `start` argument to be a Tensor. - ''' - for sample in sample_inputs_narrow_copy(op_info, device, dtype, requires_grad, **kwargs): - yield sample - yield SampleInput(sample.input, args=(sample.args[0], torch.tensor(sample.args[1]), sample.args[2])) + shapes_and_args = ( + # 1-dim + ((M,), 0, 0, 0), # 0 elems from the left + ((M,), -1, -1, 0), # 0 elems from the right + ((M,), 0, 5, 3), # 3 elems from the left + ((M,), 0, -5, 2), # 2 elems from the right + ((M,), -1, 0, M), # M elems from the left + ((M,), 0, -M, M), # M elems from the right + + # 2-dim + ((M, S), 1, 0, 0), # dim 1, 0 elems from the left + ((S, M), -2, -1, 0), # dim 0, 0 elems from the right + ((L, S), 1, 2, 3), # dim 1, 3 elems from the left + ((L, S), -1, 3, 2), # dim 1, 2 elems from the left + ((M, L), 0, 0, M), # dim 0, M elems from the left + ((M, L), -1, -L, L), # dim 1, L elems from the right + + # 3-dim + ((L, M, S), 2, 0, 0), # dim 2, 0 elems from the left + ((M, S, L), -1, -1, 0), # dim 2, 0 elems from the right + ((S, L, M), 2, 0, M), # dim 2, M elems from the left + ((L, S, M), -1, -M, M), # dim 2, M elems from the right + ((S, L, M), 1, 0, 0), # dim 1, 0 elems from the left + ((S, L, M), 0, 2, 1), # dim 0, 1 elem from the left + ((M, S, M), -1, -5, 4), # dim 2, 4 elems from the right + ) + + for shape, dim, start, length in shapes_and_args: + tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, + requires_grad=requires_grad) + yield SampleInput(tensor, dim, start, length) + # narrow also accepts the start argument being a Tensor + if is_narrow: + yield SampleInput(tensor, dim, torch.tensor(start), length) + +def error_inputs_narrow_narrow_copy(op_info, device, *, is_narrow, is_ref): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # 0-dim + yield ErrorInput(SampleInput(make_arg(()), 0, 0, 1), + error_type=RuntimeError, + error_regex=r"narrow\(\) cannot be applied to a 0-dim tensor\.") + + # out of bounds dim + if not is_narrow and not is_ref and torch.device(device).type == 'cpu': + # narrow_copy_dense_cpu_out + yield ErrorInput(SampleInput(make_arg((M, S, L)), 3, 0, 0), + error_type=RuntimeError, + error_regex=r"Expected dim < static_cast\(self_sizes.size\(\)\) to be true, but got false\.") + else: + yield ErrorInput(SampleInput(make_arg((M, S, L)), 3, 0, 0), + error_type=IndexError, + error_regex=r"Dimension out of range \(expected to be in range of \[-3, 2\], but got 3\)") + # out of bounds dim (negative) + yield ErrorInput(SampleInput(make_arg((L, S, M)), -4, 0, 0), + error_type=IndexError, + error_regex=r"Dimension out of range \(expected to be in range of \[-3, 2\], but got -4\)") + + # out of bounds start + if not is_narrow and not is_ref and torch.device(device).type == 'cpu': + # narrow_copy_dense_cpu_out + yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0), + error_type=RuntimeError, + error_regex=r"start \(11\) \+ length \(0\) exceeds dimension size \(10\)\.") + else: + yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0), + error_type=IndexError, + error_regex=r"Dimension out of range \(expected to be in range of \[-10, 9\], but got 11\)") + # out of bounds start (negative) + yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, -M - 1, 0), + error_type=IndexError, + error_regex=r"Dimension out of range \(expected to be in range of \[-10, 9\], but got -11\)") + + # out of bounds length + yield ErrorInput(SampleInput(make_arg((S, L, M)), 2, 0, M + 1), + error_type=RuntimeError, + error_regex=r"start \(0\) \+ length \(11\) exceeds dimension size \(10\)\.") + # out of bounds length (negative) + if not is_narrow and not is_ref and torch.device(device).type == 'cpu': + # narrow_copy_dense_cpu_out + yield ErrorInput(SampleInput(make_arg((M,)), 0, 0, -1), + error_type=RuntimeError, + error_regex=r"start \(0\) \+ length \(-1\) exceeds dimension size \(10\)\.") + else: + yield ErrorInput(SampleInput(make_arg((M,)), 0, 0, -1), + error_type=RuntimeError, + error_regex=r"narrow\(\): length must be non-negative\.") + + # Test Tensor overload that was added for XLA. Start must be an 0-dim + # integral Tensor. narrow_copy doesn't have this overload. + # https://github.com/pytorch/pytorch/issues/31558 + if is_narrow: + # *1-dim* integral Tensor + yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, make_arg(S, dtype=torch.int), 2), + error_type=RuntimeError, + error_regex=r"start must be an 0-dim integral Tensor\.") + + # 0-dim *bool* Tensor (bools are not allowed) + yield ErrorInput(SampleInput(make_arg((L, M, S)), -3, make_arg((), dtype=torch.bool), 3), + error_type=RuntimeError, + error_regex=r"start must be an 0-dim integral Tensor\.") def sample_trapezoid(op_info, device, dtype, requires_grad, **kwargs): @@ -12407,7 +12505,9 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, - sample_inputs_func=sample_inputs_narrow, + sample_inputs_func=partial(sample_inputs_narrow_narrow_copy, is_narrow=True), + reference_inputs_func=partial(reference_inputs_narrow_narrow_copy, is_narrow=True), + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=True, is_ref=False), skips=( # Use of .item() DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), @@ -12423,15 +12523,16 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_fwgrad_bwgrad=False, supports_autograd=False, # https://github.com/pytorch/pytorch/issues/86931 - sample_inputs_func=sample_inputs_narrow_copy, + sample_inputs_func=partial(sample_inputs_narrow_narrow_copy, is_narrow=False), + reference_inputs_func=partial(reference_inputs_narrow_narrow_copy, is_narrow=False), + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=False, is_ref=False), skips=( # https://github.com/pytorch/pytorch/issues/84577 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), - # Not implemented - DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_outplace', device_type='cuda'), - DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_outplace', device_type='cuda'), - DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta', device_type='cuda'), + # Lazy tensor failures: mutating and aliasing ops should all have codegen'd kernels + DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness'), + DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'), )), UnaryUfuncInfo('neg', aliases=('negative', ), @@ -18061,22 +18162,20 @@ def reference_flatten(input, start_dim=0, end_dim=-1): "_refs.narrow", torch_opinfo_name="narrow", supports_nvfuser=False, - skips=( - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'), - ) - ), - PythonRefInfo( - "_refs.nn.functional.group_norm", - torch_opinfo_name="nn.functional.group_norm", - supports_nvfuser=False, - validate_view_consistency=False, + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=True, is_ref=True), ), PythonRefInfo( "_refs.narrow_copy", torch_opinfo_name="narrow_copy", supports_out=True, supports_nvfuser=False, + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=False, is_ref=True), + ), + PythonRefInfo( + "_refs.nn.functional.group_norm", + torch_opinfo_name="nn.functional.group_norm", + supports_nvfuser=False, + validate_view_consistency=False, ), PythonRefInfo( "_refs.native_layer_norm", From 27dc03e09b6b1948e416a9fd78e6ca2b0a0bb1c7 Mon Sep 17 00:00:00 2001 From: soulitzer Date: Fri, 11 Nov 2022 11:51:22 -0500 Subject: [PATCH 62/62] Turn internal assert when saved tensor is detached inplace into torch check (#88860) Fixes https://github.com/pytorch/pytorch/issues/88809 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88860 Approved by: https://github.com/albanD --- test/test_autograd.py | 14 ++++++++++++++ torch/csrc/autograd/saved_variable.cpp | 11 ++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/test/test_autograd.py b/test/test_autograd.py index e08047860e42333..33cf188af065913 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -6776,6 +6776,20 @@ def inplace_double(x): # not leaf, not output test(lambda: (1 + torch.randn(5, requires_grad=True)), False) + def test_saved_variable_saved_original_inplace_detach(self): + # Detaching a tensor that is saved input raises + a = torch.tensor(1., requires_grad=True).clone() + b = a.sin() + a.detach_() + with self.assertRaisesRegex(RuntimeError, "Trying to use a saved tensor that has been detached"): + b.backward() + + # Detaching a tensor that is saved as output is OK + a = torch.tensor(1., requires_grad=True).clone() + b = a.exp() + a.detach_() + b.backward() + def test_saved_variable_packing_unpacking_did_not_save_original_with_hooks(self): # Tests that packing/unpacking a SavedVariable works correctly with user-defined hooks # The saved_original / did_not_save_original distinction corresponds to the `save_original` diff --git a/torch/csrc/autograd/saved_variable.cpp b/torch/csrc/autograd/saved_variable.cpp index a2e0f05b63943ba..d438205e8947fc8 100644 --- a/torch/csrc/autograd/saved_variable.cpp +++ b/torch/csrc/autograd/saved_variable.cpp @@ -144,7 +144,16 @@ Variable SavedVariable::unpack(std::shared_ptr saved_for) const { : grad_fn_; if (!is_leaf_ && !grad_fn) { - TORCH_INTERNAL_ASSERT(saved_for, "No grad_fn for non-leaf saved tensor"); + // This issue was introduced when we added logic to save the original + // because now we rely on data_.grad_fn(), but can be unreliable if the + // autograd_meta of that saved tensor is cleared with an in-place detach. + // As a simple fix, we choose to disallow that behavior here even though + // it makes behavior inconsistent depending on whether you are saving + // input or output. + TORCH_CHECK( + saved_for, + "Trying to use a saved tensor that has been detached in-place, i.e. with .detach_()." + "This is not supported, please use out-of-place `.detach()` instead"); grad_fn = std::move(saved_for); }