diff --git a/mypy/suggestions.py b/mypy/suggestions.py index 6e987d455ae4..19879dee0cda 100644 --- a/mypy/suggestions.py +++ b/mypy/suggestions.py @@ -42,7 +42,9 @@ reverse_builtin_aliases, ) from mypy.server.update import FineGrainedBuildManager -from mypy.util import module_prefix, split_target +from mypy.util import split_target +from mypy.find_sources import SourceFinder, InvalidSourceList +from mypy.modulefinder import PYTHON_EXTENSIONS from mypy.plugin import Plugin, FunctionContext, MethodContext from mypy.traverser import TraverserVisitor from mypy.checkexpr import has_any_type @@ -162,6 +164,7 @@ def __init__(self, fgmanager: FineGrainedBuildManager, self.manager = fgmanager.manager self.plugin = self.manager.plugin self.graph = fgmanager.graph + self.finder = SourceFinder(self.manager.fscache) self.give_json = json self.no_errors = no_errors @@ -174,19 +177,21 @@ def __init__(self, fgmanager: FineGrainedBuildManager, def suggest(self, function: str) -> str: """Suggest an inferred type for function.""" - with self.restore_after(function): + mod, func_name, node = self.find_node(function) + + with self.restore_after(mod): with self.with_export_types(): - suggestion = self.get_suggestion(function) + suggestion = self.get_suggestion(mod, node) if self.give_json: - return self.json_suggestion(function, suggestion) + return self.json_suggestion(mod, func_name, node, suggestion) else: return self.format_signature(suggestion) def suggest_callsites(self, function: str) -> str: """Find a list of call sites of function.""" - with self.restore_after(function): - _, _, node = self.find_node(function) + mod, _, node = self.find_node(function) + with self.restore_after(mod): callsites, _ = self.get_callsites(node) return '\n'.join(dedup( @@ -195,7 +200,7 @@ def suggest_callsites(self, function: str) -> str: )) @contextmanager - def restore_after(self, target: str) -> Iterator[None]: + def restore_after(self, module: str) -> Iterator[None]: """Context manager that reloads a module after executing the body. This should undo any damage done to the module state while mucking around. @@ -203,9 +208,7 @@ def restore_after(self, target: str) -> Iterator[None]: try: yield finally: - module = module_prefix(self.graph, target) - if module: - self.reload(self.graph[module]) + self.reload(self.graph[module]) @contextmanager def with_export_types(self) -> Iterator[None]: @@ -321,13 +324,12 @@ def find_best(self, func: FuncDef, guesses: List[CallableType]) -> Tuple[Callabl key=lambda s: (count_errors(errors[s]), self.score_callable(s))) return best, count_errors(errors[best]) - def get_suggestion(self, function: str) -> PyAnnotateSignature: + def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature: """Compute a suggestion for a function. Return the type and whether the first argument should be ignored. """ graph = self.graph - mod, _, node = self.find_node(function) callsites, orig_errors = self.get_callsites(node) if self.no_errors and orig_errors: @@ -386,15 +388,49 @@ def format_args(self, return "(%s)" % (", ".join(args)) def find_node(self, key: str) -> Tuple[str, str, FuncDef]: - """From a target name, return module/target names and the func def.""" + """From a target name, return module/target names and the func def. + + The 'key' argument can be in one of two formats: + * As the function full name, e.g., package.module.Cls.method + * As the function location as file and line separated by column, + e.g., path/to/file.py:42 + """ # TODO: Also return OverloadedFuncDef -- currently these are ignored. - graph = self.fgmanager.graph - target = split_target(graph, key) - if not target: - raise SuggestionFailure("Cannot find module for %s" % (key,)) - modname, tail = target + node = None # type: Optional[SymbolNode] + if ':' in key: + if key.count(':') > 1: + raise SuggestionFailure( + 'Malformed location for function: {}. Must be either' + ' package.module.Class.method or path/to/file.py:line'.format(key)) + file, line = key.split(':') + if not line.isdigit(): + raise SuggestionFailure('Line number must be a number. Got {}'.format(line)) + line_number = int(line) + modname, node = self.find_node_by_file_and_line(file, line_number) + tail = node.fullname()[len(modname) + 1:] # add one to account for '.' + else: + target = split_target(self.fgmanager.graph, key) + if not target: + raise SuggestionFailure("Cannot find module for %s" % (key,)) + modname, tail = target + node = self.find_node_by_module_and_name(modname, tail) - tree = self.ensure_loaded(graph[modname]) + if isinstance(node, Decorator): + node = self.extract_from_decorator(node) + if not node: + raise SuggestionFailure("Object %s is a decorator we can't handle" % key) + + if not isinstance(node, FuncDef): + raise SuggestionFailure("Object %s is not a function" % key) + + return modname, tail, node + + def find_node_by_module_and_name(self, modname: str, tail: str) -> Optional[SymbolNode]: + """Find symbol node by module id and qualified name. + + Raise SuggestionFailure if can't find one. + """ + tree = self.ensure_loaded(self.fgmanager.graph[modname]) # N.B. This is reimplemented from update's lookup_target # basically just to produce better error messages. @@ -416,18 +452,38 @@ def find_node(self, key: str) -> Tuple[str, str, FuncDef]: # Look for the actual function/method funcname = components[-1] if funcname not in names: + key = modname + '.' + tail raise SuggestionFailure("Unknown %s %s" % ("method" if len(components) > 1 else "function", key)) - node = names[funcname].node - if isinstance(node, Decorator): - node = self.extract_from_decorator(node) - if not node: - raise SuggestionFailure("Object %s is a decorator we can't handle" % key) + return names[funcname].node - if not isinstance(node, FuncDef): - raise SuggestionFailure("Object %s is not a function" % key) + def find_node_by_file_and_line(self, file: str, line: int) -> Tuple[str, SymbolNode]: + """Find symbol node by path to file and line number. - return (modname, tail, node) + Return module id and the node found. Raise SuggestionFailure if can't find one. + """ + if not any(file.endswith(ext) for ext in PYTHON_EXTENSIONS): + raise SuggestionFailure('Source file is not a Python file') + try: + modname, _ = self.finder.crawl_up(os.path.normpath(file)) + except InvalidSourceList: + raise SuggestionFailure('Invalid source file name: ' + file) + if modname not in self.graph: + raise SuggestionFailure('Unknown module: ' + modname) + # We must be sure about any edits in this file as this might affect the line numbers. + tree = self.ensure_loaded(self.fgmanager.graph[modname], force=True) + node = None # type: Optional[SymbolNode] + for _, sym, _ in tree.local_definitions(): + if isinstance(sym.node, FuncDef) and sym.node.line == line: + node = sym.node + break + elif isinstance(sym.node, Decorator) and sym.node.func.line == line: + node = sym.node + break + # TODO: add support for OverloadedFuncDef. + if not node: + raise SuggestionFailure('Cannot find a function at line {}'.format(line)) + return modname, node def extract_from_decorator(self, node: Decorator) -> Optional[FuncDef]: for dec in node.decorators: @@ -483,9 +539,9 @@ def reload(self, state: State, check_errors: bool = False) -> List[str]: raise SuggestionFailure("Error while trying to load %s" % state.id) return res - def ensure_loaded(self, state: State) -> MypyFile: + def ensure_loaded(self, state: State, force: bool = False) -> MypyFile: """Make sure that the module represented by state is fully loaded.""" - if not state.tree or state.tree.is_cache_skeleton: + if not state.tree or state.tree.is_cache_skeleton or force: self.reload(state, check_errors=True) assert state.tree is not None return state.tree @@ -493,9 +549,9 @@ def ensure_loaded(self, state: State) -> MypyFile: def builtin_type(self, s: str) -> Instance: return self.manager.semantic_analyzer.builtin_type(s) - def json_suggestion(self, function: str, suggestion: PyAnnotateSignature) -> str: + def json_suggestion(self, mod: str, func_name: str, node: FuncDef, + suggestion: PyAnnotateSignature) -> str: """Produce a json blob for a suggestion suitable for application by pyannotate.""" - mod, func_name, node = self.find_node(function) # pyannotate irritatingly drops class names for class and static methods if node.is_class or node.is_static: func_name = func_name.split('.', 1)[-1] diff --git a/mypy/test/testfinegrained.py b/mypy/test/testfinegrained.py index 736cfe000bb0..868dcfa39871 100644 --- a/mypy/test/testfinegrained.py +++ b/mypy/test/testfinegrained.py @@ -96,7 +96,8 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: if messages: a.extend(normalize_messages(messages)) - a.extend(self.maybe_suggest(step, server, main_src)) + assert testcase.tmpdir + a.extend(self.maybe_suggest(step, server, main_src, testcase.tmpdir.name)) if server.fine_grained_manager: if CHECK_CONSISTENCY: @@ -155,7 +156,8 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: a.append('==') a.extend(new_messages) - a.extend(self.maybe_suggest(step, server, main_src)) + assert testcase.tmpdir + a.extend(self.maybe_suggest(step, server, main_src, testcase.tmpdir.name)) # Normalize paths in test output (for Windows). a = [line.replace('\\', '/') for line in a] @@ -268,7 +270,7 @@ def parse_sources(self, program_text: str, return [base] + create_source_list([test_temp_dir], options, allow_empty_dir=True) - def maybe_suggest(self, step: int, server: Server, src: str) -> List[str]: + def maybe_suggest(self, step: int, server: Server, src: str, tmp_dir: str) -> List[str]: output = [] # type: List[str] targets = self.get_suggest(src, step) for flags, target in targets: @@ -285,13 +287,17 @@ def maybe_suggest(self, step: int, server: Server, src: str) -> List[str]: try_text=try_text, flex_any=flex_any, callsites=callsites)) val = res['error'] if 'error' in res else res['out'] + res['err'] + if json: + # JSON contains already escaped \ on Windows, so requires a bit of care. + val = val.replace('\\\\', '\\') + val = val.replace(tmp_dir + os.path.sep, '') output.extend(val.strip().split('\n')) return normalize_messages(output) def get_suggest(self, program_text: str, incremental_step: int) -> List[Tuple[str, str]]: step_bit = '1?' if incremental_step == 1 else str(incremental_step) - regex = '# suggest{}: (--[a-zA-Z0-9_\\-./=?^ ]+ )*([a-zA-Z0-9_./?^ ]+)$'.format(step_bit) + regex = '# suggest{}: (--[a-zA-Z0-9_\\-./=?^ ]+ )*([a-zA-Z0-9_.:/?^ ]+)$'.format(step_bit) m = re.findall(regex, program_text, flags=re.MULTILINE) return m diff --git a/test-data/unit/fine-grained-suggest.test b/test-data/unit/fine-grained-suggest.test index d24a2909559b..09ee7e1a215c 100644 --- a/test-data/unit/fine-grained-suggest.test +++ b/test-data/unit/fine-grained-suggest.test @@ -607,3 +607,115 @@ def bar(iany) -> None: (int, int) -> int (str, int) -> str == + +[case testSuggestColonBasic] +# suggest: tmp/foo.py:1 +# suggest: tmp/bar/baz.py:2 +[file foo.py] +def func(arg): + return 0 +func('test') +from bar.baz import C +C().method('test') +[file bar/__init__.py] +[file bar/baz.py] +class C: + def method(self, x): + return 0 +[out] +(str) -> int +(str) -> int +== + +[case testSuggestColonBadLocation] +# suggest: tmp/foo.py:7:8:9 +[file foo.py] +[out] +Malformed location for function: tmp/foo.py:7:8:9. Must be either package.module.Class.method or path/to/file.py:line +== + +[case testSuggestColonBadLine] +# suggest: tmp/foo.py:bad +[file foo.py] +[out] +Line number must be a number. Got bad +== + +[case testSuggestColonBadFile] +# suggest: tmp/foo.txt:1 +[file foo.txt] +def f(): pass +[out] +Source file is not a Python file +== + +[case testSuggestColonUnknownLine] +# suggest: tmp/foo.py:42 +[file foo.py] +def func(x): + return 0 +func('test') +[out] +Cannot find a function at line 42 +== + +[case testSuggestColonClass] +# suggest: tmp/foo.py:1 +[file foo.py] +class C: + pass +[out] +Cannot find a function at line 1 +== + +[case testSuggestColonDecorator] +# suggest: tmp/foo.py:6 +[file foo.py] +from typing import TypeVar, Callable, Any +F = TypeVar('F', bound=Callable[..., Any]) +def deco(f: F) -> F: ... + +@deco +def func(arg): + return 0 +func('test') +[out] +(str) -> int +== + +[case testSuggestColonMethod] +# suggest: tmp/foo.py:3 +[file foo.py] +class Out: + class In: + def method(self, x): + return Out() +x: Out.In +x.method(x) +[out] +(foo:Out.In) -> foo.Out +== + +[case testSuggestColonMethodJSON] +# suggest: --json tmp/foo.py:3 +[file foo.py] +class Out: + class In: + def method(self, x): + return Out() +x: Out.In +x.method(x) +[out] +[[{"func_name": "Out.In.method", "line": 3, "path": "tmp/foo.py", "samples": 0, "signature": {"arg_types": ["foo:Out.In"], "return_type": "foo.Out"}}] +== + +[case testSuggestColonNonPackageDir] +# cmd: mypy foo/bar/baz.py +# suggest: tmp/foo/bar/baz.py:1 +[file foo/bar/baz.py] +def func(arg): + return 0 +func('test') +[out] +(str) -> int +==