Add generate kwargs to Seq2SeqTrainingArguments (#13339)
* Add generate kwargs to Seq2SeqTrainingArguments

* typo

* Address review comments + doc

* Style
sgugger committed Aug 31, 2021
1 parent 702f4a4 commit c76de10
Showing 4 changed files with 41 additions and 23 deletions.
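In short, this commit moves generation settings onto the training arguments, so they no longer need to be threaded through every evaluate()/predict() call. A minimal usage sketch of the new arguments, assuming a transformers version that includes this commit (the output_dir value is a placeholder):

    from transformers import Seq2SeqTrainingArguments

    # The two new fields become the defaults Seq2SeqTrainer uses for generation
    # whenever predict_with_generate=True and no per-call override is given.
    args = Seq2SeqTrainingArguments(
        output_dir="out",
        predict_with_generate=True,
        generation_max_length=128,
        generation_num_beams=4,
    )
    print(args.generation_max_length, args.generation_num_beams)  # 128 4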
examples/pytorch/summarization/run_summarization.py (16 changes: 8 additions & 8 deletions)
@@ -556,12 +556,15 @@ def compute_metrics(eval_preds):
 
     # Evaluation
     results = {}
+    max_length = (
+        training_args.generation_max_length
+        if training_args.generation_max_length is not None
+        else data_args.val_max_target_length
+    )
+    num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
 
-        metrics = trainer.evaluate(
-            max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
-        )
+        metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
         max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
         metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

@@ -572,10 +575,7 @@ def compute_metrics(eval_preds):
         logger.info("*** Predict ***")
 
         predict_results = trainer.predict(
-            predict_dataset,
-            metric_key_prefix="predict",
-            max_length=data_args.val_max_target_length,
-            num_beams=data_args.num_beams,
+            predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
         )
         metrics = predict_results.metrics
         max_predict_samples = (
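One subtlety in the hunks above: the two settings resolve in opposite directions. For the length, the new training argument generation_max_length takes precedence over the data argument val_max_target_length; for beams, the data argument num_beams still wins over generation_num_beams. A self-contained sketch of that resolution, with plain dataclasses standing in for the script's parsed argument objects (the stand-in class names are hypothetical):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class DataArgs:  # stand-in for the script's DataTrainingArguments
        val_max_target_length: int = 128
        num_beams: Optional[int] = None

    @dataclass
    class TrainArgs:  # stand-in for Seq2SeqTrainingArguments
        generation_max_length: Optional[int] = None
        generation_num_beams: Optional[int] = None

    data_args = DataArgs(num_beams=2)
    training_args = TrainArgs(generation_max_length=64, generation_num_beams=4)

    # The same two expressions as in the diff above:
    max_length = (
        training_args.generation_max_length
        if training_args.generation_max_length is not None
        else data_args.val_max_target_length
    )
    num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams

    assert (max_length, num_beams) == (64, 2)  # training arg wins for length, data arg wins for beams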
examples/pytorch/translation/run_translation.py (15 changes: 8 additions & 7 deletions)
@@ -549,12 +549,16 @@ def compute_metrics(eval_preds):
 
     # Evaluation
     results = {}
+    max_length = (
+        training_args.generation_max_length
+        if training_args.generation_max_length is not None
+        else data_args.val_max_target_length
+    )
+    num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
 
-        metrics = trainer.evaluate(
-            max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
-        )
+        metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
         max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
         metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

@@ -565,10 +569,7 @@ def compute_metrics(eval_preds):
         logger.info("*** Predict ***")
 
         predict_results = trainer.predict(
-            predict_dataset,
-            metric_key_prefix="predict",
-            max_length=data_args.val_max_target_length,
-            num_beams=data_args.num_beams,
+            predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
         )
         metrics = predict_results.metrics
         max_predict_samples = (
src/transformers/trainer_seq2seq.py (12 changes: 4 additions & 8 deletions)
@@ -70,10 +70,8 @@ def evaluate(
             A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
             dictionary also contains the epoch number which comes from the training state.
         """
-        if max_length is not None or not hasattr(self, "_max_length"):
-            self._max_length = max_length
-        if num_beams is not None or not hasattr(self, "_num_beams"):
-            self._num_beams = num_beams
+        self._max_length = max_length if max_length is not None else self.args.generation_max_length
+        self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
         return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
 
     def predict(
@@ -119,10 +117,8 @@ def predict(
             - metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
               contained labels).
         """
-        if max_length is not None or not hasattr(self, "_max_length"):
-            self._max_length = max_length
-        if num_beams is not None or not hasattr(self, "_num_beams"):
-            self._num_beams = num_beams
+        self._max_length = max_length if max_length is not None else self.args.generation_max_length
+        self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
         return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
 
     def prediction_step(
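The rewrite in both methods also fixes a caching quirk: previously _max_length and _num_beams were sticky, since the hasattr guard kept an earlier value when a later call passed None. Now every call recomputes them, preferring the explicit argument, then the new training arguments, and finally (when both are None) the model config inside generate(). A standalone illustration of that fallback (the resolve helper is hypothetical and simply mirrors the two new one-liners):

    from typing import Optional

    def resolve(explicit: Optional[int], from_args: Optional[int]) -> Optional[int]:
        # The call-site value wins; otherwise fall back to the value from
        # Seq2SeqTrainingArguments, which may itself be None.
        return explicit if explicit is not None else from_args

    assert resolve(64, 128) == 64       # explicit max_length passed to evaluate()
    assert resolve(None, 128) == 128    # falls back to args.generation_max_length
    assert resolve(None, None) is None  # generate() then reads the model config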
src/transformers/training_args_seq2seq.py (21 changes: 21 additions & 0 deletions)
@@ -14,6 +14,7 @@
 
 import logging
 from dataclasses import dataclass, field
+from typing import Optional
 
 from .file_utils import add_start_docstrings
 from .training_args import TrainingArguments
@@ -34,9 +35,29 @@ class Seq2SeqTrainingArguments(TrainingArguments):
             the training set.
         predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether to use generate to calculate generative metrics (ROUGE, BLEU).
+        generation_max_length (:obj:`int`, `optional`):
+            The :obj:`max_length` to use on each evaluation loop when :obj:`predict_with_generate=True`. Will
+            default to the :obj:`max_length` value of the model configuration.
+        generation_num_beams (:obj:`int`, `optional`):
+            The :obj:`num_beams` to use on each evaluation loop when :obj:`predict_with_generate=True`. Will
+            default to the :obj:`num_beams` value of the model configuration.
     """
 
     sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."})
     predict_with_generate: bool = field(
         default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
     )
+    generation_max_length: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default "
+            "to the `max_length` value of the model configuration."
+        },
+    )
+    generation_num_beams: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default "
+            "to the `num_beams` value of the model configuration."
+        },
+    )
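Because the example scripts build their CLIs with HfArgumentParser, the new dataclass fields are automatically exposed as --generation_max_length and --generation_num_beams flags. A quick sketch of that round trip (flag values are arbitrary):

    from transformers import HfArgumentParser, Seq2SeqTrainingArguments

    parser = HfArgumentParser(Seq2SeqTrainingArguments)
    (train_args,) = parser.parse_args_into_dataclasses(
        ["--output_dir", "out", "--generation_max_length", "128", "--generation_num_beams", "4"]
    )
    assert train_args.generation_max_length == 128
    assert train_args.generation_num_beams == 4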
