-
-
Notifications
You must be signed in to change notification settings - Fork 385
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* SWA script added * Codestyle fixed * fixed pr issues * added test and fix * codestyle fix * fixes * test fixed * division fix * codestyle fixed * fixed again * check keys matching added * fixed test * fixed docs * fix
- Loading branch information
Ivan Ivashnev
committed
Oct 14, 2020
1 parent
cf1190c
commit f2acebb
Showing
5 changed files
with
197 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import argparse | ||
from argparse import ArgumentParser | ||
from pathlib import Path | ||
|
||
import torch | ||
|
||
from catalyst.dl.utils.swa import generate_averaged_weights | ||
|
||
|
||
def build_args(parser: ArgumentParser):
    """Register the ``swa`` command line options on ``parser``.

    Args:
        parser: argument parser to extend

    Returns:
        The same parser, with ``--logdir``, ``--models-mask``/``-m`` and
        ``--output-path`` added.
    """
    # Each entry is (option flags, keyword arguments for ``add_argument``).
    option_specs = (
        (
            ("--logdir",),
            {"type": Path, "default": None, "help": "Path to models logdir"},
        ),
        (
            ("--models-mask", "-m"),
            {
                "type": str,
                "default": "*.pth",
                "help": "Pattern for models to average",
            },
        ),
        (
            ("--output-path",),
            {
                "type": Path,
                "default": "./swa.pth",
                "help": "Path to save averaged model",
            },
        ),
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)

    return parser
|
||
|
||
def parse_args():
    """Build a fresh parser with the ``swa`` options and parse the CLI.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``.
    """
    # ``build_args`` returns the parser it was given, so the calls chain.
    return build_args(argparse.ArgumentParser()).parse_args()
|
||
|
||
def main(args, _):
    """Entry point of ``catalyst-dl swa``: average checkpoints and save them.

    Args:
        args: parsed command line namespace providing ``logdir``,
            ``models_mask`` and ``output_path``
        _: unused second positional argument
    """
    source_dir, pattern, target = (
        args.logdir,
        args.models_mask,
        args.output_path,
    )

    swa_state = generate_averaged_weights(source_dir, pattern)

    # ``output_path`` is a ``Path``; it is converted to ``str`` for saving.
    torch.save(swa_state, str(target))
|
||
|
||
if __name__ == "__main__":
    # Run as a standalone script: parse the command line and pass ``None``
    # for the unused second argument of ``main``.
    cli_args = parse_args()
    main(cli_args, None)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import os
from pathlib import Path
import shutil
import tempfile
import unittest
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from catalyst.dl.utils.swa import generate_averaged_weights | ||
from catalyst.utils.checkpoint import load_checkpoint | ||
|
||
|
||
class Net(nn.Module):
    """Tiny one-layer network whose parameters are all set to one value."""

    def __init__(self, init_weight=4):
        """Create a ``Linear(2, 1)`` layer filled with ``init_weight``.

        Args:
            init_weight: constant written into every weight and bias entry
        """
        super().__init__()
        self.fc = nn.Linear(2, 1)
        for parameter in (self.fc.weight, self.fc.bias):
            parameter.data.fill_(init_weight)
|
||
|
||
class TestSwa(unittest.TestCase):
    """Tests for checkpoint weight averaging (SWA)."""

    def setUp(self):
        """Create a throwaway logdir holding two constant-weight checkpoints."""
        # A private temporary directory keeps the test independent of the
        # current working directory: creating ``./checkpoints`` directly
        # fails when that directory already exists, and removing it in
        # ``tearDown`` could delete data the test did not create.
        self.logdir = Path(tempfile.mkdtemp())
        self.checkpoints_dir = self.logdir / "checkpoints"
        self.checkpoints_dir.mkdir()

        net1 = Net(init_weight=2.0)
        net2 = Net(init_weight=5.0)
        torch.save(net1.state_dict(), str(self.checkpoints_dir / "net1.pth"))
        torch.save(net2.state_dict(), str(self.checkpoints_dir / "net2.pth"))

    def tearDown(self):
        """Remove the temporary logdir and everything saved into it."""
        shutil.rmtree(str(self.logdir), ignore_errors=True)

    def test_averaging(self):
        """Averaging checkpoints filled with 2.0 and 5.0 must give 3.5."""
        weights = generate_averaged_weights(
            logdir=self.logdir, models_mask="net*"
        )
        # Round-trip through disk so the averaged weights are checked the
        # same way a user would consume them.
        swa_path = str(self.checkpoints_dir / "swa_weights.pth")
        torch.save(weights, swa_path)
        model = Net()
        model.load_state_dict(load_checkpoint(swa_path))
        self.assertEqual(float(model.fc.weight.data[0][0]), 3.5)
        self.assertEqual(float(model.fc.weight.data[0][1]), 3.5)
        self.assertEqual(float(model.fc.bias.data[0]), 3.5)
