Add asset zipping functionality to TFJS converter (#6915)
* Add asset zipping functionality to TFJS converter

* Add TFDF to converter requirements

* Add TFDF dependency

* Fix assets overwrite bug

* Make copy assets conditional on TFDF input
ahmedsabie committed Oct 12, 2022
1 parent 2cc528b commit 567754e
Showing 7 changed files with 117 additions and 5 deletions.
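
For context, a minimal usage sketch (not part of the diff below; paths are placeholders and the call mirrors the new test in the conversion test file): converting a SavedModel that contains TFDF ops now also writes the model's assets directory as assets.zip next to model.json.

from tensorflowjs.converters import tf_saved_model_conversion_v2

# Hypothetical paths; the SavedModel is assumed to contain TFDF ops such as
# SimpleMLInferenceOpWithHandle, which is what triggers the asset copy.
tf_saved_model_conversion_v2.convert_tf_saved_model(
    '/path/to/tfdf_saved_model',  # input SavedModel directory
    '/path/to/web_model',         # output: model.json, weight shards, assets.zip
    skip_op_check=True)
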
3 changes: 2 additions & 1 deletion tfjs-converter/python/BUILD
@@ -52,7 +52,8 @@ py_wheel(
"importlib_resources>=5.9.0",
"jax>=0.3.16",
"protobuf<3.20,>=3.9.2",
"tensorflow>=2.1.0,<3",
"tensorflow>=2.10.0,<3",
"tensorflow-decision-forests>=1.0.1",
"six>=1.12.0,<2",
"tensorflow-hub>=0.7.0,<0.13",
"packaging~=20.9",
1 change: 1 addition & 0 deletions tfjs-converter/python/requirements.txt
@@ -3,6 +3,7 @@ jax>=0.3.16
importlib_resources>=5.9.0
protobuf<3.20,>=3.9.2
tensorflow>=2.10.0,<3
tensorflow-decision-forests>=1.0.1
six>=1.12.0,<2
tensorflow-hub>=0.7.0,<0.13; python_version >= "3"
packaging~=20.9
9 changes: 9 additions & 0 deletions tfjs-converter/python/tensorflowjs/BUILD
@@ -80,6 +80,15 @@ py_library(
deps = [requirement("tensorflow")],
)

py_library(
name = "expect_tensorflow_decision_forests_installed",
# This is a dummy rule used as a tensorflow dependency in open-source.
# We expect tensorflow-decision-forests to already be installed on
# the system, e.g. via
# `pip install tensorflow-decision-forests`.
deps = [requirement("tensorflow-decision-forests")],
)

py_library(
name = "expect_tensorflow_hub_installed",
# This is a dummy rule used as a tensorflow_hub dependency in open-source.
1 change: 1 addition & 0 deletions tfjs-converter/python/tensorflowjs/converters/BUILD
@@ -188,6 +188,7 @@ py_library(
":graph_rewrite_util",
"//tfjs-converter/python/tensorflowjs:expect_numpy_installed",
"//tfjs-converter/python/tensorflowjs:expect_packaging_installed",
"//tfjs-converter/python/tensorflowjs:expect_tensorflow_decision_forests_installed",
"//tfjs-converter/python/tensorflowjs:expect_tensorflow_hub_installed",
"//tfjs-converter/python/tensorflowjs:expect_tensorflow_installed",
"//tfjs-converter/python/tensorflowjs:resource_loader",
1 change: 1 addition & 0 deletions tfjs-converter/python/tensorflowjs/converters/common.py
@@ -17,6 +17,7 @@

# File name for the indexing JSON file in an artifact directory.
ARTIFACT_MODEL_JSON_FILE_NAME = 'model.json'
ASSETS_DIRECTORY_NAME = 'assets'

# JSON string keys for fields of the indexing JSON.
ARTIFACT_MODEL_TOPOLOGY_KEY = 'modelTopology'
tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py
@@ -20,13 +20,19 @@

import json
import os
import shutil
import tempfile
from zipfile import ZipFile

# Required to load saved models that use TFDF.
import tensorflow_decision_forests
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import device_properties_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.io import gfile
from tensorflow.python.checkpoint.trackable_view import TrackableView
from tensorflow.python.eager import context
from tensorflow.python.framework import convert_to_constants
@@ -399,7 +405,7 @@ def write_artifacts(topology,
assert isinstance(weights_manifest, list)
model_json[common.ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest

with tf.io.gfile.GFile(output_graph, 'w') as f:
with gfile.GFile(output_graph, 'w') as f:
json.dump(model_json, f)

def _remove_unused_control_flow_inputs(input_graph_def):
@@ -421,6 +427,49 @@ def _check_signature_in_model(saved_model, signature_name):
"are available: %s" % (signature_name,
saved_model.signatures.keys()))

def _copy_assets(saved_model_dir, output_dir):
  """Zips the SavedModel's assets directory into <output_dir>/assets.zip."""
  input_assets_path = os.path.join(saved_model_dir, common.ASSETS_DIRECTORY_NAME)

  if gfile.exists(input_assets_path) and gfile.isdir(input_assets_path):

    tmp_dir = tempfile.mkdtemp()
    zip_path = gfile.join(tmp_dir, common.ASSETS_DIRECTORY_NAME + '.zip')

    with ZipFile(zip_path, 'w') as archive:
      for (input_dir_path, _, file_names) in gfile.walk(input_assets_path):

        relative_dir_path = os.path.relpath(input_dir_path, input_assets_path)

        for file_name in file_names:

          input_file_path = gfile.join(input_dir_path, file_name)
          relative_file_path = gfile.join(relative_dir_path, file_name)

          with gfile.GFile(input_file_path, 'rb') as input_file:
            with archive.open(relative_file_path, 'w') as relative_file:
              shutil.copyfileobj(input_file, relative_file)

    output_assets_path = gfile.join(output_dir, common.ASSETS_DIRECTORY_NAME + '.zip')
    gfile.copy(zip_path, output_assets_path, overwrite=True)

    if gfile.isdir(tmp_dir):
      gfile.rmtree(tmp_dir)
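
# For reference only (not part of this change): a minimal, hypothetical sketch
# of how the archive written by _copy_assets could be inspected after
# conversion. The helper name is illustrative.
def _list_copied_assets(output_dir):
  zip_path = os.path.join(output_dir, common.ASSETS_DIRECTORY_NAME + '.zip')
  with ZipFile(zip_path, 'r') as archive:
    return archive.namelist()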

# TFDF stores the necessary files for its binary in the assets folder.
ASSET_REQUIRING_OPS = set([
    'SimpleMLCreateModelResource',
    'SimpleMLLoadModelFromPathWithHandle',
    'SimpleMLInferenceOpWithHandle',
])

def _is_assets_required(model_ops):
  return not ASSET_REQUIRING_OPS.isdisjoint(model_ops)

def _get_frozen_graph_ops(frozen_graph):
  if frozen_graph is None:
    return []
  return [node.op for node in frozen_graph.as_graph_def().node]
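
# Usage sketch (illustrative values, not part of this change): the asset copy
# is gated on a simple set intersection with ASSET_REQUIRING_OPS, e.g.
#   _is_assets_required({'Const', 'SimpleMLInferenceOpWithHandle'})  -> True
#   _is_assets_required({'Const', 'MatMul'})                         -> False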


def _freeze_saved_model_v1(saved_model_dir, saved_model_tags,
output_node_names):
@@ -745,8 +794,8 @@ def _convert_tf_saved_model(output_dir,
if signature_def is None:
signature_def = 'serving_default'

if not tf.io.gfile.exists(output_dir):
tf.io.gfile.makedirs(output_dir)
if not gfile.exists(output_dir):
gfile.makedirs(output_dir)
output_graph = os.path.join(
output_dir, common.ARTIFACT_MODEL_JSON_FILE_NAME)

@@ -852,6 +901,12 @@ def _convert_tf_saved_model(output_dir,
# tensorflow version.
tf_version = tf.__version__

  if saved_model_dir:
    model_ops = set(_get_frozen_graph_ops(frozen_graph)) | \
        set(_get_frozen_graph_ops(frozen_initializer_graph))
    if _is_assets_required(model_ops):
      _copy_assets(saved_model_dir, output_dir)

optimize_graph(frozen_graph, signature,
output_graph, tf_version,
quantization_dtype_map=quantization_dtype_map,
@@ -1137,7 +1192,7 @@ def convert_tf_hub_module(module_handle, output_dir,
# TODO(vbardiovskyg): We can remove this v1 code path once loading of all v1
# modules is fixed on the TF side, or once the modules we cannot load become
# replaced with newer versions.
if tf.io.gfile.exists(os.path.join(module_path, _HUB_V1_MODULE_PB)):
if gfile.exists(os.path.join(module_path, _HUB_V1_MODULE_PB)):
print("Loading the module using TF 1.X interface from %s." % module_path)
convert_tf_hub_module_v1(module_path, output_dir, signature,
quantization_dtype_map,
tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py
@@ -21,8 +21,10 @@
import shutil
import tempfile
import unittest
import numpy as np

import tensorflow.compat.v2 as tf
from tensorflow_decision_forests.keras import GradientBoostedTreesModel
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -35,6 +37,7 @@
from tensorflowjs import version
from tensorflowjs.converters import graph_rewrite_util
from tensorflowjs.converters import tf_saved_model_conversion_v2
from tensorflowjs.converters.common import ASSETS_DIRECTORY_NAME

SAVED_MODEL_DIR = 'saved_model'
HUB_MODULE_DIR = 'hub_module'
@@ -246,6 +249,22 @@ def find_next_odd(v):
save_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
save(root, save_dir, to_save)

  def _create_saved_model_with_tfdf(self):
    """Creates a basic TFDF (gradient boosted trees) SavedModel for testing."""
    P = 5
    NUM_EXAMPLES = 10
    NUM_FEATURES = 4

    x_train = np.random.uniform(size=(NUM_EXAMPLES, NUM_FEATURES))
    y_train = np.random.uniform(size=NUM_EXAMPLES) > 0.5
    w_train = y_train * (P - 1) + 1  # 1 or P depending on the class.

    model = GradientBoostedTreesModel()
    model.fit(x=x_train, y=y_train, sample_weight=w_train)

    save_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
    model.save(save_dir)

def _create_unsupported_saved_model(self):
root = tracking.AutoTrackable()
root.w = variables.Variable(tf.random.uniform([2, 2]))
@@ -936,6 +955,31 @@ def test_convert_saved_model_with_control_flow_v2(self):
glob.glob(
os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'group*-*')))

  def test_convert_saved_model_with_tfdf(self):
    self._create_saved_model_with_tfdf()

    tfjs_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
    tf_saved_model_conversion_v2.convert_tf_saved_model(
        tfjs_path, tfjs_path, skip_op_check=True
    )

    # Check model.json and weights manifest.
    with open(os.path.join(tfjs_path, 'model.json'), 'rt') as f:
      model_json = json.load(f)

    # Check TFDF ops are present.
    model_ops = [node['op'] for node in model_json['modelTopology']['node']]
    self.assertTrue('SimpleMLInferenceOpWithHandle' in model_ops)

    initializer_ops = [node['op'] for node in model_json['modelInitializer']['node']]
    self.assertTrue('SimpleMLCreateModelResource' in initializer_ops)
    self.assertTrue('SimpleMLLoadModelFromPathWithHandle' in initializer_ops)

    # Check assets containing TFDF files were copied over.
    self.assertTrue(
        os.path.exists(
            os.path.join(tfjs_path, ASSETS_DIRECTORY_NAME + '.zip')))

def test_convert_saved_model_sharded(self):
self._create_saved_model()
model_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
