Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add asset zipping functionality to TFJS converter #6915

Merged
merged 8 commits into from
Oct 12, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion tfjs-converter/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ py_wheel(
"importlib_resources>=5.9.0",
"jax>=0.3.16",
"protobuf<3.20,>=3.9.2",
"tensorflow>=2.1.0,<3",
"tensorflow>=2.10.0,<3",
"tensorflow-decision-forests>=1.0.1",
"six>=1.12.0,<2",
"tensorflow-hub>=0.7.0,<0.13",
"packaging~=20.9",
Expand Down
1 change: 1 addition & 0 deletions tfjs-converter/python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ jax>=0.3.16
importlib_resources>=5.9.0
protobuf<3.20,>=3.9.2
tensorflow>=2.10.0,<3
tensorflow-decision-forests>=1.0.1
six>=1.12.0,<2
tensorflow-hub>=0.7.0,<0.13; python_version >= "3"
packaging~=20.9
9 changes: 9 additions & 0 deletions tfjs-converter/python/tensorflowjs/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ py_library(
deps = [requirement("tensorflow")],
)

py_library(
name = "expect_tensorflow_decision_forests_installed",
# This is a dummy rule used as a tensorflow dependency in open-source.
# We expect tensorflow-decision-forests to already be installed on
# the system, e.g. via
# `pip install tensorflow-decision-forests`.
deps = [requirement("tensorflow-decision-forests")],
)

py_library(
name = "expect_tensorflow_hub_installed",
# This is a dummy rule used as a tensorflow_hub dependency in open-source.
Expand Down
1 change: 1 addition & 0 deletions tfjs-converter/python/tensorflowjs/converters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ py_library(
":graph_rewrite_util",
"//tfjs-converter/python/tensorflowjs:expect_numpy_installed",
"//tfjs-converter/python/tensorflowjs:expect_packaging_installed",
"//tfjs-converter/python/tensorflowjs:expect_tensorflow_decision_forests_installed",
"//tfjs-converter/python/tensorflowjs:expect_tensorflow_hub_installed",
"//tfjs-converter/python/tensorflowjs:expect_tensorflow_installed",
"//tfjs-converter/python/tensorflowjs:resource_loader",
Expand Down
1 change: 1 addition & 0 deletions tfjs-converter/python/tensorflowjs/converters/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

# File name for the indexing JSON file in an artifact directory.
ARTIFACT_MODEL_JSON_FILE_NAME = 'model.json'
ASSETS_DIRECTORY_NAME = 'assets'

# JSON string keys for fields of the indexing JSON.
ARTIFACT_MODEL_TOPOLOGY_KEY = 'modelTopology'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,19 @@

import json
import os
import shutil
import tempfile
from zipfile import ZipFile

# Required to load saved models that use TFDF.
import tensorflow_decision_forests
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import device_properties_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.io import gfile
from tensorflow.python.checkpoint.trackable_view import TrackableView
from tensorflow.python.eager import context
from tensorflow.python.framework import convert_to_constants
Expand Down Expand Up @@ -399,7 +405,7 @@ def write_artifacts(topology,
assert isinstance(weights_manifest, list)
model_json[common.ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest

with tf.io.gfile.GFile(output_graph, 'w') as f:
with gfile.GFile(output_graph, 'w') as f:
json.dump(model_json, f)

def _remove_unused_control_flow_inputs(input_graph_def):
Expand All @@ -421,6 +427,33 @@ def _check_signature_in_model(saved_model, signature_name):
"are available: %s" % (signature_name,
saved_model.signatures.keys()))

def _copy_assets(saved_model_dir, output_dir):
input_assets_path = os.path.join(saved_model_dir, common.ASSETS_DIRECTORY_NAME)

if gfile.exists(input_assets_path) and gfile.isdir(input_assets_path):

tmp_dir = tempfile.mkdtemp()
zip_path = gfile.join(tmp_dir, common.ASSETS_DIRECTORY_NAME + '.zip')

with ZipFile(zip_path, 'w') as archive:
for (input_dir_path, _, file_names) in gfile.walk(input_assets_path):

relative_dir_path = os.path.relpath(input_dir_path, input_assets_path)

for file_name in file_names:

input_file_path = gfile.join(input_dir_path, file_name)
relative_file_path = gfile.join(relative_dir_path, file_name)

with gfile.GFile(input_file_path, 'rb') as input_file:
with archive.open(relative_file_path, 'w') as relative_file:
shutil.copyfileobj(input_file, relative_file)

output_assets_path = gfile.join(output_dir, common.ASSETS_DIRECTORY_NAME + '.zip')
gfile.copy(zip_path, output_assets_path, overwrite=True)

if gfile.isdir(tmp_dir):
gfile.rmtree(tmp_dir)

def _freeze_saved_model_v1(saved_model_dir, saved_model_tags,
output_node_names):
Expand Down Expand Up @@ -745,8 +778,8 @@ def _convert_tf_saved_model(output_dir,
if signature_def is None:
signature_def = 'serving_default'

if not tf.io.gfile.exists(output_dir):
tf.io.gfile.makedirs(output_dir)
if not gfile.exists(output_dir):
gfile.makedirs(output_dir)
output_graph = os.path.join(
output_dir, common.ARTIFACT_MODEL_JSON_FILE_NAME)

Expand All @@ -763,6 +796,7 @@ def _convert_tf_saved_model(output_dir,
model = _load_model(saved_model_dir, saved_model_tags_list)
_check_signature_in_model(model, signature_def)
concrete_func = model.signatures[signature_def]
_copy_assets(saved_model_dir, output_dir)
elif keras_model:
model = keras_model
input_signature = None
Expand Down Expand Up @@ -1137,7 +1171,7 @@ def convert_tf_hub_module(module_handle, output_dir,
# TODO(vbardiovskyg): We can remove this v1 code path once loading of all v1
# modules is fixed on the TF side, or once the modules we cannot load become
# replaced with newer versions.
if tf.io.gfile.exists(os.path.join(module_path, _HUB_V1_MODULE_PB)):
if gfile.exists(os.path.join(module_path, _HUB_V1_MODULE_PB)):
print("Loading the module using TF 1.X interface from %s." % module_path)
convert_tf_hub_module_v1(module_path, output_dir, signature,
quantization_dtype_map,
Expand Down