Keep track of ViewMeta with symbolic inputs. (#125876)
Fix: #125387

This PR keeps track of whether an instantiated `ViewMeta` has symbolic values as inputs or not.
That flag is used to decide whether to take the AOTAutograd `ViewMeta`-replay execution path,
since that path does not support tensors whose `ViewMeta` have symbolic inputs.

In summary, the changes are:

- Add the field `ViewMeta::has_symbolic_inputs` and make it a required constructor
parameter
- Add the field `FunctionalTensorWrapper::is_symbolic_` and the method
`FunctionalTensorWrapper::maybe_mark_symbolic`
    - Marks a `FunctionalTensorWrapper` as symbolic iff any of its `ViewMeta` have
    symbolic inputs
- Add the plumbing of `FunctionalTensorWrapper::is_symbolic` to the Python API
- Codegen the computation of `ViewMeta::has_symbolic_inputs` for each view operation
- Use the AOTAutograd `ViewMeta`-replay path (see the sketch after this list) if:
    - `target_functional_tensor` is not `None`; and
    - `target_functional_tensor` is not symbolic (instead of using a functorch config)
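
A minimal Python sketch of this gating condition (mirroring the change to
`torch/_functorch/_aot_autograd/functional_utils.py` shown further down; `config` and
`target_functional_tensor` are stand-ins for the names used there):

```python
import torch

# Sketch only; the real check lives in
# torch/_functorch/_aot_autograd/functional_utils.py (see the diff below).
def should_replay_view_metas(config, target_functional_tensor) -> bool:
    return (
        config.view_replay_for_aliased_outputs
        and target_functional_tensor is not None
        # New in this PR: skip ViewMeta replay when any recorded ViewMeta
        # had symbolic inputs; reconstruction then falls back to as_strided.
        and not torch._functionalize_is_symbolic(target_functional_tensor.tensor)
    )
```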

Pull Request resolved: #125876
Approved by: https://github.com/ezyang
ysiraichi authored and pytorchmergebot committed May 12, 2024
1 parent 6ffc94f commit 02093b6
Showing 10 changed files with 171 additions and 38 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/FunctionalStorageImpl.cpp
@@ -10,7 +10,7 @@ namespace at::functionalization {

ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
if (out_idx == this->out_index) return *this;
return ViewMeta(forward_fn, reverse_fn, is_multi_output, is_as_strided, out_idx);
return ViewMeta(forward_fn, reverse_fn, has_symbolic_inputs, is_multi_output, is_as_strided, out_idx);
}

// Note [Functionalization: Alias Removal Part 2]
7 changes: 6 additions & 1 deletion aten/src/ATen/FunctionalStorageImpl.h
@@ -31,14 +31,16 @@ struct ViewMeta {
ViewMeta(
std::function<Tensor(const Tensor&, int64_t)> forward,
std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse,
bool has_symbolic_inputs,
bool is_multi_output = false,
bool is_as_strided = false,
int64_t out_idx = 0)
: forward_fn(std::move(forward)),
reverse_fn(std::move(reverse)),
out_index(out_idx),
is_multi_output(is_multi_output),
is_as_strided(is_as_strided) {}
is_as_strided(is_as_strided),
has_symbolic_inputs(has_symbolic_inputs) {}

std::function<Tensor(const Tensor&, int64_t)> forward_fn;
std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse_fn;
@@ -50,6 +52,9 @@ struct ViewMeta {

bool is_as_strided;

// Tells us if this view operation has any symbolic inputs
bool has_symbolic_inputs;

// Returns a copy of the current ViewMeta, if out_idx matches the current
// out_index. Otherwise, returns a new ViewMeta with the same forward/reverse
// functions, but a new out index.
8 changes: 7 additions & 1 deletion aten/src/ATen/FunctionalTensorWrapper.cpp
@@ -137,7 +137,8 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const
),
value_(view_value),
is_multi_output_view_(base->is_multi_output_view_ || meta.is_multi_output),
was_storage_changed_(base->was_storage_changed_)
was_storage_changed_(base->was_storage_changed_),
is_symbolic_(base->is_symbolic_)
{
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
@@ -147,6 +148,7 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const
view_metas_ = base->view_metas_; // copy
}
view_metas_.push_back(meta);
maybe_mark_symbolic(meta);
storage_ = base->storage_; // alias this tensor's storage with the base tensor's
}

@@ -178,6 +180,8 @@ void FunctionalTensorWrapper::mutate_view_meta(const at::functionalization::View
view_metas_.push_back(meta);
// Manually track the fact that this tensor received a metadata mutation!
has_metadata_mutation_ = true;
// Mark this tensor as being symbolic if there are any symbolic inputs used by the view operation.
maybe_mark_symbolic(meta);
// Note [Functionalization Pass - Inplace View Ops]
// So, these ops are special - they're mutation AND view ops. They get special codegen.
// An example is transpose_, e.g. `a.transpose_()`
@@ -257,6 +261,7 @@ void FunctionalTensorWrapper::set__impl(const FunctionalTensorWrapper* other) {
value_ = other->value_;
generation_ = other->generation_;
view_metas_ = other->view_metas_;
is_symbolic_ = other->is_symbolic_;
// FREEZE the old storage, preventing mutations to it.
// this is a huge pain to handle properly in all cases, so we ban it.
functional_storage_impl()->freeze();
@@ -414,6 +419,7 @@ void FunctionalTensorWrapper::copy_tensor_metadata(
dest_impl->has_metadata_mutation_ = src_impl->has_metadata_mutation_;
dest_impl->is_multi_output_view_ = src_impl->is_multi_output_view_;
dest_impl->was_storage_changed_ = src_impl->was_storage_changed_;
dest_impl->is_symbolic_ = src_impl->is_symbolic_;
dest_impl->generation_ = src_impl->generation_;
dest_impl->view_metas_ = src_impl->view_metas_;
}
10 changes: 10 additions & 0 deletions aten/src/ATen/FunctionalTensorWrapper.h
@@ -97,6 +97,14 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
->are_all_mutations_under_no_grad_or_inference_mode();
}

void maybe_mark_symbolic(const functionalization::ViewMeta& meta) {
is_symbolic_ = is_symbolic_ | meta.has_symbolic_inputs;
}

bool is_symbolic() const {
return is_symbolic_;
}

// Runs the forward_fn of every ViewMeta collected in the current instance
// to some other base.
Tensor apply_view_metas(const Tensor& base);
@@ -250,6 +258,8 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
bool is_multi_output_view_ = false;
// Did the tensor experience a set_() call.
bool was_storage_changed_ = false;
// Did the tensor experience any view operation with symbolic int.
bool is_symbolic_ = false;

size_t generation_ = 0;
std::vector<at::functionalization::ViewMeta> view_metas_;
8 changes: 6 additions & 2 deletions aten/src/ATen/FunctionalizeFallbackKernel.cpp
@@ -178,7 +178,8 @@ static const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatch
},
[size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx) -> at::Tensor {
return base.as_strided_scatter(mutated_view, size, c10::contiguous_strides(size));
}
},
/*has_symbolic_inputs=*/false
);
at::functionalization::impl::mutate_view_meta(self, view_meta);
return self;
@@ -298,13 +299,16 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt
tmp_output = at::_unsafe_view_symint(self_, size);
}

bool has_symbolic_inputs = std::any_of(size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); });

at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
[size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx) -> at::Tensor {
return at::_unsafe_view_symint(base, size);
},
[size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx) -> at::Tensor {
return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
}
},
/*has_symbolic_inputs=*/has_symbolic_inputs
);

auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, std::move(view_meta));
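
For intuition, a rough Python analogue of the `has_symbolic_inputs` computation above
(illustration only, not part of the PR; the PR performs this check in C++ over `SymInt`
sizes and codegens it for each view operation):

```python
import torch

def any_symbolic(sizes) -> bool:
    # Roughly mirrors the std::any_of check: a size counts as symbolic when it
    # is a torch.SymInt rather than a plain Python int.
    return any(isinstance(s, torch.SymInt) for s in sizes)
```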
31 changes: 28 additions & 3 deletions test/functorch/test_aotdispatch.py
@@ -67,6 +67,7 @@
parametrize,
run_tests,
skipIfRocm,
skipIfTorchDynamo,
TestCase,
)
from torch.testing._internal.hop_db import hop_db
@@ -3475,7 +3476,6 @@ def wrapper(g, *args, **kwargs):

return lambda f: aot_function(f, fw_compiler=lambda g, _: partial(wrapper, g))

@patch("functorch.compile.config.view_replay_for_aliased_outputs", True)
def test_output_aliases_input_view_meta_replay(self):
@self._compile_and_erase_bases(0)
def f(a):
@@ -3489,7 +3489,6 @@ def f(a):
str(out.grad_fn.__class__), """<class 'ViewBackward0'>"""
)

@patch("functorch.compile.config.view_replay_for_aliased_outputs", True)
def test_output_aliases_intermediate_view_meta_replay(self):
@self._compile_and_erase_bases(0, 1)
def f(a):
@@ -3509,7 +3508,6 @@ def f(a):
str(out2.grad_fn.__class__), """<class 'ViewBackward0'>"""
)

@patch("functorch.compile.config.view_replay_for_aliased_outputs", True)
def test_output_aliases_output_view_meta_replay(self):
@self._compile_and_erase_bases(1)
def f(a):
@@ -3525,6 +3523,33 @@ def f(a):
str(out2.grad_fn.__class__), """<class 'ViewBackward0'>"""
)

@skipIfTorchDynamo()
@patch("torch._dynamo.config.assume_static_by_default", False)
def test_dynamic_output_aliases_input_view_meta_replay(self):
# - torch.compile: using it so we can have a SymInt in the FX graph.
# - Compiling with inductor, so that tensor._base isn't tracked.
#
# This should force the use of as_strided in the view reconstruction path.
# The first 2 view-replay paths won't be taken because:
# - target_functional_tensor will be symbolic (_functionalize_is_symbolic call)
# - tensor._base will be None
@torch.compile(backend="inductor")
def f(a, sz):
return a.view(sz), a.view(-1)

inp = torch.ones(2, 2, requires_grad=True)
out1, out2 = f(inp, (4,))

self.assertIsNotNone(out1.grad_fn)
self.assertExpectedInline(
str(out1.grad_fn.__class__), """<class 'AsStridedBackward0'>"""
)

self.assertIsNotNone(out2.grad_fn)
self.assertExpectedInline(
str(out2.grad_fn.__class__), """<class 'ViewBackward0'>"""
)


def extract_graph(fx_g, _, graph_cell):
graph_cell[0] = fx_g
3 changes: 3 additions & 0 deletions tools/pyi/gen_pyi.py
@@ -811,6 +811,9 @@ def gen_pyi(
"_functionalize_apply_view_metas": [
"def _functionalize_apply_view_metas(tensor: Tensor, base: Tensor) -> Tensor: ..."
],
"_functionalize_is_symbolic": [
"def _functionalize_is_symbolic(tensor: Tensor) -> _bool: ..."
],
"_enable_functionalization": [
"def _enable_functionalization(*, reapply_views: _bool = False): ..."
],
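
A hypothetical usage sketch of the new binding (not from the PR's tests;
`torch._to_functional_tensor` is assumed here as the existing helper that wraps a plain
tensor in a `FunctionalTensorWrapper`):

```python
import torch

x = torch.ones(2, 2)
ft = torch._to_functional_tensor(x)  # wrap in a FunctionalTensorWrapper
# No view with symbolic inputs has been recorded on `ft`, so the flag is
# expected to be False here.
print(torch._functionalize_is_symbolic(ft))
```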
44 changes: 16 additions & 28 deletions torch/_functorch/_aot_autograd/functional_utils.py
@@ -234,39 +234,27 @@ def patch_requires_grad(out):
# In summary, we use the fact that FunctionalTensorWrapper saves the view
# functions applied to itself (collected during functionalization) so as
# to replay them (view functions) on the aliased_base_tensor.
if config.view_replay_for_aliased_outputs and target_functional_tensor is not None:
if (
config.view_replay_for_aliased_outputs
and target_functional_tensor is not None
and not torch._functionalize_is_symbolic(target_functional_tensor.tensor)
):
from .schemas import FunctionalTensorMetadataEq

assert isinstance(target_functional_tensor, FunctionalTensorMetadataEq)
functional_tensor = target_functional_tensor.tensor

try:
out = torch._functionalize_apply_view_metas(
functional_tensor, aliased_base_tensor
)
except RuntimeError as e:
# NYI for dynamic shapes.
#
# On functionalization, the ViewMeta lambdas will have symbolic shapes.
# When trying to apply those lambdas on concrete tensors, it will fail.
#
# In order for this to work, we should have a way to replace those
# symbolic shapes with concrete numbers.
aot_joint_log.info(
"could not reconstruct view by re-applying a ViewMeta sequence. "
"Fallbacking to reconstruction using as_strided. "
"Reason: %s",
str(e),
)
else:
# If re-applying the ViewMeta sequence succeeded, there should be no more
# problems going forward. We just check we got to the target shape and
# patch requires_grad flag.
assert out.shape == target_meta_tensor.shape, (
"incorrect out shape after application of ViewMeta sequence: "
f"{tuple(out.shape)} (actual) vs {tuple(target_meta_tensor.shape)} (expected)"
)
return patch_requires_grad(out)
out = torch._functionalize_apply_view_metas(
functional_tensor, aliased_base_tensor
)
# If re-applying the ViewMeta sequence succeeded, there should be no more
# problems going forward. We just check we got to the target shape and
# patch requires_grad flag.
assert out.shape == target_meta_tensor.shape, (
"incorrect out shape after application of ViewMeta sequence: "
f"{tuple(out.shape)} (actual) vs {tuple(target_meta_tensor.shape)} (expected)"
)
return patch_requires_grad(out)

# Try to do view-replay if possible.
# fall back to .as_strided() if we can't.
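
The comments above mention falling back to `.as_strided()` when ViewMeta replay cannot be
used; a minimal sketch of that style of reconstruction (illustrative only, not the exact
AOTAutograd helper):

```python
import torch

def reconstruct_alias_with_as_strided(aliased_base_tensor: torch.Tensor,
                                      target_meta_tensor: torch.Tensor) -> torch.Tensor:
    # Rebuild the aliased output from its base using concrete size/stride/offset
    # metadata instead of replaying the recorded ViewMeta sequence.
    return aliased_base_tensor.as_strided(
        target_meta_tensor.size(),
        target_meta_tensor.stride(),
        target_meta_tensor.storage_offset(),
    )
```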
29 changes: 28 additions & 1 deletion torch/csrc/autograd/python_torch_functions_manual.cpp
@@ -685,6 +685,29 @@ static PyObject* THPVariable__functionalize_sync(
END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable__functionalize_is_symbolic(
PyObject* self,
PyObject* args,
PyObject* kwargs) {
HANDLE_TH_ERRORS
static PythonArgParser parser(
{"_functionalize_is_symbolic(Tensor tensor)"},
/*traceable=*/true);

ParsedArgs<1> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
auto tensor = r.tensor(0);
TORCH_INTERNAL_ASSERT(
at::functionalization::impl::isFunctionalTensor(tensor));
auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
if (impl->is_symbolic()) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable__functionalize_apply_view_metas(
PyObject* self,
PyObject* args,
Expand All @@ -694,7 +717,7 @@ static PyObject* THPVariable__functionalize_apply_view_metas(
{"_functionalize_apply_view_metas(Tensor tensor, Tensor base)"},
/*traceable=*/true);

ParsedArgs<4> parsed_args;
ParsedArgs<2> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
auto tensor = r.tensor(0);
TORCH_INTERNAL_ASSERT(
@@ -840,6 +863,10 @@ static PyMethodDef torch_functions_manual[] = {
castPyCFunctionWithKeywords(THPVariable__functionalize_sync),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"_functionalize_is_symbolic",
castPyCFunctionWithKeywords(THPVariable__functionalize_is_symbolic),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"_functionalize_apply_view_metas",
castPyCFunctionWithKeywords(THPVariable__functionalize_apply_view_metas),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
