Basic Attention attribution #148

Merged (35 commits, Jan 16, 2023)
Commits
b9ccbf2
added jupyterlab dependency (for easier testing)
lsickert Sep 23, 2022
08d5dbd
initial commit attention methods, added output_attentions parameter t…
lsickert Oct 17, 2022
57ae54a
Merge branch 'main' into attention-attribution
lsickert Oct 19, 2022
544321c
added basic attention method stubs; added attention method registry
lsickert Oct 24, 2022
d0e859f
reverted changes to output generation (forward pass done inside attri…
lsickert Nov 21, 2022
a2a2021
first working version of basic attention methods
lsickert Nov 21, 2022
7b3c4fd
fixed rounding of values in cli output
lsickert Nov 22, 2022
13fc9f3
added documentation to most methods and generalized functions
lsickert Nov 22, 2022
bb13036
Merge branch 'main' into attention-attribution; needed to downgrade…
lsickert Nov 23, 2022
9340343
removed python 3.11 build target
lsickert Nov 24, 2022
3cfd706
fix safety warnings
lsickert Nov 24, 2022
2765d63
set correct python version in pyproject.toml
lsickert Nov 25, 2022
4dd442f
regenerated requirements without 3.11
lsickert Nov 25, 2022
b14407c
Merge branch 'main' into attention-attribution, quick fix for mps issue
lsickert Dec 9, 2022
6a72166
Merge branch 'main' into attention-attribution
lsickert Dec 12, 2022
6535b09
merge branch 'main' into attention-attribution
lsickert Jan 2, 2023
624435e
update deps after merge
lsickert Jan 2, 2023
06f89a8
include 3.11 as build target
lsickert Jan 2, 2023
7bcbe92
fix different attribution_step argument formatting
lsickert Jan 2, 2023
b2fc73c
added basic decoder-only support
lsickert Jan 2, 2023
b044b4c
fixed output error for decoder only models
lsickert Jan 3, 2023
8c344b7
removed unnecessary convergence delta references in attention attribu…
lsickert Jan 3, 2023
f51cf25
allow negative indices when selecting a specific attention head for a…
lsickert Jan 4, 2023
c6a9e70
added missing negation to head checking
lsickert Jan 4, 2023
6c9cfae
fixed last_layer_attention attribution
lsickert Jan 4, 2023
b78bcc1
use custom format_attribute_args function for attention methods
lsickert Jan 9, 2023
d27f1c3
always use decoder_input_embeds in forward output
lsickert Jan 9, 2023
cacaa31
reworked LastLayerAttention to work with any single layer and allow a…
lsickert Jan 9, 2023
a8d5264
Minor bugfixes and version bumps
gsarti Jan 10, 2023
966f63c
Generalized attention attribution
gsarti Jan 10, 2023
1301a02
updated documentation and added 'min' aggregation function
lsickert Jan 13, 2023
914ee8f
Tests, typing fix, additional checks
gsarti Jan 14, 2023
7c825ad
Fix style
gsarti Jan 14, 2023
f6f0a64
added tests for attention utils
lsickert Jan 15, 2023
f40f63b
classmethod -> staticmethod where possible
gsarti Jan 16, 2023
11 changes: 11 additions & 0 deletions docs/source/main_classes/feature_attribution.rst
@@ -61,3 +61,14 @@ Layer Attribution Methods

.. autoclass:: inseq.attr.feat.LayerDeepLiftAttribution
:members:


Attention Attribution Methods
-----------------------------------------------------------------------------------------------------------------------

.. autoclass:: inseq.attr.feat.AttentionAttributionRegistry
:members:


.. autoclass:: inseq.attr.feat.AttentionAttribution
:members:
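
As a quick illustration of the classes documented above, a minimal usage sketch follows. It assumes the public inseq.load_model API and the "attention" method name registered further down in this PR; the checkpoint and the input sentence are placeholders, not part of the change itself.

    import inseq

    # Load a model and pair it with the new attention-based attribution method.
    # The checkpoint below is only an illustrative example.
    model = inseq.load_model("Helsinki-NLP/opus-mt-en-de", "attention")

    # With no extra arguments, the method returns the average attention over
    # all heads of the last layer (see the method docstring added in this PR).
    out = model.attribute("Hello world, this is an example input.")
    out.show()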
7 changes: 3 additions & 4 deletions inseq/attr/feat/__init__.py
@@ -1,4 +1,4 @@
from .attention_attribution import AggregatedAttentionAtribution, AttentionAtribution, SingleLayerAttentionAttribution
from .attention_attribution import AttentionAttribution, AttentionAttributionRegistry
from .attribution_utils import STEP_SCORES_MAP, extract_args, join_token_ids, list_step_scores, register_step_score
from .feature_attribution import FeatureAttribution, list_feature_attribution_methods
from .gradient_attribution import (
@@ -31,7 +31,6 @@
"LayerIntegratedGradientsAttribution",
"LayerGradientXActivationAttribution",
"LayerDeepLiftAttribution",
"AttentionAtribution",
"AggregatedAttentionAtribution",
"SingleLayerAttentionAttribution",
"AttentionAttributionRegistry",
"AttentionAttribution",
]
86 changes: 39 additions & 47 deletions inseq/attr/feat/attention_attribution.py
@@ -19,17 +19,17 @@

from ...data import Batch, EncoderDecoderBatch, FeatureAttributionStepOutput
from ...utils import Registry, pretty_tensor
from ...utils.typing import ModelIdentifier, SingleScorePerStepTensor, TargetIdsTensor
from ...utils.typing import SingleScorePerStepTensor, TargetIdsTensor
from ..attribution_decorators import set_hook, unset_hook
from .attribution_utils import get_source_target_attributions
from .attribution_utils import STEP_SCORES_MAP, get_source_target_attributions
from .feature_attribution import FeatureAttribution
from .ops import AggregatedAttention, SingleLayerAttention
from .ops import Attention


logger = logging.getLogger(__name__)


class AttentionAtribution(FeatureAttribution, Registry):
class AttentionAttributionRegistry(FeatureAttribution, Registry):
r"""Attention-based attribution method registry."""

@set_hook
@@ -69,6 +69,12 @@ def format_attribute_args(
:obj:`dict`: A dictionary containing the formatted attribution arguments.
"""
logger.debug(f"batch: {batch},\ntarget_ids: {pretty_tensor(target_ids, lpad=4)}")
if attributed_fn != STEP_SCORES_MAP[self.attribution_model.default_attributed_fn_id]:
logger.warning(
"Attention-based attribution methods are output agnostic, since they do not rely on specific output"
" targets to compute input saliency. As such, using a custom attributed function for attention"
" attribution methods does not have any effect on the method's results."
)
attribute_fn_args = {
"batch": batch,
"additional_forward_args": (
@@ -114,53 +120,39 @@ def attribute_step(
step_scores={},
)

@classmethod
def load(
cls,
method_name: str,
attribution_model=None,
model_name_or_path: Union[ModelIdentifier, None] = None,
**kwargs,
) -> "FeatureAttribution":
from inseq import AttributionModel

if model_name_or_path is None == attribution_model is None: # noqa
raise RuntimeError(
"Only one among an initialized model and a model identifier "
"must be defined when loading the attribution method."
)
if model_name_or_path:
attribution_model = AttributionModel.load(model_name_or_path)
model_name_or_path = None

if not attribution_model.model.config.output_attentions:
raise RuntimeError(
"Attention-based attribution methods require the `output_attentions` parameter to be set on the model."
)
return super().load(method_name, attribution_model, model_name_or_path, **kwargs)


class AggregatedAttentionAtribution(AttentionAtribution):
"""
Aggregated attention attribution method.
Attention values of all layers are averaged.
"""

method_name = "aggregated_attention"

def __init__(self, attribution_model, **kwargs):
super().__init__(attribution_model)
self.method = AggregatedAttention(attribution_model)


class SingleLayerAttentionAttribution(AttentionAtribution):
class AttentionAttribution(AttentionAttributionRegistry):
"""
Single-Layer attention attribution method.
Only the raw attention of the last hidden layer is retrieved.
The basic attention attribution method, which retrieves the attention weights from the model.

Attribute Args:
aggregate_heads_fn (:obj:`str` or :obj:`callable`): The method to use for aggregating across heads.
Can be one of `average` (default if heads is tuple or None), `max`, or `single` (default if heads is
int), or a custom function defined by the user.
aggregate_layers_fn (:obj:`str` or :obj:`callable`): The method to use for aggregating across layers.
Can be one of `average` (default if layers is tuple), `max`, or `single` (default if layers is int or
None), or a custom function defined by the user.
heads (:obj:`int` or :obj:`tuple[int, int]` or :obj:`list(int)`, optional): If a single value is specified,
the head at the corresponding index is used. If a tuple of two indices is specified, all heads between
the indices will be aggregated using aggregate_heads_fn. If a list of indices is specified, the respective
heads will be used for aggregation. If aggregate_heads_fn is "single", a head must be specified.
Otherwise, all heads are passed to aggregate_heads_fn by default.
layers (:obj:`int` or :obj:`tuple[int, int]` or :obj:`list(int)`, optional): If a single value is specified,
the layer at the corresponding index is used. If a tuple of two indices is specified, all layers
between the indices will be aggregated using aggregate_layers_fn. If a list of indices is specified, the
respective layers will be used for aggregation. If aggregate_layers_fn is "single", the last layer is
used by default. Otherwise, all available layers are passed to aggregate_layers_fn by default.

Example:

- ``model.attribute(src)`` will return the average attention for all heads of the last layer.
- ``model.attribute(src, heads=0)`` will return the attention weights for the first head of the last layer.
- ``model.attribute(src, heads=(0, 5), aggregate_heads_fn="max", layers=[0, 2, 7])`` will return the maximum
attention weights for the first 5 heads averaged across the first, third, and eighth layers.
"""

method_name = "single_layer_attention"
method_name = "attention"

def __init__(self, attribution_model, **kwargs):
super().__init__(attribution_model)
self.method = SingleLayerAttention(attribution_model)
self.method = Attention(attribution_model)
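
To make the head and layer aggregation arguments documented above more concrete, here is a sketch that mirrors the docstring example; the checkpoint and input sentence are placeholders, and the parameter names come straight from the docstring.

    import inseq

    # Illustrative encoder-decoder checkpoint, not prescribed by the PR.
    model = inseq.load_model("Helsinki-NLP/opus-mt-en-de", "attention")

    # Take the maximum attention over the head range (0, 5) and average it
    # across layers 0, 2, and 7, as in the docstring example above.
    out = model.attribute(
        "The developer argued with the designer because she did not like the design.",
        heads=(0, 5),
        aggregate_heads_fn="max",
        layers=[0, 2, 7],
    )
    out.show()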
4 changes: 2 additions & 2 deletions inseq/attr/feat/ops/__init__.py
@@ -1,6 +1,6 @@
from .basic_attention import AggregatedAttention, SingleLayerAttention
from .basic_attention import Attention
from .discretized_integrated_gradients import DiscretetizedIntegratedGradients
from .monotonic_path_builder import MonotonicPathBuilder


__all__ = ["DiscretetizedIntegratedGradients", "MonotonicPathBuilder", "AggregatedAttention", "SingleLayerAttention"]
__all__ = ["DiscretetizedIntegratedGradients", "MonotonicPathBuilder", "Attention"]