Skip to content

Commit

Permalink
SWA script added (#945)
Browse files Browse the repository at this point in the history
* SWA script added

* Codestyle fixed

* fixed pr issues

* added test and fix

* codestyle fix

* fixes

* test fixed

* division fix

* codestyle fixed

* fixed again

* check keys matching added

* fixed test

* fixed docs

* fix
  • Loading branch information
Ivan Ivashnev committed Oct 14, 2020
1 parent cf1190c commit f2acebb
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 1 deletion.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- SoftMax, CosFace, ArcFace layers to contrib ([#939](https://github.com/catalyst-team/catalyst/pull/939))
- ArcMargin layer to contrib ([#957](https://github.com/catalyst-team/catalyst/pull/957))
- AdaCos to contrib ([#958](https://github.com/catalyst-team/catalyst/pull/958))
- Manual SWA to utils ([#945](https://github.com/catalyst-team/catalyst/pull/945))

### Changed

Expand Down Expand Up @@ -305,4 +306,4 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
-
5 changes: 5 additions & 0 deletions bin/tests/check_dl_cv.sh
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ PYTHONPATH=./examples:.:${PYTHONPATH} \
python catalyst/dl/scripts/trace.py \
${LOGDIR}

# Smoke-test the SWA script: average the checkpoints found under
# ${LOGDIR}/checkpoints (default mask "*.pth") and save the averaged
# weights to ./swa.pth.
echo 'pipeline 01 - swa'
PYTHONPATH=./examples:.:${PYTHONPATH} \
python catalyst/dl/scripts/swa.py \
--logdir=${LOGDIR} --output-path=./swa.pth

rm -rf ${LOGDIR}


Expand Down
52 changes: 52 additions & 0 deletions catalyst/dl/scripts/swa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import argparse
from argparse import ArgumentParser
from pathlib import Path

import torch

from catalyst.dl.utils.swa import generate_averaged_weights


def build_args(parser: ArgumentParser):
    """Register the ``swa`` command line options on ``parser``.

    Args:
        parser: argument parser to extend

    Returns:
        the same parser, with ``--logdir``, ``--models-mask``/``-m`` and
        ``--output-path`` registered
    """
    # Each entry is (option flags, keyword arguments for ``add_argument``).
    option_specs = (
        (
            ("--logdir",),
            {"type": Path, "default": None, "help": "Path to models logdir"},
        ),
        (
            ("--models-mask", "-m"),
            {
                "type": str,
                "default": "*.pth",
                "help": "Pattern for models to average",
            },
        ),
        (
            ("--output-path",),
            {
                "type": Path,
                "default": "./swa.pth",
                "help": "Path to save averaged model",
            },
        ),
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)

    return parser


def parse_args():
    """Build the ``swa`` argument parser and parse the process arguments.

    Returns:
        namespace with the parsed command line arguments
    """
    cli_parser = argparse.ArgumentParser()
    build_args(cli_parser)
    return cli_parser.parse_args()


def main(args, _):
    """Entry point for ``catalyst-dl swa``.

    Averages the checkpoints selected by ``args.logdir`` and
    ``args.models_mask`` and saves the result to ``args.output_path``.

    Args:
        args: parsed command line arguments
        _: unused
    """
    source_dir, pattern, destination = (
        args.logdir,
        args.models_mask,
        args.output_path,
    )
    torch.save(
        generate_averaged_weights(source_dir, pattern), str(destination)
    )


# Allow running this module directly as a script.
if __name__ == "__main__":
    # ``main`` does not use its second argument, so ``None`` is passed.
    main(parse_args(), None)
53 changes: 53 additions & 0 deletions catalyst/dl/tests/test_swa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import os
from pathlib import Path
import shutil
import tempfile
import unittest

import torch
import torch.nn as nn

from catalyst.dl.utils.swa import generate_averaged_weights
from catalyst.utils.checkpoint import load_checkpoint


class Net(nn.Module):
    """Tiny single-layer network whose parameters are all one constant."""

    def __init__(self, init_weight=4):
        """Create a ``Linear(2, 1)`` layer filled with ``init_weight``.

        Args:
            init_weight: value written into every weight and bias element
        """
        super().__init__()
        self.fc = nn.Linear(2, 1)
        for parameter in (self.fc.weight, self.fc.bias):
            parameter.data.fill_(init_weight)


class TestSwa(unittest.TestCase):
    """Tests for the checkpoint averaging (SWA) helpers."""

    def setUp(self):
        """Create two constant-weight checkpoints in a private logdir."""
        # A fresh temporary logdir keeps the test independent of the
        # current working directory and of leftovers from earlier runs
        # (a pre-existing ``./checkpoints`` used to break ``os.mkdir``).
        self.logdir = Path(tempfile.mkdtemp())
        # Registered right away so the directory is removed even when the
        # rest of ``setUp`` fails (``tearDown`` would not run in that case).
        self.addCleanup(shutil.rmtree, str(self.logdir), ignore_errors=True)

        self.checkpoints_dir = self.logdir / "checkpoints"
        os.mkdir(str(self.checkpoints_dir))

        net1 = Net(init_weight=2.0)
        net2 = Net(init_weight=5.0)
        torch.save(net1.state_dict(), str(self.checkpoints_dir / "net1.pth"))
        torch.save(net2.state_dict(), str(self.checkpoints_dir / "net2.pth"))

    def test_averaging(self):
        """Averaging weights filled with 2 and 5 must give 3.5 everywhere."""
        weights = generate_averaged_weights(
            logdir=self.logdir, models_mask="net*"
        )
        # Round-trip through a file to also cover saving and loading of
        # the averaged state dict; the name does not match ``net*``.
        swa_path = str(self.checkpoints_dir / "swa_weights.pth")
        torch.save(weights, swa_path)
        model = Net()
        model.load_state_dict(load_checkpoint(swa_path))
        self.assertEqual(float(model.fc.weight.data[0][0]), 3.5)
        self.assertEqual(float(model.fc.weight.data[0][1]), 3.5)
        self.assertEqual(float(model.fc.bias.data[0]), 3.5)


# Run the test cases of this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()
85 changes: 85 additions & 0 deletions catalyst/dl/utils/swa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from typing import List, Union
from collections import OrderedDict
import glob
import os
from pathlib import Path

import torch

from catalyst.utils.checkpoint import load_checkpoint


def average_weights(state_dicts: List[dict]) -> OrderedDict:
    """
    Averaging of input weights.

    Args:
        state_dicts: Weights to average; every state dict must list the
            same parameter names in the same order

    Raises:
        KeyError: If states do not match
        ValueError: If ``state_dicts`` is empty

    Returns:
        Averaged weights, keyed in the order of the first state dict
    """
    # source https://gist.github.com/qubvel/70c3d5e4cddcde731408f478e12ef87b
    if not state_dicts:
        # Fail early with a clear message instead of an IndexError below.
        raise ValueError("Cannot average weights: no state dicts were given")

    # The first state dict defines the reference parameter list.
    params_keys = list(state_dicts[0].keys())
    for i, state_dict in enumerate(state_dicts[1:], start=1):
        model_params_keys = list(state_dict.keys())
        if params_keys != model_params_keys:
            raise KeyError(
                "For checkpoint {}, expected list of params: {}, "
                "but found: {}".format(i, params_keys, model_params_keys)
            )

    num_models = len(state_dicts)
    average_dict = OrderedDict()
    for k in params_keys:
        average_dict[k] = torch.div(
            sum(state_dict[k] for state_dict in state_dicts), num_models
        )

    return average_dict


def load_weight(path: str) -> dict:
    """
    Read a checkpoint and return the model weights stored in it.

    Args:
        path: Path to model weights

    Returns:
        The ``"model_state_dict"`` entry when the loaded checkpoint has
        one, otherwise the loaded checkpoint itself
    """
    checkpoint = load_checkpoint(path)
    return (
        checkpoint["model_state_dict"]
        if "model_state_dict" in checkpoint
        else checkpoint
    )


def generate_averaged_weights(
    logdir: Union[str, Path, None], models_mask: str
) -> OrderedDict:
    """
    Average the weights of all checkpoints matching a pattern.

    Args:
        logdir: Path to logs directory; checkpoints are searched in its
            ``checkpoints`` subdirectory. If ``None``, ``models_mask`` is
            used as the search pattern on its own
        models_mask: glob-like pattern for models to average

    Raises:
        FileNotFoundError: If no checkpoint matches the pattern

    Returns:
        Averaged weights
    """
    if logdir is None:
        pattern = models_mask
    else:
        pattern = os.path.join(logdir, "checkpoints", models_mask)

    # ``glob`` returns matches in arbitrary order; sorting fixes the
    # order in which checkpoints are loaded and summed.
    models_pathes = sorted(glob.glob(pattern))
    if not models_pathes:
        # Report the pattern instead of failing later inside the averaging.
        raise FileNotFoundError(
            "No checkpoints found for pattern: {}".format(pattern)
        )

    all_weights = [load_weight(path) for path in models_pathes]
    averaged_dict = average_weights(all_weights)

    return averaged_dict

0 comments on commit f2acebb

Please sign in to comment.