added documentation to most methods and generalized functions
lsickert committed Nov 22, 2022
1 parent 7b3c4fd commit 13fc9f3
Showing 3 changed files with 76 additions and 45 deletions.
8 changes: 8 additions & 0 deletions inseq/attr/feat/attention_attribution.py
@@ -120,6 +120,10 @@ def load(
 
 
 class AggregatedAttentionAtribution(AttentionAtribution):
+    """
+    Aggregated attention attribution method.
+    Attention values of all layers are averaged.
+    """
 
     method_name = "aggregated_attention"
 
@@ -129,6 +133,10 @@ def __init__(self, attribution_model, **kwargs):
 
 
 class LastLayerAttentionAttribution(AttentionAtribution):
+    """
+    Last-layer attention attribution method.
+    Only the raw attention of the last hidden layer is retrieved.
+    """
 
    method_name = "last_layer_attention"
 
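The two hunks above only add docstrings, but they document how the two methods differ: one averages attention across every layer, the other reads only the final layer. For orientation, a minimal usage sketch, assuming the usual inseq entry points (`inseq.load_model`, `attribute`, `show`) and a hypothetical Marian checkpoint; the method names come from the `method_name` fields above:

```python
import inseq

# "aggregated_attention" averages attention over all layers;
# "last_layer_attention" keeps only the final layer's attention.
model = inseq.load_model("Helsinki-NLP/opus-mt-en-de", "aggregated_attention")
out = model.attribute("Hello everyone, how are you?")
out.show()
```
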
13 changes: 0 additions & 13 deletions inseq/attr/feat/attribution_utils.py
@@ -231,16 +231,3 @@ def register_step_score(
     if agg_name not in DEFAULT_ATTRIBUTION_AGGREGATE_DICT["step_scores"]:
         DEFAULT_ATTRIBUTION_AGGREGATE_DICT["step_scores"][agg_name] = {}
     DEFAULT_ATTRIBUTION_AGGREGATE_DICT["step_scores"][agg_name][identifier] = agg_fn
-
-
-def num_attention_heads(attention: torch.Tensor) -> int:
-    """
-    Returns the number of heads an attention tensor has.
-
-    Args:
-        attention: an attention tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`
-
-    Returns:
-        `int`: The number of attention heads
-    """
-    return attention.size(1)
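The deleted module-level helper is not lost: it reappears below as `AttentionAttribution._num_attention_heads` in `basic_attention.py`. A minimal sketch of the shape convention it relies on:

```python
import torch

# Attention tensors follow the (batch_size, num_heads, seq_len, seq_len) layout,
# so the head count is just the size of dimension 1.
attention = torch.rand(2, 8, 5, 5)
assert attention.size(1) == 8
```
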
100 changes: 68 additions & 32 deletions inseq/attr/feat/ops/basic_attention.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Tuple, Union
+from typing import Any, Optional, Tuple, Union
 
 import logging

@@ -21,9 +21,9 @@
 from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
 from captum.attr._utils.attribution import Attribution
 from captum.log import log_usage
+from transformers.modeling_outputs import Seq2SeqLMOutput
 
 from ....utils.typing import MultiStepEmbeddingsTensor
-from ..attribution_utils import num_attention_heads
 
 
 logger = logging.getLogger(__name__)
@@ -41,17 +41,43 @@ class AttentionAttribution(Attribution):
     def has_convergence_delta(self) -> bool:
         return False
 
-    def _merge_attention_heads(self, attention: torch.Tensor, option: str = "average", head: int = None):
+    def _num_attention_heads(self, attention: torch.Tensor) -> int:
+        """
+        Returns the number of heads an attention tensor has.
+
+        Args:
+            attention: an attention tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`
+
+        Returns:
+            `int`: The number of attention heads
+        """
+        return attention.size(1)
+
+    def _merge_attention_heads(
+        self, attention: torch.Tensor, option: str = "average", head: int = None
+    ) -> torch.Tensor:
+        """
+        Merges the attention values of the different heads together, either by averaging across them,
+        selecting the head with the maximal values, or selecting a specific attention head.
+
+        Args:
+            attention: an attention tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`
+            option: The method to use for merging. Should be one of `average` (default), `max`, or `single`
+            head: The index of the head to use, when option is set to `single`
+
+        Returns:
+            `torch.Tensor`: The attention tensor with its attention heads merged.
+        """
 
-        num_heads = num_attention_heads(attention[0])
+        num_heads = self._num_attention_heads(attention[0])
 
         if option == "single" and head is None:
             raise RuntimeError("An attention head has to be specified when choosing single-head attention attribution")
 
         if head is not None:
             if head > num_heads:
                 raise RuntimeError(
-                    "Attention head index for attribution too high. " f"The model only has {num_heads} heads."
+                    f"Attention head index for attribution too high. The model only has {num_heads} heads."
                 )
 
         if option != "single":
@@ -63,22 +89,48 @@ def _merge_attention_heads(self, attention: torch.Tensor, option: str = "average
             return attention.select(1, head)
 
         if option == "average":
-            return attention.mean(1, keepdim=True)
+            return attention.mean(1)
 
-        # TODO: test this, I feel like this method is not doing what we want here
         elif option == "max":
-            return attention.max(1, keepdim=True)
+            # max over a dimension returns a (values, indices) tuple; keep only the values
+            return attention.max(1).values
 
         else:
             raise RuntimeError(
                 "Invalid merge method for attention heads specified. "
                 "Valid methods are: `average`, `max` and `single`"
             )
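A minimal sketch of what the three merge options compute, on a dummy tensor with hypothetical sizes:

```python
import torch

attn = torch.rand(2, 8, 5, 5)  # (batch_size, num_heads, seq_len, seq_len)

merged_avg = attn.mean(1)        # "average": mean over the head dimension -> (2, 5, 5)
merged_max = attn.max(1).values  # "max": element-wise maximum over heads -> (2, 5, 5)
merged_one = attn.select(1, 3)   # "single" with head=3: one head's values -> (2, 5, 5)
```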

-    def _get_batch_size(self, attention: torch.Tensor):
-        """returns the batch size of a tensor of shape `batch_size, heads, seq, seq`"""
+    def _get_batch_size(self, attention: torch.Tensor) -> int:
+        """Returns the batch size of a tensor of shape `(batch_size, heads, seq, seq)`."""
         return attention.size(0)

+    def _extract_forward_pass_args(
+        self, inputs: MultiStepEmbeddingsTensor, forward_args: Optional[Tuple], is_target_attr: bool
+    ) -> dict:
+        """Extracts the arguments needed for a standard forward pass
+        from the `inputs` and `additional_forward_args` parameters used by Captum."""
+
+        use_embeddings = forward_args[6] if is_target_attr else forward_args[7]
+
+        forward_pass_args = {
+            "attention_mask": forward_args[4] if is_target_attr else forward_args[5],
+            "decoder_attention_mask": forward_args[5] if is_target_attr else forward_args[6],
+        }
+
+        if use_embeddings:
+            forward_pass_args["inputs_embeds"] = inputs[0]
+            forward_pass_args["decoder_inputs_embeds"] = inputs[1] if is_target_attr else forward_args[0]
+        else:
+            forward_pass_args["input_ids"] = forward_args[0] if is_target_attr else forward_args[1]
+            forward_pass_args["decoder_input_ids"] = forward_args[1] if is_target_attr else forward_args[2]
+
+        return forward_pass_args
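The positional layout of Captum's `additional_forward_args` is only implicit in the indices above. The toy tuple below reconstructs it for the source-attribution case (`is_target_attr=False`); the tensor values and `None` placeholders are illustrative stand-ins, not part of the source:

```python
import torch

input_ids = torch.tensor([[101, 7, 8, 102]])
decoder_input_ids = torch.tensor([[0]])

forward_args = (
    None,                                # [0] not read in this case
    input_ids,                           # [1] -> "input_ids"
    decoder_input_ids,                   # [2] -> "decoder_input_ids"
    None,                                # [3] not read
    None,                                # [4] not read
    torch.ones_like(input_ids),          # [5] -> "attention_mask"
    torch.ones_like(decoder_input_ids),  # [6] -> "decoder_attention_mask"
    False,                               # [7] use_embeddings flag
)
```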

+    def _run_forward_pass(self, **forward_args: dict) -> Seq2SeqLMOutput:
+        pass


 class AggregatedAttention(AttentionAttribution):
     """
@@ -100,28 +152,20 @@ def attribute(

         is_target_attribution = True if len(inputs) > 1 else False
 
-        input_ids = additional_forward_args[0] if is_target_attribution else additional_forward_args[1]
-        decoder_input_ids = additional_forward_args[1] if is_target_attribution else additional_forward_args[2]
-        attention_mask = additional_forward_args[4] if is_target_attribution else additional_forward_args[5]
-        decoder_attention_mask = additional_forward_args[5] if is_target_attribution else additional_forward_args[6]
-
-        outputs = self.forward_func.model(
-            input_ids=input_ids,
-            decoder_input_ids=decoder_input_ids,
-            attention_mask=attention_mask,
-            decoder_attention_mask=decoder_attention_mask,
-        )
+        forward_pass_args = self._extract_forward_pass_args(inputs, additional_forward_args, is_target_attribution)
+
+        outputs = self.forward_func.model(**forward_pass_args)
 
         cross_aggregation = torch.stack(outputs.cross_attentions).mean(0)
         cross_aggregation = self._merge_attention_heads(cross_aggregation, merge_head_option, use_head)
-        cross_aggregation = torch.squeeze(cross_aggregation, 1).select(1, -1)
+        cross_aggregation = cross_aggregation.select(1, -1)
 
         attributions = (cross_aggregation,)
 
         if is_target_attribution:
             decoder_aggregation = torch.stack(outputs.decoder_attentions).mean(0)
             decoder_aggregation = self._merge_attention_heads(decoder_aggregation, merge_head_option, use_head)
-            decoder_aggregation = torch.squeeze(decoder_aggregation, 1).select(1, -1)
+            decoder_aggregation = decoder_aggregation.select(1, -1)
 
             attributions = attributions + (decoder_aggregation,)
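A shape walk-through of the aggregation above, with hypothetical sizes (6 decoder layers, batch 2, 8 heads, 4 target positions, 9 source positions):

```python
import torch

# Stand-in for outputs.cross_attentions: one tensor per decoder layer.
cross_attentions = tuple(torch.rand(2, 8, 4, 9) for _ in range(6))

agg = torch.stack(cross_attentions).mean(0)  # average over layers -> (2, 8, 4, 9)
agg = agg.mean(1)                            # "average" head merge -> (2, 4, 9)
step = agg.select(1, -1)                     # last generated token's attention -> (2, 9)
```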

@@ -148,17 +192,9 @@ def attribute(

         is_target_attribution = True if len(inputs) > 1 else False
 
-        input_ids = additional_forward_args[0] if is_target_attribution else additional_forward_args[1]
-        decoder_input_ids = additional_forward_args[1] if is_target_attribution else additional_forward_args[2]
-        attention_mask = additional_forward_args[4] if is_target_attribution else additional_forward_args[5]
-        decoder_attention_mask = additional_forward_args[5] if is_target_attribution else additional_forward_args[6]
-
-        outputs = self.forward_func.model(
-            input_ids=input_ids,
-            decoder_input_ids=decoder_input_ids,
-            attention_mask=attention_mask,
-            decoder_attention_mask=decoder_attention_mask,
-        )
+        forward_pass_args = self._extract_forward_pass_args(inputs, additional_forward_args, is_target_attribution)
+
+        outputs = self.forward_func.model(**forward_pass_args)
 
         last_layer_cross = outputs.cross_attentions[-1]
         last_layer_cross = self._merge_attention_heads(last_layer_cross, merge_head_option, use_head)
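For contrast with the aggregated variant: the last-layer method skips the stack-and-mean step and keeps only the final entry of `cross_attentions`. Same hypothetical sizes as above:

```python
import torch

cross_attentions = tuple(torch.rand(2, 8, 4, 9) for _ in range(6))

last = cross_attentions[-1]  # final layer only -> (2, 8, 4, 9)
last = last.mean(1)          # "average" head merge -> (2, 4, 9)
step = last.select(1, -1)    # last generated token's attention -> (2, 9)
```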
