diff --git a/src/transformers/commands/pt_to_tf.py b/src/transformers/commands/pt_to_tf.py
index 77a822544f41f..3a2465093c415 100644
--- a/src/transformers/commands/pt_to_tf.py
+++ b/src/transformers/commands/pt_to_tf.py
@@ -45,7 +45,7 @@ def convert_command_factory(args: Namespace):
 
     Returns: ServeCommand
     """
-    return PTtoTFCommand(args.model_name, args.local_dir, args.no_pr, args.new_weights)
+    return PTtoTFCommand(args.model_name, args.local_dir, args.new_weights, args.no_pr, args.push)
 
 
 class PTtoTFCommand(BaseTransformersCLICommand):
@@ -76,13 +76,18 @@ def register_subcommand(parser: ArgumentParser):
             default="",
             help="Optional local directory of the model repository. Defaults to /tmp/{model_name}",
         )
+        train_parser.add_argument(
+            "--new-weights",
+            action="store_true",
+            help="Optional flag to create new TensorFlow weights, even if they already exist.",
+        )
         train_parser.add_argument(
             "--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights."
         )
         train_parser.add_argument(
-            "--new-weights",
+            "--push",
             action="store_true",
-            help="Optional flag to create new TensorFlow weights, even if they already exist.",
+            help="Optional flag to push the weights directly to `main` (requires permissions)",
         )
         train_parser.set_defaults(func=convert_command_factory)
 
@@ -129,12 +134,13 @@ def _find_pt_tf_differences(pt_out, tf_out, differences, attr_name=""):
 
         return _find_pt_tf_differences(pt_outputs, tf_outputs, {})
 
-    def __init__(self, model_name: str, local_dir: str, no_pr: bool, new_weights: bool, *args):
+    def __init__(self, model_name: str, local_dir: str, new_weights: bool, no_pr: bool, push: bool, *args):
         self._logger = logging.get_logger("transformers-cli/pt_to_tf")
         self._model_name = model_name
         self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
-        self._no_pr = no_pr
         self._new_weights = new_weights
+        self._no_pr = no_pr
+        self._push = push
 
     def get_text_inputs(self):
         tokenizer = AutoTokenizer.from_pretrained(self._local_dir)
@@ -234,7 +240,12 @@ def run(self):
                 )
             )
 
-        if not self._no_pr:
+        if self._push:
+            repo.git_add(auto_lfs_track=True)
+            repo.git_commit("Add TF weights")
+            repo.git_push(blocking=True)  # this prints a progress bar with the upload
+            self._logger.warning(f"TF weights pushed into {self._model_name}")
+        elif not self._no_pr:
             # TODO: remove try/except when the upload to PR feature is released
             # (https://github.com/huggingface/huggingface_hub/pull/884)
             try: