
Adding the state-of-the-art contrastive search decoding method to generation_utils.py #19477

Merged
merged 39 commits into from Oct 19, 2022
Changes from 15 commits
Commits
39 commits
942a7c6
add: the contrastive search for generation_utils
gmftbyGMFTBY Oct 10, 2022
5909423
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 10, 2022
3e71819
add: testing scripts for contrastive search under examples/text-gener…
gmftbyGMFTBY Oct 10, 2022
47b2b3b
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 11, 2022
41e37a5
update the code quality
gmftbyGMFTBY Oct 11, 2022
e278a46
Merge branch 'csearch-pr-v2' of https://github.com/gmftbyGMFTBY/trans…
gmftbyGMFTBY Oct 11, 2022
38b100f
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 12, 2022
9abd1bb
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 12, 2022
32e2a30
revise the docstring; make the generation_contrastive_search.py scripts;
gmftbyGMFTBY Oct 12, 2022
e9e2b26
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 12, 2022
1f1dac2
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 12, 2022
6226b9a
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 12, 2022
f3bfd87
revise the examples/pytorch/text-generation/run_generation_contrastiv…
gmftbyGMFTBY Oct 12, 2022
d3a91b8
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 12, 2022
ce26f9f
revise the necessary documents
gmftbyGMFTBY Oct 12, 2022
e801c6f
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 13, 2022
c78cf91
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 13, 2022
fb4174e
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 13, 2022
68429ad
fix: revise the docstring of generation_contrastive_search.py
gmftbyGMFTBY Oct 13, 2022
e1f0db9
Fix the code indentation
gmftbyGMFTBY Oct 13, 2022
42d78be
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 13, 2022
d2a5e02
fix: revise the nits and examples in contrastive_search docstring.
gmftbyGMFTBY Oct 13, 2022
1d4f782
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 14, 2022
d5f90fb
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 14, 2022
d5d30b7
fix the copyright
gmftbyGMFTBY Oct 14, 2022
49000c6
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 17, 2022
3058e1c
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 17, 2022
c344a0a
delete generation_contrastive_search.py
gmftbyGMFTBY Oct 17, 2022
628ecda
Merge branch 'csearch-pr-v2' of https://github.com/gmftbyGMFTBY/trans…
gmftbyGMFTBY Oct 17, 2022
5ae4ce2
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 18, 2022
7af4cbb
revise the logic in contrastive_search
gmftbyGMFTBY Oct 18, 2022
b219a17
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 18, 2022
183d7cc
update the integration test and the docstring
gmftbyGMFTBY Oct 18, 2022
65a1ebd
run the tests over
gmftbyGMFTBY Oct 18, 2022
4972bfb
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 18, 2022
e11d342
add the slow decorator to the contrastive_search integration test
gmftbyGMFTBY Oct 18, 2022
da014bb
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 19, 2022
2aa768c
add more tests
gmftbyGMFTBY Oct 19, 2022
ced9f70
do the style, quality, consistency checks
gmftbyGMFTBY Oct 19, 2022
3 changes: 3 additions & 0 deletions docs/source/en/internal/generation_utils.mdx
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.

This page lists all the utility functions used by [`~generation_utils.GenerationMixin.generate`],
[`~generation_utils.GenerationMixin.greedy_search`],
[`~generation_utils.GenerationMixin.contrastive_search`],
[`~generation_utils.GenerationMixin.sample`],
[`~generation_utils.GenerationMixin.beam_search`],
[`~generation_utils.GenerationMixin.beam_sample`],
@@ -261,3 +262,5 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] top_k_top_p_filtering

[[autodoc]] tf_top_k_top_p_filtering

[[autodoc]] ContrastiveDecodingOneStepFast
1 change: 1 addition & 0 deletions docs/source/en/main_classes/text_generation.mdx
@@ -26,6 +26,7 @@ Each framework has a generate method for auto-regressive text generation impleme
- sample
- beam_search
- beam_sample
- contrastive_search
- group_beam_search
- constrained_beam_search

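For reference, contrastive search is reached through the usual `generate()` call by passing `penalty_alpha > 0` together with `top_k > 1`. A minimal sketch using the same model and values as the example script added below (the prompt is illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
model = AutoModelForCausalLM.from_pretrained("gpt2-large")

inputs = tokenizer("DeepMind Company is", return_tensors="pt")
# penalty_alpha > 0 and top_k > 1 select the contrastive search decoding path
outputs = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_length=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```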
139 changes: 139 additions & 0 deletions examples/pytorch/text-generation/run_generation_contrastive_search.py
@@ -0,0 +1,139 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" The examples of running contrastive search on the auto-APIs;

Running this examples:
gante marked this conversation as resolved.
Show resolved Hide resolved
CUDA_VISIBLE_DEVICES=0 python run_generation_contrastive_search.py --model_name_or_path=gpt2-large --penalty_alpha=0.6 --k=4 --length=256
gante marked this conversation as resolved.
Show resolved Hide resolved
"""


import argparse
import logging

import numpy as np
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer


logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)


def set_seed(args):
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
)
parser.add_argument("--prompt", type=str, default="")
parser.add_argument("--length", type=int, default=20)
parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
)
parser.add_argument(
"--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
)
parser.add_argument("--k", type=int, default=0)
parser.add_argument("--penalty_alpha", type=float, default=0.0)
parser.add_argument("--p", type=float, default=0.9)

parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")
parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")

parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
parser.add_argument(
"--fp16",
action="store_true",
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
)
args = parser.parse_args()

args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()

logger.warning(f"device: {args.device}, n_gpu: {args.n_gpu}, 16-bits training: {args.fp16}")

set_seed(args)

# Initialize the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)

# tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
# model = OPTForCausalLM.from_pretrained(args.model_name_or_path)
model.to(args.device)

if args.fp16:
model.half()

logger.info(args)
prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")

inputs = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False)
inputs = {key: value.to(args.device) for key, value in inputs.items()}

output_sequences = model.generate(
**inputs,
max_length=args.length + len(inputs["input_ids"][0]),
penalty_alpha=args.penalty_alpha,
top_k=args.k,
)

generated_sequences = []
for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
generated_sequence = generated_sequence.tolist()

# Decode text
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, add_special_tokens=False)

# Remove all text after the stop token
text = text[: text.find(args.stop_token) if args.stop_token else None]

# Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
total_sequence = (
prompt_text + text[len(tokenizer.decode(inputs["input_ids"][0], clean_up_tokenization_spaces=True)) :]
)

generated_sequences.append(total_sequence)
print(total_sequence)

return generated_sequences


if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -800,6 +800,7 @@
"PhrasalConstraint",
]
_import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer", "ConstrainedBeamSearchScorer"]
_import_structure["generation_contrastive_search"] = ["ContrastiveDecodingOneStepFast"]
_import_structure["generation_logits_process"] = [
"ForcedBOSTokenLogitsProcessor",
"ForcedEOSTokenLogitsProcessor",
@@ -3746,6 +3747,7 @@
PhrasalConstraint,
)
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .generation_contrastive_search import ContrastiveDecodingOneStepFast
Contributor comment:
I would maybe not put this in the public init, as it could be prone to change in the future, and I don't think most people will use it outside of `generate`, no?

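If the symbol does end up staying out of the public init, it would still be importable from the submodule added in this PR; a hypothetical usage sketch:

```python
# Hypothetical: import directly from the submodule rather than from the top-level package.
from transformers.generation_contrastive_search import ContrastiveDecodingOneStepFast
```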
from .generation_logits_process import (
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
197 changes: 197 additions & 0 deletions src/transformers/generation_contrastive_search.py
@@ -0,0 +1,197 @@
# coding=utf-8
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch
from torch import nn

from .utils import logging


logger = logging.get_logger(__name__)


"""
This file contains the utils functions for the contrastive search, which will be called in `generation_utils`
"""
gante marked this conversation as resolved.
Show resolved Hide resolved


def ranking_fast(
Contributor comment:
If possible I'd be in favor of not adding a new file here. IMO ranking_fast can be changed to a logit processor and ContrastiveDecodingOneStepFast could be fully moved into generation_utils.py right away.

IMO this will be easier to maintain and understand for people that already know how generate works - wdyt @gante ?

context_hidden: torch.FloatTensor,
next_hidden: torch.FloatTensor,
next_top_k_probs: torch.FloatTensor,
alpha: float,
beam_width: int,
) -> Tuple[torch.FloatTensor]:
"""
context_hidden: bsz*beam x seqlen x embed_dim
next_hidden: bsz*beam x 1 x embed_dim
next_top_k_probs: bsz x beam
"""
_, context_len, embed_dim = context_hidden.size()
norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1) # [B*K, S]
scores, _ = torch.max(cosine_matrix, dim=-1) # [B*K]
next_top_k_probs = next_top_k_probs.view(-1) # [B*K]
scores = (1.0 - alpha) * next_top_k_probs - alpha * scores
scores = torch.stack(torch.split(scores, beam_width)) # [B, K]
selected_scores, selected_idx = scores.max(dim=-1) # [B]
return selected_scores, selected_idx
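To make the expected shapes concrete, a toy invocation of `ranking_fast` with dummy tensors (batch size 1, beam width 2, context length 3, hidden size 4; the values are illustrative only):

```python
import torch

bsz, k, seqlen, embed_dim = 1, 2, 3, 4
context_hidden = torch.randn(bsz * k, seqlen, embed_dim)  # [B*K, S, E]
next_hidden = torch.randn(bsz * k, 1, embed_dim)          # [B*K, 1, E]
next_top_k_probs = torch.tensor([[0.7, 0.3]])             # [B, K]

# score = (1 - alpha) * candidate probability - alpha * max cosine similarity to the context
selected_scores, selected_idx = ranking_fast(context_hidden, next_hidden, next_top_k_probs, alpha=0.6, beam_width=k)
print(selected_scores.shape, selected_idx.shape)  # torch.Size([1]) torch.Size([1])
```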


def ContrastiveDecodingOneStepFast(
model,
Contributor comment:
not very happy about passing the model into the function here and upper-casing the function.

Could we try to have this whole functionality directly inside the contrastive_search function? It would help readability a lot IMO.

@gante - what do you think?

beam_width: int = 1,
penalty_alpha: float = 0.0,
past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
last_hidden_states: torch.FloatTensor = None,
logit_for_next_step: torch.FloatTensor = None,
is_encoder_decoder: bool = False,
**model_inputs,
) -> Tuple:
"""
contrastive search first selects top-k candidates by the logit scores; then these candidate tokens are fed into the
language models to compute the degeneration penalty, which will be used to re-rank these candidate tokens.

Args:
model: the generation model
gante marked this conversation as resolved.
Show resolved Hide resolved
beam_width (`int`, *optional*, defauls to 1):
If > 1, top k candidate tokens are selected for re-ranking
penalty_alpha (`float`, *optional*, defaults to 0.0):
If > 0, the model confidence will minus the degeneration penalty that is weighted by the penalty_alpha
parameter
past_key_value (`Tuple[Tuple[torch.FloatTensor]]`, *optional*, defaults to None):
it saves the cached key-value results that computed by previous steps
last_hidden_states (`torch.FloatTensor`, *optional*, defaults to None):
the last_hidden_states generated by the previous step
logit_for_next_step (`torch.FloatTensor`, *optional*, defaults to None):
the logit_for_next_step generated by the previous step
gante marked this conversation as resolved.
Show resolved Hide resolved
is_encoder_decode (`bool`, *optional*, defaults to False):
gante marked this conversation as resolved.
Show resolved Hide resolved
if True, the model is an encoder-decode model, else the model is an decoder-only model, such as GPT2 and
gante marked this conversation as resolved.
Show resolved Hide resolved
OPT

From: https://github.com/yxuansu/SimCTG/blob/main/simctg/utlisgpt.py
gante marked this conversation as resolved.
Show resolved Hide resolved
"""
bsz, seqlen, embed_dim = last_hidden_states.size()
next_probs = nn.functional.softmax(logit_for_next_step, dim=-1)
_, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width)
top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids)
past_key_values = enlarge_past_key_values(past_key_values, beam_width)

# build next attention mask
attention_mask = model_inputs["attention_mask"] # [B, S]
# decoder-only models need the full attention mask, not only the mask for the last token
if is_encoder_decoder is False:
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((bsz, 1))], dim=-1)
attention_mask = attention_mask.unsqueeze(1).expand(-1, beam_width, -1).reshape(-1, attention_mask.size(-1))

# encoder-decoder model also contains the `encoder_outputs`
if is_encoder_decoder and "encoder_outputs" in model_inputs:
encoder_outputs = model_inputs["encoder_outputs"]
else:
encoder_outputs = None
next_model_inputs = model.prepare_inputs_for_generation(
top_k_ids.view(-1, 1),
past=past_key_values,
attention_mask=attention_mask,
use_cache=True,
encoder_outputs=encoder_outputs,
)
# compute the candidate tokens with the language model and collect their hidden_states
output = model(output_hidden_states=True, **next_model_inputs)
past_key_values = output.past_key_values
logits = output.logits[:, -1, :]
# name is different for encoder-decoder and decoder-only models
if is_encoder_decoder:
next_hidden = output.decoder_hidden_states[-1]
full_hidden_states = output.decoder_hidden_states
else:
next_hidden = output.hidden_states[-1]
full_hidden_states = output.hidden_states
context_hidden = (
last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz * beam_width, seqlen, embed_dim)
)

# compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the model confidence
# the scores and index of the selected tokens are returned
selected_scores, selected_idx = ranking_fast(
Contributor comment:
@gante I think we can/should make this a logit processor that exceptionally takes context_hidden and next_hidden as input arguments as well. top_k_probs can be computed inside the logit processor, penalty_alpha and beam_width can be stored inside the logit processor at init

IMO the logit processor could then (just like all other processors) return the scores, and then we compute selected_idx afterwards

context_hidden,
next_hidden,
top_k_probs,
penalty_alpha,
beam_width,
)
# prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores (model confidence minus degeneration penalty); (6) decoder hidden_states
next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1)
next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width))
next_hidden = next_hidden[range(bsz), selected_idx, :]
last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)

decoder_hidden_states = []
for layer in full_hidden_states:
layer = torch.stack(torch.split(layer.squeeze(dim=1), beam_width))
layer = layer[range(bsz), selected_idx, :]
decoder_hidden_states.append(layer)

past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx)
logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :]
return next_id.squeeze(dim=-1), past_key_values, last_hidden_states, logits, selected_scores, decoder_hidden_states
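For readers trying to follow how this per-step helper gets used, here is a rough, illustrative sketch of a decoder-only decoding loop around it. The real integration (stopping criteria, logits processors, batching details) lives in `contrastive_search` inside `generation_utils.py` and is not part of this file, so the loop structure and names below are assumptions for illustration only:

```python
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation_contrastive_search import ContrastiveDecodingOneStepFast


def toy_contrastive_decode(prompt, steps=32, beam_width=4, penalty_alpha=0.6):
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer(prompt, return_tensors="pt")

    # Prime the cache, the last hidden states and the next-step logits with one forward pass.
    outputs = model(**inputs, output_hidden_states=True, use_cache=True)
    past_key_values = outputs.past_key_values
    last_hidden_states = outputs.hidden_states[-1]
    logit_for_next_step = outputs.logits[:, -1, :]

    generated = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    for _ in range(steps):
        next_id, past_key_values, last_hidden_states, logit_for_next_step, _, _ = ContrastiveDecodingOneStepFast(
            model,
            beam_width=beam_width,
            penalty_alpha=penalty_alpha,
            past_key_values=past_key_values,
            last_hidden_states=last_hidden_states,
            logit_for_next_step=logit_for_next_step,
            is_encoder_decoder=False,
            attention_mask=attention_mask,
        )
        generated = torch.cat([generated, next_id.unsqueeze(-1)], dim=-1)
        # The helper scores candidates against a mask extended by one token; grow the caller-side mask to match.
        attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, 1))], dim=-1)
    return tokenizer.decode(generated[0])


print(toy_contrastive_decode("DeepMind Company is"))
```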


def enlarge_past_key_values(
past_key_values: Tuple[Tuple[torch.FloatTensor]], beam_width: int
) -> Tuple[Tuple[torch.FloatTensor]]:
"""
Copy and expand the `past_key_values` for the re-ranking of the next step. Each item in `past_key_values` is a
4-dimensional tensor of shape [batch_size, num_head, seq_len, embed_dim]. If the number of next-token candidates
is K, we build a new `past_key_values` of shape [batch_size*K, num_head, seq_len, embed_dim].
"""
# from [B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz]
new_key_values = []
for layer in past_key_values:
items = []
# item is either the key or the value matrix
for item in layer:
bsz, num_head, seq_len, esz = item.size()
item = (
item.unsqueeze(1).expand(-1, beam_width, -1, -1, -1).reshape(bsz * beam_width, num_head, seq_len, esz)
) # [bsz*beam, num_head, seq_len, esz]
items.append(item)
new_key_values.append(items)
return new_key_values


def select_past_key_values(
past_key_values: Tuple[Tuple[torch.FloatTensor]], beam_width: int, selected_idx: torch.FloatTensor
) -> Tuple[Tuple[torch.FloatTensor]]:
"""
Extract the `past_key_values` for the selected tokens. Each item in `past_key_values` is a 4-dimensional tensor of
shape [batch_size*K, num_head, seq_len, embed_dim], where K is the number of candidate tokens. We extract the
`past_key_values` of the selected next token, of shape [batch_size, num_head, seq_len, embed_dim].
"""
new_key_values = []
for layer in past_key_values:
items = []
# item is either the key or the value matrix
for item in layer:
bsz_and_beam, num_head, seq_len, esz = item.size()
bsz = int(bsz_and_beam // beam_width)
item = torch.stack(torch.split(item, beam_width, dim=0)) # [B, K, num_head, seq_len, esz]
item = item[range(bsz), selected_idx, :, :, :] # [B, num_head, seq_len, esz]
items.append(item)
new_key_values.append(items)
return new_key_values
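A quick shape round-trip with dummy tensors, assuming one layer, batch size 2, 12 heads, 5 cached positions, head size 64 and K=4 candidates, to show how the two cache helpers mirror each other on the batch dimension:

```python
import torch

bsz, num_head, seq_len, esz, k = 2, 12, 5, 64, 4
dummy_past = [(torch.randn(bsz, num_head, seq_len, esz), torch.randn(bsz, num_head, seq_len, esz))]

enlarged = enlarge_past_key_values(dummy_past, beam_width=k)
print(enlarged[0][0].shape)  # torch.Size([8, 12, 5, 64]) -> batch dimension becomes bsz * K

selected = select_past_key_values(enlarged, beam_width=k, selected_idx=torch.tensor([1, 3]))
print(selected[0][0].shape)  # torch.Size([2, 12, 5, 64]) -> back to one entry per batch item
```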