# Copyright 2021 The Inseq Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Attention-based feature attribution methods. """
from typing import Any, Callable, Dict, Union
import logging
from ...data import EncoderDecoderBatch, FeatureAttributionStepOutput
from ...utils import Registry, pretty_tensor
from ...utils.typing import ModelIdentifier, SingleScorePerStepTensor, TargetIdsTensor
from ..attribution_decorators import set_hook, unset_hook
from .feature_attribution import FeatureAttribution
from .ops import AggregatedAttention, LastLayerAttention
logger = logging.getLogger(__name__)


class AttentionAttribution(FeatureAttribution, Registry):
    r"""Attention-based attribution method registry."""

    @set_hook
    def hook(self, **kwargs):
        # Attention extraction needs no model hooks, so this is a no-op.
        pass

    @unset_hook
    def unhook(self, **kwargs):
        # Nothing was hooked, so there is nothing to remove.
        pass

    def attribute_step(
        self,
        batch: EncoderDecoderBatch,
        target_ids: TargetIdsTensor,
        attributed_fn: Callable[..., SingleScorePerStepTensor],
        attribute_target: bool = False,
        attribution_args: Dict[str, Any] = {},
        attributed_fn_args: Dict[str, Any] = {},
    ) -> FeatureAttributionStepOutput:
        r"""Performs a single attribution step for the specified target_ids,
        given sources and targets in the batch.

        Args:
            batch (:class:`~inseq.data.EncoderDecoderBatch`): The batch of sequences on which attribution
                is performed.
            target_ids (:obj:`torch.Tensor`): Target token ids of size `(batch_size,)` corresponding to the
                tokens for which the attribution step must be performed.
            attributed_fn (:obj:`Callable[..., SingleScorePerStepTensor]`): A function of the model outputs
                representing what should be attributed (e.g. the output probability of the model's best
                prediction after softmax). The function must take multiple keyword arguments and return a
                :obj:`tensor` of size `(batch_size,)`. If not provided, the model's default attributed
                function is used (see `attribution_model.default_attributed_fn_id`).
            attribute_target (:obj:`bool`, `optional`): Whether to attribute the target prefix.
                Defaults to False.
            attribution_args (:obj:`dict`, `optional`): Additional arguments passed to the attribution method.
                Defaults to {}.
            attributed_fn_args (:obj:`dict`, `optional`): Additional arguments passed to the attributed
                function. Defaults to {}.

        Returns:
            :class:`~inseq.data.FeatureAttributionStepOutput`: A dataclass containing source attributions of
                size `(batch_size, source_length)`, target attributions of size `(batch_size, prefix_length)`
                when attribute_target=True, plus batch information and any step scores computed along the way.
        """
        logger.debug(f"batch: {batch},\ntarget_ids: {pretty_tensor(target_ids, lpad=4)}")
        attribute_fn_args = self.format_attribute_args(
            batch, target_ids, attributed_fn, attribute_target, attributed_fn_args, **attribution_args
        )
        attr = self.method.attribute(**attribute_fn_args, **attribution_args)
        deltas = None
        # Methods exposing a convergence delta return an (attributions, deltas) pair when requested.
        if (
            attribution_args.get("return_convergence_delta", False)
            and hasattr(self.method, "has_convergence_delta")
            and self.method.has_convergence_delta()
        ):
            attr, deltas = attr
        return FeatureAttributionStepOutput(
            # A tuple result carries source attributions first and, if present, target attributions second.
            source_attributions=attr[0] if isinstance(attr, tuple) else attr,
            target_attributions=attr[1] if isinstance(attr, tuple) and len(attr) > 1 else None,
            step_scores={"deltas": deltas} if deltas is not None else {},
        )
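
    # A minimal sketch of a single attribution step (assumption: `attribution_model`,
    # `batch`, `target_ids` and `probability_fn` below are illustrative stand-ins for
    # objects prepared by the step loop in `FeatureAttribution.attribute`):
    #
    #     method = AggregatedAttentionAttribution(attribution_model)
    #     step_output = method.attribute_step(
    #         batch=batch,                   # the current EncoderDecoderBatch
    #         target_ids=target_ids,         # tensor of shape (batch_size,)
    #         attributed_fn=probability_fn,  # any callable returning a (batch_size,) tensor
    #     )
    #     step_output.source_attributions    # shape (batch_size, source_length)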

    @classmethod
    def load(
        cls,
        method_name: str,
        attribution_model=None,
        model_name_or_path: Union[ModelIdentifier, None] = None,
        **kwargs,
    ) -> "FeatureAttribution":
        from inseq import AttributionModel

        # Exactly one of the two ways to specify a model must be used.
        if (model_name_or_path is None) == (attribution_model is None):
            raise RuntimeError(
                "Exactly one among an initialized model and a model identifier "
                "must be defined when loading the attribution method."
            )
        if model_name_or_path:
            attribution_model = AttributionModel.load(model_name_or_path)
            model_name_or_path = None
        if not attribution_model.model.config.output_attentions:
            raise RuntimeError(
                "Attention-based attribution methods require the `output_attentions` parameter "
                "to be set on the model."
            )
        return super().load(method_name, attribution_model, model_name_or_path, **kwargs)
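
    # Hypothetical loading sketch (the model identifier is an assumption used purely
    # for illustration):
    #
    #     attribution = AttentionAttribution.load(
    #         "aggregated_attention",
    #         model_name_or_path="Helsinki-NLP/opus-mt-en-de",
    #     )
    #
    # The underlying model must be configured with `output_attentions=True`, otherwise
    # `load` raises a RuntimeError.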


class AggregatedAttentionAttribution(AttentionAttribution):
    """Aggregated attention attribution method.

    Attention values of all layers are averaged.
    """

    method_name = "aggregated_attention"

    def __init__(self, attribution_model, **kwargs):
        super().__init__(attribution_model)
        self.method = AggregatedAttention(attribution_model)


class LastLayerAttentionAttribution(AttentionAttribution):
    """Last-layer attention attribution method.

    Only the raw attention of the last hidden layer is retrieved.
    """

    method_name = "last_layer_attention"

    def __init__(self, attribution_model, **kwargs):
        super().__init__(attribution_model)
        self.method = LastLayerAttention(attribution_model)
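

# A minimal end-to-end sketch, assuming the method names registered above are exposed
# through the top-level `inseq.load_model` API (the model identifier is illustrative):
#
#     import inseq
#
#     model = inseq.load_model("Helsinki-NLP/opus-mt-en-de", "aggregated_attention")
#     out = model.attribute(input_texts="Hello everyone!")
#     out.show()
#
# Swapping "aggregated_attention" for "last_layer_attention" retrieves the raw attention
# of the final layer instead of the average over all layers.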