PyTorch pretraining, train_step_callback, staged training #1447

Open
albertz opened this issue Oct 25, 2023 · 7 comments
@albertz (Member) commented Oct 25, 2023

(As initially discussed in #1120.)

How to handle pretraining? The current suggested APIs (get_model and co) might need to be changed, because we do not want to call get_model every epoch? What would the APIs look like?

Or better: Leave the current functions, but have a separate function get_stage(epoch, step) -> None|str|int or so. When the stage changes, it signals to RETURNN to reconstruct the network.
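
A minimal sketch of how that get_stage variant could look in a user config (the function name, signature, and the reconstruct-on-change behavior are all part of the proposal here, not an existing RETURNN API):

def get_stage(epoch: int, step: int) -> int:
    """
    Hypothetical staging function (proposal only).
    Whenever the returned stage changes, RETURNN would call get_model() again
    to reconstruct the network for the new stage. The schedule could also be
    step-based instead of epoch-based.
    """
    if epoch <= 5:
        return 0  # e.g. tiny model
    if epoch <= 10:
        return 1  # e.g. medium model
    return 2  # full model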

Or, maybe more explicit: train_step_callback(model, epoch, step) -> StepCallbackReturn, where StepCallbackReturn optionally contains a new model.

@JackTemaki (Collaborator) commented

For PyTorch I just handled the pre-training in the model definition itself. This means of course that I already construct all the layers from the beginning, even if they are not used, but this did not require any extra work and allows for much more complex schemes (and also step-based instead of epoch-based).

The only thing I am missing right now is the functionality to change e.g. batch_size dynamically per epoch.
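
A rough sketch of that approach, assuming a plain torch.nn.Module and that the current global train step is passed into forward (the module, dimensions, and schedule below are purely illustrative):

import torch


class StagedModel(torch.nn.Module):
    """All layers are constructed up front; which ones are used depends on the step."""

    def __init__(self, num_layers: int = 12, dim: int = 512):
        super().__init__()
        self.layers = torch.nn.ModuleList(torch.nn.Linear(dim, dim) for _ in range(num_layers))

    def forward(self, x: torch.Tensor, *, global_train_step: int) -> torch.Tensor:
        # Step-based schedule: start with 2 layers, enable 2 more every 10k steps.
        num_active = min(len(self.layers), 2 + 2 * (global_train_step // 10_000))
        for layer in self.layers[:num_active]:
            x = torch.relu(layer(x))
        return x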

@albertz (Member, Author) commented Oct 25, 2023

That's certainly also an option, but it's limited in other ways. You could combine it with the mechanism I am proposing here, though.

E.g. I have often also grown the layer dimensions. That would be tricky and/or inefficient with the approach you describe.

Your option would also take a bit more GPU memory than necessary. E.g. my initial models were really tiny, a 2-layer Conformer or so with small dimensions.

Also, as you say, it does not cover changing any other config option, like the batch size, the dataset itself, etc.

In TF, we would always have reconstructed the whole model. The proposed API here would avoid that. Such a step_callback would allow the user to modify the model in place, e.g. changing the dimensions of some layers, or whatever else. The callback is also per step, so it would allow for very fine-grained control. Maybe separately, we could have an epoch_callback.

Those callbacks would also allow changing config options, similar to what we did in TF (although the API would look different).

I'm leaning towards the last option, to have such a step_callback (and maybe additionally an epoch_callback), as it would be the most flexible. I'm now thinking about how to design the API in a way that is most flexible and at the same time simple and straightforward to use.

Currently I'm thinking about a return value like this for the callback:

from typing import Any, Dict, Optional, Union

import torch
import returnn.frontend as rf


class StepCallbackReturn:
    """
    Return value of the step callback.
    """

    def __init__(
        self,
        *,
        updated_model: Optional[Union[rf.Module, torch.nn.Module]] = None,
        config_updates: Optional[Dict[str, Any]] = None,
    ):
        """
        :param updated_model: Set this in case the model was updated.
            It can be the same instance as before or a new instance.
            You should set this whenever you modify the model in any way,
            such that the engine can recreate any wrapper objects
            (e.g. the DDP-wrapped module, or the RF-wrapped module).
        :param config_updates: Can include learning_rate, batch_size, etc.
        """
        self.updated_model = updated_model
        self.config_updates = config_updates
It could also be extended later with more logic.
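
As a hypothetical usage example in a user config (the callback signature and grow_model are illustrative; only StepCallbackReturn from the sketch above is assumed):

import torch


def train_step_callback(*, model: torch.nn.Module, epoch: int, step: int) -> StepCallbackReturn:
    """Per-step callback following the proposed API above (not an existing RETURNN API)."""
    if step == 10_000:
        # Grow the model once at step 10k. grow_model is a hypothetical user-defined
        # helper that e.g. widens some layers and copies over the old parameters.
        model = grow_model(model)
        return StepCallbackReturn(
            updated_model=model,
            config_updates={"learning_rate": 1e-4},
        )
    return StepCallbackReturn()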

@albertz (Member, Author) commented Oct 26, 2023

To expand a bit on step_callback: When exactly do we want this to be executed? Right before a train step? Right after a train step? Only in training, or also in forwarding?

I currently think: It should be right before a train step, but only if this is not the very first train step after initialization, i.e. if this is not step 0. (Edit: why only if not step 0? Reconsider this; maybe for consistency it is better to always call it...?) And I think it only makes sense in training. (Maybe we should call it train_step_callback?)

Example with initialization:

  • Epoch 1, step 0.
  • get_model()
  • train_step() + backprop + param update
  • Inc step: step 1.
  • step_callback()
  • train_step() + backprop + param update.
  • ...
  • Inc step.
  • Epoch 1 finished.
  • Save model checkpoint (Note: it is for the current epoch, and saves the step, but the step is already increased here, so it stores the next upcoming step.)
  • Inc epoch: epoch 2.
  • step_callback()
  • train_step() + backprop + param update.
  • ...

Example with loading an existing model:

  • Found model checkpoint from epoch 10, step S. (Note, the checkpoint was saved after step S - 1 completed, so S is actually the next step.)
  • Epoch 10, step S - 1.
  • get_model()
  • Load model from checkpoint.
  • Inc step: step S.
  • Inc epoch: epoch 11.
  • (We are now in the same state as above just after we increased the epoch.)
  • step_callback() (maybe now the model should change when going from epoch 10 to 11, or when going from step S - 1 to step S)
  • train_step() + backprop + param update.
  • ...
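
Putting the two examples above together, a small runnable toy sketch of the intended call order (engine internals are heavily simplified; get_model, step_callback, and train_step stand in for the user-defined config functions):

def get_model(*, epoch, step):
    return {"num_layers": 2}  # dummy stand-in for the real model

def step_callback(*, model, epoch, step):
    print(f"step_callback: epoch={epoch} step={step}")

def train_step(*, model, epoch, step):
    print(f"train_step:    epoch={epoch} step={step} model={model}")

num_epochs, steps_per_epoch = 2, 3
epoch, step = 1, 0
model = get_model(epoch=epoch, step=step)  # only once, at initialization / checkpoint load
while epoch <= num_epochs:
    for _ in range(steps_per_epoch):
        if step > 0:
            # Per the proposal: right before the train step, except the very first step.
            step_callback(model=model, epoch=epoch, step=step)
        train_step(model=model, epoch=epoch, step=step)  # + backprop + param update
        step += 1
    # Save the checkpoint here; note that it stores the already-incremented step.
    epoch += 1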

albertz changed the title from "PyTorch pretraining" to "PyTorch pretraining, step_callback, staged training" on Jan 3, 2024
albertz changed the title from "PyTorch pretraining, step_callback, staged training" to "PyTorch pretraining, train_step_callback, staged training" on Jan 3, 2024
@albertz (Member, Author) commented Jan 3, 2024

For other use cases (e.g. adapting the gradient accumulation or other settings), I'm also thinking about a train_epoch_callback or similar. The question here is also when exactly to call it, e.g. at the beginning of an epoch or afterwards. I think at the beginning makes more sense.

@albertz (Member, Author) commented Jan 3, 2024

(Btw, about naming: in PyTorch, the forward hook is called after the forward pass, so it could also have been named "post forward hook", and then there is a pre-forward hook, which is called before. Maybe not quite the same thing, though... Also, "hook" vs. "callback".)

@albertz (Member, Author) commented Jan 3, 2024

When it comes to settings (config updates), e.g. the example of changing grad accum dynamically, or whatever else, I was also thinking that doing this in one such callback function is maybe not so easy and convenient for the user. Maybe instead, for a few selected supported settings (e.g. accum_grad_multiple_step), instead of just allowing an int, we can support a callable there, or maybe a more well-defined interface, e.g. an instance of StepwiseSettingIntf or EpochwiseSettingIntf. RETURNN would then register all the necessary callbacks automatically at the beginning and dynamically update the setting.

Advantages:

  • Much more convenient to the user.
  • More intuitive for the user.
  • Using this feature with an old RETURNN version would lead to a crash, while using the callback instead with an old RETURNN version would just silently ignore the setting and lead to unexpected behavior.

Disadvantages:

  • A bit more complex on RETURNN side.
  • Depending on the specific case, e.g. accum_grad_multiple_step, it could be problematic with multi-GPU training if the user does something weird or non-deterministic? But maybe not; not sure.
  • Will not work for everything. This is only about config updates, not about model updates.

@albertz (Member, Author) commented Jan 5, 2024

accum_grad_multiple_step has now been extended so that it can be a callable, which is evaluated every step.
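
As a config sketch (note: the exact keyword arguments RETURNN passes to the callable are not spelled out in this thread; the ones below are an assumption, so check the actual implementation):

def accum_grad_multiple_step(*, global_train_step: int, **_kwargs) -> int:
    """
    Dynamic gradient accumulation schedule, defined directly in the config.
    Assumption: RETURNN calls this every step and passes at least the global
    train step (and possibly more) as keyword arguments.
    """
    if global_train_step < 10_000:
        return 1
    if global_train_step < 50_000:
        return 2
    return 4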
