From 3f899b9f77a2585b8b8bed2e6305c28ce8ec1d98 Mon Sep 17 00:00:00 2001 From: Gonzalo Tixilima Date: Fri, 19 Mar 2021 11:18:43 -0500 Subject: [PATCH] fix: display wrong epoch on keras resume training --- tests/tests_keras.py | 19 +++++++++++++++++++ tqdm/keras.py | 6 ++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/tests/tests_keras.py b/tests/tests_keras.py index b26cdbb78..88cf98761 100644 --- a/tests/tests_keras.py +++ b/tests/tests_keras.py @@ -80,3 +80,22 @@ def test_keras(capsys): assert "training: " in res assert "{epochs}/{epochs}".format(epochs=epochs) in res assert "{batches}/{batches}".format(batches=batches) in res + + # continue training (start from epoch != 0) + initial_epoch = 3 + model.fit( + x, + x, + initial_epoch=initial_epoch, + epochs=epochs, + batch_size=batch_size, + verbose=False, + callbacks=[TqdmCallback( + desc="training", + verbose=2 + )], + ) + _, res = capsys.readouterr() + assert "training: " in res + assert "{epochs}/{epochs}".format(epochs=epochs) in res + assert "{batches}/{batches}".format(batches=batches) in res diff --git a/tqdm/keras.py b/tqdm/keras.py index 45caf6189..0ed3f14a4 100644 --- a/tqdm/keras.py +++ b/tqdm/keras.py @@ -69,10 +69,12 @@ def __init__(self, epochs=None, data_size=None, batch_size=None, verbose=1, def on_train_begin(self, *_, **__): params = self.params.get auto_total = params('epochs', params('nb_epoch', None)) - if auto_total is not None: + if auto_total is not None and auto_total != self.epoch_bar.total: self.epoch_bar.reset(total=auto_total) - def on_epoch_begin(self, *_, **__): + def on_epoch_begin(self, epoch, *_, **__): + if self.epoch_bar.n < epoch: + self.epoch_bar.update(epoch-self.epoch_bar.n) if self.verbose: params = self.params.get total = params('samples', params(