Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SWA script added #945

Merged
merged 16 commits into from
Oct 14, 2020
12 changes: 1 addition & 11 deletions catalyst/dl/scripts/swa.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,6 @@ def build_args(parser: ArgumentParser):
parser.add_argument(
"--models_mask", "-m", type=str, help="Pattern for models to average"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's simplify to --models-mask

)
parser.add_argument(
"--save-avaraged-model",
"-s",
type=bool,
default=True,
help="Flag for saving avaraged model",
)

return parser

Expand All @@ -37,11 +30,8 @@ def main(args, _):
"""Main method for ``catalyst-dl swa``."""
logdir: Path = args.logdir
models_mask: str = args.models_mask
save_avaraged_model: bool = args.save_avaraged_model

averaged_weights = generate_averaged_weights(
logdir, models_mask, save_avaraged_model=save_avaraged_model
)
averaged_weights = generate_averaged_weights(logdir, models_mask)
Scitator marked this conversation as resolved.
Show resolved Hide resolved


if __name__ == "__main__":
Expand Down
18 changes: 6 additions & 12 deletions catalyst/dl/utils/swa.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ def average_weights(state_dicts: List[dict]) -> OrderedDict:
"""
Averaging of input weights.
Args:
state_dicts (List[dict]): Weights to avarage
state_dicts (List[dict]): Weights to average
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you please remove type from docs (the type already specified in arguments)

Returns:
Avaraged weights
Averaged weights
"""
# source https://gist.github.com/qubvel/70c3d5e4cddcde731408f478e12ef87b

Expand All @@ -45,17 +45,14 @@ def load_weight(path: str) -> dict:
return weights


def generate_averaged_weights(
logdir: Path, models_mask: str, save_avaraged_model: bool = True
) -> OrderedDict:
def generate_averaged_weights(logdir: Path, models_mask: str) -> OrderedDict:
"""
Averaging of input weights.
Averaging of input weights and saving them.
Args:
logdir (Path): Path to logs directory
models_mask (str): globe-like pattern for models to average
save_avaraged_model (bool): Flag for saving avaraged model
Returns:
Avaraged weights
Averaged weights
"""

config_path = logdir / "configs" / "_config.json"
Expand All @@ -66,9 +63,6 @@ def generate_averaged_weights(
all_weights = [load_weight(path) for path in models_pathes]
averaged_dict = average_weights(all_weights)

if save_avaraged_model:
torch.save(
averaged_dict, str(logdir / "checkpoints" / "swa_weights.pth")
)
torch.save(averaged_dict, str(logdir / "checkpoints" / "swa_weights.pth"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's return the model and save it outside


return averaged_dict