Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

CLI: Add flag to push TF weights directly into main #17720

Merged
merged 3 commits into from Jun 15, 2022
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
70 changes: 41 additions & 29 deletions src/transformers/commands/pt_to_tf.py
Expand Up @@ -45,7 +45,7 @@ def convert_command_factory(args: Namespace):

Returns: ServeCommand
"""
return PTtoTFCommand(args.model_name, args.local_dir, args.no_pr, args.new_weights)
return PTtoTFCommand(args.model_name, args.local_dir, args.new_weights, args.no_pr, args.push)


class PTtoTFCommand(BaseTransformersCLICommand):
Expand Down Expand Up @@ -76,13 +76,18 @@ def register_subcommand(parser: ArgumentParser):
default="",
help="Optional local directory of the model repository. Defaults to /tmp/{model_name}",
)
train_parser.add_argument(
"--new-weights",
action="store_true",
help="Optional flag to create new TensorFlow weights, even if they already exist.",
)
train_parser.add_argument(
"--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights."
)
train_parser.add_argument(
"--new-weights",
"--push",
action="store_true",
help="Optional flag to create new TensorFlow weights, even if they already exist.",
help="Optional flag to push the weights directly to `main` (requires permissions)",
)
train_parser.set_defaults(func=convert_command_factory)

Expand Down Expand Up @@ -129,12 +134,13 @@ def _find_pt_tf_differences(pt_out, tf_out, differences, attr_name=""):

return _find_pt_tf_differences(pt_outputs, tf_outputs, {})

def __init__(self, model_name: str, local_dir: str, no_pr: bool, new_weights: bool, *args):
def __init__(self, model_name: str, local_dir: str, new_weights: bool, no_pr: bool, push: bool, *args):
self._logger = logging.get_logger("transformers-cli/pt_to_tf")
self._model_name = model_name
self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
self._no_pr = no_pr
self._new_weights = new_weights
self._no_pr = no_pr
self._push = push

def get_text_inputs(self):
tokenizer = AutoTokenizer.from_pretrained(self._local_dir)
Expand Down Expand Up @@ -234,27 +240,33 @@ def run(self):
)
)

if not self._no_pr:
# TODO: remove try/except when the upload to PR feature is released
# (https://github.com/huggingface/huggingface_hub/pull/884)
try:
self._logger.warn("Uploading the weights into a new PR...")
hub_pr_url = upload_file(
path_or_fileobj=tf_weights_path,
path_in_repo=TF_WEIGHTS_NAME,
repo_id=self._model_name,
create_pr=True,
pr_commit_summary="Add TF weights",
pr_commit_description=(
"Model converted by the `transformers`' `pt_to_tf` CLI -- all converted model outputs and"
" hidden layers were validated against its Pytorch counterpart. Maximum crossload output"
f" difference={max_crossload_diff:.3e}; Maximum converted output"
f" difference={max_conversion_diff:.3e}."
),
)
self._logger.warn(f"PR open in {hub_pr_url}")
except TypeError:
self._logger.warn(
f"You can now open a PR in https://huggingface.co/{self._model_name}/discussions, manually"
f" uploading the file in {tf_weights_path}"
)
if self._push:
repo.git_add(auto_lfs_track=True)
repo.git_commit("Add TF weights")
repo.git_push(blocking=True) # this prints a progress bar with the upload
self._logger.warn(f"TF weights pushed into {self._model_name}")
else:
if not self._no_pr:
gante marked this conversation as resolved.
Show resolved Hide resolved
# TODO: remove try/except when the upload to PR feature is released
# (https://github.com/huggingface/huggingface_hub/pull/884)
try:
self._logger.warn("Uploading the weights into a new PR...")
hub_pr_url = upload_file(
path_or_fileobj=tf_weights_path,
path_in_repo=TF_WEIGHTS_NAME,
repo_id=self._model_name,
create_pr=True,
pr_commit_summary="Add TF weights",
pr_commit_description=(
"Model converted by the `transformers`' `pt_to_tf` CLI -- all converted model outputs and"
" hidden layers were validated against its Pytorch counterpart. Maximum crossload output"
f" difference={max_crossload_diff:.3e}; Maximum converted output"
f" difference={max_conversion_diff:.3e}."
),
)
self._logger.warn(f"PR open in {hub_pr_url}")
except TypeError:
self._logger.warn(
f"You can now open a PR in https://huggingface.co/{self._model_name}/discussions, manually"
f" uploading the file in {tf_weights_path}"
)