From 6414462596f6c19629b3fa21394541980ec23fe3 Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Thu, 3 Nov 2022 10:32:29 +0100 Subject: [PATCH] [CI] Test with `torch=={1.9.0, 1.12.0}` and make tests compatible (#276) * [CI] Test with `torch=={1.9.0, 1.10.0}` * [CI] Test with `torch=={1.9.0, 1.11.0}` * [FIX] flake8 * [CI] Test with `torch=={1.9.0, 1.12.0}` * [TEST] Replace `parameters_to_vector` by custom function This should fix `test_network_diag_ggn[]` in `test/converter/test_converter.py`. Between torch 1.11.0 and torch 1.12.0, the GGN-vector products for this case became non-contiguous, and `torch.nn.utils.convert_parameters.parameters_to_vector` stopped working as it uses `view`. Here is a short self-contained snippet to reproduce the issue: ```python from torch import Tensor, permute, rand, rand_like from torch.autograd import grad from torch.nn import Linear, Module from torch.nn.utils.convert_parameters import parameters_to_vector from backpack.utils.convert_parameters import tensor_list_to_vector class Permute(Module): def __init__(self): super().__init__() self.batch_size = 3 self.in_dim = (5, 3) out_dim = 2 self.linear = Linear(self.in_dim[-1], out_dim) self.linear2 = Linear(self.in_dim[-2], out_dim) def forward(self, x): x = self.linear(x) x = x.permute(0, 2, 1) # method permute x = self.linear2(x) x = permute(x, (0, 2, 1)) # function permute return x def input_fn(self) -> Tensor: return rand(self.batch_size, *self.in_dim) model = Permute() inputs = model.input_fn() outputs = model(inputs) params = list(model.parameters()) grad_outputs = rand_like(outputs) v = [rand_like(p) for p in model.parameters()] vJ_tuple = grad(outputs, params, grad_outputs=grad_outputs) for p, vJ in zip(params, vJ_tuple): # all contiguous() print(p.shape, vJ.shape) # between 1.11.0 and 1.12.0, the vector-Jacobian product w.r.t. the second # linear layer's weight is not contiguous anymore print(p.is_contiguous(), vJ.is_contiguous()) vJ_vector = parameters_to_vector(vJ_tuple) vJ_vector = tensor_list_to_vector(vJ_tuple) ``` * [REF] Use f-string and add type hints * [REQ] Require `torch<1.13` See https://github.com/f-dangel/backpack/issues/272. Waiting for https://github.com/pytorch/pytorch/issues/88312 before `torch>=1.13` can be supported. * [DOC] Update changelog to prepare compatibility patch * [DOC] fix date Co-authored-by: Felix Dangel --- .github/workflows/test.yaml | 2 +- backpack/core/derivatives/basederivatives.py | 2 +- backpack/utils/convert_parameters.py | 45 ++++++++++++++------ backpack/utils/examples.py | 8 ++-- changelog.md | 13 +++++- fully_documented.txt | 1 + setup.cfg | 2 +- 7 files changed, 51 insertions(+), 22 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index ee7391bf..a4ec683b 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -20,7 +20,7 @@ jobs: strategy: matrix: python-version: [3.7, 3.8, 3.9] - pytorch-version: [1.9.0, 1.9.1] + pytorch-version: [1.9.0, 1.12.0] steps: - uses: actions/checkout@v1 - uses: actions/setup-python@v1 diff --git a/backpack/core/derivatives/basederivatives.py b/backpack/core/derivatives/basederivatives.py index 94c15288..e3c39805 100644 --- a/backpack/core/derivatives/basederivatives.py +++ b/backpack/core/derivatives/basederivatives.py @@ -9,7 +9,7 @@ from backpack.core.derivatives import shape_check -class BaseDerivatives(ABC): +class BaseDerivatives(ABC): # noqa: B024 """First- and second-order partial derivatives of unparameterized module. Note: diff --git a/backpack/utils/convert_parameters.py b/backpack/utils/convert_parameters.py index b3f91731..2ec80f77 100644 --- a/backpack/utils/convert_parameters.py +++ b/backpack/utils/convert_parameters.py @@ -1,9 +1,12 @@ -import torch +"""Utility functions to convert between parameter lists and vectors.""" +from typing import Iterable, List -def vector_to_parameter_list(vec, parameters): - """ - Convert the vector `vec` to a parameter-list format matching `parameters`. +from torch import Tensor, cat, typename + + +def vector_to_parameter_list(vec: Tensor, parameters: Iterable[Tensor]) -> List[Tensor]: + """Convert the vector `vec` to a parameter-list format matching `parameters`. This function is the inverse of `parameters_to_vector` from the pytorch module `torch.nn.utils.convert_parameters`. @@ -21,18 +24,20 @@ def vector_to_parameter_list(vec, parameters): assert torch.all_close(a, b) ``` - Parameters: - ----------- - vec: Tensor - a single vector represents the parameters of a model - parameters: (Iterable[Tensor]) - an iterator of Tensors that are of the desired shapes. + Args: + vec: A single vector represents the parameters of a model + parameters: An iterator of Tensors that are of the desired shapes. + + Raises: + TypeError: If `vec` is not a PyTorch tensor. + + Returns: + List of parameter-shaped tensors containing the entries of `vec`. """ # Ensure vec of type Tensor - if not isinstance(vec, torch.Tensor): - raise TypeError( - "expected torch.Tensor, but got: {}".format(torch.typename(vec)) - ) + if not isinstance(vec, Tensor): + raise TypeError(f"expected Tensor, but got: {typename(vec)}") + params_new = [] # Pointer for slicing the vector for each parameter pointer = 0 @@ -46,3 +51,15 @@ def vector_to_parameter_list(vec, parameters): pointer += num_param return params_new + + +def tensor_list_to_vector(tensor_list: Iterable[Tensor]) -> Tensor: + """Convert a list of tensors into a vector by flattening and concatenation. + + Args: + tensor_list: List of tensors. + + Returns: + Vector containing the flattened and concatenated tensor inputs. + """ + return cat([t.flatten() for t in tensor_list]) diff --git a/backpack/utils/examples.py b/backpack/utils/examples.py index 78824114..519f600d 100644 --- a/backpack/utils/examples.py +++ b/backpack/utils/examples.py @@ -3,13 +3,15 @@ from torch import Tensor, stack, zeros from torch.nn import Module -from torch.nn.utils.convert_parameters import parameters_to_vector from torch.utils.data import DataLoader, Dataset from torchvision.datasets import MNIST from torchvision.transforms import Compose, Normalize, ToTensor from backpack.hessianfree.ggnvp import ggn_vector_product -from backpack.utils.convert_parameters import vector_to_parameter_list +from backpack.utils.convert_parameters import ( + tensor_list_to_vector, + vector_to_parameter_list, +) def load_mnist_dataset() -> Dataset: @@ -114,4 +116,4 @@ def _autograd_ggn_exact_columns( ggn_d_list = ggn_vector_product(loss, outputs, model, e_d_list) - yield d, parameters_to_vector(ggn_d_list) + yield d, tensor_list_to_vector(ggn_d_list) diff --git a/changelog.md b/changelog.md index 7526fa4c..5d617267 100644 --- a/changelog.md +++ b/changelog.md @@ -6,6 +6,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.5.1] - 2022-11-03 + +This patch fixes temporary compatibility issues with the latest PyTorch release. + +### Fixed/Removed +- Circumvent compatibility issues with `torch==1.13.0` by requiring + `torch<1.13.0`` ([PR](https://github.com/f-dangel/backpack/pull/276)) + ## [1.5.0] - 2022-02-15 This small release improves ResNet support of some second-order extensions and @@ -378,8 +386,9 @@ co-authoring many PRs shipped in this release. Initial release -[Unreleased]: https://github.com/f-dangel/backpack/compare/v1.5.0...HEAD -[1.4.0]: https://github.com/f-dangel/backpack/compare/1.5.0...1.4.0 +[Unreleased]: https://github.com/f-dangel/backpack/compare/v1.5.1...HEAD +[1.5.1]: https://github.com/f-dangel/backpack/compare/1.5.1...1.5.0 +[1.5.0]: https://github.com/f-dangel/backpack/compare/1.5.0...1.4.0 [1.4.0]: https://github.com/f-dangel/backpack/compare/1.4.0...1.3.0 [1.3.0]: https://github.com/f-dangel/backpack/compare/1.3.0...1.2.0 [1.2.0]: https://github.com/f-dangel/backpack/compare/1.2.0...1.1.1 diff --git a/fully_documented.txt b/fully_documented.txt index 275f0b92..9441c26e 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -76,6 +76,7 @@ backpack/utils/__init__.py backpack/utils/module_classification.py backpack/utils/hooks.py backpack/utils/examples.py +backpack/utils/convert_parameters.py test/extensions/automated_settings.py test/extensions/problem.py diff --git a/setup.cfg b/setup.cfg index b4ed8c35..70a38822 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,7 +34,7 @@ setup_requires = setuptools_scm # Dependencies of the project (semicolon/line-separated): install_requires = - torch >= 1.9.0, < 2.0.0 + torch >= 1.9.0, < 1.13.0 torchvision >= 0.7.0, < 1.0.0 einops >= 0.3.0, < 1.0.0 # Require a specific Python version, e.g. Python 2.7 or >= 3.4