Ref for aten.full; symint changes in prim (pytorch#88762)
Pull Request resolved: pytorch#88762
Approved by: https://github.com/ezyang
SherlockNoMad authored and pytorchmergebot committed Nov 11, 2022
1 parent 3fbf748 commit 495e7b1
Showing 5 changed files with 56 additions and 8 deletions.
1 change: 1 addition & 0 deletions test/functorch/test_vmap.py
@@ -3233,6 +3233,7 @@ def test():
         xfail('empty', ''),  # test runner can't handle factory functions
         xfail('ones', ''),  # test runner can't handle factory functions
         xfail('zeros', ''),  # test runner can't handle factory functions
+        xfail('full', ''),  # test runner can't handle factory functions
         xfail('eye', ''),  # non-tensor input
         xfail('broadcast_shapes', ''),  # test runner can't handle non-Tensor ops
         xfail('sparse.sampled_addmm'),  # sparse
1 change: 0 additions & 1 deletion test/test_ops.py
@@ -1743,7 +1743,6 @@ class TestRefsOpsInfo(TestCase):
         '_refs.unflatten',
         '_refs.sum_to_size',
         # ref implementation missing kwargs
-        '_refs.full',  # missing "layout"
         '_refs.full_like',  # missing "layout"
         '_refs.ones_like',  # missing "layout"
         '_refs.round',  # missing "decimals"
5 changes: 3 additions & 2 deletions torch/_prims_common/__init__.py
@@ -837,10 +837,11 @@ def type_to_dtype(typ: type) -> torch.dtype:
 
     if typ is bool:
         return torch.bool
-    if typ is int:
+    if typ in [int, torch.SymInt]:
         return torch.long
-    if typ is float:
+    if typ in [float, torch.SymFloat]:
         return torch.get_default_dtype()
+    # TODO: sym_complex_float?
     if typ is complex:
         return corresponding_complex_dtype(torch.get_default_dtype())
 
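For context, a minimal standalone sketch of the updated mapping — the helper name below is illustrative (the real function lives in torch._prims_common), and the complex branch is omitted to keep the snippet self-contained:

```python
import torch

def type_to_dtype_sketch(typ: type) -> torch.dtype:
    # Mirrors the diff above: symbolic ints resolve like int,
    # symbolic floats follow the current default floating dtype.
    if typ is bool:
        return torch.bool
    if typ in [int, torch.SymInt]:
        return torch.long
    if typ in [float, torch.SymFloat]:
        return torch.get_default_dtype()
    raise ValueError(f"unsupported scalar type: {typ}")

assert type_to_dtype_sketch(torch.SymInt) is torch.long
assert type_to_dtype_sketch(float) is torch.get_default_dtype()
```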
17 changes: 12 additions & 5 deletions torch/_refs/__init__.py
@@ -322,7 +322,7 @@ def _broadcast_shapes(*_shapes):
     common_shape = [
         1,
     ] * reduce(max, (len(shape) for shape in shapes))
-    for shape in shapes:
+    for arg_idx, shape in enumerate(shapes):
         for idx in range(-1, -1 - len(shape), -1):
             if common_shape[idx] == 1:
                 if shape[idx] < 0:
@@ -333,9 +333,9 @@ def _broadcast_shapes(*_shapes):
             elif shape[idx] != 1:
                 if common_shape[idx] != shape[idx]:
                     raise RuntimeError(
-                        "Attempting to broadcast a dimension of length ",
-                        str(shape[idx]),
-                        "!",
+                        f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! "
+                        f"Mismatching argument at index {arg_idx} had {shape}; but expected shape "
+                        f"should be broadcastable to {common_shape}"
                     )
 
     return common_shape
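A quick illustration of the sharpened error — this calls the private helper directly, so treat it as a sketch against internal API rather than a supported entry point:

```python
import torch
from torch._refs import _broadcast_shapes

try:
    _broadcast_shapes((3, 2), (3, 4))
except RuntimeError as e:
    # The message now names the offending length, the dim index,
    # and which argument carried the mismatching shape.
    print(e)
```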
@@ -4495,6 +4495,7 @@ def eye(
     # result.requires_grad_(requires_grad)
 
 
+@register_decomposition(torch.ops.aten.full)
 @out_wrapper()
 def full(
     shape: ShapeType,
@@ -4506,6 +4507,12 @@ def full(
     pin_memory: bool = False,
     requires_grad: bool = False,
 ) -> TensorLikeType:
+    utils.check_layout(layout)
+    utils.check_pin_memory(pin_memory)
+
+    dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value))
+    device = device if device is not None else torch.device("cpu")
+
     e = empty(
         shape,
         dtype=dtype,
@@ -4514,7 +4521,7 @@ def full(
         pin_memory=pin_memory,
         requires_grad=requires_grad,
     )
-    return fill(e, fill_value)
+    return torch.fill(e, fill_value)  # type: ignore[arg-type]
 
 
 def full_like(
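The new defaulting mirrors eager torch.full, which infers the dtype from the Python type of fill_value when none is passed — for example:

```python
import torch

print(torch.full((2, 2), 7).dtype)     # torch.int64: int fill maps to torch.long
print(torch.full((2, 2), 7.0).dtype)   # default dtype, normally torch.float32
print(torch.full((2, 2), True).dtype)  # torch.bool
```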
40 changes: 40 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -772,6 +772,20 @@ def sample_inputs_ones_zeros(op, device, dtype, requires_grad, **kwargs):
     for size in sizes:
         yield SampleInput(size, kwargs={'dtype': dtype, 'device': device})
 
+def sample_inputs_full(op, device, dtype, requires_grad, **kwargs):
+    def get_val(dtype):
+        return make_tensor([], dtype=dtype, device="cpu").item()
+
+    sizes = (
+        (M,),
+        (S, S),
+    )
+    fill_values = [get_val(dtype), get_val(torch.int)]
+
+    for size, fill_value in product(sizes, fill_values):
+        yield SampleInput(size, fill_value, dtype=dtype, device=device)
+
+
 def error_inputs_uniform(op, device, **kwargs):
     t = torch.zeros([10], device=device)
     yield ErrorInput(
@@ -14373,6 +14387,32 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
                # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
            )),
+    OpInfo('full',
+           op=torch.full,
+           supports_autograd=False,
+           is_factory_function=True,
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           supports_out=True,
+           sample_inputs_func=sample_inputs_full,
+           skips=(
+               # Tests that assume input is a tensor or sequence of tensors
+               DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"),
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+               # Same failure as arange: cannot find linspace in captured graph
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+               # boolean alpha not handled properly
+               DecorateInfo(unittest.expectedFailure,
+                            'TestCudaFuserOpInfo',
+                            'test_nvfuser_correctness',
+                            dtypes=(torch.bool,)),
+               # RuntimeError: UNSUPPORTED DTYPE: bool
+               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bool,)),
+           )),
     OpInfo('new_empty',
            op=lambda x, *args, **kwargs: x.new_empty(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
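For a sense of what the new OpInfo exercises, here is a rough sketch of the sample enumeration — M and S are size constants from the testing internals (the concrete values below are assumptions), and plain torch.full stands in for the op under test:

```python
from itertools import product
import torch

M, S = 10, 5  # assumed values of the test-suite size constants
sizes = ((M,), (S, S))
# One fill value drawn in the sample dtype and one in int, so the suite
# also covers fill_value/dtype mismatches, as sample_inputs_full does above.
fill_values = [1.5, 7]

for size, fill_value in product(sizes, fill_values):
    t = torch.full(size, fill_value, dtype=torch.float32, device="cpu")
    print(t.shape, t.dtype)
```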
