
Example of gradient clipping with manual optimization does not handle gradient unscaling properly #18089

Open
function2-llx opened this issue Jul 15, 2023 · 3 comments · May be fixed by #19536

@function2-llx
Contributor

function2-llx commented Jul 15, 2023

📚 Documentation

The docs for manual optimization give an example of gradient clipping (added by #16023):

from lightning.pytorch import LightningModule


class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        # compute loss
        loss = self.compute_loss(batch)

        opt.zero_grad()
        self.manual_backward(loss)

        # clip gradients
        self.clip_gradients(opt, gradient_clip_val=0.5, gradient_clip_algorithm="norm")

        opt.step()

However, it seems that this example does not handle gradient unscaling properly. When using mixed-precision training, the gradients should be unscaled before calling self.clip_gradients.
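
For context, this is the order plain PyTorch AMP expects (scale → backward → unscale → clip → step → update). The toy model, optimizer, and data below are just placeholders, and the snippet assumes a CUDA device:

import torch
from torch import nn

model = nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(3):
    batch = torch.randn(4, 10, device="cuda")
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(batch).pow(2).mean()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # unscale *before* clipping, so the threshold applies to real gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
    scaler.step(optimizer)  # step() skips the unscale it would otherwise do itself
    scaler.update()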

cc @carmocca @justusschock @awaelchli @Borda

@function2-llx function2-llx added docs Documentation related needs triage Waiting to be triaged by maintainers labels Jul 15, 2023
@function2-llx
Contributor Author

function2-llx commented Jul 15, 2023

I'm not sure whether this is a limitation or not; currently I can find no simple way to achieve this.

For automatic optimization, gradient unscaling is performed right after the optimizer closure (training step, zero grad, backward) and before the gradient clipping (called in self._after_closure).
https://github.com/Lightning-AI/lightning/blob/e9c42ed11f68aafc18fe64a26d87118d57a5743c/src/lightning/pytorch/plugins/precision/amp.py#L78-L84
https://github.com/Lightning-AI/lightning/blob/e9c42ed11f68aafc18fe64a26d87118d57a5743c/src/lightning/pytorch/plugins/precision/precision_plugin.py#L77-L87
https://github.com/Lightning-AI/lightning/blob/e9c42ed11f68aafc18fe64a26d87118d57a5743c/src/lightning/pytorch/plugins/precision/precision_plugin.py#L116-L125

However, for manual optimization, the calling order is:

  1. epoch_loop.manual_optimization.run()
  2. model.training_step()
  3. Inside the training step, the user manually backpropagates the loss (with gradient scaling) and calls optimizer.step().
  4. In optimizer.step(), the gradients are unscaled, but gradient clipping of the unscaled gradients is skipped because of manual optimization (in _after_closure -> _clip_gradients).

In summary, there seems to be no room for the user to insert gradient unscaling in training_step, since the gradients are always unscaled inside optimizer.step(). On the other hand, the user is also unable to clip gradients after they are unscaled but before the optimizer's actual step.
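
The only thing I can think of is poking at Trainer internals, e.g. something like the following untested sketch that reaches for the precision plugin's scaler attribute (not a public API, and it's unclear how well it interacts with the later unscale inside optimizer.step()), which is clearly not something the docs example should require:

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.compute_loss(batch)

        opt.zero_grad()
        self.manual_backward(loss)

        # only the AMP precision plugin has a "scaler"; other precisions need no unscaling
        scaler = getattr(self.trainer.precision_plugin, "scaler", None)
        if scaler is not None:
            scaler.unscale_(opt.optimizer)  # unscale the wrapped torch optimizer's grads

        self.clip_gradients(opt, gradient_clip_val=0.5, gradient_clip_algorithm="norm")
        opt.step()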

This raises a question: why not also allow automatic gradient clipping for manual optimization? If users are supposed to take care of gradient clipping themselves, most of the time they would simply call self.clip_gradients on unscaled gradients, just like automatic optimization does; if they want to do something extra, they can do so via configure_gradient_clipping.
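
For the "something extra" case, a minimal sketch of what I mean, reusing the SimpleModel from the docs example (the body is only illustrative):

class SimpleModel(LightningModule):
    def configure_gradient_clipping(self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None):
        # any custom per-step logic would go here, then fall back to the default clipping
        self.clip_gradients(optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm)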

@kkoutini

Hi, did you manage to find the correct way to do the unscaling before the clipping for manual optimization?

@function2-llx
Contributor Author

@kkoutini No, I gave up and use Fabric instead.

@awaelchli awaelchli added bug Something isn't working and removed docs Documentation related needs triage Waiting to be triaged by maintainers labels Feb 27, 2024
@awaelchli awaelchli added this to the 2.2.x milestone Feb 27, 2024
@awaelchli awaelchli self-assigned this Feb 27, 2024
@awaelchli awaelchli added the precision: amp Automatic Mixed Precision label Feb 27, 2024