forked from huggingface/transformers
-
Notifications
You must be signed in to change notification settings - Fork 0
/
generation_contrastive_search.py
196 lines (176 loc) · 9.15 KB
/
generation_contrastive_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
# coding=utf-8
# Copyright 2022 University of Cambridge, Tencent AI Lab, DeepMind and The University of Hong Kong Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains the utils functions for the contrastive search, which will be called in `generation_utils`
"""
from typing import Tuple
import torch
from torch import nn
from .utils import logging
logger = logging.get_logger(__name__)
def ranking_fast(
    context_hidden: torch.FloatTensor,
    next_hidden: torch.FloatTensor,
    next_top_k_probs: torch.FloatTensor,
    alpha: float,
    beam_width: int,
) -> Tuple[torch.FloatTensor, torch.LongTensor]:
    """
    Re-rank the top-k candidate tokens by combining model confidence with a
    degeneration penalty: the maximum cosine similarity between a candidate's
    hidden state and every hidden state of the preceding context.

    Args:
        context_hidden: hidden states of the context, [bsz*beam, seq_len, embed_dim].
        next_hidden: hidden state of each candidate token, [bsz*beam, 1, embed_dim].
        next_top_k_probs: model confidence of each candidate, [bsz, beam].
        alpha: weight of the degeneration penalty (0 disables the penalty).
        beam_width: number of candidates per batch element.

    Returns:
        `(selected_scores, selected_idx)`, both of shape [bsz]: the best
        contrastive score per batch element and the index (< beam_width)
        of the chosen candidate.
    """
    # cosine similarity of each candidate against every context position
    norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1)  # [B*K, S]
    # degeneration penalty: largest similarity to any previous token
    degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1)  # [B*K]
    next_top_k_probs = next_top_k_probs.view(-1)  # [B*K]
    # contrastive score = confidence term minus weighted degeneration penalty
    contrastive_scores = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
    contrastive_scores = torch.stack(torch.split(contrastive_scores, beam_width))  # [B, K]
    selected_scores, selected_idx = contrastive_scores.max(dim=-1)  # [B]
    return selected_scores, selected_idx
def ContrastiveDecodingOneStepFast(
    model,
    beam_width: int = 1,
    penalty_alpha: float = 0.0,
    past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
    last_hidden_states: torch.FloatTensor = None,
    logit_for_next_step: torch.FloatTensor = None,
    is_encoder_decoder: bool = False,
    **model_inputs,
) -> Tuple:
    """
    Perform one step of contrastive search: first select the top-k candidates by the logit scores;
    then feed these candidate tokens into the language model to compute the degeneration penalty,
    which is used to re-rank the candidates (via `ranking_fast`).

    Args:
        model ([`PretrainedModel`]): The decoder-only or encoder-decoder generation language model.
        beam_width (`int`, *optional*, defaults to 1):
            If > 1, top k candidate tokens are selected for re-ranking.
        penalty_alpha (`float`, *optional*, defaults to 0.0):
            If > 0, the model confidence will minus the degeneration penalty that is weighted by the penalty_alpha
            parameter.
        past_key_values (`Tuple[Tuple[torch.FloatTensor]]`):
            The cached key-value results computed by previous steps.
        last_hidden_states (`torch.FloatTensor`, *optional*):
            The last_hidden_states generated by the previous step, [bsz, seq_len, embed_dim].
        logit_for_next_step (`torch.FloatTensor`, *optional*):
            The logits for the next token generated by the previous step.
        is_encoder_decoder (`bool`, *optional*, defaults to `False`):
            If `True`, the model is an encoder-decoder model, otherwise the model is a decoder-only model.

    Returns:
        A 6-tuple: (next token ids, pruned past_key_values, extended last_hidden_states,
        logits of the selected tokens, selected contrastive scores, per-layer decoder hidden states
        of the selected tokens).

    From: https://github.com/yxuansu/SimCTG/blob/main/contrastive_search_explanation/README.md
    """
    bsz, seqlen, embed_dim = last_hidden_states.size()
    next_probs = nn.functional.softmax(logit_for_next_step, dim=-1)
    # pick the k most likely next tokens and their probabilities (model confidence)
    _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width)
    top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids)
    # duplicate the cache k times so every candidate can be scored in one batched forward pass
    past_key_values = enlarge_past_key_values(past_key_values, beam_width)
    # build next attention mask
    if "attention_mask" in model_inputs:
        attention_mask = model_inputs["attention_mask"]  # [B, S]
        # decoder-only models need the full attention mask, not only the mask for the last token
        if is_encoder_decoder is False:
            attention_mask = torch.cat([attention_mask, attention_mask.new_ones((bsz, 1))], dim=-1)
        # replicate the mask per candidate: [B, S] -> [B*K, S]
        attention_mask = attention_mask.unsqueeze(1).expand(-1, beam_width, -1).reshape(-1, attention_mask.size(-1))
    else:
        attention_mask = None
    # encoder-decoder models also carry the precomputed `encoder_outputs`
    if is_encoder_decoder and "encoder_outputs" in model_inputs:
        encoder_outputs = model_inputs["encoder_outputs"]
    else:
        encoder_outputs = None
    next_model_inputs = model.prepare_inputs_for_generation(
        top_k_ids.view(-1, 1),
        past=past_key_values,
        attention_mask=attention_mask,
        use_cache=True,
        encoder_outputs=encoder_outputs,
    )
    # run the k candidate tokens through the language model and collect their hidden_states
    output = model(output_hidden_states=True, **next_model_inputs)
    past_key_values = output.past_key_values
    logits = output.logits[:, -1, :]
    # the attribute name differs between encoder-decoder and decoder-only models
    if is_encoder_decoder:
        next_hidden = output.decoder_hidden_states[-1]
        full_hidden_states = output.decoder_hidden_states
    else:
        next_hidden = output.hidden_states[-1]
        full_hidden_states = output.hidden_states
    # replicate the context per candidate: [B, S, D] -> [B*K, S, D]
    context_hidden = (
        last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz * beam_width, seqlen, embed_dim)
    )
    # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the model confidence
    # the scores and index of the selected tokens are returned
    selected_scores, selected_idx = ranking_fast(
        context_hidden,
        next_hidden,
        top_k_probs,
        penalty_alpha,
        beam_width,
    )
    # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores (model confidence minus degeneration penalty); (6) decoder hidden_states
    next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1)
    # regroup [B*K, D] into [B, K, D] and keep only the winning candidate's hidden state
    next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width))
    next_hidden = next_hidden[range(bsz), selected_idx, :]
    last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)
    # same per-candidate selection for every layer's hidden states
    decoder_hidden_states = []
    for layer in full_hidden_states:
        layer = torch.stack(torch.split(layer.squeeze(dim=1), beam_width))
        layer = layer[range(bsz), selected_idx, :]
        decoder_hidden_states.append(layer)
    # prune the k-times-enlarged cache back down to the selected candidates only
    past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx)
    logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :]
    return next_id.squeeze(dim=-1), past_key_values, last_hidden_states, logits, selected_scores, decoder_hidden_states
def enlarge_past_key_values(
    past_key_values: Tuple[Tuple[torch.FloatTensor]], beam_width: int
) -> Tuple[Tuple[torch.FloatTensor]]:
    """
    Duplicate every cached key/value tensor `beam_width` times along the batch
    dimension so each of the K candidate tokens owns a copy of the cache.

    Each cached tensor has shape [batch_size, num_head, seq_len, embed_dim];
    the result has shape [batch_size*K, num_head, seq_len, embed_dim], with the
    K copies of a given batch row stored consecutively.
    """
    expanded_layers = []
    for layer_cache in past_key_values:
        # a layer's cache is a pair of (key, value) tensors
        duplicated = [
            # repeat each batch row K times consecutively:
            # [B, num_head, seq_len, esz] -> [B*K, num_head, seq_len, esz]
            torch.repeat_interleave(cache_tensor, beam_width, dim=0)
            for cache_tensor in layer_cache
        ]
        expanded_layers.append(duplicated)
    return expanded_layers
def select_past_key_values(
    past_key_values: Tuple[Tuple[torch.FloatTensor]], beam_width: int, selected_idx: torch.FloatTensor
) -> Tuple[Tuple[torch.FloatTensor]]:
    """
    Keep, for every cached key/value tensor, only the rows belonging to the
    selected candidate of each batch element.

    Input tensors have shape [batch_size*K, num_head, seq_len, embed_dim] (the
    K candidates of one batch element are stored consecutively); the returned
    tensors have shape [batch_size, num_head, seq_len, embed_dim].
    """
    pruned_layers = []
    for layer_cache in past_key_values:
        kept = []
        # a layer's cache is a pair of (key, value) tensors
        for cache_tensor in layer_cache:
            total_rows, num_head, seq_len, esz = cache_tensor.size()
            bsz = total_rows // beam_width
            # regroup candidates per batch element, then pick each element's winner
            grouped = cache_tensor.reshape(bsz, beam_width, num_head, seq_len, esz)
            kept.append(grouped[range(bsz), selected_idx])  # [B, num_head, seq_len, esz]
        pruned_layers.append(kept)
    return pruned_layers