[functorch] fix AOTAutograd tutorial (pytorch#87415) (pytorch#87434)
It was raising assertion errors previously: the old cell assigned `res` the return value of `.backward()`, which is `None`, so the `torch.allclose(ref, res)` check failed.
Pull Request resolved: pytorch#87415
Approved by: https://github.com/Chillee
zou3519 committed Oct 21, 2022
1 parent 385022c · commit 3ffddc0
1 changed file: functorch/notebooks/aot_autograd_optimizations.ipynb (7 additions, 1 deletion)
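For context, the failure mode is easy to reproduce outside the notebook: `Tensor.backward()` returns `None`, so chaining it onto the forward call leaves nothing to compare against. A minimal sketch of the bug, with a stand-in computation in place of the tutorial's `fn`:

```python
import torch

a = torch.randn(3, requires_grad=True)
ref = a * 2

# The old tutorial cell chained .backward() onto the forward call:
res = (a * 2).sum().backward()  # Tensor.backward() returns None
print(res)  # None

# torch.allclose(ref, res)  # would raise: allclose needs a Tensor, not None
```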
@@ -126,7 +126,10 @@
 "aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)\n",
 "\n",
 "# Run the aot_print_fn once to trigger the compilation and print the graphs\n",
-"res = aot_print_fn(a, b, c, d).sum().backward()\n",
+"cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]\n",
+"cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs\n",
+"res = aot_print_fn(cloned_a, cloned_b, cloned_c, cloned_d)\n",
+"res.sum().backward()\n",
 "assert torch.allclose(ref, res)"
 ]
},
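The cloning idiom in the added lines is standard PyTorch: `clone().detach().requires_grad_(True)` yields a fresh leaf tensor with the same values but no shared autograd history, so the backward pass through the compiled function leaves the original inputs' `.grad` untouched. A minimal sketch of that behavior:

```python
import torch

a = torch.randn(3, requires_grad=True)

# A fresh leaf tensor: same values, but detached from a's autograd graph
cloned_a = a.clone().detach().requires_grad_(True)

cloned_a.sum().backward()
print(a.grad)         # None -- the original leaf is untouched
print(cloned_a.grad)  # tensor([1., 1., 1.])
```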
@@ -300,6 +300,9 @@
 "source": [
 "from functorch.compile import min_cut_rematerialization_partition\n",
 "\n",
+"# Zero out the gradients so we can do a comparison later\n",
+"a.grad, b.grad, c.grad, d.grad = (None,) * 4\n",
+"\n",
 "# Lets set up the partitioner. Also set the fwd and bwd compilers to the printer function that we used earlier.\n",
 "# This will show us how the recomputation has modified the graph.\n",
 "aot_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn, partition_fn=min_cut_rematerialization_partition)\n",
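Resetting `.grad` to `None` before this second run matters because autograd accumulates gradients across backward calls; without the reset, the later gradient comparison would see the sum of both runs. A minimal sketch of the accumulation behavior:

```python
import torch

a = torch.randn(3, requires_grad=True)

(a * 2).sum().backward()
print(a.grad)  # tensor([2., 2., 2.])

# A second backward call adds to the existing .grad rather than replacing it
(a * 2).sum().backward()
print(a.grad)  # tensor([4., 4., 4.])

# Setting .grad to None gives the next backward pass a clean slate
a.grad = None
```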
