From 9806361224c0421adbca8747c9295e0f98d91948 Mon Sep 17 00:00:00 2001 From: soumik12345 <19soumik.rakshit96@gmail.com> Date: Thu, 3 Nov 2022 18:15:09 +0530 Subject: [PATCH 01/12] fixed version issues wrt keras --- .../keras/callbacks/model_checkpoint.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/wandb/integration/keras/callbacks/model_checkpoint.py b/wandb/integration/keras/callbacks/model_checkpoint.py index 646ad4acb63..24abdddb46d 100644 --- a/wandb/integration/keras/callbacks/model_checkpoint.py +++ b/wandb/integration/keras/callbacks/model_checkpoint.py @@ -2,8 +2,10 @@ import os import string import sys +from pkg_resources import parse_version from typing import Any, Dict, List, Optional, Union +import tensorflow as tf from tensorflow.keras import callbacks # type: ignore import wandb @@ -113,9 +115,12 @@ def on_train_batch_end(self, batch: int, logs: Dict[str, float] = None) -> None: # Save the model self._save_model(epoch=self._current_epoch, batch=batch, logs=logs) # Get filepath where the model checkpoint is saved. - filepath = self._get_file_path( - epoch=self._current_epoch, batch=batch, logs=logs - ) + if parse_version(tf.keras.__version__) <= parse_version("2.6.0"): + filepath = self._get_file_path(epoch=self._current_epoch, logs=logs) + else: + filepath = self._get_file_path( + epoch=self._current_epoch, batch=batch, logs=logs + ) # Log the model as artifact aliases = ["latest", f"epoch_{self._current_epoch}_batch_{batch}"] self._log_ckpt_as_artifact(filepath, aliases=aliases) @@ -125,7 +130,10 @@ def on_epoch_end(self, epoch: int, logs: Dict[str, float] = None) -> None: # Check if model checkpoint is created at the end of epoch. if self.save_freq == "epoch": # Get filepath where the model checkpoint is saved. - filepath = self._get_file_path(epoch=epoch, batch=None, logs=logs) + if parse_version(tf.keras.__version__) <= parse_version("2.6.0"): + filepath = self._get_file_path(epoch=epoch, logs=logs) + else: + filepath = self._get_file_path(epoch=epoch, batch=None, logs=logs) # Log the model as artifact aliases = ["latest", f"epoch_{epoch}"] self._log_ckpt_as_artifact(filepath, aliases=aliases) From 1696d37f97816caed02ffb952c628a07f8d2c1d9 Mon Sep 17 00:00:00 2001 From: ayulockin Date: Tue, 8 Nov 2022 16:56:35 +0000 Subject: [PATCH 02/12] tests working --- .../t0_main/keras/test_keras_model_checkpoint.py | 6 ++---- wandb/integration/keras/callbacks/model_checkpoint.py | 5 ++--- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/functional_tests/t0_main/keras/test_keras_model_checkpoint.py b/tests/functional_tests/t0_main/keras/test_keras_model_checkpoint.py index 437b55faf0c..37cd62eda60 100644 --- a/tests/functional_tests/t0_main/keras/test_keras_model_checkpoint.py +++ b/tests/functional_tests/t0_main/keras/test_keras_model_checkpoint.py @@ -3,8 +3,6 @@ import wandb from wandb.keras import WandbModelCheckpoint -tf.keras.utils.set_random_seed(1234) - run = wandb.init(project="keras") x = np.random.randint(255, size=(100, 28, 28, 1)) @@ -37,9 +35,9 @@ def get_model(): WandbModelCheckpoint( filepath="wandb/model/model_{epoch}", monitor="accuracy", - save_best_only=True, + save_best_only=False, save_weights_only=False, - save_freq=1, + save_freq=2, ) ], ) diff --git a/wandb/integration/keras/callbacks/model_checkpoint.py b/wandb/integration/keras/callbacks/model_checkpoint.py index 24abdddb46d..44cff646193 100644 --- a/wandb/integration/keras/callbacks/model_checkpoint.py +++ b/wandb/integration/keras/callbacks/model_checkpoint.py @@ -112,12 +112,11 @@ def __init__( def on_train_batch_end(self, batch: int, logs: Dict[str, float] = None) -> None: if self._should_save_on_batch(batch): - # Save the model - self._save_model(epoch=self._current_epoch, batch=batch, logs=logs) - # Get filepath where the model checkpoint is saved. if parse_version(tf.keras.__version__) <= parse_version("2.6.0"): + self._save_model(epoch=self._current_epoch, logs=logs) filepath = self._get_file_path(epoch=self._current_epoch, logs=logs) else: + self._save_model(epoch=self._current_epoch, batch=batch, logs=logs) filepath = self._get_file_path( epoch=self._current_epoch, batch=batch, logs=logs ) From 4da91697621a5696a6a5fdafc278273952300b9d Mon Sep 17 00:00:00 2001 From: ayulockin Date: Tue, 8 Nov 2022 17:03:32 +0000 Subject: [PATCH 03/12] make code-check happy --- wandb/integration/keras/callbacks/model_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wandb/integration/keras/callbacks/model_checkpoint.py b/wandb/integration/keras/callbacks/model_checkpoint.py index 44cff646193..44555de34e2 100644 --- a/wandb/integration/keras/callbacks/model_checkpoint.py +++ b/wandb/integration/keras/callbacks/model_checkpoint.py @@ -2,10 +2,10 @@ import os import string import sys -from pkg_resources import parse_version from typing import Any, Dict, List, Optional, Union -import tensorflow as tf +import tensorflow as tf # type: ignore +from pkg_resources import parse_version from tensorflow.keras import callbacks # type: ignore import wandb From 371750a28615063cbff3056097f03b4abae3f613 Mon Sep 17 00:00:00 2001 From: ayulockin Date: Tue, 8 Nov 2022 17:42:57 +0000 Subject: [PATCH 04/12] add test for tf 2.4.x --- .../keras/keras_model_checkpoint_tf_2_4.yea | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 tests/functional_tests/t0_main/keras/keras_model_checkpoint_tf_2_4.yea diff --git a/tests/functional_tests/t0_main/keras/keras_model_checkpoint_tf_2_4.yea b/tests/functional_tests/t0_main/keras/keras_model_checkpoint_tf_2_4.yea new file mode 100644 index 00000000000..8ca7728a883 --- /dev/null +++ b/tests/functional_tests/t0_main/keras/keras_model_checkpoint_tf_2_4.yea @@ -0,0 +1,16 @@ +id: 0.keras.modelcheckpoint.tf24 +tag: + shard: tf24 +plugin: + - wandb +command: + program: test_keras_model_checkpoint.py +depend: + requirements: + - tensorflow==2.4.1 +assert: + - :wandb:runs_len: 1 + - :op:contains: + - :wandb:runs[0][telemetry][3] # feature + - 39 # keras_model_checkpoint + - :wandb:runs[0][exitcode]: 0 From 65e729cd767930c6e3c42141dbb1bf1800681968 Mon Sep 17 00:00:00 2001 From: soumik12345 <19soumik.rakshit96@gmail.com> Date: Wed, 9 Nov 2022 13:17:18 +0530 Subject: [PATCH 05/12] updated version check logic --- .../keras/callbacks/model_checkpoint.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/wandb/integration/keras/callbacks/model_checkpoint.py b/wandb/integration/keras/callbacks/model_checkpoint.py index 44555de34e2..78f8105f9c4 100644 --- a/wandb/integration/keras/callbacks/model_checkpoint.py +++ b/wandb/integration/keras/callbacks/model_checkpoint.py @@ -2,10 +2,10 @@ import os import string import sys +from pkg_resources import parse_version from typing import Any, Dict, List, Optional, Union -import tensorflow as tf # type: ignore -from pkg_resources import parse_version +import tensorflow as tf from tensorflow.keras import callbacks # type: ignore import wandb @@ -110,13 +110,18 @@ def __init__( if self.save_best_only: self._check_filepath() + self.is_old_tf_keras_version: bool = parse_version( + tf.keras.__version__ + ) <= parse_version("2.6.0") + def on_train_batch_end(self, batch: int, logs: Dict[str, float] = None) -> None: if self._should_save_on_batch(batch): - if parse_version(tf.keras.__version__) <= parse_version("2.6.0"): - self._save_model(epoch=self._current_epoch, logs=logs) + # Save the model + self._save_model(epoch=self._current_epoch, batch=batch, logs=logs) + # Get filepath where the model checkpoint is saved. + if self.is_old_tf_keras_version: filepath = self._get_file_path(epoch=self._current_epoch, logs=logs) else: - self._save_model(epoch=self._current_epoch, batch=batch, logs=logs) filepath = self._get_file_path( epoch=self._current_epoch, batch=batch, logs=logs ) @@ -129,7 +134,7 @@ def on_epoch_end(self, epoch: int, logs: Dict[str, float] = None) -> None: # Check if model checkpoint is created at the end of epoch. if self.save_freq == "epoch": # Get filepath where the model checkpoint is saved. - if parse_version(tf.keras.__version__) <= parse_version("2.6.0"): + if self.is_old_tf_keras_version: filepath = self._get_file_path(epoch=epoch, logs=logs) else: filepath = self._get_file_path(epoch=epoch, batch=None, logs=logs) From 5162e84240ba5554724d853062637a2e7dd448b7 Mon Sep 17 00:00:00 2001 From: ayulockin Date: Wed, 9 Nov 2022 12:23:23 +0000 Subject: [PATCH 06/12] code-check + abstract parse_version --- .../keras/callbacks/model_checkpoint.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/wandb/integration/keras/callbacks/model_checkpoint.py b/wandb/integration/keras/callbacks/model_checkpoint.py index 78f8105f9c4..4c68bc1be47 100644 --- a/wandb/integration/keras/callbacks/model_checkpoint.py +++ b/wandb/integration/keras/callbacks/model_checkpoint.py @@ -2,10 +2,9 @@ import os import string import sys -from pkg_resources import parse_version from typing import Any, Dict, List, Optional, Union -import tensorflow as tf +import tensorflow as tf # type: ignore from tensorflow.keras import callbacks # type: ignore import wandb @@ -110,9 +109,8 @@ def __init__( if self.save_best_only: self._check_filepath() - self.is_old_tf_keras_version: bool = parse_version( - tf.keras.__version__ - ) <= parse_version("2.6.0") + # Patch to make the callback compatible with TF version < 2.6.0. + self._patch_tf_keras_version() def on_train_batch_end(self, batch: int, logs: Dict[str, float] = None) -> None: if self._should_save_on_batch(batch): @@ -183,3 +181,10 @@ def _check_filepath(self) -> None: "This ensures correct interpretation of the logged artifacts.", repeat=False, ) + + def _patch_tf_keras_version(self) -> None: + from pkg_resources import parse_version + + self.is_old_tf_keras_version: bool = parse_version( + tf.keras.__version__ + ) < parse_version("2.6.0") From 1121e607b82e46d05ab0b59d73a4599c668cc532 Mon Sep 17 00:00:00 2001 From: ayulockin Date: Thu, 24 Nov 2022 14:39:08 +0000 Subject: [PATCH 07/12] address feedback --- .circleci/config.yml | 1 + tox.ini | 5 +++-- .../integration/keras/callbacks/model_checkpoint.py | 13 +++++++------ 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 3a807df14fd..d2f4ff5188c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1226,6 +1226,7 @@ workflows: - "metaflow" - "tf115" - "tf21" + - "tf24" - "tf25" - "tf26" - "ray112" diff --git a/tox.ini b/tox.ini index e64d5de625a..9171cf96820 100644 --- a/tox.ini +++ b/tox.ini @@ -7,7 +7,7 @@ envlist= flake8, docstrings, py{36,37,38,39,launch,launch38}, - func-s_{base,sklearn,metaflow,tf115,tf21,tf25,tf26,ray112,ray2,service,docs, + func-s_{base,sklearn,metaflow,tf115,tf21,tf24,tf25,tf26,ray112,ray2,service,docs, imports{1,2,3,4,5,6,7,8,9,10,11,12},noml,grpc}-py37, standalone-{cpu,gpu,tpu,local}-py38, func-cover, @@ -409,7 +409,7 @@ commands = cp .coverage coverage.xml cover-results/ coverage report -m --ignore-errors --skip-covered --omit "wandb/vendor/*" -[testenv:func-s_{base,sklearn,metaflow,tf115,tf21,tf25,tf26,ray112,ray2,service,py310,docs,imports1,imports2,imports3,imports4,imports5,imports6,imports7,imports8,imports9,imports10,imports11,imports12,noml,grpc,kfp}-{py36,py37,py38,py39,py310}] +[testenv:func-s_{base,sklearn,metaflow,tf115,tf21,tf24,tf25,tf26,ray112,ray2,service,py310,docs,imports1,imports2,imports3,imports4,imports5,imports6,imports7,imports8,imports9,imports10,imports11,imports12,noml,grpc,kfp}-{py36,py37,py38,py39,py310}] install_command = pip install --extra-index-url https://download.pytorch.org/whl/cpu {opts} {packages} commands_pre = setenv = @@ -437,6 +437,7 @@ commands = func-s_metaflow-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard metaflow run {posargs:--all} func-s_tf115-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard tf115 run {posargs:--all} func-s_tf21-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard tf21 run {posargs:--all} + func-s_tf24-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard tf24 run {posargs:--all} func-s_tf25-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard tf25 run {posargs:--all} func-s_tf26-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard tf26 run {posargs:--all} func-s_ray112-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard ray112 run {posargs:--all} diff --git a/wandb/integration/keras/callbacks/model_checkpoint.py b/wandb/integration/keras/callbacks/model_checkpoint.py index 4c68bc1be47..ab858caeb0e 100644 --- a/wandb/integration/keras/callbacks/model_checkpoint.py +++ b/wandb/integration/keras/callbacks/model_checkpoint.py @@ -109,8 +109,7 @@ def __init__( if self.save_best_only: self._check_filepath() - # Patch to make the callback compatible with TF version < 2.6.0. - self._patch_tf_keras_version() + self._is_old_tf_keras_version: Optional[bool] = None def on_train_batch_end(self, batch: int, logs: Dict[str, float] = None) -> None: if self._should_save_on_batch(batch): @@ -182,9 +181,11 @@ def _check_filepath(self) -> None: repeat=False, ) - def _patch_tf_keras_version(self) -> None: + @property + def is_old_tf_keras_version(self): from pkg_resources import parse_version - self.is_old_tf_keras_version: bool = parse_version( - tf.keras.__version__ - ) < parse_version("2.6.0") + if parse_version(tf.keras.__version__) < parse_version("2.6.0"): + self._is_old_tf_keras_version = True + + return self._is_old_tf_keras_version From f5d389b6a1d13a89cbee55bda498d3398148af7e Mon Sep 17 00:00:00 2001 From: ayulockin Date: Thu, 24 Nov 2022 15:02:58 +0000 Subject: [PATCH 08/12] fix test tf2.4 --- wandb/integration/keras/callbacks/model_checkpoint.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/wandb/integration/keras/callbacks/model_checkpoint.py b/wandb/integration/keras/callbacks/model_checkpoint.py index ab858caeb0e..3c668f21cab 100644 --- a/wandb/integration/keras/callbacks/model_checkpoint.py +++ b/wandb/integration/keras/callbacks/model_checkpoint.py @@ -113,12 +113,13 @@ def __init__( def on_train_batch_end(self, batch: int, logs: Dict[str, float] = None) -> None: if self._should_save_on_batch(batch): - # Save the model - self._save_model(epoch=self._current_epoch, batch=batch, logs=logs) - # Get filepath where the model checkpoint is saved. if self.is_old_tf_keras_version: + # Save the model and get filepath + self._save_model(epoch=self._current_epoch, logs=logs) filepath = self._get_file_path(epoch=self._current_epoch, logs=logs) else: + # Save the model and get filepath + self._save_model(epoch=self._current_epoch, batch=batch, logs=logs) filepath = self._get_file_path( epoch=self._current_epoch, batch=batch, logs=logs ) From e4ac0a14fb94dd97b4faa300b7380b47ac698b65 Mon Sep 17 00:00:00 2001 From: ayulockin Date: Thu, 24 Nov 2022 15:18:09 +0000 Subject: [PATCH 09/12] make mypy happy --- wandb/integration/keras/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wandb/integration/keras/callbacks/model_checkpoint.py b/wandb/integration/keras/callbacks/model_checkpoint.py index 3c668f21cab..b872ddfcba4 100644 --- a/wandb/integration/keras/callbacks/model_checkpoint.py +++ b/wandb/integration/keras/callbacks/model_checkpoint.py @@ -183,7 +183,7 @@ def _check_filepath(self) -> None: ) @property - def is_old_tf_keras_version(self): + def is_old_tf_keras_version(self) -> Optional[bool]: from pkg_resources import parse_version if parse_version(tf.keras.__version__) < parse_version("2.6.0"): From 946c86850e7d2cf3b8b3a7fbb086e527423f7d3b Mon Sep 17 00:00:00 2001 From: Dmitry Duev Date: Tue, 29 Nov 2022 19:02:02 -0800 Subject: [PATCH 10/12] fix coveragerc --- .coveragerc | 1 + 1 file changed, 1 insertion(+) diff --git a/.coveragerc b/.coveragerc index b8003214e21..1ee4159c69c 100644 --- a/.coveragerc +++ b/.coveragerc @@ -12,6 +12,7 @@ canonicalsrc = .tox/func-s_sklearn-py37/lib/python3.7/site-packages/wandb/ .tox/func-s_tf115-py37/lib/python3.7/site-packages/wandb/ .tox/func-s_tf21-py37/lib/python3.7/site-packages/wandb/ + .tox/func-s_tf24-py37/lib/python3.7/site-packages/wandb/ .tox/func-s_tf25-py37/lib/python3.7/site-packages/wandb/ .tox/func-s_tf26-py37/lib/python3.7/site-packages/wandb/ From ce25c4dcb5808786f27c28c683d62dbb293a1e3b Mon Sep 17 00:00:00 2001 From: Dmitry Duev Date: Wed, 30 Nov 2022 10:34:44 -0800 Subject: [PATCH 11/12] fix .codecov.yml --- .codecov.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.codecov.yml b/.codecov.yml index 6b8077b18a8..f5075ab2931 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -4,14 +4,14 @@ codecov: # To calculate after_n_builds use # ./tools/coverage-tool.py jobs | wc -l # also change comment block after_n_builds just below - after_n_builds: 37 + after_n_builds: 38 wait_for_ci: no comment: layout: "reach, diff, flags, files" behavior: default require_changes: no - after_n_builds: 37 + after_n_builds: 38 ignore: - "wandb/vendor" From 67756e112f2b8ea87399e18868b5a1bcc3a34ad5 Mon Sep 17 00:00:00 2001 From: Dmitry Duev Date: Wed, 30 Nov 2022 11:10:08 -0800 Subject: [PATCH 12/12] fix is_old_tf_keras_version caching --- wandb/integration/keras/callbacks/model_checkpoint.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/wandb/integration/keras/callbacks/model_checkpoint.py b/wandb/integration/keras/callbacks/model_checkpoint.py index a8328ede56c..dc9f6e3e4d7 100644 --- a/wandb/integration/keras/callbacks/model_checkpoint.py +++ b/wandb/integration/keras/callbacks/model_checkpoint.py @@ -186,9 +186,12 @@ def _check_filepath(self) -> None: @property def is_old_tf_keras_version(self) -> Optional[bool]: - from pkg_resources import parse_version + if self._is_old_tf_keras_version is None: + from pkg_resources import parse_version - if parse_version(tf.keras.__version__) < parse_version("2.6.0"): - self._is_old_tf_keras_version = True + if parse_version(tf.keras.__version__) < parse_version("2.6.0"): + self._is_old_tf_keras_version = True + else: + self._is_old_tf_keras_version = False return self._is_old_tf_keras_version