Skip to content

Commit

Permalink
[FIX] Support for torch>=1.13 (#295)
Browse files Browse the repository at this point in the history
* [FIX] Copy `_grad_input_padding` from torch==1.9

The function was removed between torch 1.12.1 and torch 1.13.
Reintroducing it should fix
#272.

* [CI] Use latest two torch releases for tests

* [FIX] Ignore flake8 warning about abstract methods

* [FIX] Import

* [CI] Test torch from 1.9 to 1.13

* [FIX] Ignore 'zip()' without an explicit 'strict=' parameter

* [REF] Make GGNvps contiguous before flattening and concatenation

* [CI] Unambiguously specify tested torch versions

* [REF] Import _grad_input_padding from torch for torch<1.13

* [FIX] Exception handling for Hessians of linear functions

* [REF] Same `_grad_input_padding` import strategy for conv_transpose
  • Loading branch information
f-dangel committed Dec 19, 2022
1 parent 6414462 commit bfb1fde
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 24 deletions.
10 changes: 8 additions & 2 deletions .github/workflows/test.yaml
Expand Up @@ -20,7 +20,13 @@ jobs:
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
pytorch-version: [1.9.0, 1.12.0]
pytorch-version:
- "==1.9.1"
- "==1.10.1"
- "==1.11.0"
- "==1.12.1"
- "==1.13.1"
- "" # latest
steps:
- uses: actions/checkout@v1
- uses: actions/setup-python@v1
Expand All @@ -30,7 +36,7 @@ jobs:
run: |
python -m pip install --upgrade pip
make install-test
pip install torch==${{ matrix.pytorch-version }} torchvision
pip install torch${{ matrix.pytorch-version }} torchvision
- name: Run test
if: contains('refs/heads/master refs/heads/development refs/heads/release', github.ref)
run: |
Expand Down
7 changes: 6 additions & 1 deletion backpack/core/derivatives/conv_transposend.py
Expand Up @@ -5,16 +5,21 @@
from numpy import prod
from torch import Tensor, einsum
from torch.nn import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, Module
from torch.nn.grad import _grad_input_padding

from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.utils import TORCH_VERSION_AT_LEAST_1_13
from backpack.utils.conv import get_conv_function
from backpack.utils.conv_transpose import (
get_conv_transpose_function,
unfold_by_conv_transpose,
)
from backpack.utils.subsampling import subsample

if TORCH_VERSION_AT_LEAST_1_13:
from backpack.utils.conv import _grad_input_padding
else:
from torch.nn.grad import _grad_input_padding


class ConvTransposeNDDerivatives(BaseParameterDerivatives):
"""Base class for partial derivatives of transpose convolution."""
Expand Down
7 changes: 6 additions & 1 deletion backpack/core/derivatives/convnd.py
Expand Up @@ -5,13 +5,18 @@
from numpy import prod
from torch import Tensor, einsum
from torch.nn import Conv1d, Conv2d, Conv3d, Module
from torch.nn.grad import _grad_input_padding

from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.utils import TORCH_VERSION_AT_LEAST_1_13
from backpack.utils.conv import get_conv_function, unfold_by_conv
from backpack.utils.conv_transpose import get_conv_transpose_function
from backpack.utils.subsampling import subsample

if TORCH_VERSION_AT_LEAST_1_13:
from backpack.utils.conv import _grad_input_padding
else:
from torch.nn.grad import _grad_input_padding


class weight_jac_t_save_memory:
"""Choose algorithm to apply transposed convolution weight Jacobian."""
Expand Down
1 change: 1 addition & 0 deletions backpack/utils/__init__.py
Expand Up @@ -4,5 +4,6 @@
TORCH_VERSION = packaging.version.parse(get_distribution("torch").version)
TORCH_VERSION_AT_LEAST_1_9_1 = TORCH_VERSION >= packaging.version.parse("1.9.1")
TORCH_VERSION_AT_LEAST_2_0_0 = TORCH_VERSION >= packaging.version.parse("2.0.0")
TORCH_VERSION_AT_LEAST_1_13 = TORCH_VERSION >= packaging.version.parse("1.13")

ADAPTIVE_AVG_POOL_BUG: bool = not TORCH_VERSION_AT_LEAST_2_0_0
46 changes: 46 additions & 0 deletions backpack/utils/conv.py
@@ -1,4 +1,5 @@
from typing import Callable, Type, Union
from warnings import warn

import torch
from einops import rearrange
Expand Down Expand Up @@ -164,3 +165,48 @@ def make_weight():
)

return unfold.reshape(N, C_in * kernel_size_numel, -1)