|
||
|
||
if __name__ == "__main__":
    # Run the test suite when this module is executed directly.
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from typing import List, Union | ||
from collections import OrderedDict | ||
import glob | ||
import os | ||
from pathlib import Path | ||
|
||
import torch | ||
|
||
from catalyst.utils.checkpoint import load_checkpoint | ||
|
||
|
||
def average_weights(state_dicts: List[dict]) -> OrderedDict:
    """
    Averaging of input weights.

    Args:
        state_dicts: Weights to average; every state dict must expose the
            same parameter names in the same order

    Raises:
        ValueError: If ``state_dicts`` is empty
        KeyError: If states do not match

    Returns:
        Averaged weights, keyed in the order of the first state dict
    """
    # source https://gist.github.com/qubvel/70c3d5e4cddcde731408f478e12ef87b
    if not state_dicts:
        # Without this guard the lookup of the first state dict below would
        # fail with an uninformative ``IndexError``.
        raise ValueError("Cannot average weights: no state dicts were given")

    params_keys = list(state_dicts[0].keys())
    for i, state_dict in enumerate(state_dicts):
        model_params_keys = list(state_dict.keys())
        # Lists are compared, so both the names and their order must match.
        if params_keys != model_params_keys:
            raise KeyError(
                "For checkpoint {}, expected list of params: {}, "
                "but found: {}".format(i, params_keys, model_params_keys)
            )

    average_dict = OrderedDict()
    for k in params_keys:
        average_dict[k] = torch.div(
            sum(state_dict[k] for state_dict in state_dicts),
            len(state_dicts),
        )

    return average_dict
|
||
|
||
def load_weight(path: str) -> dict:
    """
    Read model weights from a checkpoint file.

    Args:
        path: Path to model weights

    Returns:
        The ``"model_state_dict"`` entry of the loaded checkpoint when that
        key is present, otherwise the loaded object itself
    """
    checkpoint = load_checkpoint(path)
    if "model_state_dict" not in checkpoint:
        # The file already holds a bare state dict.
        return checkpoint
    return checkpoint["model_state_dict"]
|
||
|
||
def generate_averaged_weights(
    logdir: Union[str, Path, None], models_mask: str
) -> OrderedDict:
    """
    Average the weights of all checkpoints matching a glob pattern.

    The averaged weights are returned, not written to disk.

    Args:
        logdir: Path to logs directory; checkpoints are searched in its
            ``checkpoints`` subdirectory. If ``None``, ``models_mask`` is
            used as the complete glob pattern
        models_mask: glob-like pattern for models to average

    Raises:
        FileNotFoundError: If no file matches the pattern

    Returns:
        Averaged weights
    """
    if logdir is None:
        pattern = models_mask
    else:
        pattern = os.path.join(logdir, "checkpoints", models_mask)

    # ``glob`` yields matches in arbitrary order; sorting keeps the summation
    # order, and the checkpoint index reported on a key mismatch, stable
    # between runs.
    models_pathes = sorted(glob.glob(pattern))
    if not models_pathes:
        raise FileNotFoundError(
            "No checkpoints found for pattern: {}".format(pattern)
        )

    all_weights = [load_weight(path) for path in models_pathes]
    averaged_dict = average_weights(all_weights)

    return averaged_dict