Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use trained model if no custom checkpoint call back is specified. #3199

Closed
wants to merge 3 commits into from

Conversation

irasit
Copy link
Collaborator

@irasit irasit commented Oct 5, 2021

Checklist before submitting

  • Did you read the contributor guide?
  • Did you update the docs?
  • Did you write any tests to validate this change?
  • Did you update the CHANGELOG, if this change affects users?

Description

In lightning estimator, we do not need to load from check point if only want to return the last weights. The model after trianer.fit already have them.

Signed-off-by: Peng Zhang <pengz@uber.com>
Signed-off-by: Peng Zhang <pengz@uber.com>
@irasit irasit requested a review from chongxiaoc October 5, 2021 23:51
Signed-off-by: Peng Zhang <pengz@uber.com>
@@ -128,11 +128,16 @@ def train(serialized_model):
# Lightning requires to add checkpoint callbacks for all ranks.
# Otherwise we are seeing hanging in training.
_checkpoint_callback = checkpoint_callback
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For simplicity, just relying on this _checkpoint_callback variable != None check is enough.
If _checkpoint_callback is not None: modify dir_path and filename, and callbacks.append(_checkpoint_callback)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if _checkpoint_callback is None, it will set to default ModelCheckpoint in line 141.
If user do not pass in a checkpoint callback, we still want to add the default ModelCheckpoint in to callback list, right?

# rank 0 overwrites model with best checkpoint and returns.
best_model = model.load_from_checkpoint(_checkpoint_callback.best_model_path)
# Rank 0 overwrites model with best checkpoint and returns.
best_model = model if use_last_model_weights else model.load_from_checkpoint(_checkpoint_callback.best_model_path, strict=False)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

best_model = model if not _checkpoint_callback else ...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_checkpoint_callback is always not None because of the line 141

Copy link
Collaborator

@chongxiaoc chongxiaoc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add two nit comments.

@github-actions
Copy link

github-actions bot commented Oct 6, 2021

Unit Test Results

     751 files  +   366       751 suites  +366   6h 48m 43s ⏱️ + 2h 11m 10s
     699 tests ±       0       651 ✔️ +     92       48 💤  -      91  0  - 1 
16 113 runs  +7 736  11 308 ✔️ +5 649  4 805 💤 +2 088  0  - 1 

Results for commit 74d08f6. ± Comparison against base commit 062aaa0.

@github-actions
Copy link

github-actions bot commented Oct 6, 2021

Unit Test Results (with flaky tests)

     907 files  +     495       907 suites  +495   7h 13m 59s ⏱️ + 2h 24m 50s
     699 tests ±         0       652 ✔️ +     94       47 💤  -      92  0  - 2 
19 473 runs  +10 679  13 436 ✔️ +7 406  6 037 💤 +3 277  0  - 4 

Results for commit 74d08f6. ± Comparison against base commit 062aaa0.

@chongxiaoc
Copy link
Collaborator

duplicated with #3201

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants