Add generate kwargs to Seq2SeqTrainingArguments #13339

Merged · 4 commits · Aug 31, 2021
Changes from 2 commits
16 changes: 8 additions & 8 deletions examples/pytorch/summarization/run_summarization.py
@@ -556,12 +556,15 @@ def compute_metrics(eval_preds):
 
     # Evaluation
     results = {}
+    max_length = (
+        training_args.generate_max_length
+        if training_args.generate_max_length is not None
+        else data_args.val_max_target_length
+    )
+    num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generate_num_beams
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
 
-        metrics = trainer.evaluate(
-            max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
-        )
+        metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
         max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
         metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
@@ -572,10 +575,7 @@ def compute_metrics(eval_preds):
         logger.info("*** Predict ***")
 
         predict_results = trainer.predict(
-            predict_dataset,
-            metric_key_prefix="predict",
-            max_length=data_args.val_max_target_length,
-            num_beams=data_args.num_beams,
+            predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
         )
         metrics = predict_results.metrics
         max_predict_samples = (
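
Note that the resolution added above is asymmetric: for max_length the new training argument overrides the script's val_max_target_length data argument, while for num_beams the existing data argument keeps priority over the new training argument. A standalone sketch of that logic (hypothetical helper, not part of this PR):

def resolve_generation_kwargs(generate_max_length, val_max_target_length, num_beams_data_arg, generate_num_beams):
    # max_length: the training argument wins over the data argument.
    max_length = generate_max_length if generate_max_length is not None else val_max_target_length
    # num_beams: the data argument wins over the training argument.
    num_beams = num_beams_data_arg if num_beams_data_arg is not None else generate_num_beams
    return max_length, num_beams

assert resolve_generation_kwargs(128, 142, None, 4) == (128, 4)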
15 changes: 8 additions & 7 deletions examples/pytorch/translation/run_translation.py
@@ -549,12 +549,16 @@ def compute_metrics(eval_preds):
 
     # Evaluation
     results = {}
+    max_length = (
+        training_args.generate_max_length
+        if training_args.generate_max_length is not None
+        else data_args.val_max_target_length
+    )
+    num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generate_num_beams
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
 
-        metrics = trainer.evaluate(
-            max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
-        )
+        metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
         max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
         metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

@@ -565,10 +569,7 @@ def compute_metrics(eval_preds):
         logger.info("*** Predict ***")
 
         predict_results = trainer.predict(
-            predict_dataset,
-            metric_key_prefix="predict",
-            max_length=data_args.val_max_target_length,
-            num_beams=data_args.num_beams,
+            predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
        )
         metrics = predict_results.metrics
         max_predict_samples = (
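
Since both example scripts build their configuration with HfArgumentParser, each new dataclass field should also be reachable as a command-line flag. A minimal sketch, assuming a transformers checkout that includes this PR (flag names follow the field names at this commit, before the rename discussed in the review below):

from transformers import HfArgumentParser, Seq2SeqTrainingArguments

parser = HfArgumentParser(Seq2SeqTrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "out", "--predict_with_generate", "--generate_max_length", "128", "--generate_num_beams", "4"]
)
assert training_args.generate_max_length == 128
assert training_args.generate_num_beams == 4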
12 changes: 4 additions & 8 deletions src/transformers/trainer_seq2seq.py
@@ -70,10 +70,8 @@ def evaluate(
             A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
             dictionary also contains the epoch number which comes from the training state.
         """
-        if max_length is not None or not hasattr(self, "_max_length"):
-            self._max_length = max_length
-        if num_beams is not None or not hasattr(self, "_num_beams"):
-            self._num_beams = num_beams
+        self._max_length = max_length if max_length is not None else self.args.generate_max_length
+        self._num_beams = num_beams if num_beams is not None else self.args.generate_num_beams
         return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
 
     def predict(
@@ -119,10 +117,8 @@ def predict(
             - metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
               contained labels).
         """
-        if max_length is not None or not hasattr(self, "_max_length"):
-            self._max_length = max_length
-        if num_beams is not None or not hasattr(self, "_num_beams"):
-            self._num_beams = num_beams
+        self._max_length = max_length if max_length is not None else self.args.generate_max_length
+        self._num_beams = num_beams if num_beams is not None else self.args.generate_num_beams
         return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
 
     def prediction_step(
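
Combined with the existing fallback to the model configuration (noted in the help strings below), each setting now resolves through a three-level chain: explicit call argument, then training argument, then model config. A minimal illustrative sketch of that chain (not actual library code):

def resolve(call_kwarg, training_arg, model_config_value):
    # 1. trainer.evaluate(max_length=...) / trainer.predict(num_beams=...)
    if call_kwarg is not None:
        return call_kwarg
    # 2. Seq2SeqTrainingArguments.generate_max_length / generate_num_beams
    if training_arg is not None:
        return training_arg
    # 3. model.config.max_length / model.config.num_beams
    return model_config_value

assert resolve(64, 128, 20) == 64
assert resolve(None, 128, 20) == 128
assert resolve(None, None, 20) == 20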
15 changes: 15 additions & 0 deletions src/transformers/training_args_seq2seq.py
@@ -14,6 +14,7 @@
 
 import logging
 from dataclasses import dataclass, field
+from typing import Optional
 
 from .file_utils import add_start_docstrings
 from .training_args import TrainingArguments
@@ -40,3 +41,17 @@ class Seq2SeqTrainingArguments(TrainingArguments):
     predict_with_generate: bool = field(
         default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
     )
+    generate_max_length: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default "
+            "to the `max_length` value of the model configuration."
+        },
+    )
+    generate_num_beams: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default "
+            "to the `num_beams` value of the model configuration."
+        },
+    )
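
With these fields in place, generation settings can be set once on the training arguments rather than threaded through every evaluate/predict call. A hypothetical usage sketch (model and datasets assumed to exist elsewhere):

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="out",
    predict_with_generate=True,  # required for the generation settings to apply
    generate_max_length=128,  # field names as of this commit; see the review comment below
    generate_num_beams=4,
)
# trainer = Seq2SeqTrainer(model=model, args=training_args, ...)  # model/data assumed
# trainer.evaluate()  # generates with max_length=128 and num_beams=4 by default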
Review comment (Member):
The generate_* name comes from the generate method, but for users who aren't power users I don't think it's the best name, as the verb implies an action; how about renaming to generation_max_length and generation_num_beams?