In the `training_step` of `GAN(L.LightningModule)`, the generator and discriminator forwards are called several times on the same input. This slows down training because redundant computation is performed. I wonder if we could just reuse the results of the first call. After all, the `toggle_optimizer`/`untoggle_optimizer` calls should make that safe, right?
For the generator:
- First call, to log images: `self.generated_imgs = self(z)`
- Second call, inside the generator optimization: `self.discriminator(self(z))`
- Third call, inside the discriminator optimization: `self.discriminator(self(z).detach())`

For the discriminator:
- First call, inside the generator optimization: `self.discriminator(self(z))`
- Second call, inside the discriminator optimization: `self.discriminator(self(z).detach())`
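The call lists above could collapse to a single generator forward per batch. Here is a minimal plain-PyTorch sketch of the idea, with manual `requires_grad` freezing standing in for Lightning's `toggle_optimizer`/`untoggle_optimizer`; the linear stand-in networks and loss choices are assumptions for illustration, not the tutorial's code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
G = nn.Linear(4, 8)   # stand-in generator
D = nn.Linear(8, 1)   # stand-in discriminator
opt_g = torch.optim.SGD(G.parameters(), lr=0.1)
opt_d = torch.optim.SGD(D.parameters(), lr=0.1)

z = torch.randn(16, 4)
real = torch.randn(16, 8)

fake = G(z)  # single generator forward, reused for both optimizer steps

# --- generator step: freeze D (what toggle_optimizer does in Lightning) ---
for p in D.parameters():
    p.requires_grad_(False)
g_loss = F.binary_cross_entropy_with_logits(D(fake), torch.ones(16, 1))
opt_g.zero_grad()
g_loss.backward()
opt_g.step()
for p in D.parameters():
    p.requires_grad_(True)

# --- discriminator step: reuse `fake`, detached so no grads reach G ---
# Note the trade-off of caching: `fake` still reflects G before opt_g.step().
d_loss = (
    F.binary_cross_entropy_with_logits(D(real), torch.ones(16, 1))
    + F.binary_cross_entropy_with_logits(D(fake.detach()), torch.zeros(16, 1))
) / 2
opt_d.zero_grad()
d_loss.backward()
opt_d.step()
```

The one semantic difference is in the comment above: the discriminator then trains on samples from the pre-update generator, which is usually an acceptable trade for halving the generator forwards.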
After training, the quality of the samples is still very poor. I understand that the code is just an entry point for newcomers, but this poor performance makes users doubt the tutorial. We could easily improve it by using a small convolutional architecture for both the generator and the discriminator without complicating the code (e.g. `Generator`, `Discriminator`).
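As a sketch of what "small convolutional architecture" could mean, here is a hypothetical DCGAN-style pair sized for 1x28x28 MNIST images; the layer widths are illustrative choices, not the tutorial's code:

```python
import torch
import torch.nn as nn

class ConvGenerator(nn.Module):
    def __init__(self, latent_dim=100):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 128, 7, 1, 0),  # 1x1 -> 7x7
            nn.BatchNorm2d(128), nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),          # 7x7 -> 14x14
            nn.BatchNorm2d(64), nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 4, 2, 1),            # 14x14 -> 28x28
            nn.Tanh(),
        )

    def forward(self, z):
        # reshape the flat latent vector into a 1x1 "image" of channels
        return self.net(z.view(z.size(0), -1, 1, 1))

class ConvDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1),    # 28x28 -> 14x14
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, 4, 2, 1),  # 14x14 -> 7x7
            nn.LeakyReLU(0.2, True),
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x)
```

Both modules are drop-in replacements for the tutorial's MLPs: the generator still maps `(N, latent_dim)` noise to `(N, 1, 28, 28)` images, and the discriminator still returns a `(N, 1)` probability.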
`add_image` called with the constant argument `global_step=0` overwrites the results of previous epochs.
The repeated calls to the generator and discriminator could be consolidated to improve training speed (thanks to `retain_graph=True`, a graph can be backpropagated through more than once where needed).
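A toy example of what `retain_graph=True` buys: without it, PyTorch frees the computation graph after the first `backward()`, so a second backward through the same forward result raises an error.

```python
import torch

x = torch.ones(2, requires_grad=True)
y = (x * 3).sum()          # one forward pass
y.backward(retain_graph=True)  # keep the graph alive for a second backward
y.backward()                   # would raise RuntimeError without retain_graph
print(x.grad)  # gradients accumulate across backwards: tensor([6., 6.])
```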
I am willing to implement and submit a PR if you find this helpful 😄
📚 Documentation
cc @Borda