Skip to content

Commit

Permalink
working on exclude/include
Browse files Browse the repository at this point in the history
  • Loading branch information
mristin committed Dec 10, 2020
1 parent d27e7af commit f961ae6
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 31 deletions.
163 changes: 133 additions & 30 deletions icontract/integration/with_hypothesis/icontract_hypothesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import sys
import tokenize
import types
from typing import List, Optional, Tuple, TextIO, Mapping, Any, MutableMapping, Union, Callable
from typing import List, Optional, Tuple, TextIO, Mapping, Any, MutableMapping, Union, Callable, Set

import asttokens

Expand All @@ -23,11 +23,13 @@

class LineRange:
"""Represent a line range (indexed from 1, both first and last inclusive)."""
def __init__(self, first: int, last: int)->None:

def __init__(self, first: int, last: int) -> None:
"""Initialize with the given values."""
self.first = first
self.last = last


class ParamsGeneral:
"""Represent general program parameters specified regardless of the command."""

Expand All @@ -41,9 +43,11 @@ def __init__(
self.include = include
self.exclude = exclude


_LINE_RANGE_RE = re.compile(r'^\s*(?P<first>[0-9]|[1-9][0-9]+)(\s*-\s*(?P<last>[1-9]|[1-9][0-9]+))?\s*$')

def _parse_point_spec(text: str)->Tuple[Optional[Union[LineRange, re.Pattern]], List[str]]:

def _parse_point_spec(text: str) -> Tuple[Optional[Union[LineRange, re.Pattern]], List[str]]:
"""
Try to parse the given specification of function point(s).
Expand All @@ -67,11 +71,11 @@ def _parse_point_spec(text: str)->Tuple[Optional[Union[LineRange, re.Pattern]],
if first <= 0:
errors.append("Unexpected line index (expected to start from 1): {}".format(text))
return None, errors

if last < first:
errors.append("Unexpected line range (last < first): {}".format(text))
return None, errors

else:
return LineRange(first=int(mtch.group('first')), last=int(mtch.group('last'))), errors

Expand All @@ -81,7 +85,8 @@ def _parse_point_spec(text: str)->Tuple[Optional[Union[LineRange, re.Pattern]],
except re.error as err:
errors.append("Failed to parse the pattern {}: {}".format(text, err))
return None, errors



def _parse_general_params(args: argparse.Namespace) -> Tuple[Optional[ParamsGeneral], List[str]]:
"""
Try to parse general parameters of the program (regardless of the command).
Expand All @@ -95,7 +100,7 @@ def _parse_general_params(args: argparse.Namespace) -> Tuple[Optional[ParamsGene
for include_str in args.include:
point_spec, point_spec_errors = _parse_point_spec(text=include_str)
errors.extend(point_spec_errors)

if not point_spec_errors:
include.append(point_spec)

Expand Down Expand Up @@ -326,13 +331,16 @@ def captured_output():
out.seek(0)
return None, out.read(), err.read()

_PYICONTRACT_HYPOTHESIS_DIRECTIVE_RE = re.compile(r'#\s*pyicontract-hypothesis\s*:\s*(?P<value>disable|enable)')

_DIRECTIVE_RE = re.compile(r'^#\s*pyicontract-hypothesis\s*:\s*(?P<value>[^ \t]*)\s*$')


class Point:
"""Represent a testable function."""
@icontract.require(lambda srow: srow > 0)
@icontract.require(lambda erow: erow > 0)
@icontract.require(lambda srow, erow: srow <= erow)

@icontract.require(lambda first_row: first_row > 0)
@icontract.require(lambda last_row: last_row > 0)
@icontract.require(lambda first_row, last_row: first_row <= last_row)
def __init__(self, first_row: int, last_row: int, func: Callable[..., Any]) -> None:
"""
Initialize with the given values.
Expand All @@ -343,30 +351,132 @@ def __init__(self, first_row: int, last_row: int, func: Callable[..., Any]) -> N
self.last_row = last_row
self.func = func


def _select_points(
source_code: str,
mod: types.ModuleType,
include: List[Union[LineRange, re.Pattern]],
exclude: List[Union[LineRange, re.Pattern]]
)->Tuple[List[Callable[..., Any]], List[str]]:
) -> Tuple[List[Point], List[str]]:
points = [] # type: List[Point]
errors = [] # type: List[str]

for key in dir(mod):
value = getattr(mod, key)
if inspect.isfunction(value):
func = value # type: Callable[..., Any]
source_lines, srow = inspect.getsourcelines(func)

point = Point(first_row=srow, last_row = srow+len(source_lines) - 1, func=func)
point = Point(first_row=srow, last_row=srow + len(source_lines) - 1, func=func)
points.append(point)

# TODO: exclude functions if they have the directive in the body
# TODO: exclude ranges of functions if the comment is in the root
##
# Exclude functions which have the disable directive in the body
##

included = points

# TODO: test invalid value different from disable-once

excluded = set() # type: Set[Point]
for point in points:
reader = io.BytesIO(inspect.getsource(point.func).encode('utf-8'))
for toktype, _, (first_row, _), _, line in tokenize.tokenize(reader.readline):
if toktype == tokenize.COMMENT:
mtch = _DIRECTIVE_RE.match(line.strip())
if mtch:
value = mtch.group('value')

if value != 'disable-once':
errors.append(
("Unexpected directive within a function {} on line {}. "
"Only '# pyicontract-hypothesis: disable-once' expected, "
"but got: {}").format(point.func.__name__, first_row + point.first_row - 1, line.strip()))
continue

excluded.add(point)

if errors:
return [], errors

included = [point for point in points if point not in excluded]

##
# Add ranges of lines given by comment directives to the ``exclude``
##

extended_exclude = exclude[:]

# TODO: test invalid value different from enable/disable

range_start = None # type: Optional[int]
reader = io.BytesIO(source_code.encode('utf-8'))
for toktype, _, (first_row, _), _, line in tokenize.tokenize(reader.readline):
if toktype == tokenize.COMMENT:
mtch = _DIRECTIVE_RE.match(line.strip())
if mtch:
value = mtch.group('value')

if value not in ['enable', 'disable']:
errors.append(
("Unexpected directive on line {}. "
"Expected '# pyicontract-hypothesis: (disable|enable)', "
"but got: {}").format(
first_row, line.strip()))
continue

if value == 'disable':
if range_start is not None:
continue

range_start = first_row

elif value == 'enable':
if range_start is not None:
extended_exclude.append(LineRange(first=range_start, last=first_row))

return [], []
else:
raise AssertionError("Unexpected value: {}".format(json.dumps(value)))

if errors:
return [], errors

##
# Remove ``included`` which do not match ``include``
##
# TODO: two-iterator sweep

##
# Exclude all points in ``included`` if matched in ``extended_exclude``
##
# TODO: two-iterator sweep

# TODO: test at point level, manually calling _load_module_from_source_file

return included, []


def _load_module_from_source_file(path: pathlib.Path) -> Tuple[Optional[types.ModuleType], List[str]]:
"""
Try to load a module from the source file.
def test(general: ParamsGeneral, command: ParamsTest)->List[str]:
Return (loaded module, errors if any).
"""
fullname = re.sub(r'[^A-Za-z0-9_]', '_', path.stem)

mod = None # type: Optional[types.ModuleType]
try:
loader = importlib.machinery.SourceFileLoader(fullname=fullname, path=str(path))
mod = types.ModuleType(loader.name)
loader.exec_module(mod)
except Exception as error:
return None, ['Failed to import the file {}: {}'.format(path, error)]

assert mod is not None, "Expected mod to be set before"

return mod, []


def test(general: ParamsGeneral, command: ParamsTest) -> List[str]:
"""
Test the specified functions.
Expand All @@ -380,17 +490,9 @@ def test(general: ParamsGeneral, command: ParamsTest)->List[str]:
except Exception as error:
return ['Failed to read the file {}: {}'.format(command.path, error)]

fullname = re.sub(r'[^A-Za-z0-9_]', '_', command.path.stem)

mod = None # type: Optional[types.ModuleType]
try:
loader = importlib.machinery.SourceFileLoader(fullname=fullname, path=str(command.path))
mod = types.ModuleType(loader.name)
loader.exec_module(mod)
except Exception as error:
return ['Failed to import the file {}: {}'.format(command.path, error)]

assert mod is not None, "Expected mod to be set before"
mod, errors = _load_module_from_source_file(path=command.path)
if errors:
return errors

points, errors = _select_points(source_code=source_code, mod=mod, include=general.include, exclude=general.exclude)
if errors:
Expand All @@ -399,14 +501,15 @@ def test(general: ParamsGeneral, command: ParamsTest)->List[str]:
print(f"points is {points!r}") # TODO: debug


def ghostwrite(general: ParamsGeneral, command: ParamsGhostwrite)->Tuple[str, List[str]]:
def ghostwrite(general: ParamsGeneral, command: ParamsGhostwrite) -> Tuple[str, List[str]]:
"""
Write a unit test module for the specified functions.
Return (generated code, errors if any).
"""
raise NotImplementedError()


def testable_main(argv: List[str], stdout: TextIO, stderr: TextIO) -> int:
"""Execute the testable_main routine."""
parser = _make_argument_parser()
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/with_hypothesis/sample_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def untestable_some_func(x: int) -> None:
pass

def untestable_another_func(x: int) -> None:
# pyicontract-hypothesis: disable
# pyicontract-hypothesis: disable-for-this-function
pass

# pyicontract-hypothesis: disable
Expand Down

0 comments on commit f961ae6

Please sign in to comment.