Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add asset zipping functionality to TFJS converter #6915

Merged
merged 8 commits into from
Oct 12, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion tfjs-converter/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ py_wheel(
"importlib_resources>=5.9.0",
"jax>=0.3.16",
"protobuf<3.20,>=3.9.2",
"tensorflow>=2.1.0,<3",
"tensorflow>=2.10.0,<3",
"tensorflow-decision-forests>=1.0.1",
"six>=1.12.0,<2",
"tensorflow-hub>=0.7.0,<0.13",
"packaging~=20.9",
Expand Down
1 change: 1 addition & 0 deletions tfjs-converter/python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ jax>=0.3.16
importlib_resources>=5.9.0
protobuf<3.20,>=3.9.2
tensorflow>=2.10.0,<3
tensorflow-decision-forests>=1.0.1
six>=1.12.0,<2
tensorflow-hub>=0.7.0,<0.13; python_version >= "3"
packaging~=20.9
9 changes: 9 additions & 0 deletions tfjs-converter/python/tensorflowjs/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ py_library(
deps = [requirement("tensorflow")],
)

py_library(
name = "expect_tensorflow_decision_forests_installed",
# This is a dummy rule used as a tensorflow dependency in open-source.
# We expect tensorflow-decision-forests to already be installed on
# the system, e.g. via
# `pip install tensorflow-decision-forests`.
deps = [requirement("tensorflow-decision-forests")],
)

py_library(
name = "expect_tensorflow_hub_installed",
# This is a dummy rule used as a tensorflow_hub dependency in open-source.
Expand Down
1 change: 1 addition & 0 deletions tfjs-converter/python/tensorflowjs/converters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ py_library(
":graph_rewrite_util",
"//tfjs-converter/python/tensorflowjs:expect_numpy_installed",
"//tfjs-converter/python/tensorflowjs:expect_packaging_installed",
"//tfjs-converter/python/tensorflowjs:expect_tensorflow_decision_forests_installed",
"//tfjs-converter/python/tensorflowjs:expect_tensorflow_hub_installed",
"//tfjs-converter/python/tensorflowjs:expect_tensorflow_installed",
"//tfjs-converter/python/tensorflowjs:resource_loader",
Expand Down
1 change: 1 addition & 0 deletions tfjs-converter/python/tensorflowjs/converters/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

# File name for the indexing JSON file in an artifact directory.
ARTIFACT_MODEL_JSON_FILE_NAME = 'model.json'
ASSETS_DIRECTORY_NAME = 'assets'

# JSON string keys for fields of the indexing JSON.
ARTIFACT_MODEL_TOPOLOGY_KEY = 'modelTopology'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@

import json
import os
import shutil
import tempfile
from zipfile import ZipFile

# Required to load saved models that use TFDF.
import tensorflow_decision_forests
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
Expand Down Expand Up @@ -421,6 +426,33 @@ def _check_signature_in_model(saved_model, signature_name):
"are available: %s" % (signature_name,
saved_model.signatures.keys()))

def _copy_assets(saved_model_dir, output_dir):
input_assets_path = os.path.join(saved_model_dir, common.ASSETS_DIRECTORY_NAME)

if os.path.isdir(input_assets_path) and os.path.isdir(input_assets_path):

tmp_dir = tempfile.mkdtemp()
zip_path = os.path.join(tmp_dir, common.ASSETS_DIRECTORY_NAME + '.zip')

with ZipFile(zip_path, 'w') as archive:
for (input_dir_path, _, file_names) in tf.io.gfile.walk(input_assets_path):

relative_dir_path = os.path.relpath(input_dir_path, input_assets_path)

for file_name in file_names:

input_file_path = os.path.join(input_dir_path, file_name)
relative_file_path = os.path.join(relative_dir_path, file_name)

with tf.io.gfile.GFile(input_file_path, 'rb') as input_file:
with archive.open(relative_file_path, 'w') as relative_file:
shutil.copyfileobj(input_file, relative_file)

output_assets_path = os.path.join(output_dir, common.ASSETS_DIRECTORY_NAME + '.zip')
tf.io.gfile.copy(zip_path, output_assets_path)

if os.path.isdir(tmp_dir):
shutil.rmtree(tmp_dir)

def _freeze_saved_model_v1(saved_model_dir, saved_model_tags,
output_node_names):
Expand Down Expand Up @@ -763,6 +795,7 @@ def _convert_tf_saved_model(output_dir,
model = _load_model(saved_model_dir, saved_model_tags_list)
_check_signature_in_model(model, signature_def)
concrete_func = model.signatures[signature_def]
_copy_assets(saved_model_dir, output_dir)
elif keras_model:
model = keras_model
input_signature = None
Expand Down