Use the trained model if no custom checkpoint callback is specified. #3199
Signed-off-by: Peng Zhang <pengz@uber.com>
@@ -128,11 +128,16 @@ def train(serialized_model):
# Lightning requires to add checkpoint callbacks for all ranks.
# Otherwise we are seeing hanging in training.
_checkpoint_callback = checkpoint_callback
For simplicity, checking whether _checkpoint_callback is not None is enough. If _checkpoint_callback is not None: modify dir_path and filename, and call callbacks.append(_checkpoint_callback).
If _checkpoint_callback is None, it is set to the default ModelCheckpoint on line 141. If the user does not pass in a checkpoint callback, we still want to add the default ModelCheckpoint to the callback list, right?
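The fallback discussed in this thread can be sketched roughly as follows. This is only an illustration, not the code in the PR: StubCheckpoint is a hypothetical stand-in for Lightning's ModelCheckpoint, resolve_checkpoint_callback is a made-up helper name, and deriving use_last_model_weights from "no user callback was passed" is an assumption based on the PR title.

```python
class StubCheckpoint:
    """Hypothetical stand-in for Lightning's ModelCheckpoint (not the real class)."""

    def __init__(self, dirpath=None, filename=None):
        self.dirpath = dirpath
        self.filename = filename
        self.best_model_path = ""


def resolve_checkpoint_callback(user_callback, callbacks, ckpt_dir, ckpt_filename):
    """Pick the checkpoint callback and register it in the callback list.

    Returns (callback, use_last_model_weights). The flag is True only when
    the user did not supply a callback (assumption taken from the PR title).
    """
    use_last_model_weights = user_callback is None
    # Fall back to a default callback so every rank has one registered.
    callback = user_callback if user_callback is not None else StubCheckpoint()
    # Point the callback at the run's checkpoint location.
    callback.dirpath = ckpt_dir
    callback.filename = ckpt_filename
    if callback not in callbacks:
        callbacks.append(callback)
    return callback, use_last_model_weights
```

Remembering whether the callback came from the user in a separate flag keeps that information available even after the default has been filled in.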
# rank 0 overwrites model with best checkpoint and returns.
best_model = model.load_from_checkpoint(_checkpoint_callback.best_model_path)
# Rank 0 overwrites model with best checkpoint and returns.
best_model = model if use_last_model_weights else model.load_from_checkpoint(_checkpoint_callback.best_model_path, strict=False)
best_model = model if not _checkpoint_callback else ...
_checkpoint_callback is never None at this point because of line 141.
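Since the callback is never None by the time the model is returned, the choice has to depend on a separate flag. A minimal sketch of that selection, with the checkpoint loader passed in as a plain function: select_returned_model and load_fn are hypothetical names, and load_fn stands in for model.load_from_checkpoint(..., strict=False) so the example runs without Lightning installed.

```python
from types import SimpleNamespace


def select_returned_model(model, checkpoint_callback, use_last_model_weights, load_fn):
    """Return the in-memory model, or reload the best checkpoint.

    load_fn is a hypothetical stand-in for model.load_from_checkpoint.
    """
    if use_last_model_weights:
        # The model already holds the final weights after training,
        # so no checkpoint load is needed.
        return model
    return load_fn(checkpoint_callback.best_model_path)


# Usage with stand-in objects (illustrative values only).
trained = {"weights": "last"}
ckpt = SimpleNamespace(best_model_path="/tmp/best.ckpt")
fake_load = lambda path: {"weights": "best", "path": path}

last_model = select_returned_model(trained, ckpt, True, fake_load)
best_model = select_returned_model(trained, ckpt, False, fake_load)
```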
Added two nit comments.
Duplicate of #3201.
Description
In the Lightning estimator, we do not need to load from a checkpoint if we only want to return the last weights. The model already has them after trainer.fit.