Feature/reagent branch for ReAGent (#250)
* Middleware for ReAgent

* ReAGent implementation

* Add reagent_core

* lint for ReAGent

* update dependencies for ReAGent

* Added caching for POSTagTokenSampler, minor fixes

* reagent: adapt codebase structure & cleanup unused classes

* reagent: adaptation for encoder-decoder models

* reagent: adapt type annotation

* reagent: implement attribute_target for encoder-decoder models

* reagent: increase the default num_probes

* Various fixes to style, imports and naming

* Bump cryptography package (fix safety)

* Finished revising initial ReAGent implementation

---------

Co-authored-by: xuan25 <xuan@xuan25.com>
Co-authored-by: Gabriele Sarti <gabriele.sarti996@gmail.com>
3 people committed Apr 13, 2024
1 parent 112590b commit c9f2acf
Showing 17 changed files with 951 additions and 11 deletions.
8 changes: 7 additions & 1 deletion README.md
@@ -149,6 +149,8 @@ Use the `inseq.list_feature_attribution_methods` function to list all available

- `value_zeroing`: [Quantifying Context Mixing in Transformers](https://aclanthology.org/2023.eacl-main.245/) (Mohebbi et al. 2023)

- `reagent`: [ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models](https://arxiv.org/abs/2402.00794) (Zhao et al., 2024)

#### Step functions

Step functions are used to extract custom scores from the model at each step of the attribution process with the `step_scores` argument in `model.attribute`. They can also be used as targets for attribution methods relying on model outputs (e.g. gradient-based methods) by passing them as the `attributed_fn` argument. The following step functions are currently supported:
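
For example, here is a minimal sketch of both usages; the model and the `probability` step function are chosen purely for illustration and are not part of this change:

```python
import inseq

# Load any supported model together with a feature attribution method
model = inseq.load_model("gpt2", "saliency")

# Record the probability of each generated token alongside the attribution scores.
# The same step function names can also be passed as `attributed_fn` to redefine
# the attribution target for output-dependent methods.
out = model.attribute(
    "The first Super Mario game was released in",
    step_scores=["probability"],
)
out.show()
```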
@@ -303,7 +305,10 @@ If you use Inseq in your research we suggest to include a mention to the specifi

## Research using Inseq

Inseq has been used in various research projects. A list of known publications that use Inseq to conduct interpretability analyses of generative models is shown below. If you know more, please let us know or submit a pull request (*last updated: February 2024*).
Inseq has been used in various research projects. A list of known publications that use Inseq to conduct interpretability analyses of generative models is shown below.

> [!TIP]
> Last update: April 2024. Please open a pull request to add your publication to the list.
<details>
<summary><b>2023</b></summary>
@@ -324,6 +329,7 @@ Inseq has been used in various research projects. A list of known publications t
<ol>
<li><a href="https://arxiv.org/abs/2401.12576">LLMCheckup: Conversational Examination of Large Language Models via Interpretability Tools</a> (Wang et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2402.00794">ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models</a> (Zhao et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2404.02421">Revisiting subword tokenization: A case study on affixal negation in large language models</a> (Truong et al., 2024)</li>
</ol>

</details>
23 changes: 22 additions & 1 deletion docs/source/main_classes/feature_attribution.rst
@@ -90,4 +90,25 @@ Perturbation-based Attribution Methods
:members:

.. autoclass:: inseq.attr.feat.ValueZeroingAttribution
:members:
:members:

.. autoclass:: inseq.attr.feat.ReagentAttribution
    :members:

    .. automethod:: __init__

.. code:: python

    import inseq

    model = inseq.load_model(
        "gpt2-medium",
        "reagent",
        keep_top_n=5,
        stopping_condition_top_k=3,
        replacing_ratio=0.3,
        max_probe_steps=3000,
        num_probes=8
    )
    out = model.attribute("Super Mario Land is a game that developed by")
    out.show()
2 changes: 2 additions & 0 deletions inseq/attr/feat/__init__.py
@@ -18,6 +18,7 @@
    LimeAttribution,
    OcclusionAttribution,
    PerturbationAttributionRegistry,
    ReagentAttribution,
    ValueZeroingAttribution,
)

@@ -43,4 +44,5 @@
    "SequentialIntegratedGradientsAttribution",
    "ValueZeroingAttribution",
    "PerturbationAttributionRegistry",
    "ReagentAttribution",
]
2 changes: 2 additions & 0 deletions inseq/attr/feat/ops/__init__.py
@@ -1,6 +1,7 @@
from .discretized_integrated_gradients import DiscretetizedIntegratedGradients
from .lime import Lime
from .monotonic_path_builder import MonotonicPathBuilder
from .reagent import Reagent
from .sequential_integrated_gradients import SequentialIntegratedGradients
from .value_zeroing import ValueZeroing

@@ -9,5 +10,6 @@
    "MonotonicPathBuilder",
    "ValueZeroing",
    "Lime",
    "Reagent",
    "SequentialIntegratedGradients",
]
134 changes: 134 additions & 0 deletions inseq/attr/feat/ops/reagent.py
@@ -0,0 +1,134 @@
from typing import TYPE_CHECKING, Any, Union

import torch
from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
from torch import Tensor
from typing_extensions import override

from ....utils.typing import InseqAttribution
from .reagent_core import (
    AggregateRationalizer,
    DeltaProbImportanceScoreEvaluator,
    POSTagTokenSampler,
    TopKStoppingConditionEvaluator,
    UniformTokenReplacer,
)

if TYPE_CHECKING:
    from ....models import HuggingfaceModel


class Reagent(InseqAttribution):
    r"""Recursive attribution generator (ReAGent) method.

    Measures importance as the drop in prediction probability produced by replacing a token with a plausible
    alternative predicted by a LM.

    Reference implementation:
    `ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models
    <https://arxiv.org/abs/2402.00794>`__

    Args:
        forward_func (callable): The forward function of the model or any modification of it.
        keep_top_n (int): If set to a value greater than 0, the top n tokens based on their importance score will be
            kept during the prediction inference. If set to 0, the top n will be determined by ``keep_ratio``.
        keep_ratio (float): If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep.
        invert_keep: If specified, the top tokens selected either via ``keep_top_n`` or ``keep_ratio`` will be
            replaced instead of being kept.
        stopping_condition_top_k (int): Threshold for the stopping condition: probing stops once the predicted
            target appears among the top k predictions.
        replacing_ratio (float): Ratio of tokens replaced in each probing step.
        max_probe_steps (int): Maximum number of probing steps.
        num_probes (int): Number of probes run in parallel.

    Example:
        ```
        import inseq

        model = inseq.load_model("gpt2-medium", "reagent",
            keep_top_n=5,
            stopping_condition_top_k=3,
            replacing_ratio=0.3,
            max_probe_steps=3000,
            num_probes=8
        )
        out = model.attribute("Super Mario Land is a game that developed by")
        out.show()
        ```
    """

    def __init__(
        self,
        attribution_model: "HuggingfaceModel",
        keep_top_n: int = 5,
        keep_ratio: float = None,
        invert_keep: bool = False,
        stopping_condition_top_k: int = 3,
        replacing_ratio: float = 0.3,
        max_probe_steps: int = 3000,
        num_probes: int = 16,
    ) -> None:
        super().__init__(attribution_model)

        model = attribution_model.model
        tokenizer = attribution_model.tokenizer
        model_name = attribution_model.model_name

        sampler = POSTagTokenSampler(tokenizer=tokenizer, identifier=model_name, device=attribution_model.device)
        stopping_condition_evaluator = TopKStoppingConditionEvaluator(
            model=model,
            sampler=sampler,
            top_k=stopping_condition_top_k,
            keep_top_n=keep_top_n,
            keep_ratio=keep_ratio,
            invert_keep=invert_keep,
        )
        importance_score_evaluator = DeltaProbImportanceScoreEvaluator(
            model=model,
            tokenizer=tokenizer,
            token_replacer=UniformTokenReplacer(sampler=sampler, ratio=replacing_ratio),
            stopping_condition_evaluator=stopping_condition_evaluator,
            max_steps=max_probe_steps,
        )

        self.rationalizer = AggregateRationalizer(
            importance_score_evaluator=importance_score_evaluator,
            batch_size=num_probes,
            overlap_threshold=0,
            overlap_strict_pos=True,
            keep_top_n=keep_top_n,
            keep_ratio=keep_ratio,
        )

    @override
    def attribute(  # type: ignore
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        _target: TargetType = None,
        additional_forward_args: Any = None,
    ) -> Union[
        TensorOrTupleOfTensorsGeneric,
        tuple[TensorOrTupleOfTensorsGeneric, Tensor],
    ]:
        """Run the rationalizer and expand its mean importance scores to the shape of the input embeddings."""
        # encoder-decoder
        if self.forward_func.is_encoder_decoder:
            # with target-side attribution
            if len(inputs) > 1:
                self.rationalizer(
                    additional_forward_args[0], additional_forward_args[2], additional_forward_args[1], True
                )
                mean_importance_score = torch.unsqueeze(self.rationalizer.mean_importance_score, 0)
                res = torch.unsqueeze(mean_importance_score, 2).repeat(1, 1, inputs[0].shape[2])
                # split scores into source and target spans
                return (
                    res[:, : additional_forward_args[0].shape[1], :],
                    res[:, additional_forward_args[0].shape[1] :, :],
                )
            # source-side only
            else:
                self.rationalizer(additional_forward_args[1], additional_forward_args[3], additional_forward_args[2])
        # decoder-only
        else:
            self.rationalizer(additional_forward_args[0], additional_forward_args[1])
        mean_importance_score = torch.unsqueeze(self.rationalizer.mean_importance_score, 0)
        res = torch.unsqueeze(mean_importance_score, 2).repeat(1, 1, inputs[0].shape[2])
        return (res,)
13 changes: 13 additions & 0 deletions inseq/attr/feat/ops/reagent_core/__init__.py
@@ -0,0 +1,13 @@
from .importance_score_evaluator import DeltaProbImportanceScoreEvaluator
from .rationalizer import AggregateRationalizer
from .stopping_condition_evaluator import TopKStoppingConditionEvaluator
from .token_replacer import UniformTokenReplacer
from .token_sampler import POSTagTokenSampler

__all__ = [
    "DeltaProbImportanceScoreEvaluator",
    "AggregateRationalizer",
    "TopKStoppingConditionEvaluator",
    "UniformTokenReplacer",
    "POSTagTokenSampler",
]
