From 74ec6226610b22c4a2b509d9a7bcf8e90c3f4373 Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Mon, 3 May 2021 00:52:05 +0100 Subject: [PATCH] keras: fix resume from `initial_epoch` - set initial epochs - update tests --- tests/tests_keras.py | 20 ++++++-------------- tqdm/keras.py | 3 ++- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/tests/tests_keras.py b/tests/tests_keras.py index 88cf98761..220f9461d 100644 --- a/tests/tests_keras.py +++ b/tests/tests_keras.py @@ -38,9 +38,7 @@ def test_keras(capsys): desc="training", data_size=len(x), batch_size=batch_size, - verbose=0, - )], - ) + verbose=0)]) _, res = capsys.readouterr() assert "training: " in res assert "{epochs}/{epochs}".format(epochs=epochs) in res @@ -59,9 +57,7 @@ def test_keras(capsys): desc="training", data_size=len(x), batch_size=batch_size, - verbose=2, - )], - ) + verbose=2)]) _, res = capsys.readouterr() assert "training: " in res assert "{epochs}/{epochs}".format(epochs=epochs) in res @@ -74,8 +70,7 @@ def test_keras(capsys): epochs=epochs, batch_size=batch_size, verbose=False, - callbacks=[TqdmCallback(desc="training", verbose=2)], - ) + callbacks=[TqdmCallback(desc="training", verbose=2)]) _, res = capsys.readouterr() assert "training: " in res assert "{epochs}/{epochs}".format(epochs=epochs) in res @@ -90,12 +85,9 @@ def test_keras(capsys): epochs=epochs, batch_size=batch_size, verbose=False, - callbacks=[TqdmCallback( - desc="training", - verbose=2 - )], - ) + callbacks=[TqdmCallback(desc="training", verbose=0, + miniters=1, mininterval=0, maxinterval=0)]) _, res = capsys.readouterr() assert "training: " in res + assert "{epochs}/{epochs}".format(epochs=initial_epoch - 1) not in res assert "{epochs}/{epochs}".format(epochs=epochs) in res - assert "{batches}/{batches}".format(batches=batches) in res diff --git a/tqdm/keras.py b/tqdm/keras.py index 0ed3f14a4..68a88d451 100644 --- a/tqdm/keras.py +++ b/tqdm/keras.py @@ -74,7 +74,8 @@ def on_train_begin(self, *_, **__): def on_epoch_begin(self, epoch, *_, **__): if self.epoch_bar.n < epoch: - self.epoch_bar.update(epoch-self.epoch_bar.n) + ebar = self.epoch_bar + ebar.n = ebar.last_print_n = ebar.initial = epoch if self.verbose: params = self.params.get total = params('samples', params(