diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index eb7497a..53a5165 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -12,6 +12,7 @@ this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ### Fixed +- [Imports used in generics and wrapped in string get removed by @hadialqattan](https://github.com/hadialqattan/pycln/pull/178) - [Setting an artificial lock to LibCST version `+0.4.0` for Python `3.6.x`; a bug introduced in LibCST `+0.4.0` that affects `PY3.6` by @hadialqattan](https://github.com/hadialqattan/pycln/pull/174) ### Added diff --git a/docs/README.md b/docs/README.md index 36f1f64..67aace5 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1005,6 +1005,23 @@ __all__ = ["os", "time"] > Not supported, also not on the > [roadmap](https://github.com/hadialqattan/pycln/projects/3). +### Generics wrapping strings + +> Pycln can understand imports used in generics and wrapped in string. + +```python +from typing import Generic, TypeVar +from xxx import Baz # marked as used. + +CustomType = TypeVar("CustomType") + +class Foo(Generic[CustomType]): + ... + +class Bar(Foo["Baz"]): # <~ + ... +``` + ### Init file (`__init__.py`) > Pycln can not decide whether the unused imported names are useless or imported to be diff --git a/pycln/utils/scan.py b/pycln/utils/scan.py index 1e3fca1..e9165fe 100644 --- a/pycln/utils/scan.py +++ b/pycln/utils/scan.py @@ -299,6 +299,29 @@ def visit_FunctionDef(self, node: FunctionDefT): # Support `ast.AsyncFunctionDef`. visit_AsyncFunctionDef = visit_FunctionDef + @recursive + def visit_ClassDef(self, node: ast.ClassDef): + #: Support imports used in generics and wrapped in string: + #: + #: >>> from typing import Generic + #: >>> from foo import Bar + #: >>> + #: >>> class SuperClass(Generic[SomeType]): + #: >>> ... + #: >>> + #: >>> class SubClass(SuperClass["Bar"]) # <~ detecting Bar. + #: >>> ... + #: + #: Issue: https://github.com/hadialqattan/pycln/issues/169 + for base in node.bases: + if isinstance(base, ast.Subscript): + if PY39_PLUS: + s_val = base.slice # type: ignore + else: + s_val = base.slice.value # type: ignore + for elt in getattr(s_val, "elts", ()) or (s_val,): + self._parse_string(elt) # type: ignore + @recursive def visit_Assign(self, node: ast.Assign): # Support Python ^3.8 type comments. diff --git a/tests/test_scan.py b/tests/test_scan.py index 001e7ed..9ff7402 100644 --- a/tests/test_scan.py +++ b/tests/test_scan.py @@ -563,6 +563,31 @@ def test_visit_AsyncFunctionDef(self, code, expec_names): source_stats, _ = analyzer.get_stats() self.assert_set_equal_or_not(source_stats.name_, expec_names) + @pytest.mark.parametrize( + "code, expec_names", + [ + pytest.param( + ("class Foo(Bar):\n" " pass"), + {"Bar"}, + id="no generics", + ), + pytest.param( + ("class Foo(Bar['Baz']):\n" " pass"), + {"Bar", "Baz"}, + id="one generic", + ), + pytest.param( + ("class Foo(Bar['Baz'], Bax['Tax']):\n" " pass"), + {"Bar", "Baz", "Bax", "Tax"}, + id="many generics", + ), + ], + ) + def test_visit_ClassDef(self, code, expec_names): + analyzer = self._get_analyzer(code) + source_stats, _ = analyzer.get_stats() + self.assert_set_equal_or_not(source_stats.name_, expec_names) + @pytest.mark.parametrize( "code, expec_names, expec_names_to_skip", [