Skip to content

Commit

Permalink
Merge pull request #3977 from tybug/misaligned-caching
Browse files Browse the repository at this point in the history
Improve caching for misaligned `ConjectureData`s
  • Loading branch information
tybug committed May 9, 2024
2 parents 4071c30 + e9799d0 commit c579480
Show file tree
Hide file tree
Showing 10 changed files with 456 additions and 50 deletions.
3 changes: 3 additions & 0 deletions hypothesis-python/RELEASE.rst
@@ -0,0 +1,3 @@
RELEASE_TYPE: patch

This patch improves our internal caching logic. We don't expect it to result in any performance improvements (yet!).
92 changes: 78 additions & 14 deletions hypothesis-python/src/hypothesis/internal/conjecture/data.py
Expand Up @@ -136,6 +136,7 @@ class BooleanKWargs(TypedDict):
IntegerKWargs, FloatKWargs, StringKWargs, BytesKWargs, BooleanKWargs
]
IRTypeName: TypeAlias = Literal["integer", "string", "boolean", "float", "bytes"]
InvalidAt: TypeAlias = Tuple[IRTypeName, IRKWargsType]


class ExtraInformation:
Expand Down Expand Up @@ -956,6 +957,9 @@ def draw_boolean(
) -> None:
pass

def mark_invalid(self, invalid_at: InvalidAt) -> None:
    """Observer hook called when a draw misaligns with the supplied ir tree.

    *invalid_at* is an ``(ir_type, kwargs)`` pair describing the draw that
    could not be satisfied by the next node in the tree. The base
    implementation is a no-op; ``TreeRecordingObserver`` overrides it to
    record the misalignment on the current tree node.
    """
    pass


@attr.s(slots=True, repr=False, eq=False)
class IRNode:
Expand Down Expand Up @@ -1048,6 +1052,16 @@ def __eq__(self, other):
and self.was_forced == other.was_forced
)

def __hash__(self):
    # Must stay consistent with __eq__, which compares ir_type, value,
    # kwargs, and was_forced. We hash through the same canonical keys
    # (ir_value_key / ir_kwargs_key), so floats are keyed via float_to_int —
    # presumably bit-level identity rather than float equality, matching
    # how ir_value_equal compares float values.
    return hash(
        (
            self.ir_type,
            ir_value_key(self.ir_type, self.value),
            ir_kwargs_key(self.ir_type, self.kwargs),
            self.was_forced,
        )
    )

def __repr__(self):
# repr to avoid "BytesWarning: str() on a bytes instance" for bytes nodes
forced_marker = " [forced]" if self.was_forced else ""
Expand Down Expand Up @@ -1087,22 +1101,36 @@ def ir_value_permitted(value, ir_type, kwargs):
raise NotImplementedError(f"unhandled type {type(value)} of ir value {value}")


def ir_value_key(ir_type, v):
    """Return a hashable canonical key for the ir value *v*.

    Floats are keyed through ``float_to_int`` so that hashing and equality
    agree on the same canonical representation; every other ir type uses
    the value itself.
    """
    return float_to_int(v) if ir_type == "float" else v


def ir_kwargs_key(ir_type, kwargs):
    """Return a hashable canonical key for a draw's *kwargs*.

    Float bounds go through ``float_to_int`` (matching ir_value_key), and
    integer weights are converted from a list to a tuple so the key is
    hashable. All other ir types key off their kwargs sorted by name.
    """
    if ir_type == "float":
        return (
            float_to_int(kwargs["min_value"]),
            float_to_int(kwargs["max_value"]),
            kwargs["allow_nan"],
            kwargs["smallest_nonzero_magnitude"],
        )
    if ir_type == "integer":
        weights = kwargs["weights"]
        return (
            kwargs["min_value"],
            kwargs["max_value"],
            tuple(weights) if weights is not None else None,
            kwargs["shrink_towards"],
        )
    return tuple(kwargs[name] for name in sorted(kwargs))


def ir_value_equal(ir_type, v1, v2):
    """Return whether two ir values are equal.

    Comparison goes through ir_value_key so that it agrees exactly with
    hashing (IRNode.__hash__) — in particular floats are compared via
    their float_to_int key rather than plain ``==``.
    """
    # The scraped span contained the superseded pre-refactor body after an
    # unconditional return (dead code); only the key-based comparison remains.
    return ir_value_key(ir_type, v1) == ir_value_key(ir_type, v2)


