Update TF QA example (#15870)
gante committed Mar 2, 2022
1 parent 6e57a56 commit 05c237e
Showing 1 changed file with 24 additions and 61 deletions.
85 changes: 24 additions & 61 deletions examples/tensorflow/question-answering/run_qa.py
@@ -32,6 +32,8 @@
from transformers import (
AutoConfig,
AutoTokenizer,
DataCollatorWithPadding,
DefaultDataCollator,
EvalPrediction,
HfArgumentParser,
PreTrainedTokenizerFast,
@@ -209,51 +211,6 @@ def on_epoch_end(self, epoch, logs=None):
self.model.save_pretrained(self.output_dir)


def convert_dataset_for_tensorflow(
dataset, batch_size, dataset_mode="variable_batch", shuffle=True, drop_remainder=True
):
"""Converts a Hugging Face dataset to a Tensorflow Dataset. The dataset_mode controls whether we pad all batches
to the maximum sequence length, or whether we only pad to the maximum length within that batch. The former
is most useful when training on TPU, as a new graph compilation is required for each sequence length.
"""

def densify_ragged_batch(features, label=None):
features = {
feature: ragged_tensor.to_tensor(shape=batch_shape[feature]) if feature in tensor_keys else ragged_tensor
for feature, ragged_tensor in features.items()
}
if label is None:
return features
else:
return features, label

tensor_keys = ["attention_mask", "input_ids"]
label_keys = ["start_positions", "end_positions"]
if dataset_mode == "variable_batch":
batch_shape = {key: None for key in tensor_keys}
data = {key: tf.ragged.constant(dataset[key]) for key in tensor_keys}
elif dataset_mode == "constant_batch":
data = {key: tf.ragged.constant(dataset[key]) for key in tensor_keys}
batch_shape = {
key: tf.concat(([batch_size], ragged_tensor.bounding_shape()[1:]), axis=0)
for key, ragged_tensor in data.items()
}
else:
raise ValueError("Unknown dataset mode!")

if all([key in dataset.features for key in label_keys]):
for key in label_keys:
data[key] = tf.convert_to_tensor(dataset[key])
dummy_labels = tf.zeros_like(dataset[key])
tf_dataset = tf.data.Dataset.from_tensor_slices((data, dummy_labels))
else:
tf_dataset = tf.data.Dataset.from_tensor_slices(data)
if shuffle:
tf_dataset = tf_dataset.shuffle(buffer_size=len(dataset))
tf_dataset = tf_dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder).map(densify_ragged_batch)
return tf_dataset
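# Illustrative sketch, not part of this commit: the helper above padded batches by densifying
# tf.ragged tensors per batch. A toy example of that densification step:
import tensorflow as tf

ragged_demo = tf.ragged.constant([[1, 2, 3], [4, 5]])  # two sequences of different lengths
dense_demo = ragged_demo.to_tensor()  # zero-pads each row to the longest row in the batch
print(dense_demo.numpy())  # -> [[1 2 3], [4 5 0]]
# The commit replaces this hand-rolled conversion with datasets' to_tf_dataset plus a data collator.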


# endregion


@@ -391,6 +348,12 @@ def main():
)
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

if data_args.pad_to_max_length or isinstance(training_args.strategy, tf.distribute.TPUStrategy):
logger.info("Padding all batches to max length because argument was set or we're on TPU.")
padding = "max_length"
else:
padding = False
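# Illustrative sketch, not part of this commit: what the padding flag changes downstream.
# padding="max_length" gives every feature the same fixed length (constant shapes, TPU-friendly),
# while padding=False keeps natural lengths for a data collator to pad per batch later.
# The checkpoint name below is an illustrative assumption only.
from transformers import AutoTokenizer

demo_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
fixed = demo_tokenizer("Who wrote it?", "A short context.", max_length=384, padding="max_length", truncation=True)
dynamic = demo_tokenizer("Who wrote it?", "A short context.", max_length=384, padding=False, truncation=True)
print(len(fixed["input_ids"]), len(dynamic["input_ids"]))  # 384 vs. just the real token count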

# Training preprocessing
def prepare_train_features(examples):
# Some of the questions have lots of whitespace on the left, which is not useful and will make the
@@ -409,7 +372,7 @@ def prepare_train_features(examples):
stride=data_args.doc_stride,
return_overflowing_tokens=True,
return_offsets_mapping=True,
padding="max_length" if data_args.pad_to_max_length else False,
padding=padding,
)
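# Illustrative sketch, not part of this commit: how return_overflowing_tokens splits a long context.
# Each window overlaps the previous one by `stride` tokens, and overflow_to_sample_mapping points
# every resulting feature back to its source example. Checkpoint and lengths are assumptions only.
from transformers import AutoTokenizer

demo_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
demo_features = demo_tokenizer(
    ["What is photosynthesis?"],
    ["Photosynthesis is the process by which plants convert light into chemical energy. " * 40],
    truncation="only_second",
    max_length=128,
    stride=32,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    padding=False,
)
print(len(demo_features["input_ids"]), demo_features["overflow_to_sample_mapping"])  # several features, all from example 0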

# Since one example might give us several features if it has a long context, we need a map from a feature to
@@ -508,7 +471,7 @@ def prepare_validation_features(examples):
stride=data_args.doc_stride,
return_overflowing_tokens=True,
return_offsets_mapping=True,
padding="max_length" if data_args.pad_to_max_length else False,
padding=padding,
)

# Since one example might give us several features if it has a long context, we need a map from a feature to
@@ -631,27 +594,27 @@ def compute_metrics(p: EvalPrediction):
clipnorm=training_args.max_grad_norm,
)

def dummy_loss(y_true, y_pred):
return tf.reduce_mean(y_pred)

losses = {"loss": dummy_loss}
model.compile(optimizer=optimizer, loss=losses)
# no user-specified loss = will use the model internal loss
model.compile(optimizer=optimizer)
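# Illustrative sketch, not part of this commit: with no loss passed to compile(), recent transformers
# TF models compute their own loss during fit() whenever the label keys (start_positions/end_positions
# here) are present in the input batch, so the old dummy_loss wrapper is no longer needed.
# The checkpoint name below is an illustrative assumption only.
from transformers import TFAutoModelForQuestionAnswering

demo_model = TFAutoModelForQuestionAnswering.from_pretrained("distilbert-base-uncased")
demo_model.compile(optimizer="adam")  # no loss argument: Keras reports the model's internal QA loss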
# endregion

# region Training
if padding:
data_collator = DefaultDataCollator(return_tensors="tf")
else:
data_collator = DataCollatorWithPadding(tokenizer, return_tensors="tf")
tensor_keys = ["attention_mask", "input_ids"]
label_keys = ["start_positions", "end_positions"]
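# Illustrative sketch, not part of this commit: what the two collators do with a toy batch.
# DefaultDataCollator only stacks features that already share one length; DataCollatorWithPadding
# pads each batch to its longest member with the tokenizer's pad token. Values below are made up.
from transformers import AutoTokenizer, DataCollatorWithPadding

demo_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # illustrative checkpoint
demo_collator = DataCollatorWithPadding(demo_tokenizer, return_tensors="tf")
demo_batch = demo_collator(
    [
        {"input_ids": [101, 2054, 102], "attention_mask": [1, 1, 1]},
        {"input_ids": [101, 2054, 2003, 2023, 102], "attention_mask": [1, 1, 1, 1, 1]},
    ]
)
print(demo_batch["input_ids"].shape)  # (2, 5): the shorter feature was padded to the batch maximum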

if training_args.do_train:
# Make a tf.data.Dataset for this
if isinstance(training_args.strategy, tf.distribute.TPUStrategy) or data_args.pad_to_max_length:
logger.info("Padding all batches to max length because argument was set or we're on TPU.")
dataset_mode = "constant_batch"
else:
dataset_mode = "variable_batch"
training_dataset = convert_dataset_for_tensorflow(
processed_datasets["train"],
training_dataset = processed_datasets["train"].to_tf_dataset(
# labels are passed as input, as we will use the model's internal loss
columns=tensor_keys + label_keys,
shuffle=True,
batch_size=training_args.per_device_train_batch_size,
dataset_mode=dataset_mode,
collate_fn=data_collator,
drop_remainder=True,
shuffle=True,
)
model.fit(training_dataset, epochs=int(training_args.num_train_epochs))
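# Illustrative sketch, not part of this commit: a self-contained toy version of the to_tf_dataset
# pattern above. Listing the label columns under `columns` keeps them inside the input dict, which
# is what lets the compiled model apply its internal loss. All values below are made up.
from datasets import Dataset
from transformers import AutoTokenizer, DataCollatorWithPadding

demo_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # illustrative checkpoint
toy_dataset = Dataset.from_dict(
    {
        "input_ids": [[101, 2054, 102], [101, 2054, 2003, 2023, 102]],
        "attention_mask": [[1, 1, 1], [1, 1, 1, 1, 1]],
        "start_positions": [1, 2],
        "end_positions": [1, 3],
    }
)
toy_tf_dataset = toy_dataset.to_tf_dataset(
    columns=["attention_mask", "input_ids", "start_positions", "end_positions"],
    collate_fn=DataCollatorWithPadding(demo_tokenizer, return_tensors="tf"),
    batch_size=2,
    shuffle=False,
    drop_remainder=True,
)
for demo_batch in toy_tf_dataset:
    print({name: tensor.shape for name, tensor in demo_batch.items()})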
# endregion
