From 495e7b1c729e64693e794ea22640b4552816f0ef Mon Sep 17 00:00:00 2001
From: Sherlock Huang
Date: Thu, 10 Nov 2022 21:22:29 +0000
Subject: [PATCH] Ref for aten.full; symint changes in prim (#88762)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88762
Approved by: https://github.com/ezyang
---
 test/functorch/test_vmap.py                  |  1 +
 test/test_ops.py                             |  1 -
 torch/_prims_common/__init__.py              |  5 ++-
 torch/_refs/__init__.py                      | 17 +++++---
 .../_internal/common_methods_invocations.py  | 40 +++++++++++++++++++
 5 files changed, 56 insertions(+), 8 deletions(-)

diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py
index 5ba35de21b8b..6d95077b627e 100644
--- a/test/functorch/test_vmap.py
+++ b/test/functorch/test_vmap.py
@@ -3233,6 +3233,7 @@ def test():
         xfail('empty', ''),  # test runner can't handle factory functions
         xfail('ones', ''),  # test runner can't handle factory functions
         xfail('zeros', ''),  # test runner can't handle factory functions
+        xfail('full', ''),  # test runner can't handle factory functions
         xfail('eye', ''),  # non-tensor input
         xfail('broadcast_shapes', ''),  # test runner can't handle non-Tensor ops
         xfail('sparse.sampled_addmm'),  # sparse
diff --git a/test/test_ops.py b/test/test_ops.py
index 73758bfc6b46..c688f6521af1 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1743,7 +1743,6 @@ class TestRefsOpsInfo(TestCase):
         '_refs.unflatten',
         '_refs.sum_to_size',
         # ref implementation missing kwargs
-        '_refs.full',  # missing "layout"
         '_refs.full_like',  # missing "layout"
         '_refs.ones_like',  # missing "layout"
         '_refs.round',  # missing "decimals"
diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py
index 90777ed6601a..128796dfa3d0 100644
--- a/torch/_prims_common/__init__.py
+++ b/torch/_prims_common/__init__.py
@@ -837,10 +837,11 @@ def type_to_dtype(typ: type) -> torch.dtype:
 
     if typ is bool:
         return torch.bool
-    if typ is int:
+    if typ in [int, torch.SymInt]:
         return torch.long
-    if typ is float:
+    if typ in [float, torch.SymFloat]:
         return torch.get_default_dtype()
+    # TODO: sym_complex_float?
     if typ is complex:
         return corresponding_complex_dtype(torch.get_default_dtype())
 
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 36fef59df375..43b0c74192de 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -322,7 +322,7 @@ def _broadcast_shapes(*_shapes):
     common_shape = [
         1,
     ] * reduce(max, (len(shape) for shape in shapes))
-    for shape in shapes:
+    for arg_idx, shape in enumerate(shapes):
         for idx in range(-1, -1 - len(shape), -1):
             if common_shape[idx] == 1:
                 if shape[idx] < 0:
@@ -333,9 +333,9 @@ def _broadcast_shapes(*_shapes):
             elif shape[idx] != 1:
                 if common_shape[idx] != shape[idx]:
                     raise RuntimeError(
-                        "Attempting to broadcast a dimension of length ",
-                        str(shape[idx]),
-                        "!",
+                        f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! "
+                        f"Mismatching argument at index {arg_idx} had {shape}; but expected shape "
+                        f"should be broadcastable to {common_shape}"
                     )
 
     return common_shape
@@ -4495,6 +4495,7 @@ def eye(
     # result.requires_grad_(requires_grad)
 
 
+@register_decomposition(torch.ops.aten.full)
 @out_wrapper()
 def full(
     shape: ShapeType,
@@ -4506,6 +4507,12 @@ def full(
     pin_memory: bool = False,
     requires_grad: bool = False,
 ) -> TensorLikeType:
+    utils.check_layout(layout)
+    utils.check_pin_memory(pin_memory)
+
+    dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value))
+    device = device if device is not None else torch.device("cpu")
+
     e = empty(
         shape,
         dtype=dtype,
@@ -4514,7 +4521,7 @@ def full(
         pin_memory=pin_memory,
         requires_grad=requires_grad,
     )
-    return fill(e, fill_value)
+    return torch.fill(e, fill_value)  # type: ignore[arg-type]
 
 
 def full_like(
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index b41e74a24c10..5178ec978bd1 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -772,6 +772,20 @@ def sample_inputs_ones_zeros(op, device, dtype, requires_grad, **kwargs):
     for size in sizes:
         yield SampleInput(size, kwargs={'dtype': dtype, 'device': device})
 
+def sample_inputs_full(op, device, dtype, requires_grad, **kwargs):
+    def get_val(dtype):
+        return make_tensor([], dtype=dtype, device="cpu").item()
+
+    sizes = (
+        (M,),
+        (S, S),
+    )
+    fill_values = [get_val(dtype), get_val(torch.int)]
+
+    for size, fill_value in product(sizes, fill_values):
+        yield SampleInput(size, fill_value, dtype=dtype, device=device)
+
+
 def error_inputs_uniform(op, device, **kwargs):
     t = torch.zeros([10], device=device)
     yield ErrorInput(
@@ -14373,6 +14387,32 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
                # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
            )),
+    OpInfo('full',
+           op=torch.full,
+           supports_autograd=False,
+           is_factory_function=True,
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           supports_out=True,
+           sample_inputs_func=sample_inputs_full,
+           skips=(
+               # Tests that assume input is a tensor or sequence of tensors
+               DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"),
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+               # Same failure as arange: cannot find linspace in captured graph
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+               # boolean alpha not handled properly
+               DecorateInfo(unittest.expectedFailure,
+                            'TestCudaFuserOpInfo',
+                            'test_nvfuser_correctness',
+                            dtypes=(torch.bool,)),
+               # RuntimeError: UNSUPPORTED DTYPE: bool
+               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bool,)),
+           )),
     OpInfo('new_empty',
            op=lambda x, *args, **kwargs: x.new_empty(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),