[FIX] Support torch>=1.13 (#296)
* [FIX] Copy `_grad_input_padding` from torch==1.9

The function was removed between torch 1.12.1 and torch 1.13.
Reintroducing it should fix
#272.

* [CI] Use latest two torch releases for tests

* [FIX] Ignore flake8 warning about abstract methods

* [FIX] Import

* [CI] Test with `torch=={1.9.0, 1.12.0}` and make tests compatible (#276)

* [CI] Test with `torch=={1.9.0, 1.10.0}`

* [CI] Test with `torch=={1.9.0, 1.11.0}`

* [FIX] flake8

* [CI] Test with `torch=={1.9.0, 1.12.0}`

* [TEST] Replace `parameters_to_vector` by custom function

This should fix
`test_network_diag_ggn[<class
'test.converter.converter_cases._Permute'>]`
in `test/converter/test_converter.py`. Between torch 1.11.0 and torch
1.12.0, the GGN-vector products for this case became non-contiguous, and
`torch.nn.utils.convert_parameters.parameters_to_vector` stopped working
because it flattens tensors with `view`, which requires contiguous memory.

Here is a short self-contained snippet to reproduce the issue:

```python
from torch import Tensor, permute, rand, rand_like
from torch.autograd import grad
from torch.nn import Linear, Module
from torch.nn.utils.convert_parameters import parameters_to_vector

from backpack.utils.convert_parameters import tensor_list_to_vector

class Permute(Module):
    def __init__(self):
        super().__init__()
        self.batch_size = 3
        self.in_dim = (5, 3)
        out_dim = 2
        self.linear = Linear(self.in_dim[-1], out_dim)
        self.linear2 = Linear(self.in_dim[-2], out_dim)

    def forward(self, x):
        x = self.linear(x)
        x = x.permute(0, 2, 1)  # method permute
        x = self.linear2(x)
        x = permute(x, (0, 2, 1))  # function permute
        return x

    def input_fn(self) -> Tensor:
        return rand(self.batch_size, *self.in_dim)

model = Permute()

inputs = model.input_fn()
outputs = model(inputs)

params = list(model.parameters())
grad_outputs = rand_like(outputs)

# vector-Jacobian products of the output w.r.t. all parameters
vJ_tuple = grad(outputs, params, grad_outputs=grad_outputs)

for p, vJ in zip(params, vJ_tuple):
    print(p.shape, vJ.shape)
    # the parameters are contiguous; between torch 1.11.0 and 1.12.0, the
    # vector-Jacobian product w.r.t. the second linear layer's weight stopped
    # being contiguous
    print(p.is_contiguous(), vJ.is_contiguous())

# fails for torch>=1.12: `parameters_to_vector` flattens via `view`, which
# requires contiguous tensors
vJ_vector = parameters_to_vector(vJ_tuple)

# works: `tensor_list_to_vector` flattens with `flatten` and concatenates with `cat`
vJ_vector = tensor_list_to_vector(vJ_tuple)
```
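
The custom `tensor_list_to_vector` helper sidesteps the problem because `flatten` falls back to copying a non-contiguous tensor before `cat` concatenates the pieces. A minimal, BackPACK-independent illustration (the tensor shape is made up):

```python
from torch import cat, rand

t = rand(3, 2, 5).permute(0, 2, 1)  # non-contiguous, like the offending vJ product
# t.view(-1) would raise a RuntimeError here, which is what breaks
# `parameters_to_vector` on such inputs
vec = cat([t.flatten()])  # `flatten` copies non-contiguous tensors, so `cat` succeeds
print(t.is_contiguous(), vec.is_contiguous())  # False True
```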

* [REF] Use f-string and add type hints

* [REQ] Require `torch<1.13`

See #272. Waiting for
pytorch/pytorch#88312 before `torch>=1.13`
can be supported.

* [DOC] Update changelog to prepare compatibility patch

* [DOC] fix date

Co-authored-by: Felix Dangel <fdangel@tue.mpg.de>

* [CI] Test torch from 1.9 to 1.13

* [FIX] Ignore 'zip()' without an explicit 'strict=' parameter

* [REF] Make GGNvps contiguous before flattening and concatenation

* [CI] Unambiguously specify tested torch versions

* [REF] Import _grad_input_padding from torch for torch<1.13

* [FIX] Exception handling for Hessians of linear functions

* [REF] Same `_grad_input_padding` import strategy for conv_transpose

* [FIX] Merge conflict

* [CI] Ignore docstring check of _grad_input_padding

* [DOC] Add type annotation, remove unused import

* [DOC] Add type annotation for output
f-dangel committed Dec 20, 2022
1 parent d3b134f commit 0fe55d7
Showing 9 changed files with 84 additions and 27 deletions.
10 changes: 8 additions & 2 deletions .github/workflows/test.yaml
@@ -20,7 +20,13 @@ jobs:
    strategy:
      matrix:
        python-version: [3.7, 3.8, 3.9]
        pytorch-version: [1.9.0, 1.12.0]
        pytorch-version:
          - "==1.9.1"
          - "==1.10.1"
          - "==1.11.0"
          - "==1.12.1"
          - "==1.13.1"
          - "" # latest
    steps:
      - uses: actions/checkout@v1
      - uses: actions/setup-python@v1
@@ -30,7 +36,7 @@ jobs:
        run: |
          python -m pip install --upgrade pip
          make install-test
          pip install torch==${{ matrix.pytorch-version }} torchvision
          pip install torch${{ matrix.pytorch-version }} torchvision
      - name: Run test
        if: contains('refs/heads/master refs/heads/development refs/heads/release', github.ref)
        run: |
7 changes: 6 additions & 1 deletion backpack/core/derivatives/conv_transposend.py
@@ -5,16 +5,21 @@
from numpy import prod
from torch import Tensor, einsum
from torch.nn import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, Module
from torch.nn.grad import _grad_input_padding

from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.utils import TORCH_VERSION_AT_LEAST_1_13
from backpack.utils.conv import get_conv_function
from backpack.utils.conv_transpose import (
    get_conv_transpose_function,
    unfold_by_conv_transpose,
)
from backpack.utils.subsampling import subsample

if TORCH_VERSION_AT_LEAST_1_13:
    from backpack.utils.conv import _grad_input_padding
else:
    from torch.nn.grad import _grad_input_padding


class ConvTransposeNDDerivatives(BaseParameterDerivatives):
    """Base class for partial derivatives of transpose convolution."""
7 changes: 6 additions & 1 deletion backpack/core/derivatives/convnd.py
@@ -5,13 +5,18 @@
from numpy import prod
from torch import Tensor, einsum
from torch.nn import Conv1d, Conv2d, Conv3d, Module
from torch.nn.grad import _grad_input_padding

from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.utils import TORCH_VERSION_AT_LEAST_1_13
from backpack.utils.conv import get_conv_function, unfold_by_conv
from backpack.utils.conv_transpose import get_conv_transpose_function
from backpack.utils.subsampling import subsample

if TORCH_VERSION_AT_LEAST_1_13:
    from backpack.utils.conv import _grad_input_padding
else:
    from torch.nn.grad import _grad_input_padding


class weight_jac_t_save_memory:
    """Choose algorithm to apply transposed convolution weight Jacobian."""
1 change: 1 addition & 0 deletions backpack/utils/__init__.py
@@ -4,5 +4,6 @@
TORCH_VERSION = packaging.version.parse(get_distribution("torch").version)
TORCH_VERSION_AT_LEAST_1_9_1 = TORCH_VERSION >= packaging.version.parse("1.9.1")
TORCH_VERSION_AT_LEAST_2_0_0 = TORCH_VERSION >= packaging.version.parse("2.0.0")
TORCH_VERSION_AT_LEAST_1_13 = TORCH_VERSION >= packaging.version.parse("1.13")

ADAPTIVE_AVG_POOL_BUG: bool = not TORCH_VERSION_AT_LEAST_2_0_0
60 changes: 56 additions & 4 deletions backpack/utils/conv.py
@@ -1,8 +1,8 @@
"""Utility functions for convolution layers."""

from typing import Callable, Tuple, Type, Union
from warnings import warn

import torch
from einops import rearrange
from torch import Tensor, einsum
from torch.nn import (
@@ -158,9 +158,7 @@ def extract_bias_diagonal(
    return S.sum(sum_before).pow_(2).sum(sum_after)


def unfold_by_conv(
    input: torch.Tensor, module: Union[Conv1d, Conv2d, Conv3d]
) -> torch.Tensor:
def unfold_by_conv(input: Tensor, module: Union[Conv1d, Conv2d, Conv3d]) -> Tensor:
    """Return the unfolded input using convolution.

    Args:
@@ -179,3 +177,57 @@ def unfold_by_conv(
        padding=module.padding,
        stride=module.stride,
    )


def _grad_input_padding(
    grad_output: Tensor,
    input_size: Tuple[int, ...],
    stride: Tuple[int, ...],
    padding: Tuple[int, ...],
    kernel_size: Tuple[int, ...],
    dilation: Union[None, Tuple[int]] = None,
) -> Tuple[int, ...]:
    """Determine padding for the VJP of convolution.
    # noqa: DAR101
    # noqa: DAR201
    # noqa: DAR401
    Note:
        This function was copied from the PyTorch repository (version 1.9).
        It was removed between torch 1.12.1 and torch 1.13.
    """
    if dilation is None:
        # For backward compatibility
        warn(
            "_grad_input_padding 'dilation' argument not provided. Default of 1 is used."
        )
        dilation = [1] * len(stride)

    input_size = list(input_size)
    k = grad_output.dim() - 2

    if len(input_size) == k + 2:
        input_size = input_size[-k:]
    if len(input_size) != k:
        raise ValueError(f"input_size must have {k+2} elements (got {len(input_size)})")

    def dim_size(d):
        return (
            (grad_output.size(d + 2) - 1) * stride[d]
            - 2 * padding[d]
            + 1
            + dilation[d] * (kernel_size[d] - 1)
        )

    min_sizes = [dim_size(d) for d in range(k)]
    max_sizes = [min_sizes[d] + stride[d] - 1 for d in range(k)]
    for size, min_size, max_size in zip(input_size, min_sizes, max_sizes):
        if size < min_size or size > max_size:
            raise ValueError(
                f"requested an input grad size of {input_size}, but valid sizes range "
                f"from {min_sizes} to {max_sizes} (for a grad_output of "
                f"{grad_output.size()[2:]})"
            )

    return tuple(input_size[d] - min_sizes[d] for d in range(k))
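
For orientation, here is a minimal sketch (not part of the commit; the shapes, stride, padding, and dilation are made up) of the role this helper plays: the tuple it returns is passed as `output_padding` to a transposed convolution, so that the vector-Jacobian product w.r.t. the input of a strided convolution comes out with exactly the original input shape. `_grad_input_padding` here refers to the function defined above.

```python
from torch import rand, rand_like
from torch.nn.functional import conv2d, conv_transpose2d

x = rand(1, 3, 10, 10)  # convolution input
w = rand(4, 3, 3, 3)  # kernel with 4 output and 3 input channels
stride, padding, dilation, kernel_size = (2, 2), (1, 1), (1, 1), (3, 3)

out = conv2d(x, w, stride=stride, padding=padding, dilation=dilation)
g = rand_like(out)  # vector to backpropagate, i.e. the "grad_output"

# With stride > 1, several input sizes produce the same output size; the returned
# `output_padding` resolves this ambiguity for the transposed convolution.
output_padding = _grad_input_padding(g, x.shape, stride, padding, kernel_size, dilation)
vjp = conv_transpose2d(
    g, w, stride=stride, padding=padding, output_padding=output_padding, dilation=dilation
)
assert vjp.shape == x.shape
```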
14 changes: 1 addition & 13 deletions backpack/utils/convert_parameters.py
@@ -2,7 +2,7 @@

from typing import Iterable, List

from torch import Tensor, cat, typename
from torch import Tensor, typename


def vector_to_parameter_list(vec: Tensor, parameters: Iterable[Tensor]) -> List[Tensor]:
@@ -51,15 +51,3 @@ def vector_to_parameter_list(vec: Tensor, parameters: Iterable[Tensor]) -> List[
        pointer += num_param

    return params_new


def tensor_list_to_vector(tensor_list: Iterable[Tensor]) -> Tensor:
    """Convert a list of tensors into a vector by flattening and concatenation.

    Args:
        tensor_list: List of tensors.

    Returns:
        Vector containing the flattened and concatenated tensor inputs.
    """
    return cat([t.flatten() for t in tensor_list])
9 changes: 4 additions & 5 deletions backpack/utils/examples.py
@@ -3,15 +3,13 @@

from torch import Tensor, stack, zeros
from torch.nn import Module
from torch.nn.utils.convert_parameters import parameters_to_vector
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Normalize, ToTensor

from backpack.hessianfree.ggnvp import ggn_vector_product
from backpack.utils.convert_parameters import (
    tensor_list_to_vector,
    vector_to_parameter_list,
)
from backpack.utils.convert_parameters import vector_to_parameter_list


def load_mnist_dataset() -> Dataset:
@@ -115,5 +113,6 @@ def _autograd_ggn_exact_columns(
        e_d_list = vector_to_parameter_list(e_d, trainable_parameters)

        ggn_d_list = ggn_vector_product(loss, outputs, model, e_d_list)
        ggn_d_list = [t.contiguous() for t in ggn_d_list]

        yield d, tensor_list_to_vector(ggn_d_list)
        yield d, parameters_to_vector(ggn_d_list)
1 change: 1 addition & 0 deletions setup.cfg
@@ -107,6 +107,7 @@ ignore =
    W291, # trailing whitespace
    W503, # line break before binary operator
    W504, # line break after binary operator
    B905, # 'zip()' without an explicit 'strict=' parameter
exclude = docs, build, .git, docs_src/rtd, docs_src/rtd_output, .eggs

# Differences with pytorch
2 changes: 1 addition & 1 deletion test/core/derivatives/implementation/autograd.py
@@ -248,7 +248,7 @@ def _elementwise_hessian(self, tensor: Tensor, x: Tensor) -> Tensor:
        for t in tensor.flatten():
            try:
                yield self._hessian(t, x)
            except (RuntimeError, AttributeError):
            except (RuntimeError, AttributeError, TypeError):
                yield zeros(*x.shape, *x.shape, device=x.device, dtype=x.dtype)

    def hessian_is_zero(self) -> bool:  # noqa: D102