diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp index 3162c839f4af..84b3c8ef059f 100644 --- a/aten/src/ATen/native/Activation.cpp +++ b/aten/src/ATen/native/Activation.cpp @@ -363,15 +363,31 @@ auto approximate_type = get_gelutype_enum(approximate); } Tensor hardtanh(const Tensor& self, const Scalar& min, const Scalar& max) { - return at::clamp(self, min, max); + Tensor result = at::empty_like(self); + return at::hardtanh_out(result, self, min, max); } Tensor& hardtanh_out(const Tensor& self, const Scalar& min, const Scalar& max, Tensor& result) { - return at::clamp_out(result, self, min, max); + TORCH_CHECK(self.scalar_type() != at::kBool, + "Bool inputs not supported for hardtanh"); + //preserve legacy behavior of boundaries not causing type promotion + Scalar min_, max_; + if (at::isIntegralType(self.scalar_type(), /*include_bool*/false)) { + int64_t minval = min.toLong(); + int64_t maxval = max.toLong(); + TORCH_CHECK(self.dtype() != at::kByte || (minval >= 0 && + maxval >=0), "cannot do hardtanh on an unsigned type with negative limits"); + min_ = minval; + max_ = maxval; + } else { + min_ = min; + max_ = max; + } + return at::clamp_out(result, self, min_, max_); } Tensor& hardtanh_(Tensor& self, const Scalar& min, const Scalar& max) { - return at::clamp_(self, min, max); + return at::hardtanh_out(self, self, min, max); } Tensor& hardtanh_backward_out(const Tensor& grad_output, const Tensor& self, const Scalar& min, const Scalar& max, Tensor& grad_input) { @@ -425,10 +441,12 @@ Tensor hardswish_backward(const Tensor& grad_output, const Tensor& self) { } Tensor relu(const Tensor & self) { + TORCH_CHECK(self.scalar_type() != at::kBool, "Boolean inputs not supported for relu"); return at::clamp_min(self, 0); } Tensor & relu_(Tensor & self) { + TORCH_CHECK(self.scalar_type() != at::kBool, "Boolean inputs not supported for relu"); return at::clamp_min_(self, 0); } diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index 6229b43007f0..6607790b9de4 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -9,6 +9,7 @@ #include #include #include +#include namespace at { namespace meta { @@ -29,15 +30,50 @@ const OptionalScalarRef max) { if (!min && !max) { TORCH_CHECK(false, "torch.clamp: At least one of 'min' or 'max' must not be None"); } - - build_borrowing_unary_op(maybe_get_output(), self); + //Manual type promotion, since scalars have to participate in it + ScalarType result_type = self.scalar_type(); + TORCH_CHECK(!isComplexType(result_type), "clamp is not supported for complex types"); + //Floating is the highest supported + if (!isFloatingType(result_type)) { + at::native::ResultTypeState state = {}; + state = at::native::update_result_type_state(self, state); + + if (min) { + state = at::native::update_result_type_state(min.get(), state); + } + if (max) { + state = at::native::update_result_type_state(max.get(), state); + } + result_type = at::native::result_type(state); + //disallow type promoting inplace op + TORCH_CHECK((result_type == self.scalar_type()) || + (!(maybe_get_output().defined()) || !(maybe_get_output().is_same(self))), + "result type ", result_type, " can't be cast to the desired output type ", + self.dtype()); + } + build_unary_op(maybe_get_output(), self.to(result_type)); } TORCH_META_FUNC(clamp_max) ( const Tensor& self, const Scalar& max ) { - build_borrowing_unary_op(maybe_get_output(), self); + //we could wrap max into tensor and send to tensor overload, + //but relu is implemented via clamp_min, so for perf an uniformity reasons + //do a faster but correct thing + ScalarType result_type = self.scalar_type(); + TORCH_CHECK(!isComplexType(result_type), "clamp is not supported for complex types"); + //Floating is the highest supported + if (!isFloatingType(result_type)) { + auto result_type = at::native::result_type(self, max); + TORCH_CHECK((result_type == self.scalar_type()) || + (!(maybe_get_output().defined()) || !(maybe_get_output().is_same(self))), + "result type ", result_type, " can't be cast to the desired output type ", + self.dtype()); + build_unary_op(maybe_get_output(), self.to(result_type)); + } else { + build_borrowing_unary_op(maybe_get_output(), self); + } } TORCH_META_FUNC2(clamp_max, Tensor) ( @@ -52,7 +88,19 @@ TORCH_META_FUNC(clamp_min) ( const Tensor& self, const Scalar& min ) { - build_borrowing_unary_op(maybe_get_output(), self); + ScalarType result_type = self.scalar_type(); + TORCH_CHECK(!isComplexType(result_type), "clamp is not supported for complex types"); + //Floating is the highest supported + if (!isFloatingType(result_type)) { + auto result_type = at::native::result_type(self, min); + TORCH_CHECK((result_type == self.scalar_type() || + !(maybe_get_output().defined()) || !(maybe_get_output().is_same(self))), + "result type ", result_type, " can't be cast to the desired output type ", + self.dtype()); + build_unary_op(maybe_get_output(), self.to(result_type)); + } else { + build_borrowing_unary_op(maybe_get_output(), self); + } } TORCH_META_FUNC2(clamp_min, Tensor) ( @@ -521,7 +569,18 @@ Tensor& clamp_out(const Tensor& self, const c10::optional& min, } Tensor clamp(const Tensor& self, const c10::optional& min, const c10::optional& max) { - Tensor result = at::empty({0}, self.options()); + //manual type promotion to send to `out` + //won't be needed once clamp is ported to structured + at::native::ResultTypeState state = {}; + state = at::native::update_result_type_state(self, state); + if (min) { + state = at::native::update_result_type_state(*min, state); + } + if (max) { + state = at::native::update_result_type_state(*max, state); + } + auto result_type = at::native::result_type(state); + Tensor result = at::empty({0}, self.options().dtype(result_type)); return at::clamp_outf(self, min, max, result); } diff --git a/aten/src/ATen/native/TypeProperties.h b/aten/src/ATen/native/TypeProperties.h index d7b123c9c936..b0f18c594882 100644 --- a/aten/src/ATen/native/TypeProperties.h +++ b/aten/src/ATen/native/TypeProperties.h @@ -12,6 +12,7 @@ struct ResultTypeState { }; TORCH_API ResultTypeState update_result_type_state(const Tensor& tensor, const ResultTypeState& in_state); +TORCH_API ResultTypeState update_result_type_state(const Scalar& scalar, const ResultTypeState& in_state); TORCH_API ScalarType result_type(const ResultTypeState& state); TORCH_API ScalarType result_type(ITensorListRef tensors); diff --git a/test/test_ops.py b/test/test_ops.py index d397e469767e..5e33b8412020 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -183,6 +183,7 @@ def unsupported(dtype): # Checks that dtypes are listed correctly and generates an informative # error message + supported_forward = supported_dtypes - unsupported_dtypes partially_supported_forward = supported_dtypes & unsupported_dtypes unsupported_forward = unsupported_dtypes - supported_dtypes diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py index 5ff2da736ead..26d12c1d86ff 100644 --- a/test/test_type_promotion.py +++ b/test/test_type_promotion.py @@ -9,9 +9,9 @@ from torch.testing._internal.common_utils import (TestCase, run_tests, load_tests, TEST_NUMPY, torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict) from torch.testing._internal.common_device_type import (instantiate_device_type_tests, onlyNativeDeviceTypes, - dtypes, dtypesIfCUDA, onlyCPU, expectedFailureMeta, skipMeta) + dtypes, onlyCPU, expectedFailureMeta, skipMeta) from torch.testing._internal.common_dtype import ( - all_types_and_complex_and, all_types_and, get_all_math_dtypes, integral_types_and, floating_types_and + all_types_and_complex_and, get_all_math_dtypes, floating_types ) import numpy as np @@ -971,35 +971,70 @@ def test_computation_ignores_out(self, device): self.assertEqual(result, a - b, exact_dtype=False) self.assertNotEqual(result, a.double() - b, exact_dtype=False) - @dtypesIfCUDA(*itertools.product(all_types_and(torch.half, torch.bool), - all_types_and(torch.half, torch.bool))) - @dtypes(*itertools.product(all_types_and(torch.bool), - all_types_and(torch.bool))) - def test_atan2_type_promotion(self, device, dtypes): - dtype1, dtype2 = dtypes - default_float = torch.get_default_dtype() - - def is_int(dtype): - return dtype in integral_types_and(torch.bool) - - def is_float(dtype): - return dtype in floating_types_and(torch.half) - - def get_binary_float_result_type(x, y): - dtype1 = x.dtype - dtype2 = y.dtype - if is_float(dtype1) and is_float(dtype2): - return torch.result_type(x, y) - elif is_float(dtype1) and is_int(dtype2): - return dtype1 - elif is_int(dtype1) and is_float(dtype2): - return dtype2 - elif is_int(dtype1) and is_int(dtype2): - return default_float - - x = torch.tensor(1, dtype=dtype1, device=device) - y = torch.tensor(2, dtype=dtype2, device=device) - self.assertEqual(get_binary_float_result_type(x, y), torch.atan2(x, y).dtype) + @onlyNativeDeviceTypes + @dtypes(*itertools.product((torch.bool, torch.int, torch.float, torch.double), repeat=3)) + def test_clamp_type_promotion(self, device, dtypes): + dtype0, dtype1, dtype2 = dtypes + S = 4 + + def make_tensor(size, dtype): + if dtype == torch.bool: + return torch.randint(2, size, dtype=dtype, device=device) + elif dtype == torch.int: + return torch.randint(10, size, dtype=dtype, device=device) + else: + return torch.randn(size, dtype=dtype, device=device) + min_t = make_tensor((S,), dtype1) + max_t = make_tensor((S,), dtype2) + mins = (min_t, min_t[0], min_t[0].item()) + maxs = (max_t, max_t[0], max_t[0].item()) + inp = make_tensor((S,), dtype0) + for min_v, max_v in itertools.product(mins, maxs): + if type(max_v) != type(min_v): + continue + if isinstance(min_v, torch.Tensor) and min_v.ndim == 0 and max_v.ndim == 0: + continue # 0d tensors go to scalar overload, and it's tested separately + + def expected_type(inp, max, min): + arg1, arg2 = max, min + if isinstance(max, torch.Tensor) and max.ndim == 0: + # first do a maybe dimensional boundary + arg1, arg2 = min, max + exp_type = torch.result_type(inp, arg1) + inp_new = torch.empty_like(inp, dtype=exp_type) + return torch.result_type(inp_new, arg2) + exp_type = expected_type(inp, min_v, max_v) + if exp_type != torch.bool: + actual = torch.clamp(inp, min_v, max_v) + inps = list(map(lambda x: x.to(exp_type) if isinstance(x, torch.Tensor) else x, + (inp, min_v, max_v))) + expected = torch.clamp(inps[0], inps[1], inps[2]) + self.assertEqual(actual, expected) + if inp.dtype in floating_types() or exp_type == inp.dtype: + actual = torch.clamp_(inp, min_v, max_v) + self.assertEqual(actual, expected, exact_dtype=False) + for val in mins: + def expected_type(inp, val): + return torch.result_type(inp, val) + exp_type = expected_type(inp, val) + if exp_type != torch.bool: + actual = torch.clamp_min(inp, val) + inps = list(map(lambda x: x.to(exp_type) if isinstance(x, torch.Tensor) else x, + (inp, val))) + expected = torch.clamp_min(inps[0], inps[1]) + self.assertEqual(actual.dtype, exp_type) + self.assertEqual(actual, expected) + if inp.dtype == exp_type: + actual = torch.clamp_min_(inp, val) + self.assertEqual(actual, expected) + actual = torch.clamp_max(inp, val) + expected = torch.clamp_max(inps[0], inps[1]) + self.assertEqual(actual, expected) + if inp.dtype in floating_types() or exp_type == inp.dtype: + actual = torch.clamp_max_(inp, val) + self.assertEqual(actual, expected, exact_dtype=False) + + instantiate_device_type_tests(TestTypePromotion, globals()) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index f8ab05cfb5b5..2a779c7c8d23 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -1442,7 +1442,7 @@ def glu(input: Tensor, dim: int = -1) -> Tensor: return torch._C._nn.glu(input, dim) -def hardtanh(input: Tensor, min_val: float = -1.0, max_val: float = 1.0, inplace: bool = False) -> Tensor: +def hardtanh(input: Tensor, min_val: float = -1., max_val: float = 1., inplace: bool = False) -> Tensor: r""" hardtanh(input, min_val=-1., max_val=1., inplace=False) -> Tensor diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 08d680c8bb23..84580f29f646 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -10072,7 +10072,7 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs): DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), ), decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, with_tf32_off]), - # NOTE: clamp has seperate opinfos for scalar min/max (unary op) vs. tensors + # NOTE: clamp has separate opinfos for scalar min/max (unary op) vs. tensors OpInfo('clamp', aliases=('clip',), dtypes=all_types_and(torch.bfloat16), @@ -10086,8 +10086,8 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs): aliases=('clip', ), decorators=(precisionOverride({torch.bfloat16: 7e-2, torch.float16: 1e-2}),), ref=np.clip, - dtypes=all_types_and(torch.bfloat16), - dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16), + dtypes=all_types_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), assert_autodiffed=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -17555,8 +17555,8 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs): OpInfo( "nn.functional.gaussian_nll_loss", ref=_NOTHING, - dtypes=all_types_and(torch.bfloat16), - dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + dtypes=floating_types_and(torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True,