Commit
Fixing question-answering with long contexts (huggingface#13873)
* Tmp.

* Fixing BC for question answering with long context.

* Capping model_max_length to avoid tf overflow.

* Bad workaround bugged roberta.

* Fixing name.
Narsil authored and stas00 committed Oct 12, 2021
1 parent c334196 commit 9b49dc8
Showing 4 changed files with 199 additions and 98 deletions.
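In short: preprocess now derives its max_seq_len and doc_stride defaults from the tokenizer's model_max_length and converts each overlapping span into its own framework tensors up front, _forward runs the model once per span instead of batching all spans together, and postprocess loops over each span's logits so answers from every window of a long context compete for the top-k slots.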
218 changes: 120 additions & 98 deletions src/transformers/pipelines/question_answering.py
@@ -248,7 +248,13 @@ def __call__(self, *args, **kwargs):
             return super().__call__(examples[0], **kwargs)
         return super().__call__(examples, **kwargs)
 
-    def preprocess(self, example, padding="do_not_pad", doc_stride=128, max_question_len=64, max_seq_len=384):
+    def preprocess(self, example, padding="do_not_pad", doc_stride=None, max_question_len=64, max_seq_len=None):
+
+        if max_seq_len is None:
+            max_seq_len = min(self.tokenizer.model_max_length, 384)
+        if doc_stride is None:
+            doc_stride = min(max_seq_len // 4, 128)
+
         if not self.tokenizer.is_fast:
             features = squad_convert_examples_to_features(
                 examples=[example],
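To make the new defaults concrete, a quick sketch (assuming a hypothetical checkpoint whose tokenizer reports model_max_length = 512):

    max_seq_len = min(512, 384)      # -> 384
    doc_stride = min(384 // 4, 128)  # -> 96

so a long context is windowed into 384-token spans that overlap by 96 tokens; a tokenizer with a smaller model_max_length yields proportionally smaller spans and strides.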
@@ -277,7 +283,6 @@ def preprocess(self, example, padding="do_not_pad", doc_stride=128, max_question_len=64, max_seq_len=384):
                 return_offsets_mapping=True,
                 return_special_tokens_mask=True,
             )
-
             # When the input is too long, it's converted into a batch of inputs with overflowing tokens
             # and a stride of overlap between the inputs. If a batch of inputs is given, a special output
             # "overflow_to_sample_mapping" indicates which member of the encoded batch belongs to which original batch sample.
@@ -308,12 +313,15 @@ def preprocess(self, example, padding="do_not_pad", doc_stride=128, max_question_len=64, max_seq_len=384):
                 token_type_ids_span_idx = (
                     encoded_inputs["token_type_ids"][span_idx] if "token_type_ids" in encoded_inputs else None
                 )
+                submask = p_mask[span_idx]
+                if isinstance(submask, np.ndarray):
+                    submask = submask.tolist()
                 features.append(
                     SquadFeatures(
                         input_ids=input_ids_span_idx,
                         attention_mask=attention_mask_span_idx,
                         token_type_ids=token_type_ids_span_idx,
-                        p_mask=p_mask[span_idx].tolist(),
+                        p_mask=submask,
                         encoding=encoded_inputs[span_idx],
                         # We don't use the rest of the values - and actually
                         # for Fast tokenizer we could totally avoid using SquadFeatures and SquadExample
@@ -330,26 +338,41 @@ def preprocess(self, example, padding="do_not_pad", doc_stride=128, max_question_len=64, max_seq_len=384):
                         qas_id=None,
                     )
                 )
-        return {"features": features, "example": example}
 
+        split_features = []
+        for feature in features:
+            fw_args = {}
+            others = {}
+            model_input_names = self.tokenizer.model_input_names
+
+            for k, v in feature.__dict__.items():
+                if k in model_input_names:
+                    if self.framework == "tf":
+                        tensor = tf.constant(v)
+                        if tensor.dtype == tf.int64:
+                            tensor = tf.cast(tensor, tf.int32)
+                        fw_args[k] = tf.expand_dims(tensor, 0)
+                    elif self.framework == "pt":
+                        tensor = torch.tensor(v)
+                        if tensor.dtype == torch.int32:
+                            tensor = tensor.long()
+                        fw_args[k] = tensor.unsqueeze(0)
+                else:
+                    others[k] = v
+            split_features.append({"fw_args": fw_args, "others": others})
+        return {"features": split_features, "example": example}

     def _forward(self, model_inputs):
         features = model_inputs["features"]
         example = model_inputs["example"]
-        model_input_names = self.tokenizer.model_input_names
-        fw_args = {k: [feature.__dict__[k] for feature in features] for k in model_input_names}
 
-        if self.framework == "tf":
-            fw_args = {k: tf.constant(v) for (k, v) in fw_args.items()}
-            start, end = self.model(fw_args)[:2]
-            start, end = start.numpy(), end.numpy()
-        elif self.framework == "pt":
-            # Retrieve the score for the context tokens only (removing question tokens)
-            fw_args = {k: torch.tensor(v, device=self.device) for (k, v) in fw_args.items()}
-            # On Windows, the default int type in numpy is np.int32 so we get some non-long tensors.
-            fw_args = {k: v.long() if v.dtype == torch.int32 else v for (k, v) in fw_args.items()}
+        starts = []
+        ends = []
+        for feature in features:
+            fw_args = feature["fw_args"]
             start, end = self.model(**fw_args)[:2]
             start, end = start.cpu().numpy(), end.cpu().numpy()
-        return {"start": start, "end": end, "features": features, "example": example}
+            starts.append(start)
+            ends.append(end)
+        return {"starts": starts, "ends": ends, "features": features, "example": example}

     def postprocess(
         self,
@@ -360,90 +383,89 @@ def postprocess(
     ):
         min_null_score = 1000000  # large and positive
         answers = []
-        start_ = model_outputs["start"][0]
-        end_ = model_outputs["end"][0]
-        feature = model_outputs["features"][0]
         example = model_outputs["example"]
-        # Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
-        undesired_tokens = np.abs(np.array(feature.p_mask) - 1)
-
-        if feature.attention_mask is not None:
-            undesired_tokens = undesired_tokens & feature.attention_mask
-
-        # Generate mask
-        undesired_tokens_mask = undesired_tokens == 0.0
-
-        # Make sure non-context indexes in the tensor cannot contribute to the softmax
-        start_ = np.where(undesired_tokens_mask, -10000.0, start_)
-        end_ = np.where(undesired_tokens_mask, -10000.0, end_)
-
-        # Normalize logits and spans to retrieve the answer
-        start_ = np.exp(start_ - np.log(np.sum(np.exp(start_), axis=-1, keepdims=True)))
-        end_ = np.exp(end_ - np.log(np.sum(np.exp(end_), axis=-1, keepdims=True)))
-
-        if handle_impossible_answer:
-            min_null_score = min(min_null_score, (start_[0] * end_[0]).item())
-
-        # Mask CLS
-        start_[0] = end_[0] = 0.0
-
-        starts, ends, scores = self.decode(start_, end_, top_k, max_answer_len, undesired_tokens)
-        if not self.tokenizer.is_fast:
-            char_to_word = np.array(example.char_to_word_offset)
-
-            # Convert the answer (tokens) back to the original text
-            # Score: score from the model
-            # Start: Index of the first character of the answer in the context string
-            # End: Index of the character following the last character of the answer in the context string
-            # Answer: Plain text of the answer
-            for s, e, score in zip(starts, ends, scores):
-                answers.append(
-                    {
-                        "score": score.item(),
-                        "start": np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
-                        "end": np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
-                        "answer": " ".join(
-                            example.doc_tokens[feature.token_to_orig_map[s] : feature.token_to_orig_map[e] + 1]
-                        ),
-                    }
-                )
-        else:
-            # Convert the answer (tokens) back to the original text
-            # Score: score from the model
-            # Start: Index of the first character of the answer in the context string
-            # End: Index of the character following the last character of the answer in the context string
-            # Answer: Plain text of the answer
-            question_first = bool(self.tokenizer.padding_side == "right")
-            enc = feature.encoding
-
-            # Sometimes the max probability token is in the middle of a word so:
-            # - we start by finding the right word containing the token with `token_to_word`
-            # - then we convert this word in a character span with `word_to_chars`
-            sequence_index = 1 if question_first else 0
-            for s, e, score in zip(starts, ends, scores):
-                try:
-                    start_word = enc.token_to_word(s)
-                    end_word = enc.token_to_word(e)
-                    start_index = enc.word_to_chars(start_word, sequence_index=sequence_index)[0]
-                    end_index = enc.word_to_chars(end_word, sequence_index=sequence_index)[1]
-                except Exception:
-                    # Some tokenizers don't really handle words. Keep to offsets then.
-                    start_index = enc.offsets[s][0]
-                    end_index = enc.offsets[e][1]
-
-                answers.append(
-                    {
-                        "score": score.item(),
-                        "start": start_index,
-                        "end": end_index,
-                        "answer": example.context_text[start_index:end_index],
-                    }
-                )
+        for i, (feature_, start_, end_) in enumerate(
+            zip(model_outputs["features"], model_outputs["starts"], model_outputs["ends"])
+        ):
+            feature = feature_["others"]
+            # Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
+            undesired_tokens = np.abs(np.array(feature["p_mask"]) - 1)
+
+            if feature_["fw_args"].get("attention_mask", None) is not None:
+                undesired_tokens = undesired_tokens & feature_["fw_args"]["attention_mask"].numpy()
+
+            # Generate mask
+            undesired_tokens_mask = undesired_tokens == 0.0
+
+            # Make sure non-context indexes in the tensor cannot contribute to the softmax
+            start_ = np.where(undesired_tokens_mask, -10000.0, start_)
+            end_ = np.where(undesired_tokens_mask, -10000.0, end_)
+
+            # Normalize logits and spans to retrieve the answer
+            start_ = np.exp(start_ - np.log(np.sum(np.exp(start_), axis=-1, keepdims=True)))
+            end_ = np.exp(end_ - np.log(np.sum(np.exp(end_), axis=-1, keepdims=True)))
+
+            if handle_impossible_answer:
+                min_null_score = min(min_null_score, (start_[0, 0] * end_[0, 0]).item())
+
+            # Mask CLS
+            start_[0, 0] = end_[0, 0] = 0.0
+
+            starts, ends, scores = self.decode(start_, end_, top_k, max_answer_len, undesired_tokens)
+            if not self.tokenizer.is_fast:
+                char_to_word = np.array(example.char_to_word_offset)
+
+                # Convert the answer (tokens) back to the original text
+                # Score: score from the model
+                # Start: Index of the first character of the answer in the context string
+                # End: Index of the character following the last character of the answer in the context string
+                # Answer: Plain text of the answer
+                for s, e, score in zip(starts, ends, scores):
+                    token_to_orig_map = feature["token_to_orig_map"]
+                    answers.append(
+                        {
+                            "score": score.item(),
+                            "start": np.where(char_to_word == token_to_orig_map[s])[0][0].item(),
+                            "end": np.where(char_to_word == token_to_orig_map[e])[0][-1].item(),
+                            "answer": " ".join(example.doc_tokens[token_to_orig_map[s] : token_to_orig_map[e] + 1]),
+                        }
+                    )
+            else:
+                # Convert the answer (tokens) back to the original text
+                # Score: score from the model
+                # Start: Index of the first character of the answer in the context string
+                # End: Index of the character following the last character of the answer in the context string
+                # Answer: Plain text of the answer
+                question_first = bool(self.tokenizer.padding_side == "right")
+                enc = feature["encoding"]
+
+                # Sometimes the max probability token is in the middle of a word so:
+                # - we start by finding the right word containing the token with `token_to_word`
+                # - then we convert this word in a character span with `word_to_chars`
+                sequence_index = 1 if question_first else 0
+                for s, e, score in zip(starts, ends, scores):
+                    try:
+                        start_word = enc.token_to_word(s)
+                        end_word = enc.token_to_word(e)
+                        start_index = enc.word_to_chars(start_word, sequence_index=sequence_index)[0]
+                        end_index = enc.word_to_chars(end_word, sequence_index=sequence_index)[1]
+                    except Exception:
+                        # Some tokenizers don't really handle words. Keep to offsets then.
+                        start_index = enc.offsets[s][0]
+                        end_index = enc.offsets[e][1]
+
+                    answers.append(
+                        {
+                            "score": score.item(),
+                            "start": start_index,
+                            "end": end_index,
+                            "answer": example.context_text[start_index:end_index],
+                        }
+                    )

         if handle_impossible_answer:
             answers.append({"score": min_null_score, "start": 0, "end": 0, "answer": ""})
 
         answers = sorted(answers, key=lambda x: x["score"], reverse=True)[:top_k]
         if len(answers) == 1:
             return answers[0]
         return answers
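End to end, the long-context behavior this commit repairs can be sketched as follows (the checkpoint is illustrative; any extractive-QA model works):

    from transformers import pipeline

    qa = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
    long_document = "The report was written by Jane Doe. " * 300  # stand-in for a context longer than 384 tokens
    result = qa(question="Who wrote the report?", context=long_document)
    # The context is windowed into overlapping spans, every span is scored,
    # and the best answer across all windows is returned, e.g.:
    # {"score": 0.98, "start": 26, "end": 34, "answer": "Jane Doe"}
    print(result)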
5 changes: 5 additions & 0 deletions tests/test_modeling_led.py
@@ -162,6 +162,11 @@ def get_config(self):
             attention_window=self.attention_window,
         )
 
+    def get_pipeline_config(self):
+        config = self.get_config()
+        config.max_position_embeddings = 100
+        return config
+
     def prepare_config_and_inputs_for_common(self):
         config, inputs_dict = self.prepare_config_and_inputs()
         global_attention_mask = torch.zeros_like(inputs_dict["input_ids"])
1 change: 1 addition & 0 deletions tests/test_modeling_reformer.py
@@ -189,6 +189,7 @@ def get_config(self):
     def get_pipeline_config(self):
         config = self.get_config()
         config.vocab_size = 100
+        config.max_position_embeddings = 100
         config.axial_pos_shape = (4, 25)
         config.is_decoder = False
         return config
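For context, get_pipeline_config is the ModelTester hook that the pipeline test suite calls instead of get_config when it builds tiny models, letting individual testers tighten limits for pipeline runs only. A hypothetical sketch of the pattern (MyModelTester and MyConfig are illustrative names):

    class MyModelTester:
        def get_config(self):
            return MyConfig(vocab_size=99, max_position_embeddings=512)

        def get_pipeline_config(self):
            # Override fields that the pipeline tests need to differ from the
            # regular unit-test config.
            config = self.get_config()
            config.max_position_embeddings = 100
            return config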