Skip to content

Commit

Permalink
Merge branch 'main' into attention-attribution\n\nNeeded to downgrade…
Browse files Browse the repository at this point in the history
… torch to ~1.12.1 again because of platform issue: pytorch/pytorch#88826
  • Loading branch information
lsickert committed Nov 23, 2022
2 parents 13fc9f3 + 7791d0e commit bb13036
Show file tree
Hide file tree
Showing 7 changed files with 840 additions and 1,604 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/build.yml
Expand Up @@ -12,13 +12,13 @@ jobs:
if: github.actor != 'dependabot[bot]' && github.actor != 'dependabot-preview[bot]'
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.8", "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v3

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
uses: actions/setup-python@v4.3.0
with:
python-version: ${{ matrix.python-version }}

Expand All @@ -28,14 +28,14 @@ jobs:
export PATH="$HOME/.poetry/env:$PATH"
- name: Set up cache
uses: actions/cache@v3.0.1
uses: actions/cache@v3.0.11
with:
path: .venv
key: venv-${{ matrix.python-version }}-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('poetry.lock') }}
- name: Install dependencies
run: |
poetry config virtualenvs.in-project true
poetry install
make install-ci
- name: Run style checks
run: |
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Expand Up @@ -5,7 +5,7 @@ default_stages: [commit, push]

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.4.0
hooks:
- id: trailing-whitespace
- id: check-yaml
Expand Down
9 changes: 7 additions & 2 deletions Makefile
Expand Up @@ -44,15 +44,16 @@ poetry-remove:
.PHONY: add-torch-gpu
add-torch-gpu:
poetry run poe upgrade-pip
poetry run pip uninstall torch -y
poetry run poe torch-cuda11

.PHONY: install
install:
poetry install --no-dev
poetry install

.PHONY: install-dev
install-dev:
poetry install --extras all
poetry install --all-extras --with lint,docs --sync
# -poetry run mypy --install-types --non-interactive ./
poetry run pre-commit install
poetry run pre-commit autoupdate
Expand All @@ -63,6 +64,10 @@ install-gpu: install add-torch-gpu
.PHONY: install-dev-gpu
install-dev-gpu: install-dev add-torch-gpu

.PHONY: install-ci
install-ci:
poetry install --with lint

.PHONY: update-deps
update-deps:
poetry lock && poetry export --without-hashes > requirements.txt
Expand Down
6 changes: 0 additions & 6 deletions inseq/attr/feat/ops/basic_attention.py
Expand Up @@ -21,7 +21,6 @@
from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._utils.attribution import Attribution
from captum.log import log_usage
from transformers.modeling_outputs import Seq2SeqLMOutput

from ....utils.typing import MultiStepEmbeddingsTensor

Expand Down Expand Up @@ -91,7 +90,6 @@ def _merge_attention_heads(
if option == "average":
return attention.mean(1)

# TODO: test this, I feel like this method is not doing what we want here
elif option == "max":
return attention.max(1)

Expand Down Expand Up @@ -127,10 +125,6 @@ def _extract_forward_pass_args(

return forward_pass_args

def _run_forward_pass(self, **forward_args: dict) -> Seq2SeqLMOutput:

pass


class AggregatedAttention(AttentionAttribution):
"""
Expand Down

0 comments on commit bb13036

Please sign in to comment.