Repeated calls to generator and discriminator's forward in GAN tutorial #306

jhauret opened this issue Jan 18, 2024 · 1 comment

jhauret commented Jan 18, 2024

📚 Documentation

In the training_step of the GAN(L.LightningModule), the generator's and discriminator's forward passes are called several times on the same input. Obviously this slows down training because more computation is required than necessary. I wonder if we could just reuse the result of the first call (a sketch of that idea follows the lists below). After all, the toggle_optimizer/untoggle_optimizer functions should make that safe, right?

For the generator:

  • First call: to log images, self.generated_imgs = self(z)
  • Second call: Inside the generator optimization self.discriminator(self(z))
  • Third call: Inside the discriminator optimization self.discriminator(self(z).detach())

For the discriminator:

  • First call: Inside the generator optimization self.discriminator(self(z))
  • Second call: Inside the discriminator optimization self.discriminator(self(z).detach())
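
For reference, here is a minimal sketch of what reusing that first generator call could look like under manual optimization with toggle_optimizer/untoggle_optimizer. Names such as adversarial_loss, hparams.latent_dim and the overall structure are assumed to match the tutorial; this is only an illustration of the idea, not the tutorial's actual code:

```python
import torch
import lightning as L


class GAN(L.LightningModule):
    # Assumes the tutorial's generator, discriminator, adversarial_loss and
    # hparams.latent_dim are defined elsewhere, and that
    # self.automatic_optimization = False is set in __init__.

    def training_step(self, batch, batch_idx):
        imgs, _ = batch
        opt_g, opt_d = self.optimizers()

        # Run the generator once; both optimizer steps below reuse this tensor,
        # so the generator forward is no longer repeated for each step.
        z = torch.randn(imgs.size(0), self.hparams.latent_dim, device=self.device)
        self.generated_imgs = self(z)

        valid = torch.ones(imgs.size(0), 1, device=self.device)
        fake = torch.zeros(imgs.size(0), 1, device=self.device)

        # Generator step: gradients flow through generated_imgs into the generator.
        self.toggle_optimizer(opt_g)
        g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
        self.manual_backward(g_loss)
        opt_g.step()
        opt_g.zero_grad()
        self.untoggle_optimizer(opt_g)

        # Discriminator step: detach the cached images so the generator graph
        # (already freed by the backward above) is not needed again.
        self.toggle_optimizer(opt_d)
        real_loss = self.adversarial_loss(self.discriminator(imgs), valid)
        fake_loss = self.adversarial_loss(self.discriminator(self.generated_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        self.manual_backward(d_loss)
        opt_d.step()
        opt_d.zero_grad()
        self.untoggle_optimizer(opt_d)
```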

cc @Borda

awaelchli transferred this issue from Lightning-AI/pytorch-lightning Jan 20, 2024
jhauret commented Jan 22, 2024

Still on the basic GAN tutorial, I spotted a few more avenues for improvement:

  • on_validation_epoch_end is ignored because validation_step is not defined.
  • After training, the quality of the samples is still very poor. I understand that the code is just an entry point for newcomers, but this poor performance may make users doubt the tutorial. We could easily improve it by using a small convolutional architecture for both the generator and the discriminator without complicating the code (e.g. Generator, Discriminator).
  • add_image called with the constant argument global_step=0 overwrites the results of previous epochs (a sketch addressing this and the first point follows the list).
  • The repeated calls to the generator and discriminator could be avoided to improve training speed (using retain_graph=True where a graph needs to be reused).
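
To make the first and third points concrete, here is a hedged sketch of the two small fixes. Attribute names such as validation_z and self.generator.model[0].weight are assumed to match the tutorial, and a validation dataloader must also be available for the validation loop to run:

```python
import torchvision


def validation_step(self, batch, batch_idx):
    # A minimal (even no-op) validation_step is enough for Lightning to run the
    # validation loop, so that on_validation_epoch_end is actually called.
    pass


def on_validation_epoch_end(self):
    z = self.validation_z.type_as(self.generator.model[0].weight)
    sample_imgs = self(z)
    grid = torchvision.utils.make_grid(sample_imgs)
    # Log against the current epoch instead of a constant global_step=0,
    # so each epoch's image grid is kept rather than overwritten.
    self.logger.experiment.add_image("generated_images", grid, global_step=self.current_epoch)
```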

I am willing to implement these changes and submit a PR if you find this helpful 😄