def ir_kwargs_equal(ir_type, kwargs1, kwargs2):
    """Return whether two kwargs dicts describe the same draw constraints.

    Comparison goes through ir_kwargs_key so that it agrees exactly with
    hashing (IRNode.__hash__) — float bounds are compared via their
    float_to_int keys rather than plain ``==``.
    """
    # The scraped span contained the superseded pre-refactor body after an
    # unconditional return (dead code); only the key-based comparison remains.
    return ir_kwargs_key(ir_type, kwargs1) == ir_kwargs_key(ir_type, kwargs2)


@dataclass_transform()
Expand All @@ -1125,6 +1153,7 @@ class ConjectureResult:
examples: Examples = attr.ib(repr=False)
arg_slices: Set[Tuple[int, int]] = attr.ib(repr=False)
slice_comments: Dict[Tuple[int, int], str] = attr.ib(repr=False)
invalid_at: Optional[InvalidAt] = attr.ib(repr=False)

index: int = attr.ib(init=False)

Expand Down Expand Up @@ -1977,6 +2006,7 @@ def __init__(
self.extra_information = ExtraInformation()

self.ir_tree_nodes = ir_tree_prefix
self.invalid_at: Optional[InvalidAt] = None
self._node_index = 0
self.start_example(TOP_LABEL)

Expand Down Expand Up @@ -2279,7 +2309,6 @@ def _pop_ir_tree_node(self, ir_type: IRTypeName, kwargs: IRKWargsType) -> IRNode
self.mark_overrun()

node = self.ir_tree_nodes[self._node_index]
self._node_index += 1
# If we're trying to draw a different ir type at the same location, then
# this ir tree has become badly misaligned. We don't have many good/simple
# options here for realigning beyond giving up.
Expand All @@ -2292,14 +2321,21 @@ def _pop_ir_tree_node(self, ir_type: IRTypeName, kwargs: IRKWargsType) -> IRNode
# (in fact, it is possible that giving up early here results in more time
# for useful shrinks to run).
if node.ir_type != ir_type:
invalid_at = (ir_type, kwargs)
self.invalid_at = invalid_at
self.observer.mark_invalid(invalid_at)
self.mark_invalid(f"(internal) want a {ir_type} but have a {node.ir_type}")

# if a node has different kwargs (and so is misaligned), but has a value
# that is allowed by the expected kwargs, then we can coerce this node
# into an aligned one by using its value. It's unclear how useful this is.
if not ir_value_permitted(node.value, node.ir_type, kwargs):
invalid_at = (ir_type, kwargs)
self.invalid_at = invalid_at
self.observer.mark_invalid(invalid_at)
self.mark_invalid(f"(internal) got a {ir_type} but outside the valid range")

self._node_index += 1
return node

def as_result(self) -> Union[ConjectureResult, _Overrun]:
Expand Down Expand Up @@ -2328,6 +2364,7 @@ def as_result(self) -> Union[ConjectureResult, _Overrun]:
forced_indices=frozenset(self.forced_indices),
arg_slices=self.arg_slices,
slice_comments=self.slice_comments,
invalid_at=self.invalid_at,
)
assert self.__result is not None
self.blocks.transfer_ownership(self.__result)
Expand Down Expand Up @@ -2475,7 +2512,34 @@ def freeze(self) -> None:
self.frozen = True

self.buffer = bytes(self.buffer)
self.observer.conclude_test(self.status, self.interesting_origin)

# if we were invalid because of a misalignment in the tree, we don't
# want to tell the DataTree that. Doing so would lead to inconsistent behavior.
# Given an empty DataTree
# ┌──────┐
# │ root │
# └──────┘
# and supposing the very first draw is misaligned, concluding here would
# tell the datatree that the *only* possibility at the root node is Status.INVALID:
# ┌──────┐
# │ root │
# └──┬───┘
# ┌───────────┴───────────────┐
# │ Conclusion(Status.INVALID)│
# └───────────────────────────┘
# when in fact this is only the case when we try to draw a misaligned node.
# For instance, suppose we come along in the second test case and try a
# valid node as the first draw from the root. The DataTree thinks this
# is flaky (because root must lead to Status.INVALID in the tree) while
# in fact nothing in the test function has changed and the only change
# is in the ir tree prefix we are supplying.
#
# From the perspective of DataTree, it is safe to not conclude here. This
# tells the datatree that we don't know what happens after this node - which
# is true! We are aborting early here because the ir tree became misaligned,
# which is a semantically different invalidity than an assume or filter failing.
if self.invalid_at is None:
self.observer.conclude_test(self.status, self.interesting_origin)

