Skip to content

Commit

Permalink
fix(integrations): Older TF versions compatibility issues with WandbModelCheckpoint (#4432)
Browse files Browse the repository at this point in the history

fix(integrations): Older TF versions compatibility issues with WandbModelCheckpoint

Co-authored-by: Dmitry Duev <dmitryduev@users.noreply.github.com>
Co-authored-by: Ayush Thakur <mein2work@gmail.com>
Co-authored-by: Dmitry Duev <dima@wandb.com>
  • Loading branch information
4 people committed Nov 30, 2022
1 parent 48f9258 commit 8c7d8b3
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 15 deletions.
1 change: 1 addition & 0 deletions .circleci/config.yml
Expand Up @@ -1226,6 +1226,7 @@ workflows:
- "metaflow"
- "tf115"
- "tf21"
- "tf24"
- "tf25"
- "tf26"
- "ray112"
Expand Down
4 changes: 2 additions & 2 deletions .codecov.yml
Expand Up @@ -4,14 +4,14 @@ codecov:
# To calculate after_n_builds use
# ./tools/coverage-tool.py jobs | wc -l
# also change comment block after_n_builds just below
after_n_builds: 37
after_n_builds: 38
wait_for_ci: no

comment:
layout: "reach, diff, flags, files"
behavior: default
require_changes: no
after_n_builds: 37
after_n_builds: 38

ignore:
- "wandb/vendor"
Expand Down
1 change: 1 addition & 0 deletions .coveragerc
Expand Up @@ -12,6 +12,7 @@ canonicalsrc =
.tox/func-s_sklearn-py37/lib/python3.7/site-packages/wandb/
.tox/func-s_tf115-py37/lib/python3.7/site-packages/wandb/
.tox/func-s_tf21-py37/lib/python3.7/site-packages/wandb/
.tox/func-s_tf24-py37/lib/python3.7/site-packages/wandb/
.tox/func-s_tf25-py37/lib/python3.7/site-packages/wandb/
.tox/func-s_tf26-py37/lib/python3.7/site-packages/wandb/

Expand Down
@@ -0,0 +1,16 @@
id: 0.keras.modelcheckpoint.tf24
tag:
shard: tf24
plugin:
- wandb
command:
program: test_keras_model_checkpoint.py
depend:
requirements:
- tensorflow==2.4.1
assert:
- :wandb:runs_len: 1
- :op:contains:
- :wandb:runs[0][telemetry][3] # feature
- 39 # keras_model_checkpoint
- :wandb:runs[0][exitcode]: 0
Expand Up @@ -3,8 +3,6 @@
import wandb
from wandb.keras import WandbModelCheckpoint

tf.keras.utils.set_random_seed(1234)

run = wandb.init(project="keras")

x = np.random.randint(255, size=(100, 28, 28, 1))
Expand Down Expand Up @@ -37,9 +35,9 @@ def get_model():
WandbModelCheckpoint(
filepath="wandb/model/model_{epoch}",
monitor="accuracy",
save_best_only=True,
save_best_only=False,
save_weights_only=False,
save_freq=1,
save_freq=2,
)
],
)
5 changes: 3 additions & 2 deletions tox.ini
Expand Up @@ -7,7 +7,7 @@ envlist=
flake8,
docstrings,
py{36,37,38,39,launch,launch38},
func-s_{base,sklearn,metaflow,tf115,tf21,tf25,tf26,ray112,ray2,service,docs,
func-s_{base,sklearn,metaflow,tf115,tf21,tf24,tf25,tf26,ray112,ray2,service,docs,
imports{1,2,3,4,5,6,7,8,9,10,11,12},noml,grpc}-py37,
standalone-{cpu,gpu,tpu,local}-py38,
func-cover,
Expand Down Expand Up @@ -409,7 +409,7 @@ commands =
cp .coverage coverage.xml cover-results/
coverage report -m --ignore-errors --skip-covered --omit "wandb/vendor/*"

[testenv:func-s_{base,sklearn,metaflow,tf115,tf21,tf25,tf26,ray112,ray2,service,py310,docs,imports1,imports2,imports3,imports4,imports5,imports6,imports7,imports8,imports9,imports10,imports11,imports12,noml,grpc,kfp}-{py36,py37,py38,py39,py310}]
[testenv:func-s_{base,sklearn,metaflow,tf115,tf21,tf24,tf25,tf26,ray112,ray2,service,py310,docs,imports1,imports2,imports3,imports4,imports5,imports6,imports7,imports8,imports9,imports10,imports11,imports12,noml,grpc,kfp}-{py36,py37,py38,py39,py310}]
install_command = pip install --extra-index-url https://download.pytorch.org/whl/cpu {opts} {packages}
commands_pre =
setenv =
Expand Down Expand Up @@ -437,6 +437,7 @@ commands =
func-s_metaflow-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard metaflow run {posargs:--all}
func-s_tf115-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard tf115 run {posargs:--all}
func-s_tf21-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard tf21 run {posargs:--all}
func-s_tf24-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard tf24 run {posargs:--all}
func-s_tf25-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard tf25 run {posargs:--all}
func-s_tf26-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard tf26 run {posargs:--all}
func-s_ray112-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard ray112 run {posargs:--all}
Expand Down
36 changes: 29 additions & 7 deletions wandb/integration/keras/callbacks/model_checkpoint.py
Expand Up @@ -4,6 +4,7 @@
import sys
from typing import Any, Dict, List, Optional, Union

import tensorflow as tf # type: ignore
from tensorflow.keras import callbacks # type: ignore

import wandb
Expand Down Expand Up @@ -108,16 +109,22 @@ def __init__(
if self.save_best_only:
self._check_filepath()

self._is_old_tf_keras_version: Optional[bool] = None

def on_train_batch_end(
    self, batch: int, logs: Optional[Dict[str, float]] = None
) -> None:
    """Save a checkpoint and log it as an artifact when a batch save is due.

    Nothing happens unless ``self._should_save_on_batch(batch)`` is true.
    When it is, the model is saved exactly once through ``self._save_model``,
    the destination is resolved through ``self._get_file_path``, and the
    result is handed to ``self._log_ckpt_as_artifact`` with the aliases
    ``"latest"`` and ``"epoch_<epoch>_batch_<batch>"``.

    Args:
        batch: Index of the batch that just finished.
        logs: Metric values for this batch, forwarded unchanged to the
            save/path helpers.
    """
    if not self._should_save_on_batch(batch):
        return

    # Build the helper arguments once so the save and the path lookup can
    # never disagree, and so the checkpoint is written a single time.
    ckpt_kwargs: Dict[str, Any] = {"epoch": self._current_epoch, "logs": logs}
    if not self.is_old_tf_keras_version:
        # `batch` is only forwarded on newer Keras; presumably the older
        # `_save_model` / `_get_file_path` signatures do not accept it
        # (verify against the Keras ModelCheckpoint source).
        ckpt_kwargs["batch"] = batch

    self._save_model(**ckpt_kwargs)
    filepath = self._get_file_path(**ckpt_kwargs)

    # Log the model as artifact
    aliases = ["latest", f"epoch_{self._current_epoch}_batch_{batch}"]
    self._log_ckpt_as_artifact(filepath, aliases=aliases)
Expand All @@ -127,7 +134,10 @@ def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, float]] = None) -> N
# Check if model checkpoint is created at the end of epoch.
if self.save_freq == "epoch":
# Get filepath where the model checkpoint is saved.
filepath = self._get_file_path(epoch=epoch, batch=None, logs=logs)
if self.is_old_tf_keras_version:
filepath = self._get_file_path(epoch=epoch, logs=logs)
else:
filepath = self._get_file_path(epoch=epoch, batch=None, logs=logs)
# Log the model as artifact
aliases = ["latest", f"epoch_{epoch}"]
self._log_ckpt_as_artifact(filepath, aliases=aliases)
Expand Down Expand Up @@ -173,3 +183,15 @@ def _check_filepath(self) -> None:
"This ensures correct interpretation of the logged artifacts.",
repeat=False,
)

