diff --git a/pyupgrade/_plugins/typing_pep604.py b/pyupgrade/_plugins/typing_pep604.py index 1a2ec3e6..1067af4c 100644 --- a/pyupgrade/_plugins/typing_pep604.py +++ b/pyupgrade/_plugins/typing_pep604.py @@ -144,16 +144,17 @@ def visit_Subscript( if not _supported_version(state): return - # prevent rewriting forward annotations - if ( - (sys.version_info >= (3, 9) and _any_arg_is_str(node.slice)) or - ( - sys.version_info < (3, 9) and - isinstance(node.slice, ast.Index) and - _any_arg_is_str(node.slice.value) - ) - ): - return + # don't rewrite forward annotations (unless we know they will be dequoted) + if 'annotations' not in state.from_imports['__future__']: + if ( + (sys.version_info >= (3, 9) and _any_arg_is_str(node.slice)) or + ( + sys.version_info < (3, 9) and + isinstance(node.slice, ast.Index) and + _any_arg_is_str(node.slice.value) + ) + ): + return if is_name_attr( node.value, diff --git a/tests/features/typing_pep604_test.py b/tests/features/typing_pep604_test.py index 5c4f648e..035429c7 100644 --- a/tests/features/typing_pep604_test.py +++ b/tests/features/typing_pep604_test.py @@ -185,6 +185,17 @@ def f(x: int | str) -> None: ... id='Optional rewrite multi-line', ), + pytest.param( + 'from __future__ import annotations\n' + 'from typing import Optional\n' + 'x: Optional["str"]\n', + + 'from __future__ import annotations\n' + 'from typing import Optional\n' + 'x: str | None\n', + + id='Optional rewrite with forward reference', + ), pytest.param( 'from typing import Union, Sequence\n' 'def f(x: Union[Union[A, B], Sequence[Union[C, D]]]): pass\n',