Better/correct test for callback #4049

Closed
Borda opened this issue Oct 10, 2020 · 12 comments · Fixed by #4009
Labels: ci (Continuous Integration), feature (Is an improvement or enhancement)

@Borda
Member

Borda commented Oct 10, 2020

🚀 Feature

Reopen #4009 and let it merge...

Motivation

The current test, which uses boolean flags, is weak: it does not check the call order.

@Borda Borda added feature Is an improvement or enhancement help wanted Open to be worked on ci Continuous Integration labels Oct 10, 2020
@Borda Borda self-assigned this Oct 10, 2020
@awaelchli awaelchli self-assigned this Oct 10, 2020
@rohitgr7
Contributor

I think some hooks in both ModelHooks and Callback are redundant: on_epoch_start, on_epoch_end, on_batch_start, and on_batch_end.

These already have on_train_* counterparts (e.g. on_train_epoch_start, on_train_batch_start).

Should we remove them?

@Borda Borda removed the help wanted Open to be worked on label Oct 10, 2020
@awaelchli
Member

awaelchli commented Oct 10, 2020

from unittest.mock import MagicMock, call, ANY
from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate
from unittest import mock


@mock.patch("torch.save")  # need to mock torch.save or we get pickle error
def test_callback_system(torch_save):
    model = EvalModelTemplate()
    # pretend to be a callback, record all calls
    callback = MagicMock()
    trainer = Trainer(callbacks=[callback], max_steps=1, num_sanity_val_steps=0)
    trainer.fit(model)

    # check if a method was called exactly once
    callback.on_fit_start.assert_called_once()

    # check how many times a method was called
    assert callback.on_train_batch_end.call_count == 1

    # check that a method was NEVER called
    callback.on_keyboard_interrupt.assert_not_called()

    # check the arguments of the most recent call to a method
    callback.on_fit_end.assert_called_with(trainer, model)

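    # check call order: these calls must appear consecutively and in this order
    # in callback.mock_calls; other calls may come before or after them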
    callback.assert_has_calls([
        call.on_init_start(trainer),
        call.on_init_end(trainer),
        call.setup(trainer, None, "fit"),
        call.on_fit_start(trainer, model),
        call.on_pretrain_routine_start(trainer, model),
        call.on_pretrain_routine_end(trainer, model),
        call.on_train_start(trainer, model),
        call.on_epoch_start(trainer, model),
        call.on_train_epoch_start(trainer, model),
        # BATCH 0
        call.on_batch_start(trainer, model),
        # here we don't care about exact values in batch, so we say ANY
        call.on_train_batch_start(trainer, model, ANY, 0, 0),
        call.on_batch_end(trainer, model),
        # here we don't care about exact values in batch, so we say ANY
        call.on_train_batch_end(trainer, model, [], ANY, 0, 0),
        call.on_epoch_end(trainer, model),
        call.on_train_epoch_end(trainer, model, ANY),
        call.on_save_checkpoint(trainer, model),
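        # the MagicMock returned by on_save_checkpoint was evaluated as a bool
        # during fit, and mock_calls records that __bool__ call as well
        call.on_save_checkpoint().__bool__(),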
        call.on_train_end(trainer, model),
        call.on_fit_end(trainer, model),
        call.teardown(trainer, model, "fit"),
    ])

Here is a simple example of how to track calls with unittest.mock.
It is very elegant, easy to understand and allows you to check that methods were called with the expected arguments and in the exact order.

Please consider testing the callbacks this way.
The same could be applied to the model hooks (#4010). It is very straightforward and also easier to read than the old test.
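A minimal sketch of the same idea without Lightning, using only the standard library: a MagicMock stands in for the callback and records every hook call, so counts, arguments, and order can all be asserted. The `run_hooks` function is hypothetical and only plays the role of the Trainer; the hook names are borrowed from the test above.

```python
from unittest.mock import ANY, MagicMock, call


def run_hooks(callback, trainer, model):
    # hypothetical stand-in for a Trainer: invokes callback hooks in a fixed order
    callback.on_fit_start(trainer, model)
    for batch_idx in range(2):
        batch = {"x": batch_idx}
        callback.on_train_batch_start(trainer, model, batch, batch_idx, 0)
    callback.on_fit_end(trainer, model)


trainer, model = object(), object()
callback = MagicMock()  # records every method called on it
run_hooks(callback, trainer, model)

callback.on_fit_start.assert_called_once()
assert callback.on_train_batch_start.call_count == 2
callback.on_keyboard_interrupt.assert_not_called()
callback.on_fit_end.assert_called_with(trainer, model)
callback.assert_has_calls([
    call.on_fit_start(trainer, model),
    call.on_train_batch_start(trainer, model, ANY, 0, 0),  # ANY: ignore batch contents
    call.on_train_batch_start(trainer, model, ANY, 1, 0),
    call.on_fit_end(trainer, model),
])
```

Driving the mock with a plain function keeps the mock assertions isolated from the rest of the training loop.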

@awaelchli
Member

@Borda Borda removed their assignment Oct 13, 2020
@rohitgr7
Contributor

@awaelchli are you working on this?
Also, should we deprecate/remove #4049 (comment) first and then update the test?

@awaelchli
Member

What I posted here is all I have worked on. It is a fully functional test plus a demo of other mock features. It can be extended a little and then replace all the custom callback tracking in the old test. Feel free to take this code and use it; otherwise, I will find some time.

Also, should we deprecate/remove #4049 (comment) first and then update the test?

I believe hooks like on_epoch_start can be useful if we "redefine" them to run at the start of every epoch, regardless of whether it is training, validation, or test. If that is not desired, I'd rather have them removed.

@stale

stale bot commented Nov 15, 2020

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@Borda Borda added this to To do in Code Health via automation Nov 18, 2020
@Borda Borda added this to the 1.1 milestone Nov 18, 2020
@Borda
Member Author

Borda commented Nov 27, 2020

It is very elegant, easy to understand and allows you to check that methods were called with the expected arguments and in the exact order.

I like that you can also test the call arguments, but it seems that assert_has_calls is order-independent: the test passes when I shuffle the calls inside it, or when I remove them all and use callback.assert_has_calls([]) instead.

UPDATE: it seems that we can list all the method calls and compare them exactly:

assert callback_mock.method_calls == [
    call.on_init_start(trainer),
    call.on_init_end(trainer),
]
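A small sketch of why comparing the whole recorded list is stricter: `assert_has_calls([])` accepts any mock, while `method_calls == [...]` fails if a call is missing, extra, or out of order. The hook names are reused from the snippet above; here the mock is driven by hand rather than by a Trainer.

```python
from unittest.mock import MagicMock, call

trainer = object()
callback_mock = MagicMock()
# drive the mock by hand instead of through a Trainer
callback_mock.on_init_start(trainer)
callback_mock.on_init_end(trainer)

# passes trivially: an empty sequence is contained in any call list
callback_mock.assert_has_calls([])

# exact comparison: every call, in order, nothing extra
assert callback_mock.method_calls == [
    call.on_init_start(trainer),
    call.on_init_end(trainer),
]

# a reordered or truncated expectation no longer matches
assert callback_mock.method_calls != [
    call.on_init_end(trainer),
    call.on_init_start(trainer),
]
assert callback_mock.method_calls != [call.on_init_start(trainer)]
```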

@Borda Borda self-assigned this Nov 27, 2020
@awaelchli
Member

That's strange, because docs say it asserts the exact order in sequence:
https://docs.python.org/3/library/unittest.mock.html#unittest.mock.Mock.assert_has_calls

@Borda
Member Author

Borda commented Nov 27, 2020

That's strange, because docs say it asserts the exact order in sequence:
https://docs.python.org/3/library/unittest.mock.html#unittest.mock.Mock.assert_has_calls

Maybe it is a bug in the implementation, but if you simply replace the list with an empty one, it passes too.

@awaelchli
Member

awaelchli commented Nov 27, 2020

Yes, it makes sense that it passes with an empty list, because any number of calls can occur before or after the sequence you pass in. For example:

before()
a()
b()
c()
after()

assert_has_calls([]) # true
assert_has_calls([a]) # true
assert_has_calls([a, b]) # true
assert_has_calls([a, b, c]) # true
assert_has_calls([b, a, c]) # FALSE!!
assert_has_calls([before, a, b, c, after]) # true
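The table above can be checked with a runnable sketch: with the default `any_order=False`, `assert_has_calls` passes when the expected list appears as a consecutive run somewhere in `mock_calls`. The names `before`, `a`, `b`, `c`, `after` are the placeholders from the example; `has_calls` is a hypothetical helper that turns the AssertionError into a bool. The `[a, c]` case is added to show that gaps are not allowed either.

```python
from unittest.mock import MagicMock, call

m = MagicMock()
m.before()
m.a()
m.b()
m.c()
m.after()


def has_calls(expected, any_order=False):
    # hypothetical helper: report whether assert_has_calls accepts `expected`
    try:
        m.assert_has_calls(expected, any_order=any_order)
    except AssertionError:
        return False
    return True


results = {
    "[]": has_calls([]),
    "[a]": has_calls([call.a()]),
    "[a, b]": has_calls([call.a(), call.b()]),
    "[a, b, c]": has_calls([call.a(), call.b(), call.c()]),
    "[b, a, c]": has_calls([call.b(), call.a(), call.c()]),
    "[a, c]": has_calls([call.a(), call.c()]),  # not consecutive
    "[before, a, b, c, after]": has_calls(
        [call.before(), call.a(), call.b(), call.c(), call.after()]
    ),
    "[b, a, c] any_order": has_calls([call.b(), call.a(), call.c()], any_order=True),
}
for name, ok in results.items():
    print(f"{name}: {ok}")
```

Only `[b, a, c]` and `[a, c]` are rejected by the default check; with `any_order=True` the ordering requirement is dropped entirely, so `[b, a, c]` passes.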

Test Coverage automation moved this from To do to Done Nov 27, 2020
Code Health automation moved this from To do to Done Nov 27, 2020
@awaelchli
Member

Thanks for taking care of this, @Borda!

@Borda
Member Author

Borda commented Nov 27, 2020

yes it makes sense that it passes with an empty list because any number of calls can occur before or after the sequence you pass in, so for example,

I see. It is a bit dangerous: if you miss some calls at the beginning or at the end, everything still seems fine... 8-)

3 participants