
Double-backward with full_backward_hook causes RuntimeError in PyTorch 1.13 #88312

Closed
f-dangel opened this issue Nov 2, 2022 · 3 comments
Assignees: soulitzer
Labels: actionable · high priority · module: autograd (Related to torch.autograd, and the autograd engine in general) · module: regression (It used to work, and now it doesn't) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Milestone: 1.13.1

Comments


f-dangel commented Nov 2, 2022

Hi,

I am using double-backward calls to compute Hessian matrices, in combination with PyTorch's full_backward_hooks. After upgrading from 1.12.1 to 1.13.0, I now run into the following error in the second backward pass:

RuntimeError: Module backward hook for grad_input is called before the grad_output one. This happens because the gradient in your nn.Module flows to the Module's input without passing through the Module's output. Make sure that the output depends on the input and that the loss is computed based on the output.

The following snippet reproduces my problem when computing the Hessian of f(x, y) w.r.t. x, where all symbols are scalars for simplicity.

"""Compute the scalar-valued second-order derivative of f(x, y) w.r.t. x.

Use Hessian-vector products (double-backward pass) in combination with
full_backward_hook.
"""

from torch import ones_like, rand, rand_like
from torch.autograd import grad
from torch.nn import MSELoss

x = rand(1)
x.requires_grad_(True)
y = rand_like(x)

# without hook (working in 1.12.1 and 1.13.0)
lossfunc = MSELoss()
f = lossfunc(x, y)

(gradx_f,) = grad(f, x, create_graph=True)
(gradxgradx_f,) = grad(gradx_f @ ones_like(x), x)

# with hook (working in 1.12.1, broken in 1.13.0)
lossfunc = MSELoss()


def hook(module, grad_input, grad_output):
    print("This is a test hook")


lossfunc.register_full_backward_hook(hook)

f = lossfunc(x, y)

# this line triggers the backward hook as expected
(gradx_f,) = grad(f, x, create_graph=True)
# the double-backward with hook crashes in 1.13, but used to work before
try:
    (gradxgradx_f,) = grad(gradx_f @ ones_like(x), x)
except RuntimeError as e:
    print(f"Caught RuntimeError: {e}")

Is this the intended behavior? If so, how do I compute higher-order derivatives through multiple backward calls while using hooks, e.g. for monitoring?
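One workaround I can think of (an untested sketch, not an official recommendation): monitor gradients with a plain tensor hook on the module output instead of a module full_backward_hook. Tensor hooks fire whenever a gradient w.r.t. that tensor is computed and make no assumption about a paired grad_input/grad_output call:

```python
from torch import ones_like, rand, rand_like
from torch.autograd import grad
from torch.nn import MSELoss

x = rand(1)
x.requires_grad_(True)
y = rand_like(x)

lossfunc = MSELoss()
f = lossfunc(x, y)

# Tensor hook on the loss output: called whenever a gradient w.r.t. f
# is computed, independent of the module hook machinery.
f.register_hook(lambda grad_f: print("grad w.r.t. f:", grad_f))

(gradx_f,) = grad(f, x, create_graph=True)         # hook fires here
(gradxgradx_f,) = grad(gradx_f @ ones_like(x), x)  # runs without error
```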

Best,
Felix

Versions

PyTorch version: 1.13.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.10

Python version: 3.7.6 (default, Jan 8 2020, 19:59:22) [GCC 7.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-52-generic-x86_64-with-debian-bookworm-sid
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] backpack-for-pytorch==1.5.1.dev14+g5401cde6
[pip3] mypy==0.940
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.20.1
[pip3] pytorch-memlab==0.2.3
[pip3] torch==1.13.0
[pip3] torchvision==0.10.0
[conda] backpack-for-pytorch 1.5.1.dev14+g5401cde6 dev_0
[conda] numpy 1.20.1 pypi_0 pypi
[conda] pytorch-memlab 0.2.3 pypi_0 pypi
[conda] torch 1.13.0 pypi_0 pypi
[conda] torchvision 0.10.0 pypi_0 pypi

cc @ezyang @gchanan @zou3519 @albanD @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7

@gchanan gchanan added the module: autograd (Related to torch.autograd, and the autograd engine in general) label Nov 2, 2022
@soulitzer soulitzer self-assigned this Nov 2, 2022

soulitzer commented Nov 2, 2022

I think the assertion is wrong, and we can just remove it. Previously, we assumed that every time the backward post-hook fires, the backward pre-hook must have fired first, and so we rely on the pre-hook to set up the grad_outputs information. Double backward fails because gradx_f does not actually depend on the output of f, so the grad_outputs information was not set up when the post-hook fires.

This was only "working" previously because we failed to clear the grad_output after the hook fires (#82788), so those buffers would've been leaked. And since you aren't truly backwarding through that module again, the grad_outputs you get from that hook don't really mean anything.
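To make that concrete, here is a small sketch (not from the original report) that walks gradx_f's autograd graph; the module output f never appears in it, so the module's grad_output bookkeeping is never set up before the post-hook fires:

```python
from torch import rand, rand_like
from torch.autograd import grad
from torch.nn import MSELoss

x = rand(1, requires_grad=True)
y = rand_like(x)
f = MSELoss()(x, y)
(gradx_f,) = grad(f, x, create_graph=True)

def walk(fn, depth=0):
    """Recursively print the backward graph feeding a grad_fn node."""
    if fn is None:
        return
    print("  " * depth + type(fn).__name__)
    for next_fn, _ in fn.next_functions:
        walk(next_fn, depth + 1)

# Only the *Backward nodes of the double-backward show up; no node
# computes a gradient w.r.t. the module output f itself.
walk(gradx_f.grad_fn)
```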


albanD commented Nov 2, 2022

Interesting! Sounds good!

soulitzer added a commit that referenced this issue Nov 2, 2022
@soulitzer soulitzer added the high priority, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module), module: regression (It used to work, and now it doesn't), and actionable labels Nov 2, 2022
f-dangel added a commit to f-dangel/backpack that referenced this issue Nov 3, 2022
See #272. Waiting for
pytorch/pytorch#88312 before `torch>=1.13`
can be supported.
f-dangel added a commit to f-dangel/backpack that referenced this issue Nov 3, 2022
* [CI] Test with `torch=={1.9.0, 1.10.0}`

* [CI] Test with `torch=={1.9.0, 1.11.0}`

* [FIX] flake8

* [CI] Test with `torch=={1.9.0, 1.12.0}`

* [TEST] Replace `parameters_to_vector` by custom function

This should fix
`test_network_diag_ggn[<class
'test.converter.converter_cases._Permute'>]`
in `test/converter/test_converter.py`. Between torch 1.11.0 and torch
1.12.0, the GGN-vector products for this case became non-contiguous, and
`torch.nn.utils.convert_parameters.parameters_to_vector` stopped working
as it uses `view`.

Here is a short self-contained snippet to reproduce the issue:

```python
from torch import Tensor, permute, rand, rand_like
from torch.autograd import grad
from torch.nn import Linear, Module
from torch.nn.utils.convert_parameters import parameters_to_vector

from backpack.utils.convert_parameters import tensor_list_to_vector

class Permute(Module):
    def __init__(self):
        super().__init__()
        self.batch_size = 3
        self.in_dim = (5, 3)
        out_dim = 2
        self.linear = Linear(self.in_dim[-1], out_dim)
        self.linear2 = Linear(self.in_dim[-2], out_dim)

    def forward(self, x):
        x = self.linear(x)
        x = x.permute(0, 2, 1)  # method permute
        x = self.linear2(x)
        x = permute(x, (0, 2, 1))  # function permute
        return x

    def input_fn(self) -> Tensor:
        return rand(self.batch_size, *self.in_dim)

model = Permute()

inputs = model.input_fn()
outputs = model(inputs)

params = list(model.parameters())
grad_outputs = rand_like(outputs)
v = [rand_like(p) for p in model.parameters()]  # unused in this snippet

vJ_tuple = grad(outputs, params, grad_outputs=grad_outputs)

for p, vJ in zip(params, vJ_tuple):
    # all contiguous()
    print(p.shape, vJ.shape)
    # between 1.11.0 and 1.12.0, the vector-Jacobian product w.r.t. the second
    # linear layer's weight is not contiguous anymore
    print(p.is_contiguous(), vJ.is_contiguous())

# raises a RuntimeError in torch>=1.12.0 for this model: view(-1) on the
# non-contiguous vJ of the second linear layer's weight fails
vJ_vector = parameters_to_vector(vJ_tuple)

# the custom function copies non-contiguous tensors, so this works
vJ_vector = tensor_list_to_vector(vJ_tuple)
```
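For reference, a minimal stand-in for the custom function (a hypothetical implementation, not necessarily backpack's actual code): flatten with `reshape(-1)`, which silently copies non-contiguous tensors, where `view(-1)` raises:

```python
from typing import List

import torch


def tensor_list_to_vector(tensors: List[torch.Tensor]) -> torch.Tensor:
    """Concatenate tensors into a flat vector.

    Unlike parameters_to_vector, which uses view and therefore fails on
    non-contiguous tensors, reshape falls back to a copy when needed.
    """
    return torch.cat([t.reshape(-1) for t in tensors])
```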

* [REF] Use f-string and add type hints

* [REQ] Require `torch<1.13`

See #272. Waiting for
pytorch/pytorch#88312 before `torch>=1.13`
can be supported.

* [DOC] Update changelog to prepare compatibility patch

* [DOC] fix date

Co-authored-by: Felix Dangel <fdangel@tue.mpg.de>
f-dangel added a commit to f-dangel/backpack that referenced this issue Nov 3, 2022

gchanan commented Nov 3, 2022

Should this be in the 1.13.1 milestone?

@soulitzer soulitzer added this to the 1.13.1 milestone Nov 3, 2022
soulitzer added a commit that referenced this issue Nov 11, 2022
…oring in double backward"

Fixes #88312
soulitzer added a commit that referenced this issue Nov 11, 2022
soulitzer added a commit that referenced this issue Nov 11, 2022
soulitzer added a commit that referenced this issue Nov 11, 2022
soulitzer added a commit that referenced this issue Nov 13, 2022
…oring in double backward"

See https://docs.google.com/document/d/1tFZKYdsSzRBJ7Di7SWt8X8fSg-E3eiUPwomMF10UyhM/edit for more details regarding the question: should module full_backward_hooks be called every time the gradients w.r.t. module inputs are computed, or only when the "backward for the module" has been computed?

Fixes #88312
soulitzer added a commit that referenced this issue Nov 13, 2022
soulitzer added a commit that referenced this issue Nov 16, 2022
soulitzer added a commit that referenced this issue Nov 16, 2022
soulitzer added a commit that referenced this issue Nov 16, 2022
soulitzer added a commit that referenced this issue Nov 16, 2022
weiwangmeta pushed a commit to weiwangmeta/pytorch that referenced this issue Nov 30, 2022
…ytorch#88357)

Also clarifies the documentation to say "execute if and only if gradients w.r.t. outputs are computed" (previously, "execute every time gradients w.r.t. inputs are computed").

See https://docs.google.com/document/d/1tFZKYdsSzRBJ7Di7SWt8X8fSg-E3eiUPwomMF10UyhM/edit for more details regarding the question: should module full_backward_hooks be called every time the gradients w.r.t. module inputs are computed, or only when the "backward for the module" has been computed?

Fixes pytorch#88312

Pull Request resolved: pytorch#88357
Approved by: https://github.com/albanD
weiwangmeta pushed a commit to weiwangmeta/pytorch that referenced this issue Dec 6, 2022
atalman pushed a commit that referenced this issue Dec 6, 2022
…88357) (#89928)

Fixes #88312

Pull Request resolved: #88357
Approved by: https://github.com/albanD

Co-authored-by: soulitzer <soulitzer@gmail.com>
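With this fix in 1.13.1, the documented contract implies the original reproducer runs cleanly: the hook fires on the first backward (gradients w.r.t. the module output are computed) and stays silent during the double-backward instead of raising. A quick check, assuming torch >= 1.13.1:

```python
from torch import ones_like, rand, rand_like
from torch.autograd import grad
from torch.nn import MSELoss

x = rand(1, requires_grad=True)
y = rand_like(x)

lossfunc = MSELoss()
lossfunc.register_full_backward_hook(
    lambda module, grad_input, grad_output: print("hook fired")
)

f = lossfunc(x, y)
(gradx_f,) = grad(f, x, create_graph=True)         # prints "hook fired"
(gradxgradx_f,) = grad(gradx_f @ ones_like(x), x)  # silent, no RuntimeError
```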
kulinseth pushed a commit to kulinseth/pytorch that referenced this issue Dec 10, 2022
f-dangel added a commit to f-dangel/backpack that referenced this issue Dec 20, 2022
* [FIX] Copy `_grad_input_padding` from torch==1.9

The function was removed between torch 1.12.1 and torch 1.13.
Reintroducing it should fix
#272.

* [CI] Use latest two torch releases for tests

* [FIX] Ignore flake8 warning about abstract methods

* [FIX] Import

* [CI] Test with `torch=={1.9.0, 1.12.0}` and make tests compatible (#276)

* [CI] Test torch from 1.9 to 1.13

* [FIX] Ignore 'zip()' without an explicit 'strict=' parameter

* [REF] Make GGNvps contiguous before flattening and concatenation

* [CI] Unambiguously specify tested torch versions

* [REF] Import _grad_input_padding from torch for torch<1.13 (see the sketch after this list)

* [FIX] Exception handling for Hessians of linear functions

* [REF] Same `_grad_input_padding` import strategy for conv_transpose

* [FIX] Merge conflict

* [CI] Ignore docstring check of _grad_input_padding

* [DOC] Add type annotation, remove unused import

* [DOC] Add type annotation for output
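The version-gated import mentioned in the `[REF]` items above might look like the following sketch (the vendored-module path is hypothetical):

```python
import torch
from packaging.version import parse

if parse(torch.__version__) < parse("1.13"):
    # torch still ships the private helper up to 1.12.x.
    from torch.nn.grad import _grad_input_padding
else:
    # Hypothetical path to the copy vendored from torch==1.9.
    from backpack.utils.conv import _grad_input_padding
```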