[WIP] fix incorrect type hints + refactor for improved consistency #2304

Closed
wants to merge 10 commits into from

Conversation

danphenderson
Contributor

@danphenderson danphenderson commented Feb 16, 2024

Proposed revisions for issue 2303

I am opening the pull request as a work-in-progress to collect and address feedback.

I was able to automate the revisions using Perl - reference the commit log to see the shell command that created each diff. There was some manual intervention after executing the commands in b642750 and dcc5d7e.

I need to do further investigation to test for any possible regressions in the documentation and to confirm that the generate_api.py script is still working as expected. I have validated mypy and the test-suite locally.

Thanks,
Daniel H

danphenderson added 10 commits February 15, 2024 16:34
Commit was automated using this powerful one-liner:

`find . -type f -name "*.py" -print0 | xargs -0 perl -i -pe \
	's/([a-zA-Z_]+): (?!(Optional\[|Any\]))([a-zA-Z_]+) = None/\1: Optional[\3] = None/g'`

Which modifies variable annotations initialized with None to use Optional[Type],
except when the type is already annotated as Optional[Type] or Any

After executing the command, I manually updated modules that didn't
already import Optional.
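
For reference, the same substitution can be reproduced with Python's `re` module to see which annotations it reaches. This is a minimal sketch and not part of the commit; `PATTERN` and `wrap_simple_annotations` are illustrative names. As written, the `Any\]` alternative only excludes the literal text `Any]`, so a bare `Any` annotation still appears to get wrapped, and subscripted types such as `Dict[...]` are not matched at all.

```python
import re

# Same pattern as the Perl one-liner; group 2 lives inside the negative lookahead.
PATTERN = re.compile(r"([a-zA-Z_]+): (?!(Optional\[|Any\]))([a-zA-Z_]+) = None")


def wrap_simple_annotations(line: str) -> str:
    """Wrap `name: Type = None` in Optional[...] (hypothetical helper)."""
    return PATTERN.sub(r"\1: Optional[\3] = None", line)


print(wrap_simple_annotations("timeout: float = None"))     # timeout: Optional[float] = None
print(wrap_simple_annotations("x: Optional[int] = None"))   # unchanged
print(wrap_simple_annotations("x: Any = None"))             # x: Optional[Any] = None
print(wrap_simple_annotations("x: Dict[str, int] = None"))  # unchanged
```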

issue: 2303
Changes were automated with this perl one-liner:

`perl -pi -e 's/Union\[\s*(\w+)\s*,\s*None\s*\]/Optional[$1]/g' \
 $(find . -type f -name '*.py')`

WIP, still need to address the case where the union specifies more
than two different types.

issue: 2303
Changes were created with this perl command:

`perl -pi -e 's/typing\.//g' $(find . -type f -name '*.py')`

This removes every occurrence of the 'typing.' prefix. I had to update
the imports manually.

WIP: care should be taken to test that 'generate_api.py' and
'documentation_provider.py' scripts are working properly.

issue: 2303
Changes made with perl command:

`perl -pi -e 's/: Dict\[(.*?)\] = None/: Optional[Dict[$1]] = None/g'
$(find . -type f -name '*.py')`

Which replaces ': Dict[..., ...] = None' with ': Optional[Dict[..., ...]] = None', where '...' can be any type.
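
The lazy `(.*?)` group is what lets this pattern cope with nested brackets, but it also explains how a line-based substitution can overreach. The sketch below is a Python `re` equivalent written for illustration (`DICT_PATTERN` and `wrap_dict` are made-up names, and the sample lines are hypothetical): when two parameters share a line, the group keeps growing until it finds `] = None`, even across a parameter boundary.

```python
import re

# Python equivalent of the Perl substitution used in this commit.
DICT_PATTERN = re.compile(r": Dict\[(.*?)\] = None")


def wrap_dict(line: str) -> str:
    """Wrap `: Dict[...] = None` in Optional[...] (hypothetical helper)."""
    return DICT_PATTERN.sub(r": Optional[Dict[\1]] = None", line)


# Nested brackets on a single parameter are handled as intended.
nested = wrap_dict("headers: Dict[str, List[str]] = None")

# Two parameters on one line: the lazy group runs past the first parameter.
two_params = wrap_dict("a: Dict[str, int] = {}, b: Dict[str, str] = None")

print(nested)      # headers: Optional[Dict[str, List[str]]] = None
print(two_params)  # a: Optional[Dict[str, int] = {}, b: Dict[str, str]] = None
```

The second result is not valid Python, which is the kind of unintended modification that needs manual review after a bulk regex rewrite.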

issue: 2303
Changes made with perl command:

`perl -pi -e 's/: Literal\[(.*?)\] = None/: Optional[Literal[$1]] = None/g'
    $(find . -type f -name '*.py')`

Which replaces ': Literal[..., ...] = None' with ': Optional[Literal[..., ...]] = None', where '...' can be any type.

issue: 2303
Changes made with a perl command (see 09f138e, 728d3ea)
with manual intervention to fix unintended modifications.

': Sequence[...] = None' has been corrected to ': Optional[Sequence[...]] = None'.

issue: 2303
Changes made with the perl command:

find . -type f -name '*.py' -exec perl -i -pe \
	's/(\s*)(\w+): Union\[(.*?)\] = None,/$1$2: Optional[Union[$3]] = None,/g' {} +

Which replaces instances of '<param>: Union[...] = None,' with '<param>: Optional[Union[...]] = None,'
Note, this does not modify incorrect class attributes of this form (they
don't end with a ',').

issue: 2303
Changes made with the perl command:

find . -type f -name '*.py' -exec perl -i -pe \
	's/Union\[((?:(?![,]).)*), None\]/Optional[$1]/g' {} +

This replaces Union type hints consisting of a single 'TYPE' followed by 'None'
with 'Optional[<TYPE>]'.
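
To make the "single TYPE" restriction concrete, here is a small Python `re` sketch of the same pattern (illustrative only; `UNION_PATTERN` and `union_none_to_optional` are not names from the repository). The tempered group `(?:(?![,]).)*` refuses to cross a comma, so only unions whose first member contains no comma are rewritten.

```python
import re

# Python equivalent of the Perl pattern: TYPE may not contain a comma.
UNION_PATTERN = re.compile(r"Union\[((?:(?![,]).)*), None\]")


def union_none_to_optional(line: str) -> str:
    """Rewrite Union[TYPE, None] as Optional[TYPE] (hypothetical helper)."""
    return UNION_PATTERN.sub(r"Optional[\1]", line)


print(union_none_to_optional("a: Union[int, None]"))             # a: Optional[int]
print(union_none_to_optional("a: Union[List[int], None]"))       # a: Optional[List[int]]
print(union_none_to_optional("a: Union[str, Union[int, None]]")) # a: Union[str, Optional[int]]
print(union_none_to_optional("a: Union[Dict[str, int], None]"))  # unchanged (comma inside TYPE)
```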

issue: 2303
Changes made with shell command:

`perl -pi -e 's/: Callable\[(.*?)\] = None/: Optional[Callable[$1]] = None/g' $(find . -type f -name '*.py')`

Which replaces ': Callable[...] = None' with ': Optional[Callable[...]] = None', where '...' can be any type.

issue: 2303
@danphenderson
Contributor Author

danphenderson commented Feb 19, 2024

I am closing this PR. After considering the issue further, I was able to write a much more robust solution for the refactoring: it transforms the abstract syntax tree rather than matching on text, and it is testable. For documentation purposes, here is the POC script:

import ast

from typing import Union, List, Callable
from pathlib import Path
from black import format_str, FileMode


def ensure_typing_imports(ast_tree, import_name="Optional"):
    """
    Ensure that the specified import_name from the 'typing' module is present in the AST tree.
    If the 'typing' import exists but doesn't include import_name, add it.
    If there's no 'typing' import, add one with import_name.
    """
    typing_import_found = False
    for node in ast.iter_child_nodes(ast_tree):
        # Check if 'typing' is already imported and if import_name is included
        if isinstance(node, ast.ImportFrom) and node.module == 'typing':
            typing_import_found = True
            if import_name not in {alias.name for alias in node.names}:
                # Add import_name to the existing 'typing' import
                node.names.append(ast.alias(name=import_name, asname=None))
                node.names.sort(key=lambda alias: alias.name)  # Optional: sort imports alphabetically
    if not typing_import_found:
        # Add a new import statement for 'typing' including import_name
        new_import = ast.ImportFrom(module='typing', names=[ast.alias(name=import_name, asname=None)], level=0)
        ast_tree.body.insert(0, new_import)  # Add the new import at the top of the file



class UnionToOptionalTransformer(ast.NodeTransformer):
    """
    Transformer to convert 'Union' types containing 'None' to 'Optional' types.
    """

    def visit_Subscript(self, node: ast.Subscript) -> ast.AST:
        """
        Visit a Subscript node in the AST and transform Union[Type, None] to Optional[Type].
        """
        # Check if the node is a Union type
        if isinstance(node.value, ast.Name) and node.value.id == 'Union':
            # Extract the types in the Union
            types = self._extract_types_from_union(node.slice) # type: ignore
            
            # Check if None is one of the types in the Union
            if any(isinstance(t, ast.Constant) and t.value is None for t in types):
                # Remove None from the types
                types_without_none = [t for t in types if not (isinstance(t, ast.Constant) and t.value is None)]
                
                # Convert to Optional type
                if len(types_without_none) == 1:
                    # Single type + None becomes Optional[SingleType]
                    new_node = self._create_optional_node(types_without_none[0])
                else:
                    # Multiple types + None becomes Optional[Union[Types]]
                    union_node = self._create_union_node(types_without_none)
                    new_node = self._create_optional_node(union_node)
                    
                # Keep visiting so Unions nested inside the new node are converted too
                return self.generic_visit(ast.copy_location(new_node, node))
        return self.generic_visit(node)

    def _extract_types_from_union(self, slice_node: ast.AST) -> List[ast.AST]:
        """
        Extract types from the Union's slice.
        """
        # Python 3.8 wraps the subscript in ast.Index; 3.9+ exposes the expression directly
        if isinstance(slice_node, ast.Index):
            slice_node = slice_node.value  # type: ignore
        if isinstance(slice_node, ast.Tuple):
            return list(slice_node.elts)
        # A single-element subscript such as Union[int]
        return [slice_node]
    
    def _create_optional_node(self, type_node: ast.AST) -> ast.Subscript:
        """
        Create an Optional[type] node.
        """
        optional_name = ast.Name(id='Optional', ctx=ast.Load())
        return ast.Subscript(value=optional_name, slice=ast.Index(value=type_node), ctx=ast.Load())

    def _create_union_node(self, types: List[ast.AST]) -> ast.Subscript:
        """
        Create a Union[Types] node.
        """
        union_name = ast.Name(id='Union', ctx=ast.Load())
        types_tuple = ast.Tuple(elts=types, ctx=ast.Load())
        return ast.Subscript(value=union_name, slice=ast.Index(value=types_tuple), ctx=ast.Load())


def union_to_optional_transform(source_code: str) -> str:
    """
    Parse the source code into an abstract syntax tree, apply the UnionToOptionalTransformer,
    and return the modified code.
    """
    tree = ast.parse(source_code)
    transformer = UnionToOptionalTransformer()
    transformed_tree = transformer.visit(tree)
    
    # Ensure 'Optional' is imported from 'typing', if needed
    ensure_typing_imports(transformed_tree)
    
    # Assuming Python 3.9+ for ast.unparse
    transformed_code = ast.unparse(transformed_tree) # type: ignore
    return format_str(transformed_code, mode=FileMode())


class EnforceOptionalTransformer(ast.NodeTransformer):
    """
    Transformer to enforce the use of 'Optional' for annotated assignments to None.
    """
    def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AST:
        # Check if the value assigned is None
        if isinstance(node.value, ast.Constant) and node.value.value is None:
            # Ensure the annotation is not already Optional or Any
            if not (self._is_optional_annotation(node.annotation) or self._is_any_annotation(node.annotation)):
                # Wrap the existing annotation in Optional
                node.annotation = self._wrap_with_optional(node.annotation)
        return self.generic_visit(node)
    
    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
        return self._visit_function_common(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST:
        return self._visit_function_common(node)

    def _visit_function_common(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]) -> ast.AST:
        # Adjust the annotations for parameters with a default value of None
        for i, default in enumerate(node.args.defaults):
            if isinstance(default, ast.Constant) and default.value is None:
                arg = node.args.args[len(node.args.args) - len(node.args.defaults) + i]
                if arg.annotation:
                    if not (self._is_optional_annotation(arg.annotation) or self._is_any_annotation(arg.annotation)):
                        arg.annotation = self._wrap_with_optional(arg.annotation)
        return self.generic_visit(node)

    def _is_optional_annotation(self, annotation: ast.expr) -> bool:
        """Check if the annotation is already Optional."""
        return (isinstance(annotation, ast.Subscript) and
                isinstance(annotation.value, ast.Name) and
                annotation.value.id == 'Optional')

    def _is_any_annotation(self, annotation: ast.expr) -> bool:
        """Check if the annotation is Any."""
        return isinstance(annotation, ast.Name) and annotation.id == 'Any'

    def _wrap_with_optional(self, type_annotation: ast.expr) -> ast.Subscript:
        """Wrap the provided type annotation in Optional."""
        optional_name = ast.Name(id='Optional', ctx=ast.Load())
        return ast.Subscript(value=optional_name, slice=ast.Index(value=type_annotation), ctx=ast.Load())
    


def enforce_optional_transform(source_code: str) -> str:
    """
    Parse the source code, apply corrections for optional type hints, and return the modified code.
    """
    tree = ast.parse(source_code)
    transformer = EnforceOptionalTransformer()
    transformed_tree = transformer.visit(tree)
    # Ensure 'Optional' is imported from 'typing', if needed
    ensure_typing_imports(transformed_tree)
    
    # Assuming Python 3.9+ for ast.unparse
    transformed_code = ast.unparse(transformed_tree) # type: ignore
    return format_str(transformed_code, mode=FileMode())



def process_files_in_directory(directory_path: Path, transformation_pipeline: List[Callable]):
    for file in directory_path.glob('**/*.py'):  # Adjust the glob pattern if needed
        with open(file, 'r', encoding='utf-8') as f:
            original_content = f.read()
        
        # Apply the transformation pipeline
        transformed_content = original_content
        for transform in transformation_pipeline:
            transformed_content = transform(transformed_content)

        if original_content != transformed_content:
            with open(file, 'w', encoding='utf-8') as f:
                f.write(transformed_content)
            print(f"Processed {file}")



def main() -> None:
    # Define the transformation pipeline
    transformation_pipeline = [union_to_optional_transform, enforce_optional_transform]

    # Assuming the script is run from the root of the repository
    directory = Path.cwd() / "playwright"
    if not directory.exists():
        raise FileNotFoundError("Directory 'playwright' not found")

    process_files_in_directory(directory, transformation_pipeline)

if __name__ == "__main__":
    main()
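
One case the POC above does not appear to cover is keyword-only parameters: `_visit_function_common` only walks `node.args.defaults`, while defaults for parameters after a bare `*` live in `node.args.kw_defaults`. The following is a minimal sketch of how both kinds of defaults could be collected; `params_defaulting_to_none` and the sample signature are hypothetical and only illustrate the alignment rules of `ast.arguments`.

```python
import ast
from typing import List, Union


def params_defaulting_to_none(
    func: Union[ast.FunctionDef, ast.AsyncFunctionDef]
) -> List[str]:
    """Return names of annotated parameters whose default value is None."""
    args = func.args
    # Positional defaults align with the tail of posonlyargs + args.
    positional = args.posonlyargs + args.args
    offset = len(positional) - len(args.defaults)
    pairs = list(zip(positional[offset:], args.defaults))
    # kw_defaults has one entry per keyword-only arg (Python None when no default).
    pairs += list(zip(args.kwonlyargs, args.kw_defaults))
    return [
        arg.arg
        for arg, default in pairs
        if arg.annotation is not None
        and isinstance(default, ast.Constant)
        and default.value is None
    ]


# Hypothetical signature with keyword-only parameters.
tree = ast.parse("def goto(url: str, *, timeout: float = None, referer: str = None): ...")
names = params_defaulting_to_none(tree.body[0])
print(names)  # ['timeout', 'referer']
```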

Note that this script also has some inherent issues, namely:

  • It loses inline comments ('#') during the transformation.
  • It can't be executed in a Python 3.8 installation of playwright (ast.unparse requires Python 3.9+).

I am going to work on porting the script above to libcst.
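
The comment loss is inherent to the `ast` round trip rather than to the transformers themselves: comments are never stored in the tree, so `ast.unparse` has nothing to emit. A minimal demonstration, assuming Python 3.9+ (the sample line is made up):

```python
import ast

source = "timeout: float = None  # seconds\n"

# Parse and immediately unparse, with no transformation in between.
round_tripped = ast.unparse(ast.parse(source))  # ast.unparse needs Python 3.9+

print(round_tripped)  # timeout: float = None
```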

Also, here is the test coverage that I have written for the POC implementation:

import pytest
from utils.linting.check_type_hints import union_to_optional_transform, enforce_optional_transform

def normalize_code(code: str) -> str:
    """Normalize Python code string for comparison."""
    # Remove leading/trailing whitespace and replace sequences of whitespace characters with a single space
    normalized = ' '.join(code.split())
    # Replace double quotes with single quotes in string literals for consistent comparison
    normalized = normalized.replace('"', "'")
    return normalized


@pytest.mark.parametrize("source_code, expected_code", [
    # Test single type with None
    (
        "a: Union[int, None]",
        "a: Optional[int]"
    ),
    # Test subscripted single type with None
    (
        "a: Union[Dict[str, int], None]",
        "a: Optional[Dict[str, int]]"
    ),
    # Test multiple types including None
    (
        "a: Union[str, int, None]",
        "a: Optional[Union[str, int]]"
    ),
    # Test multiple subscripted types with None
    (
        "a: Union[Dict[str, int], List[int], None]",
        "a: Optional[Union[Dict[str, int], List[int]]]"
    ),  
    # Test Union without None should remain unchanged
    (
        "a: Union[str, int]",
        "a: Union[str, int]"
    ),
    # Test Union without None and multiple types should remain unchanged
    (
        "a: Union[str, int, float]",
        "a: Union[str, int, float]"
    ),
    # Test nested Unions with None
    (
        "a: Union[str, Union[int, None]]",
        "a: Union[str, Optional[int]]"
    ),
    # Test nested Unions without None should remain unchanged
    (
        "a: Union[str, Union[int, float]]",
        "a: Union[str, Union[int, float]]"
    ),
    # Test function parameter and return type annotations
    (
        "def func(a: Union[str, None]) -> Union[int, None]: pass",
        "def func(a: Optional[str]) -> Optional[int]: pass"
    ),
    # Test async function parameter and return type annotations
    (
        "async def func(a: Union[str, None]) -> Union[int, None]: pass",
        "async def func(a: Optional[str]) -> Optional[int]: pass"
    ),
    # Test variable annotations
    (
        "var: Union[str, None] = 'hello'",
        "var: Optional[str] = 'hello'"
    ),
    # Test class attribute annotations
    (
        "class A: a: Union[str, None] = 'hello'",
        "class A: a: Optional[str] = 'hello'"
    ),
    # Test variable annotations within a function
    (
        "def func(): var: Union[str, None] = 'hello'",
        "def func(): var: Optional[str] = 'hello'"
    ),
])
def test_union_to_optional_transform(source_code, expected_code):
    transformed_code = union_to_optional_transform(source_code)
    assert normalize_code(transformed_code) == normalize_code(expected_code)


@pytest.mark.parametrize("source_code, expected_code", [
    # Single type, should be wrapped with Optional
    (
        "var: int = None",
        "var: Optional[int] = None"
    ),
    # Single type with existing Optional, should remain unchanged
    (
        "var: Optional[int] = None",
        "var: Optional[int] = None"
    ),
    # Type hint is Any, should remain unchanged
    (
        "var: Any = None",
        "var: Any = None"
    ),
    # Nested Optional, inner type should not be wrapped again
    (
        "var: Optional[Dict[str, int]] = None",
        "var: Optional[Dict[str, int]] = None"
    ),
    # Without initial annotation, should remain unchanged (though not a common use case for this transformer)
    (
        "var = None",
        "var = None"
    ),
    # Ensure complex types are correctly handled
    (
        "var: Dict[str, List[int]] = None",
        "var: Optional[Dict[str, List[int]]] = None"
    ),
    # Test annotations within a function param
    (
        "def func(var: int = None): pass",
        "def func(var: Optional[int] = None): pass"
    ),
    # Test annotations within a class
    (
        "class A: var: int = None",
        "class A: var: Optional[int] = None"
    ),
    # Test annotations within an async function
    (
        "async def func(var: int = None): pass",
        "async def func(var: Optional[int] = None): pass"
    ),
])
def test_enforce_optional_transform(source_code, expected_code):
    transformed_code = enforce_optional_transform(source_code)
    assert normalize_code(transformed_code) == normalize_code(expected_code)