def _grad_input_padding(
grad_output, input_size, stride, padding, kernel_size, dilation=None
):
"""Determine padding for the VJP of convolution.
Note:
This function was copied from the PyTorch repository (version 1.9).
It was removed between torch 1.12.1 and torch 1.13.
"""
if dilation is None:
# For backward compatibility
warn(
"_grad_input_padding 'dilation' argument not provided. Default of 1 is used."
)
dilation = [1] * len(stride)

input_size = list(input_size)
k = grad_output.dim() - 2

if len(input_size) == k + 2:
input_size = input_size[-k:]
if len(input_size) != k:
raise ValueError(f"input_size must have {k+2} elements (got {len(input_size)})")

def dim_size(d):
return (
(grad_output.size(d + 2) - 1) * stride[d]
- 2 * padding[d]
+ 1
+ dilation[d] * (kernel_size[d] - 1)
)

min_sizes = [dim_size(d) for d in range(k)]
max_sizes = [min_sizes[d] + stride[d] - 1 for d in range(k)]
for size, min_size, max_size in zip(input_size, min_sizes, max_sizes):
if size < min_size or size > max_size:
raise ValueError(
f"requested an input grad size of {input_size}, but valid sizes range "
f"from {min_sizes} to {max_sizes} (for a grad_output of "
f"{grad_output.size()[2:]})"
)

return tuple(input_size[d] - min_sizes[d] for d in range(k))
14 changes: 1 addition & 13 deletions backpack/utils/convert_parameters.py
Expand Up @@ -2,7 +2,7 @@

from typing import Iterable, List

from torch import Tensor, cat, typename
from torch import Tensor, typename


def vector_to_parameter_list(vec: Tensor, parameters: Iterable[Tensor]) -> List[Tensor]:
Expand Down Expand Up @@ -51,15 +51,3 @@ def vector_to_parameter_list(vec: Tensor, parameters: Iterable[Tensor]) -> List[
pointer += num_param

return params_new


def tensor_list_to_vector(tensor_list: Iterable[Tensor]) -> Tensor:
"""Convert a list of tensors into a vector by flattening and concatenation.
Args:
tensor_list: List of tensors.
Returns:
Vector containing the flattened and concatenated tensor inputs.
"""
return cat([t.flatten() for t in tensor_list])
9 changes: 4 additions & 5 deletions backpack/utils/examples.py
Expand Up @@ -3,15 +3,13 @@

from torch import Tensor, stack, zeros
from torch.nn import Module
from torch.nn.utils.convert_parameters import parameters_to_vector
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Normalize, ToTensor

from backpack.hessianfree.ggnvp import ggn_vector_product
from backpack.utils.convert_parameters import (
tensor_list_to_vector,
vector_to_parameter_list,
)
from backpack.utils.convert_parameters import vector_to_parameter_list


def load_mnist_dataset() -> Dataset:
Expand Down Expand Up @@ -115,5 +113,6 @@ def _autograd_ggn_exact_columns(
e_d_list = vector_to_parameter_list(e_d, trainable_parameters)

ggn_d_list = ggn_vector_product(loss, outputs, model, e_d_list)
ggn_d_list = [t.contiguous() for t in ggn_d_list]

yield d, tensor_list_to_vector(ggn_d_list)
yield d, parameters_to_vector(ggn_d_list)
3 changes: 2 additions & 1 deletion setup.cfg
Expand Up @@ -34,7 +34,7 @@ setup_requires =
setuptools_scm
# Dependencies of the project (semicolon/line-separated):
install_requires =
torch >= 1.9.0, < 1.13.0
torch >= 1.9.0, < 2.0.0
torchvision >= 0.7.0, < 1.0.0
einops >= 0.3.0, < 1.0.0
# Require a specific Python version, e.g. Python 2.7 or >= 3.4
Expand Down Expand Up @@ -106,6 +106,7 @@ ignore =
W291, # trailing whitespace
W503, # line break before binary operator
W504, # line break after binary operator
B905, # 'zip()' without an explicit 'strict=' parameter
exclude = docs, build, .git, docs_src/rtd, docs_src/rtd_output, .eggs

# Differences with pytorch
Expand Down
2 changes: 1 addition & 1 deletion test/core/derivatives/implementation/autograd.py
Expand Up @@ -248,7 +248,7 @@ def _elementwise_hessian(self, tensor: Tensor, x: Tensor) -> Tensor:
for t in tensor.flatten():
try:
yield self._hessian(t, x)
except (RuntimeError, AttributeError):
except (RuntimeError, AttributeError, TypeError):
yield zeros(*x.shape, *x.shape, device=x.device, dtype=x.dtype)

def hessian_is_zero(self) -> bool: # noqa: D102
Expand Down

0 comments on commit bfb1fde

Please sign in to comment.