Skip to content

Commit

Permalink
Merge branch 'fix/resumeTrainingKeras' into devel
Browse files Browse the repository at this point in the history
- closes #1150
- fixes #1138
  • Loading branch information
casperdcl committed May 3, 2021
2 parents d9372a2 + 74ec622 commit 9aea3ae
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
27 changes: 19 additions & 8 deletions tests/tests_keras.py
Expand Up @@ -38,9 +38,7 @@ def test_keras(capsys):
desc="training",
data_size=len(x),
batch_size=batch_size,
verbose=0,
)],
)
verbose=0)])
_, res = capsys.readouterr()
assert "training: " in res
assert "{epochs}/{epochs}".format(epochs=epochs) in res
Expand All @@ -59,9 +57,7 @@ def test_keras(capsys):
desc="training",
data_size=len(x),
batch_size=batch_size,
verbose=2,
)],
)
verbose=2)])
_, res = capsys.readouterr()
assert "training: " in res
assert "{epochs}/{epochs}".format(epochs=epochs) in res
Expand All @@ -74,9 +70,24 @@ def test_keras(capsys):
epochs=epochs,
batch_size=batch_size,
verbose=False,
callbacks=[TqdmCallback(desc="training", verbose=2)],
)
callbacks=[TqdmCallback(desc="training", verbose=2)])
_, res = capsys.readouterr()
assert "training: " in res
assert "{epochs}/{epochs}".format(epochs=epochs) in res
assert "{batches}/{batches}".format(batches=batches) in res

# continue training (start from epoch != 0)
initial_epoch = 3
model.fit(
x,
x,
initial_epoch=initial_epoch,
epochs=epochs,
batch_size=batch_size,
verbose=False,
callbacks=[TqdmCallback(desc="training", verbose=0,
miniters=1, mininterval=0, maxinterval=0)])
_, res = capsys.readouterr()
assert "training: " in res
assert "{epochs}/{epochs}".format(epochs=initial_epoch - 1) not in res
assert "{epochs}/{epochs}".format(epochs=epochs) in res
7 changes: 5 additions & 2 deletions tqdm/keras.py
Expand Up @@ -69,10 +69,13 @@ def __init__(self, epochs=None, data_size=None, batch_size=None, verbose=1,
def on_train_begin(self, *_, **__):
    """Sync the epoch progress bar's total with the model's fit parameters.

    Reads ``epochs`` (or the legacy Keras key ``nb_epoch``) from
    ``self.params`` and resets ``self.epoch_bar`` to that total.
    Ignores positional/keyword callback arguments.
    """
    params = self.params.get
    # `nb_epoch` is the pre-Keras-2 name for `epochs`; fall back for old APIs.
    auto_total = params('epochs', params('nb_epoch', None))
    # Only reset when the total actually changed, so resuming training
    # (fit() called again with initial_epoch != 0) does not wipe progress.
    if auto_total is not None and auto_total != self.epoch_bar.total:
        self.epoch_bar.reset(total=auto_total)

def on_epoch_begin(self, *_, **__):
def on_epoch_begin(self, epoch, *_, **__):
if self.epoch_bar.n < epoch:
ebar = self.epoch_bar
ebar.n = ebar.last_print_n = ebar.initial = epoch
if self.verbose:
params = self.params.get
total = params('samples', params(
Expand Down

0 comments on commit 9aea3ae

Please sign in to comment.