diff --git a/pyupgrade/_plugins/versioned_branches.py b/pyupgrade/_plugins/versioned_branches.py index df66b744..030e1a4e 100644 --- a/pyupgrade/_plugins/versioned_branches.py +++ b/pyupgrade/_plugins/versioned_branches.py @@ -77,11 +77,13 @@ def _eq(test: ast.Compare, n: int) -> bool: def _compare_to_3( test: ast.Compare, op: Union[Type[ast.cmpop], Tuple[Type[ast.cmpop], ...]], + minor: int = 0, ) -> bool: + min_len = 2 if minor else 1 if not ( isinstance(test.ops[0], op) and isinstance(test.comparators[0], ast.Tuple) and - len(test.comparators[0].elts) >= 1 and + len(test.comparators[0].elts) >= min_len and all(isinstance(n, ast.Num) for n in test.comparators[0].elts) ): return False @@ -89,7 +91,13 @@ def _compare_to_3( # checked above but mypy needs help elts = cast('List[ast.Num]', test.comparators[0].elts) - return elts[0].n == 3 and all(n.n == 0 for n in elts[1:]) + retv = elts[0].n == 3 + offset = 1 + if minor: + retv &= elts[1].n == minor + offset += 1 + retv &= all(n.n == 0 for n in elts[offset:]) + return retv @register(ast.If) @@ -98,6 +106,11 @@ def visit_If( node: ast.If, parent: ast.AST, ) -> Iterable[Tuple[Offset, TokenFunc]]: + if state.settings.min_version >= (3,) and ( + len(state.settings.min_version) >= 2 + ): + py3_minor = state.settings.min_version[1] + if ( state.settings.min_version >= (3,) and ( # if six.PY2: @@ -126,6 +139,33 @@ def visit_If( _eq(node.test, 2) or _compare_to_3(node.test, ast.Lt) ) + ) or + # sys.version_info < (3, n) (with n<=m) + # or sys.version_info <= (3, n) (with n= 2 and ( + isinstance(node.test, ast.Compare) and + is_name_attr( + node.test.left, + state.from_imports, + 'sys', + ('version_info',), + ) and + len(node.test.ops) == 1 and ( + _compare_to_3( + node.test, + ast.Lt, + minor=py3_minor, + ) or any( + _compare_to_3( + node.test, + (ast.Lt, ast.LtE), + minor=minor, + ) + for minor in range(py3_minor) + ) + ) + ) ) ) ): @@ -159,6 +199,36 @@ def visit_If( _eq(node.test, 3) or _compare_to_3(node.test, (ast.Gt, ast.GtE)) ) + ) or + # sys.version_info >= (3, n) (with n<=m) + # or sys.version_info > (3, n) (with n= 2 and ( + + isinstance(node.test, ast.Compare) and + is_name_attr( + node.test.left, + state.from_imports, + 'sys', + ('version_info',), + ) and + len(node.test.ops) == 1 and + ( + _compare_to_3( + node.test, + ast.GtE, + minor=py3_minor, + ) or any( + _compare_to_3( + node.test, + (ast.Gt, ast.GtE), + minor=minor, + ) + for minor in range(py3_minor) + ) + ) + + ) ) ) ): diff --git a/tests/features/versioned_branches_test.py b/tests/features/versioned_branches_test.py index b135fd90..4b54673d 100644 --- a/tests/features/versioned_branches_test.py +++ b/tests/features/versioned_branches_test.py @@ -23,11 +23,6 @@ ' pass\n' 'elif False:\n' ' pass\n', - # don't rewrite version compares with not 3.0 compares - 'if sys.version_info >= (3, 6):\n' - ' 3.6\n' - 'else:\n' - ' 3.5\n', # don't try and think about `sys.version` 'from sys import version\n' 'if sys.version[0] > "2":\n' @@ -452,3 +447,155 @@ def test_fix_py2_blocks(s, expected): def test_fix_py3_only_code(s, expected): ret = _fix_plugins(s, settings=Settings(min_version=(3,))) assert ret == expected + + +@pytest.mark.parametrize( + ('s', 'expected'), + ( + pytest.param( + 'import sys\n' + 'if sys.version_info > (3, 5):\n' + ' 3+6\n' + 'else:\n' + ' 3-5\n', + + 'import sys\n' + '3+6\n', + id='sys.version_info > (3, 5)', + ), + pytest.param( + 'from sys import version_info\n' + 'if version_info > (3, 5):\n' + ' 3+6\n' + 'else:\n' + ' 3-5\n', + + 'from sys import version_info\n' + '3+6\n', + id='from sys import version_info, > (3, 5)', + ), + pytest.param( + 'import sys\n' + 'if sys.version_info >= (3, 6):\n' + ' 3+6\n' + 'else:\n' + ' 3-5\n', + + 'import sys\n' + '3+6\n', + id='sys.version_info >= (3, 6)', + ), + pytest.param( + 'from sys import version_info\n' + 'if version_info >= (3, 6):\n' + ' 3+6\n' + 'else:\n' + ' 3-5\n', + + 'from sys import version_info\n' + '3+6\n', + id='from sys import version_info, >= (3, 6)', + ), + pytest.param( + 'import sys\n' + 'if sys.version_info < (3, 6):\n' + ' 3-5\n' + 'else:\n' + ' 3+6\n', + + 'import sys\n' + '3+6\n', + id='sys.version_info < (3, 6)', + ), + pytest.param( + 'from sys import version_info\n' + 'if version_info < (3, 6):\n' + ' 3-5\n' + 'else:\n' + ' 3+6\n', + + 'from sys import version_info\n' + '3+6\n', + id='from sys import version_info, < (3, 6)', + ), + pytest.param( + 'import sys\n' + 'if sys.version_info <= (3, 5):\n' + ' 3-5\n' + 'else:\n' + ' 3+6\n', + + 'import sys\n' + '3+6\n', + id='sys.version_info <= (3, 5)', + ), + pytest.param( + 'from sys import version_info\n' + 'if version_info <= (3, 5):\n' + ' 3-5\n' + 'else:\n' + ' 3+6\n', + + 'from sys import version_info\n' + '3+6\n', + id='from sys import version_info, <= (3, 5)', + ), + ), +) +def test_fix_py3x_only_code(s, expected): + ret = _fix_plugins(s, settings=Settings(min_version=(3, 6))) + assert ret == expected + + +@pytest.mark.parametrize( + 's', + ( + # we timidly skip `if` without `else` as it could cause a SyntaxError + 'import sys' + 'if sys.version_info >= (3, 6):\n' + ' pass', + # here's the case where it causes a SyntaxError + 'import sys' + 'if True' + ' if sys.version_info >= (3, 6):\n' + ' pass\n', + # both branches are still relevant in the following cases + 'import sys\n' + 'if sys.version_info > (3, 7):\n' + ' 3-6\n' + 'else:\n' + ' 3+7\n', + + 'import sys\n' + 'if sys.version_info < (3, 7):\n' + ' 3-6\n' + 'else:\n' + ' 3+7\n', + + 'import sys\n' + 'if sys.version_info >= (3, 7):\n' + ' 3+7\n' + 'else:\n' + ' 3-6\n', + + 'import sys\n' + 'if sys.version_info <= (3, 7):\n' + ' 3-7\n' + 'else:\n' + ' 3+8\n', + + 'import sys\n' + 'if sys.version_info <= (3, 6):\n' + ' 3-6\n' + 'else:\n' + ' 3+7\n', + + 'import sys\n' + 'if sys.version_info > (3, 6):\n' + ' 3+7\n' + 'else:\n' + ' 3-6\n', + ), +) +def test_fix_py3x_only_noop(s): + assert _fix_plugins(s, settings=Settings(min_version=(3, 6))) == s