Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(integrations): fix TF compatibility issues with WandbModelCheckpoint #4432

Merged
merged 25 commits into from Nov 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
9806361
fixed version issues wrt keras
soumik12345 Nov 3, 2022
fc4eb4a
Merge branch 'main' into fix/keras-version-patch
dmitryduev Nov 3, 2022
30ea4bf
Merge branch 'main' into fix/keras-version-patch
ayulockin Nov 4, 2022
f306a80
Merge branch 'main' into fix/keras-version-patch
ayulockin Nov 8, 2022
1696d37
tests working
ayulockin Nov 8, 2022
4da9169
make code-check happy
ayulockin Nov 8, 2022
371750a
add test for tf 2.4.x
ayulockin Nov 8, 2022
8f49127
Merge branch 'main' into fix/keras-version-patch
ayulockin Nov 9, 2022
dff2e01
Merge branch 'main' into fix/keras-version-patch
ayulockin Nov 9, 2022
65e729c
updated version check logic
soumik12345 Nov 9, 2022
5162e84
code-check + abstract parse_version
ayulockin Nov 9, 2022
4f747b0
Merge branch 'main' into fix/keras-version-patch
ayulockin Nov 10, 2022
eb2564f
Merge branch 'main' into fix/keras-version-patch
ayulockin Nov 10, 2022
47dff59
Merge branch 'main' into fix/keras-version-patch
ayulockin Nov 21, 2022
1121e60
address feedback
ayulockin Nov 24, 2022
ae6bf7a
Merge branch 'main' into fix/keras-version-patch
ayulockin Nov 24, 2022
f5d389b
fix test tf2.4
ayulockin Nov 24, 2022
e4ac0a1
make mypy happy
ayulockin Nov 24, 2022
fbef027
Merge branch 'main' into fix/keras-version-patch
ayulockin Nov 29, 2022
946c868
fix coveragerc
dmitryduev Nov 30, 2022
9be5d2f
Merge branch 'main' into fix/keras-version-patch
ayulockin Nov 30, 2022
ce25c4d
fix .codecov.yml
dmitryduev Nov 30, 2022
71f1622
Merge branch 'fix/keras-version-patch' of https://github.com/wandb/wa…
dmitryduev Nov 30, 2022
67756e1
fix is_old_tf_keras_version caching
dmitryduev Nov 30, 2022
423b075
Merge branch 'main' into fix/keras-version-patch
dmitryduev Nov 30, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .circleci/config.yml
Expand Up @@ -1226,6 +1226,7 @@ workflows:
- "metaflow"
- "tf115"
- "tf21"
- "tf24"
- "tf25"
- "tf26"
- "ray112"
Expand Down
4 changes: 2 additions & 2 deletions .codecov.yml
Expand Up @@ -4,14 +4,14 @@ codecov:
# To calculate after_n_builds use
# ./tools/coverage-tool.py jobs | wc -l
# also change comment block after_n_builds just below
after_n_builds: 37
after_n_builds: 38
wait_for_ci: no

comment:
layout: "reach, diff, flags, files"
behavior: default
require_changes: no
after_n_builds: 37
after_n_builds: 38

ignore:
- "wandb/vendor"
Expand Down
1 change: 1 addition & 0 deletions .coveragerc
Expand Up @@ -12,6 +12,7 @@ canonicalsrc =
.tox/func-s_sklearn-py37/lib/python3.7/site-packages/wandb/
.tox/func-s_tf115-py37/lib/python3.7/site-packages/wandb/
.tox/func-s_tf21-py37/lib/python3.7/site-packages/wandb/
.tox/func-s_tf24-py37/lib/python3.7/site-packages/wandb/
.tox/func-s_tf25-py37/lib/python3.7/site-packages/wandb/
.tox/func-s_tf26-py37/lib/python3.7/site-packages/wandb/

Expand Down
@@ -0,0 +1,16 @@
id: 0.keras.modelcheckpoint.tf24
tag:
shard: tf24
ayulockin marked this conversation as resolved.
Show resolved Hide resolved
plugin:
- wandb
command:
program: test_keras_model_checkpoint.py
depend:
requirements:
- tensorflow==2.4.1
assert:
- :wandb:runs_len: 1
- :op:contains:
- :wandb:runs[0][telemetry][3] # feature
- 39 # keras_model_checkpoint
- :wandb:runs[0][exitcode]: 0
Expand Up @@ -3,8 +3,6 @@
import wandb
from wandb.keras import WandbModelCheckpoint

tf.keras.utils.set_random_seed(1234)

run = wandb.init(project="keras")