@property
def is_old_tf_keras_version(self) -> bool:
    """Report whether the installed Keras is older than 2.6.0.

    The answer is computed from ``tf.keras.__version__`` on first access and
    cached in ``self._is_old_tf_keras_version``; later accesses return the
    cached value without touching TensorFlow again.

    Returns:
        ``True`` when the Keras version compares below 2.6.0, else ``False``.
    """
    if self._is_old_tf_keras_version is None:
        keras_version = str(tf.keras.__version__)
        try:
            from pkg_resources import parse_version

            is_old = parse_version(keras_version) < parse_version("2.6.0")
        except (ImportError, ValueError):
            # pkg_resources may be missing (it ships with setuptools), and
            # recent releases raise an InvalidVersion (a ValueError) for
            # strings that are not PEP 440 compliant. Fall back to comparing
            # the leading "major.minor" numbers so the callback still works.
            import re

            release = re.match(r"\s*(\d+)\.(\d+)", keras_version)
            # NOTE(review): an unparseable version is treated as "not old",
            # i.e. the current Keras call signatures are assumed.
            is_old = release is not None and (
                int(release.group(1)),
                int(release.group(2)),
            ) < (2, 6)
        self._is_old_tf_keras_version = bool(is_old)

    return self._is_old_tf_keras_version

0 comments on commit 8c7d8b3

Please sign in to comment.