diff --git a/README.md b/README.md index 1f31f064..6172839a 100644 --- a/README.md +++ b/README.md @@ -694,3 +694,16 @@ Availability: -def f(x: 'queue.Queue[int]') -> C: +def f(x: queue.Queue[int]) -> C: ``` + + +### use `datetime.UTC` alias + +Availability: +- `--py311-plus` is passed on the commandline. + +```diff + import datetime + +-datetime.timezone.utc ++datetime.UTC +``` diff --git a/pyupgrade/_plugins/datetime_utc_alias.py b/pyupgrade/_plugins/datetime_utc_alias.py new file mode 100644 index 00000000..5280e5db --- /dev/null +++ b/pyupgrade/_plugins/datetime_utc_alias.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import ast +import functools +from typing import Iterable + +from tokenize_rt import Offset + +from pyupgrade._ast_helpers import ast_to_offset +from pyupgrade._data import register +from pyupgrade._data import State +from pyupgrade._data import TokenFunc +from pyupgrade._token_helpers import replace_name + + +@register(ast.Attribute) +def visit_Attribute( + state: State, + node: ast.Attribute, + parent: ast.AST, +) -> Iterable[tuple[Offset, TokenFunc]]: + if ( + state.settings.min_version >= (3, 11) and + node.attr == 'utc' and + isinstance(node.value, ast.Attribute) and + node.value.attr == 'timezone' and + isinstance(node.value.value, ast.Name) and + node.value.value.id == 'datetime' + ): + func = functools.partial( + replace_name, + name='utc', + new='datetime.UTC', + ) + yield ast_to_offset(node), func diff --git a/tests/features/datetime_utc_alias_test.py b/tests/features/datetime_utc_alias_test.py new file mode 100644 index 00000000..b40f8ff0 --- /dev/null +++ b/tests/features/datetime_utc_alias_test.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import pytest + +from pyupgrade._data import Settings +from pyupgrade._main import _fix_plugins + + +@pytest.mark.parametrize( + ('s',), + ( + pytest.param( + 'import datetime\n' + 'print(datetime.timezone(-1))', + + id='not rewriting timezone object to alias', + ), + ), +) +def test_fix_datetime_utc_alias_noop(s): + assert _fix_plugins(s, settings=Settings(min_version=(3,))) == s + assert _fix_plugins(s, settings=Settings(min_version=(3, 11))) == s + + +@pytest.mark.parametrize( + ('s', 'expected'), + ( + pytest.param( + 'import datetime\n' + 'print(datetime.timezone.utc)', + + 'import datetime\n' + 'print(datetime.UTC)', + + id='rewriting to alias', + ), + ), +) +def test_fix_datetime_utc_alias(s, expected): + assert _fix_plugins(s, settings=Settings(min_version=(3,))) == s + assert _fix_plugins(s, settings=Settings(min_version=(3, 11))) == expected