From 3f899b9f77a2585b8b8bed2e6305c28ce8ec1d98 Mon Sep 17 00:00:00 2001 From: Gonzalo Tixilima Date: Fri, 19 Mar 2021 11:18:43 -0500 Subject: [PATCH 1/2] fix: display wrong epoch on keras resume training --- tests/tests_keras.py | 19 +++++++++++++++++++ tqdm/keras.py | 6 ++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/tests/tests_keras.py b/tests/tests_keras.py index b26cdbb78..88cf98761 100644 --- a/tests/tests_keras.py +++ b/tests/tests_keras.py @@ -80,3 +80,22 @@ def test_keras(capsys): assert "training: " in res assert "{epochs}/{epochs}".format(epochs=epochs) in res assert "{batches}/{batches}".format(batches=batches) in res + + # continue training (start from epoch != 0) + initial_epoch = 3 + model.fit( + x, + x, + initial_epoch=initial_epoch, + epochs=epochs, + batch_size=batch_size, + verbose=False, + callbacks=[TqdmCallback( + desc="training", + verbose=2 + )], + ) + _, res = capsys.readouterr() + assert "training: " in res + assert "{epochs}/{epochs}".format(epochs=epochs) in res + assert "{batches}/{batches}".format(batches=batches) in res diff --git a/tqdm/keras.py b/tqdm/keras.py index 45caf6189..0ed3f14a4 100644 --- a/tqdm/keras.py +++ b/tqdm/keras.py @@ -69,10 +69,12 @@ def __init__(self, epochs=None, data_size=None, batch_size=None, verbose=1, def on_train_begin(self, *_, **__): params = self.params.get auto_total = params('epochs', params('nb_epoch', None)) - if auto_total is not None: + if auto_total is not None and auto_total != self.epoch_bar.total: self.epoch_bar.reset(total=auto_total) - def on_epoch_begin(self, *_, **__): + def on_epoch_begin(self, epoch, *_, **__): + if self.epoch_bar.n < epoch: + self.epoch_bar.update(epoch-self.epoch_bar.n) if self.verbose: params = self.params.get total = params('samples', params( From 74ec6226610b22c4a2b509d9a7bcf8e90c3f4373 Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Mon, 3 May 2021 00:52:05 +0100 Subject: [PATCH 2/2] keras: fix resume from `initial_epoch` - set initial epochs - update tests --- tests/tests_keras.py | 20 ++++++-------------- tqdm/keras.py | 3 ++- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/tests/tests_keras.py b/tests/tests_keras.py index 88cf98761..220f9461d 100644 --- a/tests/tests_keras.py +++ b/tests/tests_keras.py @@ -38,9 +38,7 @@ def test_keras(capsys): desc="training", data_size=len(x), batch_size=batch_size, - verbose=0, - )], - ) + verbose=0)]) _, res = capsys.readouterr() assert "training: " in res assert "{epochs}/{epochs}".format(epochs=epochs) in res @@ -59,9 +57,7 @@ def test_keras(capsys): desc="training", data_size=len(x), batch_size=batch_size, - verbose=2, - )], - ) + verbose=2)]) _, res = capsys.readouterr() assert "training: " in res assert "{epochs}/{epochs}".format(epochs=epochs) in res @@ -74,8 +70,7 @@ def test_keras(capsys): epochs=epochs, batch_size=batch_size, verbose=False, - callbacks=[TqdmCallback(desc="training", verbose=2)], - ) + callbacks=[TqdmCallback(desc="training", verbose=2)]) _, res = capsys.readouterr() assert "training: " in res assert "{epochs}/{epochs}".format(epochs=epochs) in res @@ -90,12 +85,9 @@ def test_keras(capsys): epochs=epochs, batch_size=batch_size, verbose=False, - callbacks=[TqdmCallback( - desc="training", - verbose=2 - )], - ) + callbacks=[TqdmCallback(desc="training", verbose=0, + miniters=1, mininterval=0, maxinterval=0)]) _, res = capsys.readouterr() assert "training: " in res + assert "{epochs}/{epochs}".format(epochs=initial_epoch - 1) not in res assert "{epochs}/{epochs}".format(epochs=epochs) in res - assert "{batches}/{batches}".format(batches=batches) in res diff --git a/tqdm/keras.py b/tqdm/keras.py index 0ed3f14a4..68a88d451 100644 --- a/tqdm/keras.py +++ b/tqdm/keras.py @@ -74,7 +74,8 @@ def on_train_begin(self, *_, **__): def on_epoch_begin(self, epoch, *_, **__): if self.epoch_bar.n < epoch: - self.epoch_bar.update(epoch-self.epoch_bar.n) + ebar = self.epoch_bar + ebar.n = ebar.last_print_n = ebar.initial = epoch if self.verbose: params = self.params.get total = params('samples', params(