x = np.random.randint(255, size=(100, 28, 28, 1))
Expand Down Expand Up @@ -37,9 +35,9 @@ def get_model():
WandbModelCheckpoint(
filepath="wandb/model/model_{epoch}",
monitor="accuracy",
save_best_only=True,
save_best_only=False,
save_weights_only=False,
save_freq=1,
save_freq=2,
)
],
)
5 changes: 3 additions & 2 deletions tox.ini
Expand Up @@ -7,7 +7,7 @@ envlist=
flake8,
docstrings,
py{36,37,38,39,launch,launch38},
func-s_{base,sklearn,metaflow,tf115,tf21,tf25,tf26,ray112,ray2,service,docs,
func-s_{base,sklearn,metaflow,tf115,tf21,tf24,tf25,tf26,ray112,ray2,service,docs,
imports{1,2,3,4,5,6,7,8,9,10,11,12},noml,grpc}-py37,
standalone-{cpu,gpu,tpu,local}-py38,
func-cover,
Expand Down Expand Up @@ -409,7 +409,7 @@ commands =
cp .coverage coverage.xml cover-results/
coverage report -m --ignore-errors --skip-covered --omit "wandb/vendor/*"

[testenv:func-s_{base,sklearn,metaflow,tf115,tf21,tf25,tf26,ray112,ray2,service,py310,docs,imports1,imports2,imports3,imports4,imports5,imports6,imports7,imports8,imports9,imports10,imports11,imports12,noml,grpc,kfp}-{py36,py37,py38,py39,py310}]
[testenv:func-s_{base,sklearn,metaflow,tf115,tf21,tf24,tf25,tf26,ray112,ray2,service,py310,docs,imports1,imports2,imports3,imports4,imports5,imports6,imports7,imports8,imports9,imports10,imports11,imports12,noml,grpc,kfp}-{py36,py37,py38,py39,py310}]
install_command = pip install --extra-index-url https://download.pytorch.org/whl/cpu {opts} {packages}
commands_pre =
setenv =
Expand Down Expand Up @@ -437,6 +437,7 @@ commands =
func-s_metaflow-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard metaflow run {posargs:--all}
func-s_tf115-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard tf115 run {posargs:--all}
func-s_tf21-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard tf21 run {posargs:--all}
func-s_tf24-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard tf24 run {posargs:--all}
func-s_tf25-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard tf25 run {posargs:--all}
func-s_tf26-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard tf26 run {posargs:--all}
func-s_ray112-py{36,37,38,39,310}: yea {env:CI_PYTEST_SPLIT_ARGS:} --strict --shard ray112 run {posargs:--all}
Expand Down
36 changes: 29 additions & 7 deletions wandb/integration/keras/callbacks/model_checkpoint.py
Expand Up @@ -4,6 +4,7 @@
import sys
from typing import Any, Dict, List, Optional, Union

import tensorflow as tf # type: ignore
from tensorflow.keras import callbacks # type: ignore

import wandb
Expand Down Expand Up @@ -108,16 +109,22 @@ def __init__(
if self.save_best_only:
self._check_filepath()

self._is_old_tf_keras_version: Optional[bool] = None

def on_train_batch_end(
self, batch: int, logs: Optional[Dict[str, float]] = None
) -> None:
if self._should_save_on_batch(batch):
# Save the model
self._save_model(epoch=self._current_epoch, batch=batch, logs=logs)
# Get filepath where the model checkpoint is saved.
filepath = self._get_file_path(
epoch=self._current_epoch, batch=batch, logs=logs
)
if self.is_old_tf_keras_version:
# Save the model and get filepath
self._save_model(epoch=self._current_epoch, logs=logs)
filepath = self._get_file_path(epoch=self._current_epoch, logs=logs)
else:
# Save the model and get filepath
self._save_model(epoch=self._current_epoch, batch=batch, logs=logs)
filepath = self._get_file_path(
epoch=self._current_epoch, batch=batch, logs=logs
)
# Log the model as artifact
aliases = ["latest", f"epoch_{self._current_epoch}_batch_{batch}"]
self._log_ckpt_as_artifact(filepath, aliases=aliases)
Expand All @@ -127,7 +134,10 @@ def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, float]] = None) -> N
# Check if model checkpoint is created at the end of epoch.
if self.save_freq == "epoch":
# Get filepath where the model checkpoint is saved.
filepath = self._get_file_path(epoch=epoch, batch=None, logs=logs)
if self.is_old_tf_keras_version:
filepath = self._get_file_path(epoch=epoch, logs=logs)
else:
filepath = self._get_file_path(epoch=epoch, batch=None, logs=logs)
# Log the model as artifact
aliases = ["latest", f"epoch_{epoch}"]
self._log_ckpt_as_artifact(filepath, aliases=aliases)
Expand Down Expand Up @@ -173,3 +183,15 @@ def _check_filepath(self) -> None:
"This ensures correct interpretation of the logged artifacts.",
repeat=False,
)

@property
def is_old_tf_keras_version(self) -> Optional[bool]:
if self._is_old_tf_keras_version is None:
from pkg_resources import parse_version

if parse_version(tf.keras.__version__) < parse_version("2.6.0"):
self._is_old_tf_keras_version = True
else:
self._is_old_tf_keras_version = False

return self._is_old_tf_keras_version