Handle shared memory cases in MathBitFallback (#66667)
* Handle shared memory cases in MathBitFallback (#63602)

Summary:
Pull Request resolved: #63602

This PR fixes the case where a read and a write are performed on memory shared between mutable and/or non-mutable arguments. Example:
```
a = torch.tensor([1+1j])
b = a.conj()
b.add_(a) # should return tensor([2.+0.j]) but returns tensor([2.-2.j])
```

The issue here is that the conjugate fallback resolves the conjugation in-place for mutable arguments, which is a problem when, as shown above, other input arguments share memory with the mutable argument(s): physically conjugating `b` also changes the values read through `a`, so the computation sees the wrong operands.
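
For context, here is the aliasing at the Python level (a minimal sketch, assuming PyTorch >= 1.10 semantics where `Tensor.conj()` returns a lazy view with the conjugate bit set rather than copying):
```
import torch

a = torch.tensor([1 + 1j])
b = a.conj()                          # lazy conjugate view: no copy is made
assert b.is_conj()                    # the conjugate bit is set on b
assert a.data_ptr() == b.data_ptr()   # a and b share the same storage
```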
This PR fixes the issue in five steps:
1. First scan the operator's input arguments and build a vector of the mutable arguments that have the conj bit set to `True` (setting the flag `check_for_alias_with_mut_arg` to `True` or `False` accordingly).
2. Iterate through all the arguments, looking only at the non-mutable ones at this stage. If `check_for_alias_with_mut_arg` is `True`, check whether the current tensor argument aliases any entry in `mutable_inputs`. If it does, clone the non-mutable tensor argument; otherwise resolve the conjugation as before.
3. Go through the `mutable_inputs` vector (which contains only mutable input tensors with the conj bit set to `True`) and conjugate each entry in place.
4. Do the computation.
5. Re-conjugate the mutable argument tensors (see the sketch below for how the landed code realizes this via a clone and copy-back).
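
As landed, the fallback handles the mutable arguments by cloning them and copying the result back (see the `MathBitsFallback.h` changes below) rather than conjugating them in place. Here is a minimal Python-level sketch of that clone-and-copy-back idea for an in-place add — `inplace_add_via_fallback` is a hypothetical helper written for illustration, not the real boxed C++ fallback:
```
import torch

def inplace_add_via_fallback(mut, other):
    # Materialize the conj bit of the mutable arg into a fresh clone, so the
    # computation never runs on storage that other arguments may share.
    mut_clone = mut.resolve_conj() if mut.is_conj() else mut.clone()
    # Materialize the non-mutable arg; aliasing with `mut` is now harmless
    # because the computation happens on the clone.
    other_resolved = other.resolve_conj() if other.is_conj() else other
    # Do the computation on the clone.
    mut_clone.add_(other_resolved)
    # Copy the result back into the original mutable tensor; `copy_` must
    # understand the conj bit, as the NOTE in MathBitsFallback.h requires.
    return mut.copy_(mut_clone)

a = torch.tensor([1 + 1j])
b = a.conj()
print(inplace_add_via_fallback(b, a))  # tensor([2.+0.j])
```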

NOTE: `TensorLists` are not fully handled in ConjugateFallback; please see the in-line comment for more details, and the workaround sketch below.
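
For mutable `TensorList` inputs, the fallback now raises an error instead of risking incorrect results; the suggested workaround is to materialize the conjugation up front. A hedged illustration — `torch._foreach_add_` is used here only as an example of an op that mutates a `TensorList`:
```
import torch

xs = [torch.tensor([1 + 1j]).conj()]   # conj-bit tensor inside a mutable TensorList
ys = [torch.tensor([2 + 2j])]

# Materialize the conjugation before calling an op that mutates the list,
# as the fallback's error message suggests.
xs = [x.resolve_conj() for x in xs]
torch._foreach_add_(xs, ys)
print(xs)  # [tensor([3.+1.j])]
```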

Fixes #59943

Test Plan: Imported from OSS

Reviewed By: gmagogsfm

Differential Revision: D30466905

Pulled By: anjali411

fbshipit-source-id: 58058e5e6481da04a12d03f743c1491942a6cc9b

* fix lint (#66572)

Summary: Pull Request resolved: #66572

Test Plan: Imported from OSS

Reviewed By: seemethere

Differential Revision: D31624043

Pulled By: suo

fbshipit-source-id: 9db9cee3140d78c2a2f0c937be84755206fee1dd

Co-authored-by: anjali411 <chourdiaanjali123@gmail.com>
Co-authored-by: Michael Suo <suo@fb.com>
3 people committed Oct 15, 2021
1 parent ddf3092 commit b544cbd
Showing 5 changed files with 78 additions and 69 deletions.
12 changes: 2 additions & 10 deletions aten/src/ATen/ConjugateFallback.cpp
@@ -2,21 +2,12 @@
#include <ATen/native/MathBitFallThroughLists.h>

namespace at {

namespace native {
struct ConjFallback : MathOpFallback {
ConjFallback() : MathOpFallback(DispatchKey::Conjugate, "conjugate") {}
bool is_bit_set(const Tensor& tensor) override {
return tensor.is_conj();
}
void _set_bit(const Tensor& tensor, bool value) override {
return tensor._set_conj(value);
}
Tensor resolve_bit(const Tensor& tensor) override {
return at::resolve_conj(tensor);
}
Tensor& math_op_(Tensor& tensor) override {
return at::conj_physical_(tensor);
}
};

void conjugateFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
@@ -60,4 +51,5 @@ TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
TENSOR_UTILITIES_AND_CONSTRUCTORS(m)
}

}
} // namespace at
3 changes: 2 additions & 1 deletion aten/src/ATen/native/MathBitFallThroughLists.h
@@ -50,7 +50,8 @@ namespace at {
m.impl("vsplit.array", torch::CppFunction::makeFallthrough()); \
m.impl("conj", torch::CppFunction::makeFallthrough()); \
m.impl("_conj", torch::CppFunction::makeFallthrough()); \
m.impl("_unsafe_view", torch::CppFunction::makeFallthrough());
m.impl("_unsafe_view", torch::CppFunction::makeFallthrough()); \
m.impl("resize_", torch::CppFunction::makeFallthrough());

#define TENSOR_UTILITIES_AND_CONSTRUCTORS(m) \
m.impl("empty_like", torch::CppFunction::makeFallthrough()); \
110 changes: 62 additions & 48 deletions aten/src/ATen/native/MathBitsFallback.h
@@ -3,42 +3,49 @@
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/Resize.h>
#include <c10/util/irange.h>
#include <torch/library.h>

namespace at {

namespace native {
// This fallback should only be used for operations that are self-inverse and have a corresponding tensor
// bit (internally implemented using a DispatchKey) that maintains the operation's state on the tensor.
// Currently there are two tensor bits that trigger this fallback: the conjugate bit and the negative bit.
// The conjugate bit is set on a tensor when `.conj()` is called, and the neg bit is set on a tensor when `.conj().imag` is called.

// NOTE: To use this fallback, `clone` and `copy_` should fully understand and be able to correctly handle the semantics of your math bit.
struct MathOpFallback {
MathOpFallback(DispatchKey key_, string op_name_) : key(key_), op_name(op_name_) {}
virtual bool is_bit_set(const Tensor&) = 0;
virtual void _set_bit(const Tensor&, bool) = 0;
// Materializes the bit, i.e., returns a new tensor containing the true output
// (after performing the math operation corresponding to the tensor bit) if the bit is set to 1,
// else returns self.
virtual Tensor resolve_bit(const Tensor&) = 0;
// In-place operation corresponding to the math op represented by the bit. In the future, if this class
// is generalized for ops that are not self-inverse, this must be replaced by op_inverse_inplace.
virtual Tensor& math_op_(Tensor&) = 0;
void fallback_impl(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
// Situations to handle:
// 1. Out-of-place operation. Easy: materialize all inputs and
// call it a day.
// 2. Inplace operation. Desugar x.add_(2) into x.conj_().add_(2).conj_().
// Materialize other inputs as in (1).
// 3. out= operation. Desugar add(x, 2, out=y) into y.copy_(add(x, 2))
// Materialize other inputs as in (1).
//
// It is important to be able to tell if we READ from an argument and if we
// WRITE from an argument. Conservative approach is to assume that we always
// READ from an argument, but in out-of-place operations you can skip
// conjugating inputs on entry that never get used. In current schema we
// can't easily tell if inplace situation has happened, so don't do it.
/*
Situations to handle:
1. Out-of-place operation. Easy: materialize all inputs and
call it a day.
2. Inplace operation. Desugar x.add_(2) into x.conj_().add_(2).conj_().
Materialize other inputs as in (1).
3. out= operation. Desugar add(x, 2, out=y) into y.copy_(add(x, 2))
Materialize other inputs as in (1).
It is important to be able to tell if we READ from an argument and if we
WRITE to an argument. The conservative approach is to assume that we always
READ from an argument, but in out= operations you can skip
conjugating inputs on entry that never get used. In the current schema we
can't easily tell whether the operation is an in-place or out= operation.
Note:
1. Mutable tensorlists containing tensors whose math bit is set to true are disallowed.
2. Mutable tensors with the math bit set to true are unconditionally cloned to ensure
correct behavior when the mutable tensor shares memory with non-mutable arguments.
If we were to resolve the math bit in place for mutable inputs, then the non-mutable inputs sharing partial or full memory
with these mutable inputs would read wrong values in the following cases:
1. The non-mutable inputs have their math bit set to false.
2. The math bit for the mutable input(s) is resolved before the non-mutable inputs (with the bit set to true and sharing memory
with one or more mutable arg(s)) are cloned.
At the end, the final values of the mutable arguments on the stack are copied back into the original mutable input tensors.
*/
const auto& arguments = op.schema().arguments();
const auto num_arguments = arguments.size();
const auto stack_start = stack->size() - num_arguments;
@@ -72,9 +79,8 @@
return;
}

// Mutable inputs to be tracked separately
std::vector<Tensor> mutable_inputs;

// Mutable inputs with math bit set to True and their clones
std::vector<std::pair<Tensor, Tensor>> mutable_inputs_with_their_clones;
for (const auto i : c10::irange(num_arguments)) {
auto& ivalue = (*stack)[stack_start + i];
if (!(ivalue.isTensor() || ivalue.isTensorList())) {
@@ -91,41 +97,49 @@
if (!is_bit_set(ivalue.toTensor())) {
continue;
}

auto tensor = std::move(ivalue).toTensor();
TORCH_CHECK_NOT_IMPLEMENTED(!tensor.is_meta(), op_name, " fallback does not support meta tensors.");
auto resolved_tensor = at::clone(tensor);
if (mut_arg) {
// TODO: This is a waste if the argument is write only
_set_bit(tensor, false);
math_op_(tensor);
mutable_inputs.emplace_back(tensor);
} else {
tensor = resolve_bit(tensor);
TORCH_CHECK(mutable_inputs_with_their_clones.empty(), op_name, " fallback does not support operators with more than one mutable tensor with ",
op_name, " bit set to true.");
mutable_inputs_with_their_clones.emplace_back(std::make_pair(std::move(tensor), resolved_tensor));
}
(*stack)[stack_start + i] = std::move(tensor);
(*stack)[stack_start + i] = std::move(resolved_tensor);
} else if (ivalue.isTensorList()) {
auto tensors = std::move(ivalue).toTensorList();
if (mut_arg) {
for(const auto j : c10::irange(tensors.size())) {
Tensor t = tensors[j];
_set_bit(t, false);
math_op_(t);
mutable_inputs.emplace_back(t);
}
} else {
for(const auto j : c10::irange(tensors.size())) {
tensors[j] = resolve_bit(tensors[j]);
for(const auto j : c10::irange(tensors.size())) {
const auto& tensor = tensors[j];
if (!is_bit_set(tensor)) {
continue;
}
TORCH_CHECK(!mut_arg, op_name, " fallback doesn't currently support mutable TensorLists with ",
op_name, " inputs. Please materialize all the ", op_name, " input tensor(s) in the mutable TensorList inputs before calling ",
op.schema().name());
tensors[j] = at::clone(tensor);
}
(*stack)[stack_start + i] = std::move(tensors);
}
}

op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);

for (auto& mutable_input : mutable_inputs) {
math_op_(mutable_input);
_set_bit(mutable_input, true);
TORCH_INTERNAL_ASSERT(mutable_inputs_with_their_clones.size() <= 1);

for (std::pair<Tensor, Tensor> mut_tensors: mutable_inputs_with_their_clones) {
auto& mutable_input = mut_tensors.first;
auto& cloned_mutable_input = mut_tensors.second;
auto& ivalue = (*stack)[stack_start];
auto returned_output = std::move(ivalue).toTensor();

// sanity check to ensure that the tensor in stack aliases the cloned_mutable_input
TORCH_INTERNAL_ASSERT(cloned_mutable_input.is_same(returned_output));

// necessary for out= arg
at::native::resize_output(mutable_input, returned_output.sizes());

mutable_input.copy_(returned_output);
(*stack)[stack_start] = std::move(mutable_input);
}
}

Expand All @@ -134,5 +148,5 @@ struct MathOpFallback {
DispatchKey key;
string op_name;
};

} // namespace at
}
}// namespace at
12 changes: 2 additions & 10 deletions aten/src/ATen/native/NegateFallback.cpp
@@ -2,21 +2,12 @@
#include <ATen/native/MathBitFallThroughLists.h>

namespace at {

namespace native {
struct NegFallback : MathOpFallback {
NegFallback() : MathOpFallback(DispatchKey::Negative, "negation") {}
bool is_bit_set(const Tensor& tensor) override {
return tensor.is_neg();
}
void _set_bit(const Tensor& tensor, bool value) override {
return tensor._set_neg(value);
}
Tensor resolve_bit(const Tensor& tensor) override {
return at::resolve_neg(tensor);
}
Tensor& math_op_(Tensor& tensor) override {
return tensor.neg_();
}
};

void negationFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
@@ -42,4 +33,5 @@ TORCH_LIBRARY_IMPL(aten, Negative, m) {
TENSOR_UTILITIES_AND_CONSTRUCTORS(m)
}

}
} // namespace at
10 changes: 10 additions & 0 deletions test/test_view_ops.py
@@ -365,6 +365,16 @@ def test_conj_imag_view(self, device, dtype) -> None:
self.assertEqual(v_imag, t_numpy_conj.imag)
self.assertTrue(v_imag.is_neg())

@onlyOnCPUAndCUDA
def test_conj_view_with_shared_memory(self, device) -> None:
a = _make_tensor((4, 5,), torch.cfloat, device)
b = a.conj()
c = a.conj()

self.assertEqual(torch.add(a, b), a.add_(b))
self.assertEqual(torch.add(b, c), torch.add(b, c, out=a))
self.assertEqual(torch.add(b, c), b.add_(c))

@onlyOnCPUAndCUDA
@dtypes(*product(get_all_complex_dtypes(), get_all_dtypes()))
@suppress_warnings
