From 8c7d8b3058e40374c8a211fb9bf3faf0c589ada3 Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Thu, 1 Dec 2022 01:01:10 +0530 Subject: [PATCH] fix(integrations): Older TF versions compatibility issues with WandbModelCheckpoint (#4432) fix(integrations): Older TF versions compatibility issues with WandbModelCheckpoint Co-authored-by: Dmitry Duev Co-authored-by: Ayush Thakur Co-authored-by: Dmitry Duev --- .circleci/config.yml | 1 + .codecov.yml | 4 +-- .coveragerc | 1 + .../keras/keras_model_checkpoint_tf_2_4.yea | 16 +++++++++ .../keras/test_keras_model_checkpoint.py | 6 ++-- tox.ini | 5 +-- .../keras/callbacks/model_checkpoint.py | 36 +++++++++++++++---- 7 files changed, 54 insertions(+), 15 deletions(-) create mode 100644 tests/functional_tests/t0_main/keras/keras_model_checkpoint_tf_2_4.yea diff --git a/.circleci/config.yml b/.circleci/config.yml index 3a807df14fd..d2f4ff5188c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1226,6 +1226,7 @@ workflows: - "metaflow" - "tf115" - "tf21" + - "tf24" - "tf25" - "tf26" - "ray112" diff --git a/.codecov.yml b/.codecov.yml index 6b8077b18a8..f5075ab2931 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -4,14 +4,14 @@ codecov: # To calculate after_n_builds use # ./tools/coverage-tool.py jobs | wc -l # also change comment block after_n_builds just below - after_n_builds: 37 + after_n_builds: 38 wait_for_ci: no comment: layout: "reach, diff, flags, files" behavior: default require_changes: no - after_n_builds: 37 + after_n_builds: 38 ignore: - "wandb/vendor" diff --git a/.coveragerc b/.coveragerc index b8003214e21..1ee4159c69c 100644 --- a/.coveragerc +++ b/.coveragerc @@ -12,6 +12,7 @@ canonicalsrc = .tox/func-s_sklearn-py37/lib/python3.7/site-packages/wandb/ .tox/func-s_tf115-py37/lib/python3.7/site-packages/wandb/ .tox/func-s_tf21-py37/lib/python3.7/site-packages/wandb/ + .tox/func-s_tf24-py37/lib/python3.7/site-packages/wandb/ .tox/func-s_tf25-py37/lib/python3.7/site-packages/wandb/ .tox/func-s_tf26-py37/lib/python3.7/site-packages/wandb/ diff --git a/tests/functional_tests/t0_main/keras/keras_model_checkpoint_tf_2_4.yea b/tests/functional_tests/t0_main/keras/keras_model_checkpoint_tf_2_4.yea new file mode 100644 index 00000000000..8ca7728a883 --- /dev/null +++ b/tests/functional_tests/t0_main/keras/keras_model_checkpoint_tf_2_4.yea @@ -0,0 +1,16 @@ +id: 0.keras.modelcheckpoint.tf24 +tag: + shard: tf24 +plugin: + - wandb +command: + program: test_keras_model_checkpoint.py +depend: + requirements: + - tensorflow==2.4.1 +assert: + - :wandb:runs_len: 1 + - :op:contains: + - :wandb:runs[0][telemetry][3] # feature + - 39 # keras_model_checkpoint + - :wandb:runs[0][exitcode]: 0 diff --git a/tests/functional_tests/t0_main/keras/test_keras_model_checkpoint.py b/tests/functional_tests/t0_main/keras/test_keras_model_checkpoint.py index 437b55faf0c..37cd62eda60 100644 --- a/tests/functional_tests/t0_main/keras/test_keras_model_checkpoint.py +++ b/tests/functional_tests/t0_main/keras/test_keras_model_checkpoint.py @@ -3,8 +3,6 @@ import wandb from wandb.keras import WandbModelCheckpoint -tf.keras.utils.set_random_seed(1234) - run = wandb.init(project="keras") x = np.random.randint(255, size=(100, 28, 28, 1)) @@ -37,9 +35,9 @@ def get_model(): WandbModelCheckpoint( filepath="wandb/model/model_{epoch}", monitor="accuracy", - save_best_only=True, + save_best_only=False, save_weights_only=False, - save_freq=1, + save_freq=2, ) ], ) diff --git a/tox.ini b/tox.ini index 791d9c4ae30..71a6ce507ab 100644 --- a/tox.ini +++ b/tox.ini @@ -7,7 +7,7 @@ envlist= flake8, docstrings, py{36,37,38,39,launch,launch38}, - func-s_{base,sklearn,metaflow,tf115,tf21,tf25,tf26,ray112,ray2,service,docs, + func-s_{base,sklearn,metaflow,tf115,tf21,tf24,tf25,tf26,ray112,ray2,service,docs, imports{1,2,3,4,5,6,7,8,9,10,11,12},noml,grpc}-py37, standalone-{cpu,gpu,tpu,local}-py38, func-cover, @@ -409,7 +409,7 @@ commands = cp .coverage coverage.xml cover-results/ coverage report -m --ignore-errors --skip-covered --omit "wandb/vendor/*" -[testenv:func-s_{base,sklearn,metaflow,tf115,tf21,tf25,tf26,ray112,ray2,service,py310,docs,imports1,imports2,imports3,imports4,imports5,imports6,imports7,imports8,imports9,imports10,imports11,imports12,noml,grpc,kfp}-{py36,py37,py38,py39,py310}] +[testenv:func-s_{base,sklearn,metaflow,tf115,tf21,tf24,tf25,tf26,ray112,ray2,service,py310,docs,imports1,imports2,imports3,imports4,imports5,imports6,imports7,imports8,imports9,imports10,imports11,imports12,noml,grpc,kfp}-{py36,py37,py38,py39,py310}] install_command = pip install --extra-index-url https://download.pytorch.org/whl/cpu {opts} {packages} commands_pre = setenv = @@ -437,6 +437,7 @@ commands = func-s_metaflow-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard metaflow run {posargs:--all} func-s_tf115-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard tf115 run {posargs:--all} func-s_tf21-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard tf21 run {posargs:--all} + func-s_tf24-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard tf24 run {posargs:--all} func-s_tf25-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard tf25 run {posargs:--all} func-s_tf26-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard tf26 run {posargs:--all} func-s_ray112-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard ray112 run {posargs:--all} diff --git a/wandb/integration/keras/callbacks/model_checkpoint.py b/wandb/integration/keras/callbacks/model_checkpoint.py index d1482ce9ee0..dc9f6e3e4d7 100644 --- a/wandb/integration/keras/callbacks/model_checkpoint.py +++ b/wandb/integration/keras/callbacks/model_checkpoint.py @@ -4,6 +4,7 @@ import sys from typing import Any, Dict, List, Optional, Union +import tensorflow as tf # type: ignore from tensorflow.keras import callbacks # type: ignore import wandb @@ -108,16 +109,22 @@ def __init__( if self.save_best_only: self._check_filepath() + self._is_old_tf_keras_version: Optional[bool] = None + def on_train_batch_end( self, batch: int, logs: Optional[Dict[str, float]] = None ) -> None: if self._should_save_on_batch(batch): - # Save the model - self._save_model(epoch=self._current_epoch, batch=batch, logs=logs) - # Get filepath where the model checkpoint is saved. - filepath = self._get_file_path( - epoch=self._current_epoch, batch=batch, logs=logs - ) + if self.is_old_tf_keras_version: + # Save the model and get filepath + self._save_model(epoch=self._current_epoch, logs=logs) + filepath = self._get_file_path(epoch=self._current_epoch, logs=logs) + else: + # Save the model and get filepath + self._save_model(epoch=self._current_epoch, batch=batch, logs=logs) + filepath = self._get_file_path( + epoch=self._current_epoch, batch=batch, logs=logs + ) # Log the model as artifact aliases = ["latest", f"epoch_{self._current_epoch}_batch_{batch}"] self._log_ckpt_as_artifact(filepath, aliases=aliases) @@ -127,7 +134,10 @@ def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, float]] = None) -> N # Check if model checkpoint is created at the end of epoch. if self.save_freq == "epoch": # Get filepath where the model checkpoint is saved. - filepath = self._get_file_path(epoch=epoch, batch=None, logs=logs) + if self.is_old_tf_keras_version: + filepath = self._get_file_path(epoch=epoch, logs=logs) + else: + filepath = self._get_file_path(epoch=epoch, batch=None, logs=logs) # Log the model as artifact aliases = ["latest", f"epoch_{epoch}"] self._log_ckpt_as_artifact(filepath, aliases=aliases) @@ -173,3 +183,15 @@ def _check_filepath(self) -> None: "This ensures correct interpretation of the logged artifacts.", repeat=False, ) + + @property + def is_old_tf_keras_version(self) -> Optional[bool]: + if self._is_old_tf_keras_version is None: + from pkg_resources import parse_version + + if parse_version(tf.keras.__version__) < parse_version("2.6.0"): + self._is_old_tf_keras_version = True + else: + self._is_old_tf_keras_version = False + + return self._is_old_tf_keras_version