Skip to content

Commit

Permalink
Type-check assignments to annotated function parameter names.
Browse files Browse the repository at this point in the history
This is a little different from what b/113372317 asks for but in the same
spirit: if you don't allow `def f(x: int = 0.0): ...`, why would you allow
`def f(x: int): x = 0.0`? Plus, now that we're enforcing annotations strictly
everywhere else, it would be confusing to be lenient here.
PiperOrigin-RevId: 317218197
  • Loading branch information
rchen152 committed Jun 23, 2020
1 parent 8be7866 commit 16ac66c
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 64 deletions.
9 changes: 8 additions & 1 deletion pytype/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -3265,7 +3265,14 @@ def call(self, node, func, args, new_locals=False, alias_map=None):
node2, _ = async_generator.run_generator(node)
node_after_call, ret = node2, async_generator.to_variable(node2)
else:
node2, ret = self.vm.run_frame(frame, node)
if self.vm.options.check_parameter_types:
annotated_locals = {
name: abstract_utils.Local(node, self.get_first_opcode(), annot,
callargs.get(name), self.vm)
for name, annot in annotations.items() if name != "return"}
else:
annotated_locals = {}
node2, ret = self.vm.run_frame(frame, node, annotated_locals)
if self.is_coroutine():
ret = Coroutine(self.vm, ret, node2).to_variable(node2)
node_after_call = node2
Expand Down
51 changes: 51 additions & 0 deletions pytype/abstract_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,3 +640,54 @@ def get_annotations_dict(members):
except ConversionError:
return None
return annots if annots.isinstance_AnnotationsDict() else None


class Local:
"""A possibly annotated local variable."""

def __init__(self, node, op, typ, orig, vm):
self._ops = [op]
if typ:
self.typ = vm.program.NewVariable([typ], [], node)
else:
# Creating too many variables bloats the typegraph, hurting performance,
# so we use None instead of an empty variable.
self.typ = None
self.orig = orig
self.vm = vm

@property
def last_op(self):
# TODO(b/74434237): This property can be removed once the usage of it in
# dataclass_overlay is gone.
return self._ops[-1]

@property
def stack(self):
return self.vm.simple_stack(self.last_op)

def update(self, node, op, typ, orig):
"""Update this variable's annotation and/or value."""
if op in self._ops:
return
self._ops.append(op)
if typ:
if self.typ:
self.typ.AddBinding(typ, [], node)
else:
self.typ = self.vm.program.NewVariable([typ], [], node)
if orig:
self.orig = orig

