Add nested tensor support to autograd (#79446)
The tracking issue for this work is #79447.

This is one in a series of PRs to add autograd support for nested tensors.
Pull Request resolved: #79446
Approved by: https://github.com/soulitzer
drisspg authored and pytorchmergebot committed Jun 16, 2022
1 parent 4b342b3 commit f965681
Showing 9 changed files with 129 additions and 28 deletions.
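
As a rough Python-level sketch of what this series enables, mirroring the new tests added in test_nestedtensor.py below (illustrative only, not a stable API):

import torch

data_a = torch.randn(2, 3, 4)
data_b = torch.randn(2, 3, 4)
mask = torch.ones_like(data_a[:, :, 0]).bool()

a = torch._nested_tensor_from_mask(data_a, mask)
b = torch._nested_tensor_from_mask(data_b, mask)

a.requires_grad_()
c = a + b  # nested add now records autograd history

grad_output = torch._nested_tensor_from_mask(torch.randn(2, 3, 4), mask)
c.backward(grad_output)  # d(a + b)/da = 1 * grad_output

# a.grad is a nested tensor matching grad_output component-wise
for g, expected in zip(a.grad.unbind(), grad_output.unbind()):
    assert torch.equal(g, expected)
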
3 changes: 0 additions & 3 deletions aten/src/ATen/NestedTensorImpl.cpp
@@ -54,9 +54,6 @@ NestedTensorImpl::NestedTensorImpl(
TORCH_INTERNAL_ASSERT(nested_size_tensor_.is_contiguous());
int64_t size_dim = nested_size_tensor_.dim();
TORCH_INTERNAL_ASSERT(size_dim == 0 || size_dim == 2);
remove_autograd_key();
key_set_ =
key_set_ - c10::DispatchKeySet({c10::DispatchKey::ADInplaceOrView});
refresh_dim();
set_sizes_strides_policy(c10::TensorImpl::SizesStridesPolicy::CustomSizes);
}
1 change: 0 additions & 1 deletion aten/src/ATen/native/nested/NestedTensorMath.h
@@ -16,6 +16,5 @@ TORCH_API std::vector<int64_t> NestedTensor_get_max_size(const NestedTensorImpl&

TORCH_API Tensor NestedTensor_to_padded_tensor_generic(const Tensor& t, double padding, OptionalIntArrayRef output_size);


} // namespace native
} // namespace at
46 changes: 46 additions & 0 deletions test/test_nestedtensor.py
@@ -482,6 +482,52 @@ def test_clone(self, device, dtype):
with self.assertRaisesRegex(RuntimeError, msg):
nt1.clone(memory_format=torch.channels_last)

class TestNestedTensorAutograd(TestCase):
def nt_equal(self, nt1, nt2):
self.assertEqual(nt1.dtype, nt2.dtype)
self.assertEqual(nt1.device, nt2.device)
ub1 = nt1.unbind()
ub2 = nt2.unbind()
self.assertEqual(len(ub1), len(ub2))
n = len(ub1)
for i in range(n):
self.assertEqual(ub1[i], ub2[i])

def _create_nested_tensor_from_list(self, requires_grad=False):
return torch.nested_tensor([torch.randn(1, 2, requires_grad=requires_grad),
torch.randn(7, 8, requires_grad=requires_grad)])

def _create_nested_tensor_from_mask(self, requires_grad=False):
data = torch.randn(2, 3, 4, requires_grad=requires_grad)
mask = torch.ones_like(data[:, :, 0]).bool()
return torch._nested_tensor_from_mask(data, mask)

def test_set_requires_grad_from_list(self):
nt = self._create_nested_tensor_from_list()
nt.requires_grad_()
assert nt.requires_grad

def test_set_requires_grad_from_mask(self):
nt = self._create_nested_tensor_from_mask()
nt.requires_grad_()
assert nt.requires_grad

def test_backward_for_add_op(self):
nt_1 = self._create_nested_tensor_from_mask()
nt_2 = self._create_nested_tensor_from_mask()

nt_1.requires_grad_()
c = nt_1 + nt_2

assert nt_1.requires_grad
assert c.requires_grad
grad_output = self._create_nested_tensor_from_mask()
c.backward(grad_output)

# Grad check doesn't work with nested yet.
# d/dnt_1 (nt + nt_1) = 1*grad_output
self.nt_equal(nt_1.grad, grad_output)

instantiate_device_type_tests(TestNestedTensorDeviceType, globals())

if __name__ == '__main__':
5 changes: 3 additions & 2 deletions tools/autograd/gen_variable_type.py
@@ -509,15 +509,16 @@
"dequantize_self",
# lift() should never actually be called with a requires_grad=True tensor,
"lift",
# Nested Tensors related functions
# _nested_tensor_size() should never actually be called with a requires_grad=True tensor
"_nested_tensor_size",
}

DONT_ENFORCE_STORAGE_IMPL_USE_COUNT = {
# These non-view functions return tensors with storage use_count != 1
"_slow_conv2d_forward",
"slow_conv3d_forward",
"channel_shuffle",
# lift() should never actually be called with a requires_grad=True tensor,
"lift",
# If an input is returned as-is in output, we cannot guarantee its storage_impl
# use count to be 1 either.
*DONT_ENFORCE_TENSOR_IMPL_USE_COUNT,
2 changes: 1 addition & 1 deletion torch/csrc/autograd/engine.cpp
@@ -724,7 +724,7 @@ void validate_outputs(

if (!metadata.is_same_shape(grad)) {
if (metadata.is_expandable_to_shape(grad)) {
grad = at::sum_to(std::move(grad), metadata.shape());
grad = metadata.reduce_grad(grad);
} else {
const auto message = metadata.incompatible_shape_error_message(i, grad);
AT_ERROR(format_error(message.str()));
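
For context, the reduce_grad call above replaces the direct at::sum_to and handles the ordinary broadcasting case, where an incoming gradient must be summed back down to the input's shape. A small dense-tensor sketch of the kind of reduction it performs:

import torch

x = torch.randn(3, 1, requires_grad=True)
y = torch.randn(3, 4)
(x + y).sum().backward()

# The gradient flowing toward x has shape (3, 4); autograd reduces it back
# to x's shape (3, 1), which is the reduction reduce_grad/at::sum_to performs.
print(x.grad.shape)  # torch.Size([3, 1])
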
3 changes: 2 additions & 1 deletion torch/csrc/autograd/function.h
@@ -190,7 +190,8 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
bool is_tensor_subclass) noexcept {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint32_t input_nr = input_metadata_.size();
input_metadata_.emplace_back(options, shape, is_tensor_subclass);
auto meta_shape = MetadataShape{c10::in_place_type<at::DimVector>, shape};
input_metadata_.emplace_back(options, meta_shape, is_tensor_subclass);
return input_nr;
}

3 changes: 2 additions & 1 deletion torch/csrc/autograd/functions/accumulate_grad.h
@@ -148,7 +148,8 @@ struct TORCH_API AccumulateGrad : public Node {
new_grad.sizes(),
new_grad.options()));
} else {
if (new_grad.is_sparse() || new_grad.is_sparse_csr()) {
if (new_grad.is_sparse() || new_grad.is_sparse_csr() ||
new_grad.is_nested()) {
update_grad(new_grad.clone());
} else {
if (new_grad.is_mkldnn()) {
85 changes: 67 additions & 18 deletions torch/csrc/autograd/input_metadata.h
@@ -1,11 +1,15 @@
#pragma once

#include <ATen/ExpandUtils.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/core/Tensor.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/Exception.h>
#include <c10/util/variant.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -14,10 +18,13 @@
#endif

#include <cstdint>
#include <utility>

namespace torch {
namespace autograd {

using MetadataShape = c10::variant<at::DimVector, at::Tensor>;

/**
* Records TensorOptions, shape of the tensor, whether or not the Python
* dispatch key is set (tensor subclass), and, where applicable, the stream the
@@ -31,10 +38,10 @@ struct InputMetadata {

InputMetadata(
const at::TensorOptions options,
at::IntArrayRef shape,
MetadataShape input_shape,
bool is_tensor_subclass)
: options_{options},
shape_{shape},
shape_{input_shape},
is_tensor_subclass_{is_tensor_subclass} {
auto device_ = options.device();
stream_ = c10::impl::getDeviceGuardImpl(device_.type())->getStream(device_);
@@ -43,17 +50,13 @@
InputMetadata(const at::Tensor& t)
: InputMetadata(
t.options(),
t.sizes(),
compute_variant_shape(t),
t.unsafeGetTensorImpl()->is_python_dispatch()) {}

const at::TensorOptions options() const {
return options_;
}

at::IntArrayRef shape() const {
return shape_;
}

caffe2::TypeMeta dtype() const {
return options_.dtype();
}
@@ -75,37 +78,83 @@
}

at::Tensor zeros_like() const {
return at::zeros(shape_, options_);
TORCH_CHECK(
!is_nested_tensor(),
"Zeros is not currently supported for nested tensors.")
return at::zeros(shape_as_dim_vector(), options_);
}

bool is_same_shape(const at::Tensor& grad) const {
TORCH_CHECK(!grad.is_nested(), "Nested grads are not currently supported.")
return grad.sizes().equals(shape());
TORCH_CHECK(
grad.is_nested() == is_nested_tensor(),
"Both grad and InputMetadata need to be either nested or non nested tensors.")
if (grad.is_nested()) {
return at::native::get_nested_size_tensor(grad).is_same_size(
shape_as_tensor());
}
return grad.sizes().equals(shape_as_dim_vector());
}
bool is_expandable_to_shape(const at::Tensor& grad) const {
// TODO: Currently NestedTensors are not expandable.
return grad.is_nested() ? false
: at::is_expandable_to(shape(), grad.sizes());
// Currently NestedTensors are not expandable. If this support is added,
// updates to reduce_grad will be needed.
TORCH_CHECK(
grad.is_nested() == is_nested_tensor(),
"Both grad and InputMetadata need to be either nested or non nested tensors.")
return grad.is_nested()
? false
: at::is_expandable_to(shape_as_dim_vector(), grad.sizes());
}

at::Tensor reduce_grad(at::Tensor& grad) const {
// Currently reduce_grad is only called if is_expandable_to_shape returns
// true. For nested tensors this always returns false, so this check
// shouldn't fail.
TORCH_INTERNAL_ASSERT(!grad.is_nested() && !is_nested_tensor())
return at::sum_to(std::move(grad), shape_as_dim_vector());
}

std::stringstream incompatible_shape_error_message(
const size_t index,
const at::Tensor& grad) const {
std::stringstream ss;
TORCH_CHECK(!grad.is_nested(), "Nested grads are not currently supported.")
ss << "invalid gradient at index " << index << " - got ";
ss << grad.sizes();
if (grad.is_nested()) {
ss << at::native::get_nested_size_tensor(grad);
} else {
ss << grad.sizes();
}
ss << " but expected shape compatible with ";
ss << shape();
if (is_nested_tensor()) {
ss << shape_as_tensor();
} else {
ss << shape_as_dim_vector();
}
return ss;
}

private:
bool is_nested_tensor() const {
return (c10::holds_alternative<at::Tensor>(shape_));
}
MetadataShape compute_variant_shape(const at::Tensor& input) {
if (input.is_nested()) {
auto nested_size = at::native::get_nested_size_tensor(input);
return MetadataShape{c10::in_place_type<at::Tensor>, nested_size};
}
return MetadataShape{c10::in_place_type<at::DimVector>, input.sizes()};
}

at::DimVector shape_as_dim_vector() const {
return c10::get<at::DimVector>(shape_);
}
at::Tensor shape_as_tensor() const {
return c10::get<at::Tensor>(shape_);
}

const at::TensorOptions options_;
at::DimVector shape_;
MetadataShape shape_;
c10::Stream stream_ = c10::Stream(c10::Stream::Default::DEFAULT, device());
bool is_tensor_subclass_ = false;
};

} // namespace autograd
} // namespace torch
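
In short, for nested tensors InputMetadata now stores the nested size tensor instead of a DimVector, and shape checks compare those size tensors. A rough Python-level illustration; calling the op via torch.ops.aten._nested_tensor_size is an assumption about how the _nested_tensor_size function mentioned in gen_variable_type.py is exposed:

import torch

a = torch.nested_tensor([torch.randn(1, 2), torch.randn(7, 8)])
b = torch.nested_tensor([torch.randn(1, 2), torch.randn(7, 8)])

# Each nested tensor carries a (num_components x dims) size tensor; autograd's
# is_same_shape compares these nested size tensors rather than regular sizes().
# NOTE: the torch.ops.aten call below is an assumed way to reach the aten op.
size_a = torch.ops.aten._nested_tensor_size(a)
size_b = torch.ops.aten._nested_tensor_size(b)
print(size_a)                       # tensor([[1, 2], [7, 8]])
print(torch.equal(size_a, size_b))  # True -> the component shapes match
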
9 changes: 8 additions & 1 deletion torch/csrc/autograd/utils/grad_layout_contract.h
@@ -17,7 +17,14 @@ inline bool obeys_layout_contract(
TORCH_INTERNAL_ASSERT(!variable.is_sparse());
TORCH_INTERNAL_ASSERT(!grad.is_sparse_csr());
TORCH_INTERNAL_ASSERT(!variable.is_sparse_csr());
if (variable.is_non_overlapping_and_dense()) {

if (variable.is_nested()) {
// TODO: Nested Tensor does not have an implementation of detach. The
// current implementation of nested tensor likely does obey the gradient
// contract and should return true, but this would likely change in the
// future
return false;
} else if (variable.is_non_overlapping_and_dense()) {
// Only look at stride for dimensions that are not of size 1.
const auto& grad_sizes = grad.sizes();
const auto& grad_strides = grad.strides();