def choice(
self,
Expand Down
18 changes: 18 additions & 0 deletions hypothesis-python/src/hypothesis/internal/conjecture/datatree.py
Expand Up @@ -24,6 +24,7 @@
DataObserver,
FloatKWargs,
IntegerKWargs,
InvalidAt,
IRKWargsType,
IRType,
IRTypeName,
Expand Down Expand Up @@ -422,6 +423,8 @@ class TreeNode:
# be explored when generating novel prefixes)
transition: Union[None, Branch, Conclusion, Killed] = attr.ib(default=None)

invalid_at: Optional[InvalidAt] = attr.ib(default=None)

# A tree node is exhausted if every possible sequence of draws below it has
# been explored. We only update this when performing operations that could
# change the answer.
Expand Down Expand Up @@ -475,6 +478,8 @@ def split_at(self, i):
del self.ir_types[i:]
del self.values[i:]
del self.kwargs[i:]
# we have a transition now, so we don't need to carry around invalid_at.
self.invalid_at = None
assert len(self.values) == len(self.kwargs) == len(self.ir_types) == i

def check_exhausted(self):
Expand Down Expand Up @@ -811,6 +816,9 @@ def simulate_test_function(self, data):
node = self.root

def draw(ir_type, kwargs, *, forced=None):
if ir_type == "float" and forced is not None:
forced = int_to_float(forced)

draw_func = getattr(data, f"draw_{ir_type}")
value = draw_func(**kwargs, forced=forced)

Expand All @@ -832,6 +840,13 @@ def draw(ir_type, kwargs, *, forced=None):
t = node.transition
data.conclude_test(t.status, t.interesting_origin)
elif node.transition is None:
if node.invalid_at is not None:
(ir_type, kwargs) = node.invalid_at
try:
draw(ir_type, kwargs)
except StopTest:
if data.invalid_at is not None:
raise
raise PreviouslyUnseenBehaviour
elif isinstance(node.transition, Branch):
v = draw(node.transition.ir_type, node.transition.kwargs)
Expand Down Expand Up @@ -977,6 +992,9 @@ def draw_boolean(
) -> None:
self.draw_value("boolean", value, was_forced=was_forced, kwargs=kwargs)

def mark_invalid(self, invalid_at: InvalidAt) -> None:
    # Record the (ir_type, kwargs) of the misaligned draw on the current
    # tree node, so simulate_test_function can later replay the invalid
    # draw instead of raising PreviouslyUnseenBehaviour.
    self.__current_node.invalid_at = invalid_at

def draw_value(
self,
ir_type: IRTypeName,
Expand Down
90 changes: 83 additions & 7 deletions hypothesis-python/src/hypothesis/internal/conjecture/engine.py
Expand Up @@ -34,6 +34,8 @@
Overrun,
PrimitiveProvider,
Status,
ir_kwargs_key,
ir_value_key,
)
from hypothesis.internal.conjecture.datatree import DataTree, PreviouslyUnseenBehaviour
from hypothesis.internal.conjecture.junkdrawer import clamp, ensure_free_stackframes
Expand Down Expand Up @@ -186,6 +188,7 @@ def __init__(
# shrinking where we need to know about the structure of the
# executed test case.
self.__data_cache = LRUReusedCache(CACHE_SIZE)
self.__data_cache_ir = LRUReusedCache(CACHE_SIZE)

self.__pending_call_explanation = None
self._switch_to_hypothesis_provider = False
Expand Down Expand Up @@ -239,10 +242,70 @@ def __stoppable_test_function(self, data):
# correct engine.
raise

def ir_tree_to_data(self, ir_tree_nodes):
    """Execute the test function against *ir_tree_nodes* and return the data.

    Builds a ConjectureData replaying the given ir tree prefix, runs it
    through the (stoppable) test function, and returns the executed data.
    """
    result = ConjectureData.for_ir_tree(ir_tree_nodes)
    self.__stoppable_test_function(result)
    return result
def _cache_key_ir(self, *, nodes=None, data=None):
    """Return a hashable key for the ir cache.

    Exactly one of *nodes* or *data* must be provided. When keying off a
    ConjectureData whose execution stopped early due to a misaligned draw
    (``data.invalid_at is not None``), the misaligned node is appended to
    the key, so the same aligned prefix with different misaligned suffixes
    maps to different cache entries.
    """
    assert (nodes is not None) ^ (data is not None)
    extension = []
    if data is not None:
        nodes = data.examples.ir_tree_nodes
        if data.invalid_at is not None:
            # if we're invalid then we should have at least one node left (the invalid one).
            assert data._node_index < len(data.ir_tree_nodes)
            extension = [data.ir_tree_nodes[data._node_index]]

    # intentionally drop was_forced from equality here, because the was_forced
    # of node prefixes on ConjectureData has no impact on that data's result
    return tuple(
        (
            node.ir_type,
            ir_value_key(node.ir_type, node.value),
            ir_kwargs_key(node.ir_type, node.kwargs),
        )
        for node in nodes + extension
    )

def _cache(self, data):
    """Store *data*'s result in both the buffer cache and the ir cache.

    Misaligned data (``data.invalid_at`` set) is written only to the ir
    cache; see the comment below for why the buffer cache is skipped.
    """
    result = data.as_result()
    # when we shrink, we try out of bounds things, which can lead to the same
    # data.buffer having multiple outcomes. eg data.buffer=b'' is Status.OVERRUN
    # in normal circumstances, but a data with
    # ir_nodes=[integer -5 {min_value: 0, max_value: 10}] will also have
    # data.buffer=b'' but will be Status.INVALID instead. We do not want to
    # change the cached value to INVALID in this case.
    #
    # We handle this specially for the ir cache by keying off the misaligned node
    # as well, but we cannot do the same for buffers as we do not know ahead of
    # time what buffer a node maps to. I think it's largely fine that we don't
    # write to the buffer cache here as we move more things to the ir cache.
    if data.invalid_at is None:
        self.__data_cache[data.buffer] = result
    key = self._cache_key_ir(data=data)
    self.__data_cache_ir[key] = result

def cached_test_function_ir(self, nodes):
    """Return the result of running the test function on *nodes*, cached.

    Lookup order: (1) the ir cache keyed on the prefix alone; (2) simulate
    the draws against the DataTree and, if the tree already knows the full
    outcome, look up the cache again with the completed key — this avoids
    executing the (possibly expensive) test function; (3) actually execute
    the test function, which also populates the caches.
    """
    key = self._cache_key_ir(nodes=nodes)
    try:
        return self.__data_cache_ir[key]
    except KeyError:
        pass

    try:
        # simulate the draws against the tree without running the real test
        # function; raises PreviouslyUnseenBehaviour when the tree cannot
        # predict the outcome.
        trial_data = self.new_conjecture_data_ir(nodes)
        self.tree.simulate_test_function(trial_data)
    except PreviouslyUnseenBehaviour:
        pass
    else:
        trial_data.freeze()
        key = self._cache_key_ir(data=trial_data)
        try:
            return self.__data_cache_ir[key]
        except KeyError:
            pass

    data = self.new_conjecture_data_ir(nodes)
    # note that calling test_function caches `data` for us, for both an ir
    # tree key and a buffer key.
    self.test_function(data)
    return data.as_result()

def test_function(self, data):
if self.__pending_call_explanation is not None:
Expand Down Expand Up @@ -274,7 +337,7 @@ def test_function(self, data):
),
}
self.stats_per_test_case.append(call_stats)
self.__data_cache[data.buffer] = data.as_result()
self._cache(data)

