From 05c237ea94e08786abbac6c6185cfdfa262a8c53 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Wed, 2 Mar 2022 10:38:13 +0000
Subject: [PATCH] Update TF QA example (#15870)

---
 .../tensorflow/question-answering/run_qa.py   | 85 ++++++-------------
 1 file changed, 24 insertions(+), 61 deletions(-)

diff --git a/examples/tensorflow/question-answering/run_qa.py b/examples/tensorflow/question-answering/run_qa.py
index a3b5c73568b0a..7d0cb6bb45dff 100755
--- a/examples/tensorflow/question-answering/run_qa.py
+++ b/examples/tensorflow/question-answering/run_qa.py
@@ -32,6 +32,8 @@
 from transformers import (
     AutoConfig,
     AutoTokenizer,
+    DataCollatorWithPadding,
+    DefaultDataCollator,
     EvalPrediction,
     HfArgumentParser,
     PreTrainedTokenizerFast,
@@ -209,51 +211,6 @@ def on_epoch_end(self, epoch, logs=None):
         self.model.save_pretrained(self.output_dir)
 
 
-def convert_dataset_for_tensorflow(
-    dataset, batch_size, dataset_mode="variable_batch", shuffle=True, drop_remainder=True
-):
-    """Converts a Hugging Face dataset to a Tensorflow Dataset. The dataset_mode controls whether we pad all batches
-    to the maximum sequence length, or whether we only pad to the maximum length within that batch. The former
-    is most useful when training on TPU, as a new graph compilation is required for each sequence length.
-    """
-
-    def densify_ragged_batch(features, label=None):
-        features = {
-            feature: ragged_tensor.to_tensor(shape=batch_shape[feature]) if feature in tensor_keys else ragged_tensor
-            for feature, ragged_tensor in features.items()
-        }
-        if label is None:
-            return features
-        else:
-            return features, label
-
-    tensor_keys = ["attention_mask", "input_ids"]
-    label_keys = ["start_positions", "end_positions"]
-    if dataset_mode == "variable_batch":
-        batch_shape = {key: None for key in tensor_keys}
-        data = {key: tf.ragged.constant(dataset[key]) for key in tensor_keys}
-    elif dataset_mode == "constant_batch":
-        data = {key: tf.ragged.constant(dataset[key]) for key in tensor_keys}
-        batch_shape = {
-            key: tf.concat(([batch_size], ragged_tensor.bounding_shape()[1:]), axis=0)
-            for key, ragged_tensor in data.items()
-        }
-    else:
-        raise ValueError("Unknown dataset mode!")
-
-    if all([key in dataset.features for key in label_keys]):
-        for key in label_keys:
-            data[key] = tf.convert_to_tensor(dataset[key])
-        dummy_labels = tf.zeros_like(dataset[key])
-        tf_dataset = tf.data.Dataset.from_tensor_slices((data, dummy_labels))
-    else:
-        tf_dataset = tf.data.Dataset.from_tensor_slices(data)
-    if shuffle:
-        tf_dataset = tf_dataset.shuffle(buffer_size=len(dataset))
-    tf_dataset = tf_dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder).map(densify_ragged_batch)
-    return tf_dataset
-
-
 # endregion
 
 
@@ -391,6 +348,12 @@ def main():
         )
     max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
 
+    if data_args.pad_to_max_length or isinstance(training_args.strategy, tf.distribute.TPUStrategy):
+        logger.info("Padding all batches to max length because argument was set or we're on TPU.")
+        padding = "max_length"
+    else:
+        padding = False
+
     # Training preprocessing
     def prepare_train_features(examples):
         # Some of the questions have lots of whitespace on the left, which is not useful and will make the
@@ -409,7 +372,7 @@ def prepare_train_features(examples):
             stride=data_args.doc_stride,
             return_overflowing_tokens=True,
             return_offsets_mapping=True,
-            padding="max_length" if data_args.pad_to_max_length else False,
+            padding=padding,
         )
 
         # Since one example might give us several features if it has a long context, we need a map from a feature to
@@ -508,7 +471,7 @@ def prepare_validation_features(examples):
             stride=data_args.doc_stride,
             return_overflowing_tokens=True,
             return_offsets_mapping=True,
-            padding="max_length" if data_args.pad_to_max_length else False,
+            padding=padding,
         )
 
         # Since one example might give us several features if it has a long context, we need a map from a feature to
@@ -631,27 +594,27 @@ def compute_metrics(p: EvalPrediction):
             clipnorm=training_args.max_grad_norm,
         )
 
-        def dummy_loss(y_true, y_pred):
-            return tf.reduce_mean(y_pred)
-
-        losses = {"loss": dummy_loss}
-        model.compile(optimizer=optimizer, loss=losses)
+        # no user-specified loss = will use the model internal loss
+        model.compile(optimizer=optimizer)
         # endregion
 
         # region Training
+        if padding:
+            data_collator = DefaultDataCollator(return_tensors="tf")
+        else:
+            data_collator = DataCollatorWithPadding(tokenizer, return_tensors="tf")
+        tensor_keys = ["attention_mask", "input_ids"]
+        label_keys = ["start_positions", "end_positions"]
+
         if training_args.do_train:
             # Make a tf.data.Dataset for this
-            if isinstance(training_args.strategy, tf.distribute.TPUStrategy) or data_args.pad_to_max_length:
-                logger.info("Padding all batches to max length because argument was set or we're on TPU.")
-                dataset_mode = "constant_batch"
-            else:
-                dataset_mode = "variable_batch"
-            training_dataset = convert_dataset_for_tensorflow(
-                processed_datasets["train"],
+            training_dataset = processed_datasets["train"].to_tf_dataset(
+                # labels are passed as input, as we will use the model's internal loss
+                columns=tensor_keys + label_keys,
+                shuffle=True,
                 batch_size=training_args.per_device_train_batch_size,
-                dataset_mode=dataset_mode,
+                collate_fn=data_collator,
                 drop_remainder=True,
-                shuffle=True,
             )
             model.fit(training_dataset, epochs=int(training_args.num_train_epochs))
         # endregion
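
Not part of the patch: a minimal, self-contained sketch of the pipeline the diff switches to, i.e. building the training set with datasets' to_tf_dataset() plus a DataCollatorWithPadding, and compiling without an explicit loss so the model's internal QA loss is used. The checkpoint name, the dummy texts, and the dummy start/end token indices below are assumptions for illustration only; the real script derives the label indices from answer character spans in prepare_train_features.

import tensorflow as tf
from datasets import Dataset
from transformers import AutoTokenizer, DataCollatorWithPadding, TFAutoModelForQuestionAnswering

checkpoint = "distilbert-base-uncased"  # assumed checkpoint; any TF QA model works
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = TFAutoModelForQuestionAnswering.from_pretrained(checkpoint)

# Stand-in for the tokenized features that prepare_train_features() produces.
encodings = tokenizer(
    ["Who maintains the example?", "Which framework does it use?"],
    ["The example is maintained by the library authors.", "The script is written for TensorFlow."],
    truncation=True,
)
features = Dataset.from_dict(
    {
        "input_ids": encodings["input_ids"],
        "attention_mask": encodings["attention_mask"],
        # Dummy token indices for illustration; must only be valid positions in each sequence.
        "start_positions": [7, 8],
        "end_positions": [9, 9],
    }
)

# Dynamic per-batch padding, mirroring the non-TPU branch of the patch.
data_collator = DataCollatorWithPadding(tokenizer, return_tensors="tf")
tf_train = features.to_tf_dataset(
    # labels are passed as inputs because the model's internal loss consumes them
    columns=["attention_mask", "input_ids", "start_positions", "end_positions"],
    shuffle=True,
    batch_size=2,
    collate_fn=data_collator,
    drop_remainder=True,
)

# No user-specified loss: Keras falls back to the model's internal loss.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5))
model.fit(tf_train, epochs=1)

Because no label_cols are passed to to_tf_dataset(), each batch is a single feature dict containing the label columns, which is exactly how the updated run_qa.py feeds the model.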