From 13fc9f3ca93604743dba9ad98af8c2c542385fa2 Mon Sep 17 00:00:00 2001
From: Ludwig Sickert
Date: Tue, 22 Nov 2022 23:51:52 +0100
Subject: [PATCH] added documentation to most methods and generalized functions

---
 inseq/attr/feat/attention_attribution.py |   8 ++
 inseq/attr/feat/attribution_utils.py     |  13 ---
 inseq/attr/feat/ops/basic_attention.py   | 100 +++++++++++++++--------
 3 files changed, 76 insertions(+), 45 deletions(-)

diff --git a/inseq/attr/feat/attention_attribution.py b/inseq/attr/feat/attention_attribution.py
index dbad75d4..25c2069c 100644
--- a/inseq/attr/feat/attention_attribution.py
+++ b/inseq/attr/feat/attention_attribution.py
@@ -120,6 +120,10 @@ def load(
 
 
 class AggregatedAttentionAtribution(AttentionAtribution):
+    """
+    Aggregated attention attribution method.
+    Attention values of all layers are averaged.
+    """
 
     method_name = "aggregated_attention"
 
@@ -129,6 +133,10 @@ def __init__(self, attribution_model, **kwargs):
 
 
 class LastLayerAttentionAttribution(AttentionAtribution):
+    """
+    Last-layer attention attribution method.
+    Only the raw attention of the last layer is retrieved.
+    """
 
     method_name = "last_layer_attention"
 
diff --git a/inseq/attr/feat/attribution_utils.py b/inseq/attr/feat/attribution_utils.py
index d60766db..677feb46 100644
--- a/inseq/attr/feat/attribution_utils.py
+++ b/inseq/attr/feat/attribution_utils.py
@@ -231,16 +231,3 @@ def register_step_score(
         if agg_name not in DEFAULT_ATTRIBUTION_AGGREGATE_DICT["step_scores"]:
             DEFAULT_ATTRIBUTION_AGGREGATE_DICT["step_scores"][agg_name] = {}
         DEFAULT_ATTRIBUTION_AGGREGATE_DICT["step_scores"][agg_name][identifier] = agg_fn
-
-
-def num_attention_heads(attention: torch.Tensor) -> int:
-    """
-    Returns the number of heads an attention tensor has.
-
-    Args:
-        attention: an attention tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`
-
-    Returns:
-        `int`: The number of attention heads
-    """
-    return attention.size(1)
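The two classes documented above differ only in how the per-layer attention tensors are aggregated: a mean over all layers versus the last layer taken as-is. A minimal sketch of the two strategies on dummy tensors (all shapes are illustrative; the tuple mimics the per-layer attentions returned by a transformers model):

    import torch

    # One attention tensor per layer, shaped (batch, num_heads, target_len, source_len),
    # mimicking the `outputs.cross_attentions` tuple of a transformers seq2seq model.
    layer_attentions = tuple(torch.rand(1, 8, 5, 7) for _ in range(6))

    # aggregated_attention: average the attention weights over all layers.
    aggregated = torch.stack(layer_attentions).mean(0)  # (1, 8, 5, 7)

    # last_layer_attention: keep only the raw attention of the final layer.
    last_layer = layer_attentions[-1]  # (1, 8, 5, 7)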
diff --git a/inseq/attr/feat/ops/basic_attention.py b/inseq/attr/feat/ops/basic_attention.py
index b9e90c62..b775c072 100644
--- a/inseq/attr/feat/ops/basic_attention.py
+++ b/inseq/attr/feat/ops/basic_attention.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Tuple, Union
+from typing import Any, Optional, Tuple, Union
 
 import logging
 
@@ -21,9 +21,9 @@
 from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
 from captum.attr._utils.attribution import Attribution
 from captum.log import log_usage
+from transformers.modeling_outputs import Seq2SeqLMOutput
 
 from ....utils.typing import MultiStepEmbeddingsTensor
-from ..attribution_utils import num_attention_heads
 
 logger = logging.getLogger(__name__)
 
@@ -41,9 +41,35 @@ class AttentionAttribution(Attribution):
     def has_convergence_delta(self) -> bool:
         return False
 
-    def _merge_attention_heads(self, attention: torch.Tensor, option: str = "average", head: int = None):
-        num_heads = num_attention_heads(attention[0])
+    def _num_attention_heads(self, attention: torch.Tensor) -> int:
+        """
+        Returns the number of heads an attention tensor has.
+
+        Args:
+            attention: an attention tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`
+
+        Returns:
+            `int`: The number of attention heads
+        """
+        return attention.size(1)
+
+    def _merge_attention_heads(
+        self, attention: torch.Tensor, option: str = "average", head: Optional[int] = None
+    ) -> torch.Tensor:
+        """
+        Merges the attention values of the different heads, either by averaging across them,
+        selecting the head with the maximal values, or selecting a specific attention head.
+
+        Args:
+            attention: an attention tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`
+            option: The method to use for merging. Should be one of `average` (default), `max`, or `single`.
+            head: The index of the head to use when `option` is set to `single`.
+
+        Returns:
+            `torch.Tensor`: The attention tensor with its attention heads merged.
+        """
+        num_heads = self._num_attention_heads(attention)
 
         if option == "single" and head is None:
             raise RuntimeError("An attention head has to be specified when choosing single-head attention attribution")
@@ -51,7 +77,7 @@ def _merge_attention_heads(self, attention: torch.Tensor, option: str = "average
         if head is not None:
             if head > num_heads:
                 raise RuntimeError(
-                    "Attention head index for attribution too high. " f"The model only has {num_heads} heads."
+                    f"Attention head index for attribution too high. The model only has {num_heads} heads."
                 )
 
             if option != "single":
@@ -63,11 +89,11 @@ def _merge_attention_heads(self, attention: torch.Tensor, option: str = "average
             return attention.select(1, head)
 
         if option == "average":
-            return attention.mean(1, keepdim=True)
+            return attention.mean(1)
 
         # TODO: test this, I feel like this method is not doing what we want here
         elif option == "max":
-            return attention.max(1, keepdim=True)
+            return attention.max(1).values  # max over a dim returns a (values, indices) tuple
 
         else:
             raise RuntimeError(
@@ -75,10 +101,36 @@ def _merge_attention_heads(self, attention: torch.Tensor, option: str = "average
                 "Valid methods are: `average`, `max` and `single`"
             )
 
-    def _get_batch_size(self, attention: torch.Tensor):
-        """returns the batch size of a tensor of shape `batch_size, heads, seq, seq`"""
+    def _get_batch_size(self, attention: torch.Tensor) -> int:
+        """Returns the batch size of a tensor of shape `(batch_size, heads, seq, seq)`."""
         return attention.size(0)
 
+    def _extract_forward_pass_args(
+        self, inputs: MultiStepEmbeddingsTensor, forward_args: Optional[Tuple], is_target_attr: bool
+    ) -> dict:
+        """Extracts the arguments needed for a standard forward pass
+        from the `inputs` and `additional_forward_args` parameters used by Captum."""
+
+        use_embeddings = forward_args[6] if is_target_attr else forward_args[7]
+
+        forward_pass_args = {
+            "attention_mask": forward_args[4] if is_target_attr else forward_args[5],
+            "decoder_attention_mask": forward_args[5] if is_target_attr else forward_args[6],
+        }
+
+        if use_embeddings:
+            forward_pass_args["inputs_embeds"] = inputs[0]
+            forward_pass_args["decoder_inputs_embeds"] = inputs[1] if is_target_attr else forward_args[0]
+        else:
+            forward_pass_args["input_ids"] = forward_args[0] if is_target_attr else forward_args[1]
+            forward_pass_args["decoder_input_ids"] = forward_args[1] if is_target_attr else forward_args[2]
+
+        return forward_pass_args
+
+    def _run_forward_pass(self, **forward_args: Any) -> Seq2SeqLMOutput:
+        # Placeholder for a shared forward-pass helper; not implemented yet.
+        pass
+
 
 class AggregatedAttention(AttentionAttribution):
     """
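The three merge options reduce the head dimension in different ways. A minimal sketch on a dummy tensor (shapes are illustrative); note that `torch.Tensor.max` over a dimension returns a `(values, indices)` named tuple, which is why the `max` branch above selects `.values`:

    import torch

    # Dummy layer-aggregated attention: (batch, num_heads, target_len, source_len).
    attention = torch.rand(1, 8, 5, 7)

    averaged = attention.mean(1)         # "average": (1, 5, 7)
    strongest = attention.max(1).values  # "max": (1, 5, 7), indices are discarded
    third_head = attention.select(1, 2)  # "single" with head=2: (1, 5, 7)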
@@ -100,28 +152,20 @@ def attribute(
 
         is_target_attribution = True if len(inputs) > 1 else False
 
-        input_ids = additional_forward_args[0] if is_target_attribution else additional_forward_args[1]
-        decoder_input_ids = additional_forward_args[1] if is_target_attribution else additional_forward_args[2]
-        attention_mask = additional_forward_args[4] if is_target_attribution else additional_forward_args[5]
-        decoder_attention_mask = additional_forward_args[5] if is_target_attribution else additional_forward_args[6]
+        forward_pass_args = self._extract_forward_pass_args(inputs, additional_forward_args, is_target_attribution)
 
-        outputs = self.forward_func.model(
-            input_ids=input_ids,
-            decoder_input_ids=decoder_input_ids,
-            attention_mask=attention_mask,
-            decoder_attention_mask=decoder_attention_mask,
-        )
+        outputs = self.forward_func.model(**forward_pass_args)
 
         cross_aggregation = torch.stack(outputs.cross_attentions).mean(0)
         cross_aggregation = self._merge_attention_heads(cross_aggregation, merge_head_option, use_head)
-        cross_aggregation = torch.squeeze(cross_aggregation, 1).select(1, -1)
+        cross_aggregation = cross_aggregation.select(1, -1)
 
         attributions = (cross_aggregation,)
 
         if is_target_attribution:
             decoder_aggregation = torch.stack(outputs.decoder_attentions).mean(0)
             decoder_aggregation = self._merge_attention_heads(decoder_aggregation, merge_head_option, use_head)
-            decoder_aggregation = torch.squeeze(decoder_aggregation, 1).select(1, -1)
+            decoder_aggregation = decoder_aggregation.select(1, -1)
 
             attributions = attributions + (decoder_aggregation,)
 
@@ -148,17 +192,9 @@ def attribute(
 
         is_target_attribution = True if len(inputs) > 1 else False
 
-        input_ids = additional_forward_args[0] if is_target_attribution else additional_forward_args[1]
-        decoder_input_ids = additional_forward_args[1] if is_target_attribution else additional_forward_args[2]
-        attention_mask = additional_forward_args[4] if is_target_attribution else additional_forward_args[5]
-        decoder_attention_mask = additional_forward_args[5] if is_target_attribution else additional_forward_args[6]
-
-        outputs = self.forward_func.model(
-            input_ids=input_ids,
-            decoder_input_ids=decoder_input_ids,
-            attention_mask=attention_mask,
-            decoder_attention_mask=decoder_attention_mask,
-        )
+        forward_pass_args = self._extract_forward_pass_args(inputs, additional_forward_args, is_target_attribution)
+
+        outputs = self.forward_func.model(**forward_pass_args)
 
         last_layer_cross = outputs.cross_attentions[-1]
         last_layer_cross = self._merge_attention_heads(last_layer_cross, merge_head_option, use_head)
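For context, a hedged sketch of how the methods registered in this patch (`aggregated_attention`, `last_layer_attention`) would be invoked through inseq's loader, assuming this branch is installed; the model name is only an example, and any encoder-decoder checkpoint would do:

    import inseq

    # Load a seq2seq model together with an attribution method,
    # referenced by its registered `method_name`.
    model = inseq.load_model("Helsinki-NLP/opus-mt-en-de", "aggregated_attention")

    # Attribute a sample input and display the attention-based saliency map.
    out = model.attribute("Hello everyone, how are you today?")
    out.show()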