
Report CycleGAN validation metrics correctly to wandb #2131

Open · wants to merge 87 commits into master
Conversation

@mcgibbon (Contributor) commented Jan 5, 2023

Currently the CycleGAN training routine does not report the same losses on validation data that it reports on training data. This PR refactors the code to compute the training losses on validation data, produce one wandb report per epoch, and add the regularization loss as an output metric.
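The per-epoch reporting can be pictured with the rough sketch below. The helper names (`log_epoch`, and `train_on_batch` returning a dict of losses) are assumptions for illustration, not the exact code in this PR; the point is that per-batch losses are averaged and wandb receives a single report per epoch covering both training and validation.

```python
# Rough sketch of one-report-per-epoch logging (hypothetical helper names,
# not the exact code in this PR): per-batch loss dicts are averaged and
# wandb receives a single report for each epoch.
from collections import defaultdict

import wandb


def log_epoch(model, train_batches, val_batches, epoch):
    report = {}
    for prefix, batches, training in [
        ("train", train_batches, True),
        ("val", val_batches, False),
    ]:
        sums, count = defaultdict(float), 0
        for real_a, real_b in batches:
            # train_on_batch returns the same loss dict in both modes
            losses = model.train_on_batch(real_a, real_b, training=training)
            for name, value in losses.items():
                sums[name] += float(value)
            count += 1
        report.update(
            {f"{prefix}/{name}": total / count for name, total in sums.items()}
        )
    wandb.log(report, step=epoch)  # exactly one wandb report per epoch
```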

Refactored public API:

  • CycleGAN metrics have been updated to appear in a single report per epoch and to include all training losses on validation data

Significant internal changes:

  • black is moved before flake8 in the pre-commit hooks so that auto-formatting reduces the line-length errors flake8 reports

Coverage reports (updated automatically):

  • test_unit: 60%

@@ -1,5 +1,10 @@
exclude: "external/gcsfs/"
repos:
- repo: https://github.com/psf/black
mcgibbon (Contributor, Author) commented:
I'd like black to run before flake8 so that its auto-formatting fixes issues before flake8 checks them.

generator
discriminator_optimizer: configuration for the optimizer used to train the
discriminator
optimizer: configuration for the optimizer used to train the
mcgibbon (Contributor, Author) commented:
Merging these was necessary so a wandb sweep can operate on the learning rate; there is no way to pair two hyperparameters in a wandb sweep.
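A hypothetical sketch of what the merged configuration might look like (class and field names are illustrative, not the repository's actual API): one optimizer config is built twice, once per network, so a sweep over a single learning-rate hyperparameter tunes both.

```python
# Hypothetical sketch of a merged optimizer configuration (names and defaults
# are illustrative, not the repository's actual classes): one config is built
# twice, so a sweep over a single learning rate affects both networks.
import dataclasses

import torch


@dataclasses.dataclass
class OptimizerConfig:
    name: str = "Adam"
    lr: float = 2e-4

    def build(self, parameters):
        # Look the optimizer class up by name on torch.optim
        cls = getattr(torch.optim, self.name)
        return cls(parameters, lr=self.lr)


# Usage (illustrative): the same config drives both optimizers, so a wandb
# sweep only needs to vary one hyperparameter, e.g. the optimizer learning rate.
# generator_optimizer = config.optimizer.build(generator.parameters())
# discriminator_optimizer = config.optimizer.build(discriminator.parameters())
```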

@@ -316,76 +314,8 @@ def _init_targets(self, shape: Tuple[int, ...]):
torch.Tensor(shape).fill_(0.0).to(DEVICE), requires_grad=False
)

def evaluate_on_dataset(
mcgibbon (Contributor, Author) commented:
This function was never actually exercised: it was called, but only on validation data, and I never provided validation datasets before now.

@@ -395,6 +325,8 @@ def train_on_batch(
[sample, time, tile, channel, y, x]
real_b: a batch of data from domain B, should have shape
[sample, time, tile, channel, y, x]
training: if True, the model will be trained, otherwise we will
mcgibbon (Contributor, Author) commented:
This allows getting training metrics on validation data.
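A minimal sketch of the pattern, assuming illustrative method and attribute names rather than the PR's exact implementation: when `training=False`, the same losses are computed under `torch.no_grad()` and no optimizer step is taken, so the identical metric names can be reported for validation data.

```python
# Illustrative sketch of a training flag on the batch step (method and
# attribute names are assumptions, not the PR's exact code): the same losses
# are computed either way, but gradients and optimizer steps only happen
# when training=True.
import contextlib

import torch


def train_on_batch(self, real_a, real_b, training: bool = True):
    context = contextlib.nullcontext() if training else torch.no_grad()
    with context:
        fake_b = self.generator_a_to_b(real_a)
        fake_a = self.generator_b_to_a(real_b)
        cycle = self.cycle_loss(real_a, real_b, fake_a, fake_b)
        gan = self.gan_loss(fake_a, fake_b)
        loss = cycle + gan
    if training:
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
    # The same metric names are returned in both modes, so validation data
    # can be reported with the training losses.
    return {"cycle_loss": float(cycle), "gan_loss": float(gan)}
```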

@@ -51,12 +50,11 @@ class CycleGANNetworkConfig:
cycle_weight: weight of the cycle loss
generator_weight: weight of the generator's gan loss
discriminator_weight: weight of the discriminator gan loss
reload_path: path to a directory containing a saved CycleGAN model to use
mcgibbon (Contributor, Author) commented:
This was just a missing docstring entry.
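For orientation, the documented fields correspond roughly to a config like the sketch below; the defaults and exact field set are assumptions for illustration, not the repository's definition.

```python
# Hedged sketch of the documented fields (defaults and exact field set are
# assumptions for illustration, not the repository's definition):
import dataclasses
from typing import Optional


@dataclasses.dataclass
class CycleGANNetworkConfig:
    cycle_weight: float = 1.0          # weight of the cycle loss
    generator_weight: float = 1.0      # weight of the generator's GAN loss
    discriminator_weight: float = 1.0  # weight of the discriminator GAN loss
    reload_path: Optional[str] = None  # directory containing a saved CycleGAN model to use
```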
