Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ruff stylechecking #159

Merged
merged 4 commits into from Jan 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 5 additions & 22 deletions .pre-commit-config.yaml
Expand Up @@ -18,22 +18,6 @@ repos:
- id: end-of-file-fixer
exclude: LICENSE

- repo: local
hooks:
- id: pyupgrade
name: pyupgrade
entry: poetry run pyupgrade --py38-plus
types: [python]
language: system

- repo: local
hooks:
- id: isort
name: isort
entry: poetry run isort --settings-path pyproject.toml
types: [python]
language: system

- repo: local
hooks:
- id: black
Expand All @@ -42,13 +26,12 @@ repos:
types: [python]
language: system

- repo: local
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.0.227'
hooks:
- id: flake8
name: flake8
entry: poetry run flake8 --config setup.cfg
types: [python]
language: system
- id: ruff
# Respect `exclude` and `extend-exclude` settings.
args: ["--force-exclude"]

- repo: local
hooks:
Expand Down
9 changes: 3 additions & 6 deletions Makefile
Expand Up @@ -78,26 +78,23 @@ update-deps:
#* Linting
.PHONY: check-style
check-style:
poetry run isort --diff --check-only --settings-path pyproject.toml ./
poetry run black --diff --check --config pyproject.toml ./
poetry run ruff --no-fix --config pyproject.toml ./
# poetry run darglint --verbosity 2 inseq tests
poetry run flake8 --config setup.cfg ./
# poetry run mypy --config-file pyproject.toml ./