def get_type(self, node, name):
"""Gets the variable's annotation."""
if not self.typ:
return None
values = self.typ.Data(node)
if len(values) > 1:
self.vm.errorlog.ambiguous_annotation(self.stack, values, name)
return self.vm.convert.unsolvable
elif values:
return values[0]
else:
return None
3 changes: 2 additions & 1 deletion pytype/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ def add_basic_options(o):
o.add_argument(
"--check-parameter-types", action="store_true",
dest="check_parameter_types", default=False,
help="Check parameter defaults against their annotations. " + temporary)
help=("Check parameter defaults and assignments against their "
"annotations. " + temporary))
o.add_argument(
"--check-variable-types", action="store_true",
dest="check_variable_types", default=False,
Expand Down
7 changes: 3 additions & 4 deletions pytype/tests/py3/test_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,14 @@ def f(x: Optional[str]):
""")
self.assertErrorRegexes(errors, {"e": r"upper.*None"})

def test_error_in_any(self):
errors = self.CheckWithErrors("""
def test_any_annotation(self):
self.Check("""
from typing import Any
def f(x: Any):
if __random__:
x = 42
x.upper() # attribute-error[e]
x.upper()
""")
self.assertErrorRegexes(errors, {"e": r"upper.*int.*Union\[Any, int\]"})


class TestAttributesPython3FeatureTest(test_base.TargetPython3FeatureTest):
Expand Down
2 changes: 1 addition & 1 deletion pytype/tests/py3/test_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_set(self):
self.Check("""
from typing import List, Set
def f(data: List[str]):
data = set(x for x in data)
data = set(x for x in data) # type: Set[str]
g(data)
def g(data: Set[str]):
pass
Expand Down
22 changes: 22 additions & 0 deletions pytype/tests/py3/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,28 @@ def f(x: int = '', y: str = 0): # annotation-type-mismatch[e1] # annotation-ty
"e2": r"Annotation: str.*Assignment: int"})


class TestCheckParameterAssignment(test_base.TargetPython3BasicTest):
"""Tests for checking assignments to annotated parameters."""

def test_basic(self):
errors = self.CheckWithErrors("""
def f(x: int):
x = '' # annotation-type-mismatch[e]
""")
self.assertErrorRegexes(errors, {"e": r"Annotation: int.*Assignment: str"})

def test_typevar(self):
errors = self.CheckWithErrors("""
from typing import TypeVar
T = TypeVar('T')
def f(x: T, y: T):
x = 0 # annotation-type-mismatch[e]
f('', '')
""")
self.assertErrorRegexes(
errors, {"e": r"Annotation: str.*Assignment: int.*Called from.*line 5"})


class TestFunctions(test_base.TargetPython3BasicTest):
"""Tests for functions."""

Expand Down
6 changes: 3 additions & 3 deletions pytype/tests/py3/test_splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def f(x: Optional[NoneType]) -> int:
if x is not None:
x = None
if x is None:
x = 1
x = 1 # type: int
return x
""")

Expand All @@ -112,9 +112,9 @@ def test_guarding_is_not_else(self):
from typing import Optional
def f(x: Optional[str]) -> int:
if x is None:
x = 1
x = 1 # type: int
else:
x = 1
x = 1 # type: int
return x
""")

Expand Down
60 changes: 6 additions & 54 deletions pytype/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,57 +75,6 @@ def is_annotate(self):
return self.op == self.ANNOTATE


class Local:
"""A possibly annotated local variable."""

def __init__(self, node, op, typ, orig, vm):
self._ops = [op]
if typ:
self.typ = vm.program.NewVariable([typ], [], node)
else:
# Creating too many variables bloats the typegraph, hurting performance,
# so we use None instead of an empty variable.
self.typ = None
self.orig = orig
self.vm = vm

@property
def last_op(self):
# TODO(b/74434237): This property can be removed once the usage of it in
# dataclass_overlay is gone.
return self._ops[-1]

@property
def stack(self):
return self.vm.simple_stack(self.last_op)

def update(self, node, op, typ, orig):
"""Update this variable's annotation and/or value."""
if op in self._ops:
return
self._ops.append(op)
if typ:
if self.typ:
self.typ.AddBinding(typ, [], node)
else:
self.typ = self.vm.program.NewVariable([typ], [], node)
if orig:
self.orig = orig

def get_type(self, node, name):
"""Gets the variable's annotation."""
if not self.typ:
return None
values = self.typ.Data(node)
if len(values) > 1:
self.vm.errorlog.ambiguous_annotation(self.stack, values, name)
return self.vm.convert.unsolvable
elif values:
return values[0]
else:
return None


_opcode_counter = metrics.MapCounter("vm_opcode")


Expand Down Expand Up @@ -356,7 +305,7 @@ def join_cfg_nodes(self, nodes):
node.ConnectTo(ret)
return ret

def run_frame(self, frame, node):
def run_frame(self, frame, node, annotated_locals=None):
"""Run a frame (typically belonging to a method)."""
self.push_frame(frame)
frame.states[frame.f_code.co_code[0]] = frame_state.FrameState.init(
Expand All @@ -367,7 +316,9 @@ def run_frame(self, frame, node):
# don't care to track locals for this frame and don't want it to overwrite
# the locals of the actual module frame.
self.local_ops[frame_name] = []
self.annotated_locals[frame_name] = {}
self.annotated_locals[frame_name] = annotated_locals or {}
else:
assert annotated_locals is None
can_return = False
return_nodes = []
for block in frame.f_code.order:
Expand Down Expand Up @@ -1315,7 +1266,8 @@ def _update_annotations_dict(
if name in annotations_dict:
annotations_dict[name].update(node, op, typ, orig_val)
else:
annotations_dict[name] = Local(node, op, typ, orig_val, self)
annotations_dict[name] = abstract_utils.Local(
node, op, typ, orig_val, self)

def _store_value(self, state, name, value, local):
if local:
Expand Down

0 comments on commit 16ac66c

Please sign in to comment.