self.debug_data(data)

Expand Down Expand Up @@ -321,8 +384,9 @@ def test_function(self, data):

# drive the ir tree through the test function to convert it
# to a buffer
data = self.ir_tree_to_data(data.examples.ir_tree_nodes)
self.__data_cache[data.buffer] = data.as_result()
data = ConjectureData.for_ir_tree(data.examples.ir_tree_nodes)
self.__stoppable_test_function(data)
self._cache(data)

key = data.interesting_origin
changed = False
Expand Down Expand Up @@ -983,6 +1047,18 @@ def _run(self):
self.shrink_interesting_examples()
self.exit_with(ExitReason.finished)

def new_conjecture_data_ir(self, ir_tree_prefix, *, observer=None):
    """Create a ConjectureData that replays *ir_tree_prefix*.

    Mirrors new_conjecture_data: the provider switches to HypothesisProvider
    when the engine has flagged a switchover, and a tree observer is used by
    default. NOTE(review): for non-"hypothesis" backends the observer is
    replaced with a no-op DataObserver — presumably to keep the DataTree out
    of alternative-backend runs; confirm against new_conjecture_data.
    """
    if self._switch_to_hypothesis_provider:
        provider = HypothesisProvider
    else:
        provider = self.provider
    # keep the same evaluation order as before: the default observer is
    # created even when the backend check below discards it.
    observer = observer or self.tree.new_observer()
    if self.settings.backend != "hypothesis":
        observer = DataObserver()

    return ConjectureData.for_ir_tree(
        ir_tree_prefix, observer=observer, provider=provider
    )

def new_conjecture_data(self, prefix, max_length=BUFFER_SIZE, observer=None):
provider = (
HypothesisProvider if self._switch_to_hypothesis_provider else self.provider
Expand Down

0 comments on commit c579480

Please sign in to comment.