.PHONY: fix-style
fix-style:
poetry run pyupgrade --exit-zero-even-if-changed --py38-plus **/*.py
poetry run isort --settings-path pyproject.toml ./
poetry run black --config pyproject.toml ./
poetry run ruff --config pyproject.toml ./

.PHONY: check-safety
check-safety:
poetry check
poetry run safety check --full-report -i 51499 -i 51457
poetry run bandit -ll --recursive inseq tests

.PHONY: lint
lint: check-style check-safety
lint: fix-style check-safety

#* Linting
.PHONY: test
Expand Down
1 change: 0 additions & 1 deletion inseq/attr/__init__.py
Expand Up @@ -7,7 +7,6 @@
register_step_score,
)


__all__ = [
"FeatureAttribution",
"list_feature_attribution_methods",
Expand Down
4 changes: 1 addition & 3 deletions inseq/attr/attribution_decorators.py
Expand Up @@ -13,14 +13,12 @@
# limitations under the License.
""" Decorators for attribution methods. """

from typing import Any, Callable, List, Optional, Sequence

import logging
from functools import wraps
from typing import Any, Callable, List, Optional, Sequence

from ..data.data_utils import TensorWrapper


logger = logging.getLogger(__name__)


Expand Down
1 change: 0 additions & 1 deletion inseq/attr/feat/__init__.py
Expand Up @@ -13,7 +13,6 @@
SaliencyAttribution,
)


__all__ = [
"FeatureAttribution",
"extract_args",
Expand Down
4 changes: 1 addition & 3 deletions inseq/attr/feat/attention_attribution.py
Expand Up @@ -13,9 +13,8 @@
# limitations under the License.
""" Attention-based feature attribution methods. """

from typing import Any, Callable, Dict, Union

import logging
from typing import Any, Callable, Dict, Union

from ...data import Batch, EncoderDecoderBatch, FeatureAttributionStepOutput
from ...utils import Registry, pretty_tensor
Expand All @@ -25,7 +24,6 @@
from .feature_attribution import FeatureAttribution
from .ops import Attention


logger = logging.getLogger(__name__)


Expand Down
8 changes: 3 additions & 5 deletions inseq/attr/feat/attribution_utils.py
@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import logging
import math
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch

Expand All @@ -20,7 +19,6 @@
TokenWithId,
)


if TYPE_CHECKING:
from ...models import AttributionModel
from .feature_attribution import FeatureAttribution
Expand Down Expand Up @@ -163,8 +161,8 @@ def extract_args(
if step_score not in STEP_SCORES_MAP:
raise AttributeError(
f"Step score {step_score} not found. Available step scores are: "
f"{', '.join([x for x in STEP_SCORES_MAP.keys()])}. Use the inseq.register_step_score"
f"function to register a custom step score."
f"{', '.join(list(STEP_SCORES_MAP.keys()))}. Use the inseq.register_step_score"
"function to register a custom step score."
)
extra_step_scores_args.update(
**extract_signature_args(
Expand Down
8 changes: 3 additions & 5 deletions inseq/attr/feat/feature_attribution.py
Expand Up @@ -16,11 +16,10 @@
Todo:
* 🟡: Allow custom arguments for model loading in the :class:`FeatureAttribution` :meth:`load` method.
"""
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import logging
from abc import abstractmethod
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

from torchtyping import TensorType

Expand All @@ -46,7 +45,6 @@
from ..attribution_decorators import batched, set_hook, unset_hook
from .attribution_utils import STEP_SCORES_MAP, check_attribute_positions, get_step_scores, tok2string


if TYPE_CHECKING:
from ...models import AttributionModel

Expand Down Expand Up @@ -465,8 +463,8 @@ def filtered_attribute_step(
if step_score not in STEP_SCORES_MAP:
raise AttributeError(
f"Step score {step_score} not found. Available step scores are: "
f"{', '.join([x for x in STEP_SCORES_MAP.keys()])}. Use the inseq.register_step_score"
f"function to register a custom step score."
f"{', '.join(list(STEP_SCORES_MAP.keys()))}. Use the inseq.register_step_score"
"function to register a custom step score."
)
step_output.step_scores[step_score] = get_step_scores(
self.attribution_model, batch, target_ids, step_score, step_scores_args
Expand Down
4 changes: 1 addition & 3 deletions inseq/attr/feat/gradient_attribution.py
Expand Up @@ -13,9 +13,8 @@
# limitations under the License.
""" Gradient-based feature attribution methods. """

from typing import Any, Dict

import logging
from typing import Any, Dict

from captum.attr import (
DeepLift,
Expand All @@ -34,7 +33,6 @@
from .feature_attribution import FeatureAttribution
from .ops import DiscretetizedIntegratedGradients


logger = logging.getLogger(__name__)


Expand Down
1 change: 0 additions & 1 deletion inseq/attr/feat/ops/__init__.py
Expand Up @@ -2,5 +2,4 @@
from .discretized_integrated_gradients import DiscretetizedIntegratedGradients
from .monotonic_path_builder import MonotonicPathBuilder


__all__ = ["DiscretetizedIntegratedGradients", "MonotonicPathBuilder", "Attention"]
8 changes: 3 additions & 5 deletions inseq/attr/feat/ops/basic_attention.py
Expand Up @@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Protocol, Tuple, Union

import logging
from typing import Any, Dict, List, Protocol, Tuple, Union

import torch
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
Expand All @@ -24,7 +23,6 @@
from ....data import Batch, EncoderDecoderBatch
from ....utils.typing import AggregatedLayerAttentionTensor, FullAttentionOutput, FullLayerAttentionTensor


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -115,7 +113,7 @@ def _aggregate_attention_heads(
if isinstance(aggregate_fn, str):
if aggregate_fn not in cls.AGGREGATE_FN_OPTIONS:
raise RuntimeError(
"Invalid aggregation method specified." f"Valid methods are: {cls.AGGREGATE_FN_OPTIONS.keys()}"
f"Invalid aggregation method specified.Valid methods are: {cls.AGGREGATE_FN_OPTIONS.keys()}"
)
aggregate_fn = cls.AGGREGATE_FN_OPTIONS[aggregate_fn]
if heads is None:
Expand Down Expand Up @@ -201,7 +199,7 @@ def _aggregate_layers(
if isinstance(aggregate_fn, str):
if aggregate_fn not in cls.AGGREGATE_FN_OPTIONS:
raise RuntimeError(
"Invalid aggregation method specified." f"Valid methods are: {cls.AGGREGATE_FN_OPTIONS.keys()}"
f"Invalid aggregation method specified.Valid methods are: {cls.AGGREGATE_FN_OPTIONS.keys()}"
)
aggregate_fn = cls.AGGREGATE_FN_OPTIONS[aggregate_fn]
if layers is None:
Expand Down
3 changes: 1 addition & 2 deletions inseq/attr/feat/ops/discretized_integrated_gradients.py
Expand Up @@ -16,9 +16,8 @@
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
# OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from typing import Any, Callable, List, Tuple, Union

from pathlib import Path
from typing import Any, Callable, List, Tuple, Union

import torch
from captum._utils.common import (
Expand Down
11 changes: 5 additions & 6 deletions inseq/attr/feat/ops/monotonic_path_builder.py
Expand Up @@ -17,21 +17,19 @@
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
# OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from typing import Any, List, Optional, Tuple, Union

import logging
import os
from enum import Enum
from itertools import islice
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union

import torch
from scipy.sparse import csr_matrix
from torchtyping import TensorType

from ....utils import is_joblib_available, is_scikitlearn_available


if is_joblib_available():
from joblib import Parallel, delayed

Expand All @@ -41,7 +39,6 @@
from ....utils import INSEQ_ARTIFACTS_CACHE, cache_results, euclidean_distance
from ....utils.typing import MultiStepEmbeddingsTensor, VocabularyEmbeddingsTensor


logger = logging.getLogger(__name__)


Expand All @@ -59,8 +56,10 @@ def __init__(
*args: Tuple[Any],
) -> None:
super().__init__(
f"Unknown strategy: {strategy}.\nAvailable strategies: "
f"{','.join([s.value for s in PathBuildingStrategies])}",
(
f"Unknown strategy: {strategy}.\nAvailable strategies: "
f"{','.join([s.value for s in PathBuildingStrategies])}"
),
*args,
)

Expand Down
7 changes: 3 additions & 4 deletions inseq/commands/attribute.py
@@ -1,7 +1,6 @@
from typing import List, Optional

import logging
from dataclasses import dataclass, field
from typing import List, Optional

from .. import list_feature_attribution_methods, load_model
from ..utils import get_default_device
Expand Down Expand Up @@ -123,9 +122,9 @@ def __post_init__(self):
if self.input_texts is None:
raise RuntimeError("Input texts must be specified.")
if isinstance(self.input_texts, str):
self.input_texts = [t for t in self.input_texts]
self.input_texts = list(self.input_texts)
if isinstance(self.generated_texts, str):
self.generated_texts = [t for t in self.generated_texts]
self.generated_texts = list(self.generated_texts)


def attribute(input_texts, generated_texts, args: AttributeBaseArgs):
Expand Down
4 changes: 1 addition & 3 deletions inseq/commands/attribute_dataset.py
@@ -1,12 +1,10 @@
from typing import List, Optional, Tuple

from dataclasses import dataclass, field
from typing import List, Optional, Tuple

from ..utils import is_datasets_available
from .attribute import AttributeBaseArgs, attribute
from .base import BaseCLICommand


if is_datasets_available():
from datasets import load_dataset

Expand Down
4 changes: 1 addition & 3 deletions inseq/commands/base.py
@@ -1,12 +1,10 @@
from typing import Any, Iterable, NewType, Union

import dataclasses
from abc import ABC, abstractstaticmethod
from argparse import Namespace
from typing import Any, Iterable, NewType, Union

from ..utils import InseqArgumentParser


DataClassType = NewType("DataClassType", Any)
OneOrMoreDataClasses = Union[DataClassType, Iterable[DataClassType]]

Expand Down
4 changes: 2 additions & 2 deletions inseq/commands/cli.py
@@ -1,10 +1,10 @@
"""Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/commands/transformers_cli.py"""
import sys

from ..utils import InseqArgumentParser
from .attribute import AttributeCommand
from .attribute_dataset import AttributeDatasetCommand


COMMANDS = [AttributeCommand, AttributeDatasetCommand]


Expand All @@ -19,7 +19,7 @@ def main():

if not hasattr(args, "factory_method"):
parser.print_help()
exit(1)
sys.exit(1)

# Run
command, command_args = args.factory_method(args)
Expand Down
1 change: 0 additions & 1 deletion inseq/data/__init__.py
Expand Up @@ -16,7 +16,6 @@
from .batch import Batch, BatchEmbedding, BatchEncoding, DecoderOnlyBatch, EncoderDecoderBatch
from .viz import show_attributions


__all__ = [
"Aggregator",
"AggregatorPipeline",
Expand Down