Skip to content

Commit

Permalink
Add support for neg to NestedTensor (pytorch#88131)
Browse files Browse the repository at this point in the history
Partially fixes pytorch#86889

Pull Request resolved: pytorch#88131
Approved by: https://github.com/drisspg
  • Loading branch information
cpuhrsch authored and pytorchmergebot committed Nov 3, 2022
1 parent 35be73d commit 5e6ceeb
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 15 deletions.
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Expand Up @@ -4247,6 +4247,7 @@
dispatch:
SparseCPU, SparseCUDA: neg_sparse
SparseCsrCPU, SparseCsrCUDA: neg_sparse_csr
NestedTensorCPU, NestedTensorCUDA: NestedTensor_neg
tags: canonical

- func: neg_(Tensor(a!) self) -> Tensor(a!)
Expand All @@ -4256,6 +4257,7 @@
dispatch:
SparseCPU, SparseCUDA: neg_sparse_
SparseCsrCPU, SparseCsrCUDA: neg_sparse_csr_
NestedTensorCPU, NestedTensorCUDA: NestedTensor_neg_

- func: neg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
Expand Down
12 changes: 12 additions & 0 deletions aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp
Expand Up @@ -58,5 +58,17 @@ Tensor NestedTensor_tanh(const Tensor& self) {
return map_nt(self, at::tanh);
}

Tensor& NestedTensor_neg_(Tensor& self) {
auto self_ptr = get_nested_tensor_impl(self);
check_numel_equals_buffer_size(self_ptr);
auto buffer = self_ptr->get_buffer();
at::neg_(buffer);
return self;
}

Tensor NestedTensor_neg(const Tensor& self) {
return map_nt(self, at::neg);
}

} // namespace native
} // namespace at
1 change: 1 addition & 0 deletions docs/source/nested.rst
Expand Up @@ -196,6 +196,7 @@ NestedTensor and any constraints they have.
:func:`torch.nn.Dropout`; "Behavior is the same as on regular tensors."
:func:`torch.relu`; "Behavior is the same as on regular tensors."
:func:`torch.gelu`; "Behavior is the same as on regular tensors."
:func:`torch.neg`; "Behavior is the same as on regular tensors."
:func:`torch.add`; "Supports elementwise addition of two nested tensors.
Supports addition of a scalar to a nested tensor."
:func:`torch.mul`; "Supports elementwise multiplication of two nested tensors.
Expand Down
34 changes: 19 additions & 15 deletions test/test_nestedtensor.py
Expand Up @@ -20,6 +20,7 @@
parametrize,
run_tests,
TestCase,
subtest,
)

# Tests are ported from pytorch/nestedtensor.
Expand Down Expand Up @@ -304,20 +305,6 @@ def test_repr_string(self):
self.assertEqual(str(a), expected)
self.assertEqual(repr(a), expected)

@torch.inference_mode()
def test_activations(self):
for func in (torch.nn.functional.relu,
torch.nn.functional.relu_,
torch.nn.functional.gelu,
torch._C._nn.gelu_,
torch.tanh,
torch.tanh_):
t = torch.tensor([-1, 0, 1], dtype=torch.float)
nt = torch.nested.nested_tensor([t])
nested_result = func(nt)
self.assertTrue(nested_result.is_nested)
self.assertEqual(func(t), nested_result.unbind()[0])

def test_to_padded_tensor_on_empty_tensor(self):

nt = torch.nested.nested_tensor([])
Expand Down Expand Up @@ -762,6 +749,24 @@ def test_nested_tensor_indexing(self, device, dtype):
expected_grad = torch.nested.nested_tensor([grad_x0, torch.zeros((3, 4), device=device, dtype=dtype)])
self.assertEqual(nt.grad, expected_grad)

@parametrize("func", [subtest(torch.nn.functional.relu, name='relu'),
subtest(torch.nn.functional.relu_, name='relu_'),
subtest(torch.nn.functional.gelu, name='gelu'),
subtest(torch._C._nn.gelu_, name='gelu_'),
subtest(torch.tanh, name='tanh'),
subtest(torch.tanh_, name='tanh_'),
subtest(torch.neg, name='neg')])
def test_activations(self, device, func):
nt, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device=device, dtype=torch.float32)
nested_result = func(nt)
self.assertTrue(nested_result.is_nested)
for t, t_res in zip(nt.unbind(), nested_result.unbind()):
self.assertEqual(func(t), t_res)
self.assertRaisesRegex(
RuntimeError,
"NestedTensor must be contiguous to get buffer.",
lambda: func(nt_noncontiguous))

@dtypes(*floating_types_and_half())
def test_nested_tensor_chunk(self, device, dtype):
# Transformer use case
Expand Down Expand Up @@ -912,7 +917,6 @@ def test_nested_tensor_div(self, device, dtype):
RuntimeError, "div requires offsets to match when given NestedTensors",
lambda: nt_chunks[0] / nt_chunks[1])


@dtypes(torch.float, torch.float16)
@skipMeta
@torch.inference_mode()
Expand Down

0 comments on commit 5e6ceeb

Please sign in to comment.