Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TF - Fix] Fix imports from tensorflow.python.keras with tf.__version__ >= 2.6.0 #3403

Merged
merged 9 commits into from Feb 26, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -27,6 +27,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- fix the example of pytorch_lightning_mnist.py ([#3245](https://github.com/horovod/horovod/pull/3245))

- Call _setup in remote trainers to point to the correct shared lib path ([#3258](https://github.com/horovod/horovod/pull/3258))

- Fix imports from tensorflow.python.keras with tensorflow 2.6.0+ ([#3403](https://github.com/horovod/horovod/pull/3403))

## [v0.23.0] - 2021-10-06

### Added
Expand Down
16 changes: 11 additions & 5 deletions Dockerfile.test.cpu
Expand Up @@ -84,6 +84,13 @@ RUN if [[ ${SPARK_PACKAGE} != *"-preview"* ]]; then \
(cd /spark/python && python setup.py sdist && pip install --no-cache-dir dist/pyspark-*.tar.gz && rm dist/pyspark-*); \
fi

# Pin cloudpickle to 1.3.0
# Dill breaks clouldpickle > 1.3.0 when using Spark2
# https://github.com/cloudpipe/cloudpickle/issues/393
RUN if [[ ${PYSPARK_PACKAGE} == "pyspark==2."* ]]; then \
pip install --no-cache-dir cloudpickle==1.3.0; \
fi

# Install Ray.
# Updating to 1.7.0 to pass ray tests
RUN pip install --no-cache-dir ray==1.7.0
Expand Down Expand Up @@ -145,21 +152,20 @@ RUN if [[ ${MPI_KIND} != "None" ]]; then \
fi

# Install TensorFlow and Keras (releases).
# Pin h5py only for tensorflow<2.5: https://github.com/h5py/h5py/issues/1732
# Pin scipy!=1.4.0: https://github.com/scipy/scipy/issues/11237
RUN if [[ ${TENSORFLOW_PACKAGE} != "tf-nightly" ]]; then \
pip install --no-cache-dir ${TENSORFLOW_PACKAGE}; \
if [[ ${KERAS_PACKAGE} != "None" ]]; then \
if [[ ${TENSORFLOW_PACKAGE} == tensorflow*==1.* ]] || [[ ${TENSORFLOW_PACKAGE} == tensorflow*==2.[012345].* ]]; then \
h5py="h5py<3"; \
fi; \
pip uninstall -y keras-nightly; \
pip install --no-cache-dir ${KERAS_PACKAGE} ${h5py:-} "scipy!=1.4.0" "pandas<1.1.0"; \
pip install --no-cache-dir ${KERAS_PACKAGE} "scipy!=1.4.0" "pandas<1.1.0"; \
fi; \
mkdir -p ~/.keras; \
python -c "import tensorflow as tf; tf.keras.datasets.mnist.load_data()"; \
fi

# Pin h5py < 3 for tensorflow: https://github.com/tensorflow/tensorflow/issues/44467
RUN pip install 'h5py<3.0' --force-reinstall

# Install PyTorch (releases).
# Pin Pillow<7.0 for torchvision < 0.5.0: https://github.com/pytorch/vision/issues/1718
# Pin Pillow!=8.3.0 for torchvision: https://github.com/pytorch/vision/issues/4146
Expand Down
9 changes: 4 additions & 5 deletions Dockerfile.test.gpu
Expand Up @@ -113,23 +113,22 @@ RUN if [[ ${MPI_KIND} != "None" ]]; then \
fi

# Install TensorFlow and Keras (releases).
# Pin h5py only for tensorflow<2.5: https://github.com/h5py/h5py/issues/1732
# Pin scipy!=1.4.0: https://github.com/scipy/scipy/issues/11237
RUN if [[ ${TENSORFLOW_PACKAGE} != "tf-nightly-gpu" ]]; then \
pip install --no-cache-dir ${TENSORFLOW_PACKAGE}; \
if [[ ${KERAS_PACKAGE} != "None" ]]; then \
if [[ ${TENSORFLOW_PACKAGE} == tensorflow*==1.* ]] || [[ ${TENSORFLOW_PACKAGE} == tensorflow*==2.[012345].* ]]; then \
h5py="h5py<3"; \
fi; \
pip uninstall -y keras-nightly; \
pip install --no-cache-dir ${KERAS_PACKAGE} ${h5py:-} "scipy!=1.4.0" "pandas<1.1.0"; \
pip install --no-cache-dir ${KERAS_PACKAGE} "scipy!=1.4.0" "pandas<1.1.0"; \
fi; \
mkdir -p ~/.keras; \
ldconfig /usr/local/cuda/targets/x86_64-linux/lib/stubs; \
python -c "import tensorflow as tf; tf.keras.datasets.mnist.load_data()"; \
ldconfig; \
fi

# Pin h5py < 3 for tensorflow: https://github.com/tensorflow/tensorflow/issues/44467
RUN pip install 'h5py<3.0' --force-reinstall

# Install PyTorch (releases).
# Pin Pillow<7.0 for torchvision < 0.5.0: https://github.com/pytorch/vision/issues/1718
# Pin Pillow!=8.3.0 for torchvision: https://github.com/pytorch/vision/issues/4146
Expand Down
4 changes: 2 additions & 2 deletions docker-compose.test.yml
Expand Up @@ -95,7 +95,7 @@ services:
KERAS_PACKAGE: None
PYTORCH_PACKAGE: torch-nightly
TORCHVISION_PACKAGE: torchvision
PYTORCH_LIGHTNING_PACKAGE: pytorch_lightning
PYTORCH_LIGHTNING_PACKAGE: pytorch_lightning==1.3.8
MXNET_PACKAGE: mxnet-nightly

test-cpu-gloo-py3_7-tf2_7_0-keras2_7_0-torch1_10_1-mxnet1_9_0-pyspark2_4_8:
Expand Down Expand Up @@ -208,7 +208,7 @@ services:
TENSORFLOW_PACKAGE: tf-nightly-gpu
KERAS_PACKAGE: None
PYTORCH_PACKAGE: torch-nightly-cu111
PYTORCH_LIGHTNING_PACKAGE: pytorch_lightning
PYTORCH_LIGHTNING_PACKAGE: pytorch_lightning==1.3.8
TORCHVISION_PACKAGE: torchvision
MXNET_PACKAGE: mxnet-nightly-cu112

Expand Down
21 changes: 21 additions & 0 deletions horovod/common/util.py
Expand Up @@ -265,3 +265,24 @@ def is_iterable(x):
except TypeError:
return False
return True


@_cache
def is_version_greater_equal_than(ver, target):
chongxiaoc marked this conversation as resolved.
Show resolved Hide resolved
from distutils.version import LooseVersion
if any([not isinstance(_str, str) for _str in (ver, target)]):
raise ValueError("This function only accepts string arguments. \n"
"Received:\n"
"\t- ver (type {type_ver}: {val_ver})"
"\t- target (type {type_target}: {val_target})".format(
type_ver=(type(ver)),
val_ver=ver,
type_target=(type(target)),
val_target=target,
))

if len(target.split(".")) != 3:
raise ValueError("We only accepts target version values in the form "
"of: major.minor.patch. Received: {}".format(target))

return LooseVersion(ver) >= LooseVersion(target)
13 changes: 11 additions & 2 deletions horovod/spark/keras/tensorflow.py
Expand Up @@ -15,8 +15,17 @@

import json

from tensorflow.python.keras import backend as K
from tensorflow.python.keras import optimizers
import tensorflow as tf

from horovod.common.util import is_version_greater_equal_than

if is_version_greater_equal_than(tf.__version__, "2.6.0"):
from keras import backend as K
from keras import optimizers
else:
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import optimizers

from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import serialization

Expand Down
9 changes: 7 additions & 2 deletions horovod/tensorflow/keras/__init__.py
Expand Up @@ -19,7 +19,13 @@
import tensorflow as tf

from tensorflow import keras
from tensorflow.python.keras import backend as K

from horovod.common.util import is_version_greater_equal_than

if is_version_greater_equal_than(tf.__version__, "2.6.0"):
from keras import backend as K
else:
from tensorflow.python.keras import backend as K

from horovod.tensorflow import init
from horovod.tensorflow import shutdown
Expand Down Expand Up @@ -247,4 +253,3 @@ def load_model(filepath, custom_optimizers=None, custom_objects=None, compressio
def wrap_optimizer(cls):
return lambda **kwargs: DistributedOptimizer(cls(**kwargs), compression=compression)
return _impl.load_model(keras, wrap_optimizer, _OPTIMIZER_MODULES, filepath, custom_optimizers, custom_objects)

10 changes: 9 additions & 1 deletion horovod/tensorflow/keras/callbacks.py
Expand Up @@ -13,8 +13,16 @@
# limitations under the License.
# ==============================================================================


import tensorflow as tf
from tensorflow import keras
from tensorflow.python.keras import backend as K

from horovod.common.util import is_version_greater_equal_than

if is_version_greater_equal_than(tf.__version__, "2.6.0"):
from keras import backend as K
else:
from tensorflow.python.keras import backend as K

from horovod._keras import callbacks as _impl

Expand Down
3 changes: 2 additions & 1 deletion setup.py
Expand Up @@ -129,7 +129,8 @@ def build_extensions(self):
tensorflow_cpu_require_list = ['tensorflow-cpu']
tensorflow_gpu_require_list = ['tensorflow-gpu']
keras_require_list = ['keras>=2.0.8,!=2.0.9,!=2.1.0,!=2.1.1']
pytorch_require_list = ['torch', 'pytorch_lightning']
# pytorch-lightning 1.3.8 is a stable version to work with horovod
pytorch_require_list = ['torch', 'pytorch_lightning==1.3.8']
mxnet_require_list = ['mxnet>=1.4.1']
pyspark_require_list = ['pyspark>=2.3.2;python_version<"3.8"',
'pyspark>=3.0.0;python_version>="3.8"']
Expand Down
2 changes: 1 addition & 1 deletion test/integration/test_spark.py
Expand Up @@ -1665,7 +1665,7 @@ def test_spark_task_service_execute_command(self):
file = os.path.sep.join([d, 'command_executed'])
self.do_test_spark_task_service_executes_command(client, file)

@mock.patch('horovod.runner.common.util.safe_shell_exec.GRACEFUL_TERMINATION_TIME_S', 0.5)
@mock.patch('horovod.runner.common.util.safe_shell_exec.GRACEFUL_TERMINATION_TIME_S', 30)
def test_spark_task_service_abort_command(self):
with spark_task_service(index=0) as (service, client, _):
with tempdir() as d:
Expand Down
3 changes: 0 additions & 3 deletions test/integration/test_spark_keras.py
Expand Up @@ -72,9 +72,6 @@ def fit(model, train_data, val_data, steps_per_epoch, validation_steps, callback
return fit


#PR3099 [https://github.com/horovod/horovod/pull/3099] doesn't fix
#Tensorflow>=2.6.0 tests
@pytest.mark.skipif(LooseVersion(tf.__version__) >= LooseVersion('2.6.0'), reason='TensorFlow>=2.6.0 tests')
class SparkKerasTests(tf.test.TestCase):
def __init__(self, *args, **kwargs):
super(SparkKerasTests, self).__init__(*args, **kwargs)
Expand Down
4 changes: 4 additions & 0 deletions test/parallel/test_tensorflow.py
Expand Up @@ -4051,6 +4051,8 @@ def test_horovod_join_allreduce(self):
self.assertSequenceEqual(ret_values, [ret] * size,
msg="hvd.join() did not return the same value on each rank")

@pytest.mark.skipif(LooseVersion(tf.__version__) >=
chongxiaoc marked this conversation as resolved.
Show resolved Hide resolved
LooseVersion('2.9.0'), reason='https://github.com/horovod/horovod/issues/3422')
def test_horovod_syncbn_gpu(self):
"""Test that the SyncBatchNormalization implementation is correct on GPU."""
# Only do this test if there are GPUs available.
Expand Down Expand Up @@ -4098,6 +4100,8 @@ def test_horovod_syncbn_gpu(self):
self.assertAllClose(self.evaluate(sync_bn.moving_mean), self.evaluate(bn.moving_mean))
self.assertAllClose(self.evaluate(sync_bn.moving_variance), self.evaluate(bn.moving_variance))

@pytest.mark.skipif(LooseVersion(tf.__version__) >=
LooseVersion('2.9.0'), reason='https://github.com/horovod/horovod/issues/3422')
def test_horovod_syncbn_cpu(self):
"""Test that the SyncBatchNormalization implementation is correct on CPU."""

Expand Down
11 changes: 10 additions & 1 deletion test/parallel/test_tensorflow2_keras.py
Expand Up @@ -26,7 +26,16 @@
import pytest

from tensorflow import keras
from tensorflow.python.keras.optimizer_v2 import optimizer_v2

from horovod.common.util import is_version_greater_equal_than

if is_version_greater_equal_than(tf.__version__, "2.6.0"):
if LooseVersion(keras.__version__) < LooseVersion("2.9.0"):
from keras.optimizer_v2 import optimizer_v2
else:
from keras.optimizers.optimizer_v2 import optimizer_v2
else:
from tensorflow.python.keras.optimizer_v2 import optimizer_v2

import horovod.tensorflow.keras as hvd

Expand Down
14 changes: 12 additions & 2 deletions test/parallel/test_tensorflow_keras.py
Expand Up @@ -25,8 +25,18 @@

from distutils.version import LooseVersion
from tensorflow import keras
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.optimizer_v2 import optimizer_v2

from horovod.common.util import is_version_greater_equal_than

if is_version_greater_equal_than(tf.__version__, "2.6.0"):
from keras import backend as K
if LooseVersion(keras.__version__) < LooseVersion("2.9.0"):
from keras.optimizer_v2 import optimizer_v2
else:
from keras.optimizers.optimizer_v2 import optimizer_v2
else:
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.optimizer_v2 import optimizer_v2

import horovod.tensorflow.keras as hvd

Expand Down