From ca0411330be56b159670c5ad23ca30afcb8e29ac Mon Sep 17 00:00:00 2001 From: Martin Thoma Date: Sun, 20 Feb 2022 08:29:10 +0100 Subject: [PATCH] SIM906: Merge nested os.path.join calls (#104) Credits to Skylion007 for defining this rule! Closes #101 --- README.md | 17 ++++++++++ flake8_simplify.py | 75 +++++++++++++++++++++++++++++++++++++++--- tests/test_simplify.py | 23 +++++++++++++ 3 files changed, 111 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 79277bf..a81494a 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,7 @@ Python-specific rules: * `SIM120` ![](https://shields.io/badge/-legacyfix-inactive): Use 'class FooBar:' instead of 'class FooBar(object):' ([example](#SIM120)) * `SIM124`: Reserved for [SIM904](#sim904) once it's stable * `SIM125`: Reserved for [SIM905](#sim905) once it's stable +* `SIM126`: Reserved for [SIM906](#sim906) once it's stable Simplifying Comparations: @@ -96,6 +97,7 @@ Current experimental rules: * `SIM901`: Use comparisons directly instead of wrapping them in a `bool(...)` call ([example](#SIM901)) * `SIM904`: Assign values to dictionary directly at initialization ([example](#SIM904)) * [`SIM905`](https://github.com/MartinThoma/flake8-simplify/issues/86): Split string directly if only constants are used ([example](#SIM905)) +* [`SIM906`](https://github.com/MartinThoma/flake8-simplify/issues/101): Merge nested os.path.join calls ([example](#SIM906)) ## Disabling Rules @@ -581,3 +583,18 @@ domains = "de com net org".split() # Good domains = ["de", "com", "net", "org"] ``` + +### SIM906 + +This rule will be renamed to `SIM126` after its 6-month trial period is over. +Please report any issues you encounter with this rule! + +The trial period starts on 20th of February and will end on 20th of September 2022. 
def _is_os_path_join_call(node: ast.AST) -> bool:
    """Return True if *node* is a call spelled exactly ``os.path.join(...)``."""
    if not isinstance(node, ast.Call):
        return False
    func = node.func
    return (
        isinstance(func, ast.Attribute)
        and func.attr == "join"
        and isinstance(func.value, ast.Attribute)
        and func.value.attr == "path"
        and isinstance(func.value.value, ast.Name)
        and func.value.value.id == "os"
    )


def _render_os_path_join_arg(arg: ast.expr) -> str:
    """Return source text for a single, non-nested os.path.join argument."""
    if isinstance(arg, ast.Name):
        return arg.id
    # repr() escapes embedded quotes and backslashes, so the suggestion stays
    # valid Python; for a plain string such as 'b' it still yields 'b'.
    if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
        return repr(arg.value)
    # Before Python 3.8 string literals are parsed as ast.Str. The version
    # guard keeps newer interpreters from touching the deprecated name.
    if sys.version_info < (3, 8) and isinstance(arg, ast.Str):
        return repr(arg.s)
    # Any other expression (attribute, call, f-string, starred, ...) is
    # rendered as-is instead of being dropped from the suggestion.
    return to_source(arg)


def _flatten_os_path_join_args(node: ast.Call) -> List[str]:
    """Collect the arguments of *node*, inlining nested os.path.join calls."""
    rendered: List[str] = []
    for arg in node.args:
        if _is_os_path_join_call(arg):
            rendered.extend(_flatten_os_path_join_args(arg))
        else:
            rendered.append(_render_os_path_join_arg(arg))
    return rendered


def _get_sim906(node: ast.Call) -> List[Tuple[int, int, str]]:
    """
    Find nested os.path.join calls that can be merged into a single call.

    An error is reported when *node* is an ``os.path.join`` call with exactly
    two positional arguments of which at least one is itself an
    ``os.path.join`` call. Nested calls are flattened recursively, so

        os.path.join(a, os.path.join(b, c))

    is reported with the suggestion ``os.path.join(a, b, c)``.

    Returns a list of ``(lineno, col_offset, message)`` tuples; the list is
    empty when the rule does not apply.
    """
    errors: List[Tuple[int, int, str]] = []
    if not (
        _is_os_path_join_call(node)
        and len(node.args) == 2
        and any(_is_os_path_join_call(arg) for arg in node.args)
    ):
        return errors

    names = _flatten_os_path_join_args(node)

    actual = to_source(node)
    expected = f"os.path.join({', '.join(names)})"
    errors.append(
        (
            node.lineno,
            node.col_offset,
            SIM906.format(actual=actual, expected=expected),
        )
    )
    return errors
+ # https://github.com/MartinThoma/flake8-simplify/issues/101 + ( + "os.path.join(a,os.path.join(b,c))", + "1:0 SIM906 Use 'os.path.join(a, b, c)' " + "instead of 'os.path.join(a, os.path.join(b, c))'", + ), + ( + "os.path.join(a,os.path.join('b',c))", + "1:0 SIM906 Use 'os.path.join(a, 'b', c)' " + "instead of 'os.path.join(a, os.path.join('b', c))'", + ), + ), + ids=["base", "str-arg"], +) +def test_sim906(s, msg): + results = _results(s) + assert results == {msg}