Better/correct test for callback #4049
There are some hooks in both … These already have their alternatives with … Should we remove them?
Here is a simple example of how to track calls with `unittest.mock`. Please consider testing the callbacks this way:

```python
from unittest import mock
from unittest.mock import MagicMock, call, ANY

from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate


@mock.patch("torch.save")  # need to mock torch.save or we get a pickle error
def test_callback_system(torch_save):
    model = EvalModelTemplate()

    # pretend to be a callback, record all calls
    callback = MagicMock()
    trainer = Trainer(callbacks=[callback], max_steps=1, num_sanity_val_steps=0)
    trainer.fit(model)

    # check that a method was called exactly once
    callback.on_fit_start.assert_called_once()

    # check how many times a method was called
    assert callback.on_train_batch_end.call_count == 1

    # check that a method was NEVER called
    callback.on_keyboard_interrupt.assert_not_called()

    # check with what arguments a method was called
    callback.on_fit_end.assert_called_with(trainer, model)

    # check exact call order
    callback.assert_has_calls([
        call.on_init_start(trainer),
        call.on_init_end(trainer),
        call.setup(trainer, None, "fit"),
        call.on_fit_start(trainer, model),
        call.on_pretrain_routine_start(trainer, model),
        call.on_pretrain_routine_end(trainer, model),
        call.on_train_start(trainer, model),
        call.on_epoch_start(trainer, model),
        call.on_train_epoch_start(trainer, model),
        # BATCH 0
        call.on_batch_start(trainer, model),
        # we don't care about the exact batch values here, so we pass ANY
        call.on_train_batch_start(trainer, model, ANY, 0, 0),
        call.on_batch_end(trainer, model),
        call.on_train_batch_end(trainer, model, [], ANY, 0, 0),
        call.on_epoch_end(trainer, model),
        call.on_train_epoch_end(trainer, model, ANY),
        call.on_save_checkpoint(trainer, model),
        # what's this? presumably the return value of on_save_checkpoint()
        # is truth-tested somewhere, and the mock records that too
        call.on_save_checkpoint().__bool__(),
        call.on_train_end(trainer, model),
        call.on_fit_end(trainer, model),
        call.teardown(trainer, model, "fit"),
    ])
```
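The same pattern can be tried without a full `Trainer`, using only the standard library. A minimal sketch; the `Worker` class and its hook names are hypothetical stand-ins for the trainer and callback hooks:

```python
from unittest.mock import MagicMock, call, ANY


class Worker:
    """Hypothetical class that invokes callback hooks, standing in for Trainer."""

    def __init__(self, callback):
        self.callback = callback

    def run(self, data):
        self.callback.on_start(self)
        for i, batch in enumerate(data):
            self.callback.on_batch(self, batch, i)
        self.callback.on_end(self)


callback = MagicMock()
worker = Worker(callback)
worker.run(["a", "b"])

# exactly one start, two batches
callback.on_start.assert_called_once()
assert callback.on_batch.call_count == 2

# ANY matches the batch payload we don't care about
callback.on_batch.assert_any_call(worker, ANY, 1)

# exact order of the recorded calls
callback.assert_has_calls([
    call.on_start(worker),
    call.on_batch(worker, "a", 0),
    call.on_batch(worker, "b", 1),
    call.on_end(worker),
])
```

Because the mock records every attribute access as a call, no real callback class is needed at all; the test only depends on the hook names and arguments.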
@awaelchli are you working on this?
What I posted here is all I worked on. It is a fully functional test plus a demo of other functionalities. It can be extended a little and then replace all the custom callback tracking in the old test. Please feel free to take this code and use it. If not, I will find some time.
I believe hooks like …
I like that you can also test the called arguments, but it seems that this … UPDATE: it seems that we can list all methods.
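Listing all called methods works because `MagicMock` keeps a record of every invocation on `method_calls` (and, including nested calls, on `mock_calls`). A quick illustration:

```python
from unittest.mock import MagicMock, call

m = MagicMock()
m.setup("fit")
m.on_train_start()
m.teardown("fit")

# every method invocation is recorded, in order
assert m.method_calls == [
    call.setup("fit"),
    call.on_train_start(),
    call.teardown("fit"),
]

# each recorded call unpacks as (name, args, kwargs),
# so the method names alone are easy to extract
names = [c[0] for c in m.method_calls]
assert names == ["setup", "on_train_start", "teardown"]
```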
That's strange, because the docs say it asserts the exact order in sequence: …
Maybe some bug in the implementation, but if you simply test the replacement with an empty list, it passes too.
Yes, it makes sense that it passes with an empty list, because any number of calls can occur before or after the sequence you pass in. For example, given the recorded calls:

```python
before()
a()
b()
c()
after()
```

the following hold:

```python
assert_has_calls([])                        # true
assert_has_calls([a])                       # true
assert_has_calls([a, b])                    # true
assert_has_calls([a, b, c])                 # true
assert_has_calls([b, a, c])                 # FALSE!!
assert_has_calls([before, a, b, c, after])  # true
```
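This behaviour is easy to verify with a plain `MagicMock`: `assert_has_calls` (with the default `any_order=False`) only checks that the given sequence appears somewhere in `mock_calls` as a contiguous run, so empty lists and partial runs pass while a wrong order fails:

```python
from unittest.mock import MagicMock, call

m = MagicMock()
m.before()
m.a()
m.b()
m.c()
m.after()

m.assert_has_calls([])                    # passes: empty sequence matches anywhere
m.assert_has_calls([call.a()])            # passes
m.assert_has_calls([call.a(), call.b()])  # passes: contiguous run
m.assert_has_calls([call.before(), call.a(), call.b(), call.c(), call.after()])

# a wrong order raises AssertionError
try:
    m.assert_has_calls([call.b(), call.a(), call.c()])
    raised = False
except AssertionError:
    raised = True
assert raised
```

This is why matching the full list from the first to the last hook, as in the test above, is the safest way to pin down the exact call order.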
Thanks for taking care of this, @Borda!
I see, it is a bit dangerous: in case you miss some calls at the beginning or at the end, everything still seems to be fine... 8-)
🚀 Feature
Reopen #4009 and let it merge...
Motivation
The current test with bools is weak; it does not check the call order.