[CI] Test with torch=={1.9.0, 1.12.0} and make tests compatible (#276)
* [CI] Test with `torch=={1.9.0, 1.10.0}`

* [CI] Test with `torch=={1.9.0, 1.11.0}`

* [FIX] flake8

* [CI] Test with `torch=={1.9.0, 1.12.0}`

* [TEST] Replace `parameters_to_vector` by custom function

This should fix
`test_network_diag_ggn[<class
'test.converter.converter_cases._Permute'>]`
in `test/converter/test_converter.py`. Between torch 1.11.0 and torch
1.12.0, the GGN-vector products for this case became non-contiguous, and
`torch.nn.utils.convert_parameters.parameters_to_vector` stopped working
as it uses `view`.

Here is a short self-contained snippet to reproduce the issue:

```python
from torch import Tensor, permute, rand, rand_like
from torch.autograd import grad
from torch.nn import Linear, Module
from torch.nn.utils.convert_parameters import parameters_to_vector

from backpack.utils.convert_parameters import tensor_list_to_vector

class Permute(Module):
    def __init__(self):
        super().__init__()
        self.batch_size = 3
        self.in_dim = (5, 3)
        out_dim = 2
        self.linear = Linear(self.in_dim[-1], out_dim)
        self.linear2 = Linear(self.in_dim[-2], out_dim)

    def forward(self, x):
        x = self.linear(x)
        x = x.permute(0, 2, 1)  # method permute
        x = self.linear2(x)
        x = permute(x, (0, 2, 1))  # function permute
        return x

    def input_fn(self) -> Tensor:
        return rand(self.batch_size, *self.in_dim)

model = Permute()

inputs = model.input_fn()
outputs = model(inputs)

params = list(model.parameters())
grad_outputs = rand_like(outputs)
v = [rand_like(p) for p in model.parameters()]

vJ_tuple = grad(outputs, params, grad_outputs=grad_outputs)

for p, vJ in zip(params, vJ_tuple):
    # parameter and vector-Jacobian product shapes always match
    print(p.shape, vJ.shape)
    # between 1.11.0 and 1.12.0, the vector-Jacobian product w.r.t. the second
    # linear layer's weight stopped being contiguous
    print(p.is_contiguous(), vJ.is_contiguous())

# raises a RuntimeError with torch==1.12.0: `parameters_to_vector` flattens
# each tensor with `view`, which requires contiguous memory
vJ_vector = parameters_to_vector(vJ_tuple)

# works: the custom helper flattens with `flatten` before concatenating
vJ_vector = tensor_list_to_vector(vJ_tuple)
```
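
The root cause is that `Tensor.view` requires a contiguous memory layout, while `flatten` (like `reshape`) falls back to copying. A minimal sketch of this behavior, independent of BackPACK (the transposed tensor is just an illustrative stand-in for the non-contiguous vector-Jacobian product):

```python
from torch import rand

# transposing creates a non-contiguous view into the original storage
x = rand(2, 3).transpose(0, 1)
print(x.is_contiguous())  # False

try:
    x.view(-1)  # the flattening used internally by `parameters_to_vector`
except RuntimeError as e:
    print(f"view failed: {e}")

print(x.flatten().shape)  # works, copies into contiguous memory if needed
```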

* [REF] Use f-string and add type hints

* [REQ] Require `torch<1.13`

See #272. Waiting for
pytorch/pytorch#88312 before `torch>=1.13`
can be supported.

* [DOC] Update changelog to prepare compatibility patch

* [DOC] fix date

Co-authored-by: Felix Dangel <fdangel@tue.mpg.de>
f-dangel committed Nov 3, 2022
1 parent 0ab9421 commit 6414462
Showing 7 changed files with 51 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
@@ -20,7 +20,7 @@ jobs:
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
pytorch-version: [1.9.0, 1.9.1]
pytorch-version: [1.9.0, 1.12.0]
steps:
- uses: actions/checkout@v1
- uses: actions/setup-python@v1
2 changes: 1 addition & 1 deletion backpack/core/derivatives/basederivatives.py
@@ -9,7 +9,7 @@
from backpack.core.derivatives import shape_check


class BaseDerivatives(ABC):
class BaseDerivatives(ABC): # noqa: B024
"""First- and second-order partial derivatives of unparameterized module.
Note:
45 changes: 31 additions & 14 deletions backpack/utils/convert_parameters.py
@@ -1,9 +1,12 @@
import torch
"""Utility functions to convert between parameter lists and vectors."""

from typing import Iterable, List

def vector_to_parameter_list(vec, parameters):
"""
Convert the vector `vec` to a parameter-list format matching `parameters`.
from torch import Tensor, cat, typename


def vector_to_parameter_list(vec: Tensor, parameters: Iterable[Tensor]) -> List[Tensor]:
"""Convert the vector `vec` to a parameter-list format matching `parameters`.
This function is the inverse of `parameters_to_vector` from the
pytorch module `torch.nn.utils.convert_parameters`.
@@ -21,18 +24,20 @@ def vector_to_parameter_list(vec, parameters):
assert torch.allclose(a, b)
```
Parameters:
-----------
vec: Tensor
a single vector represents the parameters of a model
parameters: (Iterable[Tensor])
an iterator of Tensors that are of the desired shapes.
Args:
vec: A single vector representing the parameters of a model
parameters: An iterator of Tensors that are of the desired shapes.
Raises:
TypeError: If `vec` is not a PyTorch tensor.
Returns:
List of parameter-shaped tensors containing the entries of `vec`.
"""
# Ensure vec of type Tensor
if not isinstance(vec, torch.Tensor):
raise TypeError(
"expected torch.Tensor, but got: {}".format(torch.typename(vec))
)
if not isinstance(vec, Tensor):
raise TypeError(f"expected Tensor, but got: {typename(vec)}")

params_new = []
# Pointer for slicing the vector for each parameter
pointer = 0
@@ -46,3 +51,15 @@ def vector_to_parameter_list(vec, parameters):
pointer += num_param

return params_new


def tensor_list_to_vector(tensor_list: Iterable[Tensor]) -> Tensor:
"""Convert a list of tensors into a vector by flattening and concatenation.
Args:
tensor_list: List of tensors.
Returns:
Vector containing the flattened and concatenated tensor inputs.
"""
return cat([t.flatten() for t in tensor_list])
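
As a rough orientation (not part of the diff), here is how the new `tensor_list_to_vector` pairs with the existing `vector_to_parameter_list` to round-trip between a list of parameter-shaped tensors and one flat vector; the `Linear` layer is only an illustrative stand-in:

```python
from torch import allclose
from torch.nn import Linear

from backpack.utils.convert_parameters import (
    tensor_list_to_vector,
    vector_to_parameter_list,
)

params = list(Linear(3, 2).parameters())  # weight (2, 3) and bias (2,)

vec = tensor_list_to_vector(params)  # flat vector with 2 * 3 + 2 = 8 entries
back = vector_to_parameter_list(vec, params)  # list of (2, 3) and (2,) tensors

for original, restored in zip(params, back):
    assert allclose(original, restored)
```
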
8 changes: 5 additions & 3 deletions backpack/utils/examples.py
@@ -3,13 +3,15 @@

from torch import Tensor, stack, zeros
from torch.nn import Module
from torch.nn.utils.convert_parameters import parameters_to_vector
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Normalize, ToTensor

from backpack.hessianfree.ggnvp import ggn_vector_product
from backpack.utils.convert_parameters import vector_to_parameter_list
from backpack.utils.convert_parameters import (
tensor_list_to_vector,
vector_to_parameter_list,
)


def load_mnist_dataset() -> Dataset:
@@ -114,4 +116,4 @@ def _autograd_ggn_exact_columns(

ggn_d_list = ggn_vector_product(loss, outputs, model, e_d_list)

yield d, parameters_to_vector(ggn_d_list)
yield d, tensor_list_to_vector(ggn_d_list)
13 changes: 11 additions & 2 deletions changelog.md
@@ -6,6 +6,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [1.5.1] - 2022-11-03

This patch fixes temporary compatibility issues with the latest PyTorch release.

### Fixed/Removed
- Circumvent compatibility issues with `torch==1.13.0` by requiring
  `torch<1.13.0` ([PR](https://github.com/f-dangel/backpack/pull/276))

## [1.5.0] - 2022-02-15

This small release improves ResNet support of some second-order extensions and
@@ -378,8 +386,9 @@ co-authoring many PRs shipped in this release.

Initial release

[Unreleased]: https://github.com/f-dangel/backpack/compare/v1.5.0...HEAD
[1.4.0]: https://github.com/f-dangel/backpack/compare/1.5.0...1.4.0
[Unreleased]: https://github.com/f-dangel/backpack/compare/v1.5.1...HEAD
[1.5.1]: https://github.com/f-dangel/backpack/compare/1.5.1...1.5.0
[1.5.0]: https://github.com/f-dangel/backpack/compare/1.5.0...1.4.0
[1.4.0]: https://github.com/f-dangel/backpack/compare/1.4.0...1.3.0
[1.3.0]: https://github.com/f-dangel/backpack/compare/1.3.0...1.2.0
[1.2.0]: https://github.com/f-dangel/backpack/compare/1.2.0...1.1.1
1 change: 1 addition & 0 deletions fully_documented.txt
@@ -76,6 +76,7 @@ backpack/utils/__init__.py
backpack/utils/module_classification.py
backpack/utils/hooks.py
backpack/utils/examples.py
backpack/utils/convert_parameters.py

test/extensions/automated_settings.py
test/extensions/problem.py
2 changes: 1 addition & 1 deletion setup.cfg
@@ -34,7 +34,7 @@ setup_requires =
setuptools_scm
# Dependencies of the project (semicolon/line-separated):
install_requires =
torch >= 1.9.0, < 2.0.0
torch >= 1.9.0, < 1.13.0
torchvision >= 0.7.0, < 1.0.0
einops >= 0.3.0, < 1.0.0
# Require a specific Python version, e.g. Python 2.7 or >= 3.4
