
type promote clamp #77035

Closed · wants to merge 9 commits
24 changes: 21 additions & 3 deletions aten/src/ATen/native/Activation.cpp
@@ -363,15 +363,31 @@ auto approximate_type = get_gelutype_enum(approximate);
 }
 
 Tensor hardtanh(const Tensor& self, const Scalar& min, const Scalar& max) {
-  return at::clamp(self, min, max);
+  Tensor result = at::empty_like(self);
+  return at::hardtanh_out(result, self, min, max);
 }
 
 Tensor& hardtanh_out(const Tensor& self, const Scalar& min, const Scalar& max, Tensor& result) {
-  return at::clamp_out(result, self, min, max);
+  TORCH_CHECK(self.scalar_type() != at::kBool,
+              "Bool inputs not supported for hardtanh");
+  // preserve legacy behavior of boundaries not causing type promotion
+  Scalar min_, max_;
+  if (at::isIntegralType(self.scalar_type(), /*include_bool*/ false)) {
+    int64_t minval = min.toLong();
+    int64_t maxval = max.toLong();
+    TORCH_CHECK(self.dtype() != at::kByte || (minval >= 0 && maxval >= 0),
+                "cannot do hardtanh on an unsigned type with negative limits");
+    min_ = minval;
+    max_ = maxval;
+  } else {
+    min_ = min;
+    max_ = max;
+  }
+  return at::clamp_out(result, self, min_, max_);
 }
 
 Tensor& hardtanh_(Tensor& self, const Scalar& min, const Scalar& max) {
-  return at::clamp_(self, min, max);
+  return at::hardtanh_out(self, self, min, max);
 }
 
 Tensor& hardtanh_backward_out(const Tensor& grad_output, const Tensor& self, const Scalar& min, const Scalar& max, Tensor& grad_input) {
@@ -425,10 +441,12 @@ Tensor hardswish_backward(const Tensor& grad_output, const Tensor& self) {
 }
 
 Tensor relu(const Tensor & self) {
+  TORCH_CHECK(self.scalar_type() != at::kBool, "Boolean inputs not supported for relu");
   return at::clamp_min(self, 0);
 }
 
 Tensor & relu_(Tensor & self) {
+  TORCH_CHECK(self.scalar_type() != at::kBool, "Boolean inputs not supported for relu");
   return at::clamp_min_(self, 0);
 }
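Reviewer note: a minimal sketch of the user-visible behavior these `Activation.cpp` changes aim for (illustrative only, assuming a build of this branch; the printed results are expectations, not captured output):

```python
import torch
import torch.nn.functional as F

# hardtanh keeps its legacy behavior: scalar boundaries do not type-promote,
# so an integer input stays integer even with float limits.
x = torch.arange(10, dtype=torch.int32)
print(F.hardtanh(x, -2.0, 5.0).dtype)  # expected: torch.int32

# Unsigned inputs with negative limits are rejected instead of silently wrapping.
u = torch.arange(10, dtype=torch.uint8)
try:
    F.hardtanh(u, -1, 1)
except RuntimeError as e:
    print(e)  # expected: cannot do hardtanh on an unsigned type with negative limits

# Bool inputs are rejected for hardtanh and relu alike.
try:
    torch.relu(torch.tensor([True, False]))
except RuntimeError as e:
    print(e)  # expected: Boolean inputs not supported for relu
```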
69 changes: 64 additions & 5 deletions aten/src/ATen/native/TensorCompare.cpp
@@ -9,6 +9,7 @@
 #include <ATen/native/TensorCompare.h>
 #include <ATen/NamedTensorUtils.h>
 #include <ATen/TensorIndexing.h>
+#include <ATen/native/TypeProperties.h>
 
 namespace at {
 namespace meta {
@@ -29,15 +30,50 @@ const OptionalScalarRef max) {
   if (!min && !max) {
     TORCH_CHECK(false, "torch.clamp: At least one of 'min' or 'max' must not be None");
   }
-
-  build_borrowing_unary_op(maybe_get_output(), self);
+  // Manual type promotion, since scalars have to participate in it
+  ScalarType result_type = self.scalar_type();
+  TORCH_CHECK(!isComplexType(result_type), "clamp is not supported for complex types");
+  // Floating is the highest supported type, so promotion is only needed for non-floating inputs
+  if (!isFloatingType(result_type)) {
+    at::native::ResultTypeState state = {};
+    state = at::native::update_result_type_state(self, state);
+
+    if (min) {
+      state = at::native::update_result_type_state(min.get(), state);
+    }
+    if (max) {
+      state = at::native::update_result_type_state(max.get(), state);
+    }
+    result_type = at::native::result_type(state);
+    // disallow type-promoting in-place op
+    TORCH_CHECK((result_type == self.scalar_type()) ||
+                (!(maybe_get_output().defined()) || !(maybe_get_output().is_same(self))),
+                "result type ", result_type, " can't be cast to the desired output type ",
+                self.dtype());
+  }
+  build_unary_op(maybe_get_output(), self.to(result_type));
 }
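In user-facing terms, this meta function lets scalar boundaries participate in type promotion while still refusing promoting in-place ops; a sketch of the intended semantics (expected results, assuming this branch):

```python
import torch

a = torch.arange(5)                        # torch.int64
# A float scalar boundary now promotes, to the default float dtype:
print(torch.clamp(a, max=2.5).dtype)       # expected: torch.float32
# Integer scalar boundaries keep the integer dtype:
print(torch.clamp(a, min=1, max=3).dtype)  # expected: torch.int64

# In-place clamp cannot change self's dtype, so it raises per the check above:
try:
    a.clamp_(max=2.5)
except RuntimeError as e:
    print(e)  # expected: result type ... can't be cast to the desired output type ...
```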

 TORCH_META_FUNC(clamp_max) (
   const Tensor& self,
   const Scalar& max
 ) {
-  build_borrowing_unary_op(maybe_get_output(), self);
+  // We could wrap max into a tensor and send it to the tensor overload,
+  // but relu is implemented via clamp_min, so for perf and uniformity reasons
+  // do the faster but still correct thing here
+  ScalarType result_type = self.scalar_type();
+  TORCH_CHECK(!isComplexType(result_type), "clamp is not supported for complex types");
+  // Floating is the highest supported type, so promotion is only needed for non-floating inputs
+  if (!isFloatingType(result_type)) {
+    auto result_type = at::native::result_type(self, max);
+    TORCH_CHECK((result_type == self.scalar_type()) ||
+                (!(maybe_get_output().defined()) || !(maybe_get_output().is_same(self))),
+                "result type ", result_type, " can't be cast to the desired output type ",
+                self.dtype());
+    build_unary_op(maybe_get_output(), self.to(result_type));
+  } else {
+    build_borrowing_unary_op(maybe_get_output(), self);
+  }
 }
 
 TORCH_META_FUNC2(clamp_max, Tensor) (
@@ -52,7 +88,19 @@ TORCH_META_FUNC(clamp_min) (
   const Tensor& self,
   const Scalar& min
 ) {
-  build_borrowing_unary_op(maybe_get_output(), self);
+  ScalarType result_type = self.scalar_type();
+  TORCH_CHECK(!isComplexType(result_type), "clamp is not supported for complex types");
+  // Floating is the highest supported type, so promotion is only needed for non-floating inputs
+  if (!isFloatingType(result_type)) {
+    auto result_type = at::native::result_type(self, min);
+    TORCH_CHECK((result_type == self.scalar_type() ||
+                 !(maybe_get_output().defined()) || !(maybe_get_output().is_same(self))),
+                "result type ", result_type, " can't be cast to the desired output type ",
+                self.dtype());
+    build_unary_op(maybe_get_output(), self.to(result_type));
+  } else {
+    build_borrowing_unary_op(maybe_get_output(), self);
+  }
 }
 
 TORCH_META_FUNC2(clamp_min, Tensor) (
@@ -521,7 +569,18 @@ Tensor& clamp_out(const Tensor& self, const c10::optional<Tensor>& min,
 }
 
 Tensor clamp(const Tensor& self, const c10::optional<Tensor>& min, const c10::optional<Tensor>& max) {
-  Tensor result = at::empty({0}, self.options());
+  // Manual type promotion to send to `out`;
+  // won't be needed once clamp is ported to structured
+  at::native::ResultTypeState state = {};
+  state = at::native::update_result_type_state(self, state);
+  if (min) {
+    state = at::native::update_result_type_state(*min, state);
+  }
+  if (max) {
+    state = at::native::update_result_type_state(*max, state);
+  }
+  auto result_type = at::native::result_type(state);
+  Tensor result = at::empty({0}, self.options().dtype(result_type));
   return at::clamp_outf(self, min, max, result);
 }
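Taken together, the meta functions above and this `Tensor clamp(...)` wrapper should give the scalar and tensor overloads matching promotion rules; another sketch under the same assumptions:

```python
import torch

a = torch.arange(4)                   # torch.int64
print(torch.clamp_min(a, 0.5).dtype)  # expected: torch.float32
print(torch.clamp_max(a, 2).dtype)    # expected: torch.int64

# Tensor boundaries promote across self, min, and max:
lo = torch.zeros(4, dtype=torch.int32)
hi = torch.rand(4, dtype=torch.float64)
print(torch.clamp(a, min=lo, max=hi).dtype)  # expected: torch.float64
```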
1 change: 1 addition & 0 deletions aten/src/ATen/native/TypeProperties.h
@@ -12,6 +12,7 @@ struct ResultTypeState {
 };
 
 TORCH_API ResultTypeState update_result_type_state(const Tensor& tensor, const ResultTypeState& in_state);
+TORCH_API ResultTypeState update_result_type_state(const Scalar& scalar, const ResultTypeState& in_state);

Review comment (Collaborator): Nice addition

 TORCH_API ScalarType result_type(const ResultTypeState& state);
 
 TORCH_API ScalarType result_type(ITensorListRef tensors);
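The new `Scalar` overload of `update_result_type_state` is internal C++, but its effect is observable from Python through `torch.result_type`, which applies the same category-based rules (a scalar only bumps the result within its own category):

```python
import torch

print(torch.result_type(torch.arange(3), 2))    # torch.int64: int scalar, no bump
print(torch.result_type(torch.arange(3), 2.5))  # torch.float32: float scalar promotes
```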
1 change: 1 addition & 0 deletions test/test_ops.py
@@ -183,6 +183,7 @@ def unsupported(dtype):
 
         # Checks that dtypes are listed correctly and generates an informative
         # error message
+
         supported_forward = supported_dtypes - unsupported_dtypes
         partially_supported_forward = supported_dtypes & unsupported_dtypes
         unsupported_forward = unsupported_dtypes - supported_dtypes
97 changes: 66 additions & 31 deletions test/test_type_promotion.py
@@ -9,9 +9,9 @@
 from torch.testing._internal.common_utils import (TestCase, run_tests, load_tests,
                                                   TEST_NUMPY, torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict)
 from torch.testing._internal.common_device_type import (instantiate_device_type_tests, onlyNativeDeviceTypes,
-                                                        dtypes, dtypesIfCUDA, onlyCPU, expectedFailureMeta, skipMeta)
+                                                        dtypes, onlyCPU, expectedFailureMeta, skipMeta)
 from torch.testing._internal.common_dtype import (
-    all_types_and_complex_and, all_types_and, get_all_math_dtypes, integral_types_and, floating_types_and
+    all_types_and_complex_and, get_all_math_dtypes, floating_types
 )
 
 import numpy as np
@@ -971,35 +971,70 @@ def test_computation_ignores_out(self, device):
         self.assertEqual(result, a - b, exact_dtype=False)
         self.assertNotEqual(result, a.double() - b, exact_dtype=False)
 
-    @dtypesIfCUDA(*itertools.product(all_types_and(torch.half, torch.bool),
-                                     all_types_and(torch.half, torch.bool)))
-    @dtypes(*itertools.product(all_types_and(torch.bool),
-                               all_types_and(torch.bool)))
-    def test_atan2_type_promotion(self, device, dtypes):
-        dtype1, dtype2 = dtypes
-        default_float = torch.get_default_dtype()
-
-        def is_int(dtype):
-            return dtype in integral_types_and(torch.bool)
-
-        def is_float(dtype):
-            return dtype in floating_types_and(torch.half)
-
-        def get_binary_float_result_type(x, y):
-            dtype1 = x.dtype
-            dtype2 = y.dtype
-            if is_float(dtype1) and is_float(dtype2):
-                return torch.result_type(x, y)
-            elif is_float(dtype1) and is_int(dtype2):
-                return dtype1
-            elif is_int(dtype1) and is_float(dtype2):
-                return dtype2
-            elif is_int(dtype1) and is_int(dtype2):
-                return default_float
-
-        x = torch.tensor(1, dtype=dtype1, device=device)
-        y = torch.tensor(2, dtype=dtype2, device=device)
-        self.assertEqual(get_binary_float_result_type(x, y), torch.atan2(x, y).dtype)
+    @onlyNativeDeviceTypes
+    @dtypes(*itertools.product((torch.bool, torch.int, torch.float, torch.double), repeat=3))
+    def test_clamp_type_promotion(self, device, dtypes):
+        dtype0, dtype1, dtype2 = dtypes
+        S = 4
+
+        def make_tensor(size, dtype):
+            if dtype == torch.bool:
+                return torch.randint(2, size, dtype=dtype, device=device)
+            elif dtype == torch.int:
+                return torch.randint(10, size, dtype=dtype, device=device)
+            else:
+                return torch.randn(size, dtype=dtype, device=device)
+        min_t = make_tensor((S,), dtype1)
+        max_t = make_tensor((S,), dtype2)
+        mins = (min_t, min_t[0], min_t[0].item())
+        maxs = (max_t, max_t[0], max_t[0].item())
+        inp = make_tensor((S,), dtype0)
+        for min_v, max_v in itertools.product(mins, maxs):
+            if type(max_v) != type(min_v):
+                continue
+            if isinstance(min_v, torch.Tensor) and min_v.ndim == 0 and max_v.ndim == 0:
+                continue  # 0d tensors go to scalar overload, and it's tested separately
+
+            def expected_type(inp, max, min):
+                arg1, arg2 = max, min
+                if isinstance(max, torch.Tensor) and max.ndim == 0:
+                    # promote with the possibly multi-dimensional boundary first
+                    arg1, arg2 = min, max
+                exp_type = torch.result_type(inp, arg1)
+                inp_new = torch.empty_like(inp, dtype=exp_type)
+                return torch.result_type(inp_new, arg2)
+            exp_type = expected_type(inp, min_v, max_v)
+            if exp_type != torch.bool:
+                actual = torch.clamp(inp, min_v, max_v)
+                inps = list(map(lambda x: x.to(exp_type) if isinstance(x, torch.Tensor) else x,
+                                (inp, min_v, max_v)))
+                expected = torch.clamp(inps[0], inps[1], inps[2])
+                self.assertEqual(actual, expected)
+                if inp.dtype in floating_types() or exp_type == inp.dtype:
+                    actual = torch.clamp_(inp, min_v, max_v)
+                    self.assertEqual(actual, expected, exact_dtype=False)
+        for val in mins:
+            def expected_type(inp, val):
+                return torch.result_type(inp, val)
+            exp_type = expected_type(inp, val)
+            if exp_type != torch.bool:
+                actual = torch.clamp_min(inp, val)
+                inps = list(map(lambda x: x.to(exp_type) if isinstance(x, torch.Tensor) else x,
+                                (inp, val)))
+                expected = torch.clamp_min(inps[0], inps[1])
+                self.assertEqual(actual.dtype, exp_type)
+                self.assertEqual(actual, expected)
+                if inp.dtype == exp_type:
+                    actual = torch.clamp_min_(inp, val)
+                    self.assertEqual(actual, expected)
+                actual = torch.clamp_max(inp, val)
+                expected = torch.clamp_max(inps[0], inps[1])
+                self.assertEqual(actual, expected)
+                if inp.dtype in floating_types() or exp_type == inp.dtype:
+                    actual = torch.clamp_max_(inp, val)
+                    self.assertEqual(actual, expected, exact_dtype=False)
+
+
 
 instantiate_device_type_tests(TestTypePromotion, globals())
 
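The test's `expected_type` helper folds the boundaries into the result dtype one pair at a time; the same computation can be reproduced directly with `torch.result_type` (a sketch with arbitrarily chosen dtypes):

```python
import torch

inp = torch.randint(10, (4,), dtype=torch.int32)
max_v = torch.tensor([3.0], dtype=torch.float64)
min_v = 0.5

step1 = torch.result_type(inp, max_v)  # int32 with a float64 tensor -> torch.float64
step2 = torch.result_type(torch.empty(0, dtype=step1), min_v)
print(step2)                           # torch.float64
```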
2 changes: 1 addition & 1 deletion torch/nn/functional.py
@@ -1442,7 +1442,7 @@ def glu(input: Tensor, dim: int = -1) -> Tensor:
     return torch._C._nn.glu(input, dim)
 
 
-def hardtanh(input: Tensor, min_val: float = -1.0, max_val: float = 1.0, inplace: bool = False) -> Tensor:
+def hardtanh(input: Tensor, min_val: float = -1., max_val: float = 1., inplace: bool = False) -> Tensor:
     r"""
     hardtanh(input, min_val=-1., max_val=1., inplace=False) -> Tensor
 
10 changes: 5 additions & 5 deletions torch/testing/_internal/common_methods_invocations.py
@@ -10072,7 +10072,7 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
                DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
            ),
            decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, with_tf32_off]),
-    # NOTE: clamp has seperate opinfos for scalar min/max (unary op) vs. tensors
+    # NOTE: clamp has separate opinfos for scalar min/max (unary op) vs. tensors
     OpInfo('clamp',
            aliases=('clip',),
            dtypes=all_types_and(torch.bfloat16),
@@ -10086,8 +10086,8 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
            aliases=('clip', ),
            decorators=(precisionOverride({torch.bfloat16: 7e-2, torch.float16: 1e-2}),),
            ref=np.clip,
-           dtypes=all_types_and(torch.bfloat16),
-           dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
+           dtypes=all_types_and(torch.bool, torch.bfloat16),
+           dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
            assert_autodiffed=True,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
@@ -17555,8 +17555,8 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
     OpInfo(
         "nn.functional.gaussian_nll_loss",
         ref=_NOTHING,
-        dtypes=all_types_and(torch.bfloat16),
-        dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+        dtypes=floating_types_and(torch.bfloat16),
+        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
         supports_out=False,
         supports_forward_ad=True,
         supports_fwgrad_bwgrad=True,
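The gaussian_nll_loss OpInfo now advertises only floating dtypes, matching the inputs the op is meant for; a minimal (hedged) usage sketch:

```python
import torch
import torch.nn.functional as F

# All three arguments are floating tensors; var is clamped by eps internally.
x, target, var = torch.randn(3), torch.randn(3), torch.rand(3)
print(F.gaussian_nll_loss(x, target, var))
```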