From 5e6ceebccbafa6febf8c3fa8abc058f311319015 Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Thu, 3 Nov 2022 15:15:57 +0000 Subject: [PATCH] Add support for neg to NestedTensor (#88131) Partially fixes #86889 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88131 Approved by: https://github.com/drisspg --- aten/src/ATen/native/native_functions.yaml | 2 ++ .../native/nested/NestedTensorUnaryOps.cpp | 12 +++++++ docs/source/nested.rst | 1 + test/test_nestedtensor.py | 34 +++++++++++-------- 4 files changed, 34 insertions(+), 15 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 2f757965f4e7..3af39c542918 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4247,6 +4247,7 @@ dispatch: SparseCPU, SparseCUDA: neg_sparse SparseCsrCPU, SparseCsrCUDA: neg_sparse_csr + NestedTensorCPU, NestedTensorCUDA: NestedTensor_neg tags: canonical - func: neg_(Tensor(a!) self) -> Tensor(a!) @@ -4256,6 +4257,7 @@ dispatch: SparseCPU, SparseCUDA: neg_sparse_ SparseCsrCPU, SparseCsrCUDA: neg_sparse_csr_ + NestedTensorCPU, NestedTensorCUDA: NestedTensor_neg_ - func: neg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator diff --git a/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp b/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp index 74289a1372e1..6be7239775ea 100644 --- a/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp +++ b/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp @@ -58,5 +58,17 @@ Tensor NestedTensor_tanh(const Tensor& self) { return map_nt(self, at::tanh); } +Tensor& NestedTensor_neg_(Tensor& self) { + auto self_ptr = get_nested_tensor_impl(self); + check_numel_equals_buffer_size(self_ptr); + auto buffer = self_ptr->get_buffer(); + at::neg_(buffer); + return self; +} + +Tensor NestedTensor_neg(const Tensor& self) { + return map_nt(self, at::neg); +} + } // namespace native } // namespace at diff --git a/docs/source/nested.rst b/docs/source/nested.rst index 21ff98025691..07712e0376f1 100644 --- a/docs/source/nested.rst +++ b/docs/source/nested.rst @@ -196,6 +196,7 @@ NestedTensor and any constraints they have. :func:`torch.nn.Dropout`; "Behavior is the same as on regular tensors." :func:`torch.relu`; "Behavior is the same as on regular tensors." :func:`torch.gelu`; "Behavior is the same as on regular tensors." + :func:`torch.neg`; "Behavior is the same as on regular tensors." :func:`torch.add`; "Supports elementwise addition of two nested tensors. Supports addition of a scalar to a nested tensor." :func:`torch.mul`; "Supports elementwise multiplication of two nested tensors. diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 84a30e0125e4..f914fa57dd9a 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -20,6 +20,7 @@ parametrize, run_tests, TestCase, + subtest, ) # Tests are ported from pytorch/nestedtensor. @@ -304,20 +305,6 @@ def test_repr_string(self): self.assertEqual(str(a), expected) self.assertEqual(repr(a), expected) - @torch.inference_mode() - def test_activations(self): - for func in (torch.nn.functional.relu, - torch.nn.functional.relu_, - torch.nn.functional.gelu, - torch._C._nn.gelu_, - torch.tanh, - torch.tanh_): - t = torch.tensor([-1, 0, 1], dtype=torch.float) - nt = torch.nested.nested_tensor([t]) - nested_result = func(nt) - self.assertTrue(nested_result.is_nested) - self.assertEqual(func(t), nested_result.unbind()[0]) - def test_to_padded_tensor_on_empty_tensor(self): nt = torch.nested.nested_tensor([]) @@ -762,6 +749,24 @@ def test_nested_tensor_indexing(self, device, dtype): expected_grad = torch.nested.nested_tensor([grad_x0, torch.zeros((3, 4), device=device, dtype=dtype)]) self.assertEqual(nt.grad, expected_grad) + @parametrize("func", [subtest(torch.nn.functional.relu, name='relu'), + subtest(torch.nn.functional.relu_, name='relu_'), + subtest(torch.nn.functional.gelu, name='gelu'), + subtest(torch._C._nn.gelu_, name='gelu_'), + subtest(torch.tanh, name='tanh'), + subtest(torch.tanh_, name='tanh_'), + subtest(torch.neg, name='neg')]) + def test_activations(self, device, func): + nt, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device=device, dtype=torch.float32) + nested_result = func(nt) + self.assertTrue(nested_result.is_nested) + for t, t_res in zip(nt.unbind(), nested_result.unbind()): + self.assertEqual(func(t), t_res) + self.assertRaisesRegex( + RuntimeError, + "NestedTensor must be contiguous to get buffer.", + lambda: func(nt_noncontiguous)) + @dtypes(*floating_types_and_half()) def test_nested_tensor_chunk(self, device, dtype): # Transformer use case @@ -912,7 +917,6 @@ def test_nested_tensor_div(self, device, dtype): RuntimeError, "div requires offsets to match when given NestedTensors", lambda: nt_chunks[0] / nt_chunks[1]) - @dtypes(torch.float, torch.float16) @skipMeta @torch.inference_mode()