Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support passing in Graph.artifacts #343

Merged
merged 1 commit into from
Mar 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 9 additions & 7 deletions src/arti/graphs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from types import TracebackType
from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union, cast

from pydantic import Field, PrivateAttr
from pydantic import Field, PrivateAttr, validator

import arti
from arti import io
Expand Down Expand Up @@ -105,6 +105,7 @@ class Graph(Model):
"""Graph stores a web of Artifacts connected by Producers."""

name: str
artifacts: ArtifactBox = Field(default_factory=lambda: ArtifactBox(**BOX_KWARGS[SEALED]))
# The Backend *itself* should not affect the results of a Graph build, though the contents
# certainly may (eg: stored annotations), so we avoid serializing it. This also prevents
# embedding any credentials.
Expand All @@ -113,9 +114,13 @@ class Graph(Model):

# Graph starts off sealed, but is opened within a `with Graph(...)` context
_status: Optional[bool] = PrivateAttr(None)
_artifacts: ArtifactBox = PrivateAttr(default_factory=lambda: ArtifactBox(**BOX_KWARGS[SEALED]))
_artifact_to_key: frozendict[Artifact, str] = PrivateAttr(frozendict())

@validator("artifacts")
@classmethod
def _convert_artifacts(cls, artifacts: ArtifactBox) -> ArtifactBox:
    """Rebuild the supplied ``artifacts`` as an ``ArtifactBox`` using the ``SEALED`` box options.

    Whatever value is passed for the ``artifacts`` field is re-wrapped so the stored box is
    always constructed with ``BOX_KWARGS[SEALED]``.
    """
    sealed_kwargs = BOX_KWARGS[SEALED]
    return ArtifactBox(artifacts, **sealed_kwargs)

def __enter__(self) -> Graph:
if arti.context.graph is not None:
raise ValueError(f"Another graph is being defined: {arti.context.graph}")
Expand All @@ -135,16 +140,13 @@ def __exit__(
TopologicalSorter(self.dependencies).prepare()

def _toggle(self, status: bool) -> None:
# The Graph object is "frozen", so we must bypass the assignment checks.
object.__setattr__(self, "artifacts", ArtifactBox(self.artifacts, **BOX_KWARGS[status]))
self._status = status
self._artifacts = ArtifactBox(self.artifacts, **BOX_KWARGS[status])
self._artifact_to_key = frozendict(
{artifact: key for key, artifact in self.artifacts.walk()}
)

@property
def artifacts(self) -> ArtifactBox:
return self._artifacts

# Read-only accessor for the private `_artifact_to_key` mapping (Artifact -> key), which
# `_toggle` rebuilds from `self.artifacts.walk()` each time the Graph status changes.
@property
def artifact_to_key(self) -> frozendict[Artifact, str]:
return self._artifact_to_key
Expand Down
42 changes: 39 additions & 3 deletions src/arti/internal/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from collections.abc import Generator, Mapping, Sequence
from copy import deepcopy
from functools import cached_property, partial
Expand All @@ -9,10 +11,12 @@
Literal,
Optional,
TypeVar,
Union,
get_args,
get_origin,
)

from box import Box
from pydantic import BaseModel, Extra, root_validator, validator
from pydantic.fields import ModelField, Undefined
from pydantic.json import pydantic_encoder as pydantic_json_encoder
Expand All @@ -21,6 +25,8 @@
from arti.internal.utils import class_name, frozendict

if TYPE_CHECKING:
from pydantic.typing import AbstractSetIntStr, MappingIntStrAny

from arti.fingerprints import Fingerprint
from arti.types import Type

Expand Down Expand Up @@ -211,7 +217,7 @@ def _fingerprint_json_encoder(obj: Any, encoder: Any = pydantic_json_encoder) ->
return encoder(obj)

@property
def fingerprint(self) -> "Fingerprint":
def fingerprint(self) -> Fingerprint:
from arti.fingerprints import Fingerprint

# `.json` cannot be used, even with a custom encoder, because it calls `.dict`, which
Expand All @@ -233,6 +239,36 @@ def fingerprint(self) -> "Fingerprint":
)
return Fingerprint.from_string(f"{self._class_key_}:{json_repr}")

@classmethod
def _get_value(
    cls,
    v: Any,
    to_dict: bool,
    by_alias: bool,
    include: Optional[Union[AbstractSetIntStr, MappingIntStrAny]],
    exclude: Optional[Union[AbstractSetIntStr, MappingIntStrAny]],
    exclude_unset: bool,
    exclude_defaults: bool,
    exclude_none: bool,
) -> Any:
    """Convert a value like the parent class does, but keep ``Box`` instances as their own type.

    The call is forwarded unchanged to the parent ``_get_value``. When the input ``v`` is a
    ``Box``, the converted result is wrapped back into ``v``'s class with ``v``'s Box
    configuration; any other value is returned exactly as the parent produced it.
    """
    converted = super()._get_value(
        v,
        to_dict=to_dict,
        by_alias=by_alias,
        include=include,
        exclude=exclude,
        exclude_unset=exclude_unset,
        exclude_defaults=exclude_defaults,
        exclude_none=exclude_none,
    )
    if not isinstance(v, Box):
        return converted
    # Copying dict subclasses doesn't preserve the subclass[1], and the extra Box configuration
    # (namely frozen_box=True) must be carried over as well, so rebuild the Box explicitly.
    #
    # 1: https://github.com/pydantic/pydantic/issues/5225
    box_type = v.__class__
    return box_type(converted, **v._Box__box_config())

# Filter out non-fields from ._iter (and thus .dict, .json, etc), such as `@cached_property`
# after access (which just gets cached in .__dict__).
def _iter(self, *args: Any, **kwargs: Any) -> Generator[tuple[str, Any], None, None]:
Expand All @@ -242,8 +278,8 @@ def _iter(self, *args: Any, **kwargs: Any) -> Generator[tuple[str, Any], None, N

@classmethod
def _pydantic_type_system_post_field_conversion_hook_(
cls, type_: "Type", *, name: str, required: bool
) -> "Type":
cls, type_: Type, *, name: str, required: bool
) -> Type:
return type_


Expand Down
22 changes: 21 additions & 1 deletion tests/arti/graphs/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from arti import Artifact, CompositeKey, Fingerprint, Graph, GraphSnapshot, View, producer
from arti.backends.memory import MemoryBackend
from arti.executors.local import LocalExecutor
from arti.graphs import ArtifactBox
from arti.internal.utils import frozendict
from arti.storage.literal import StringLiteral
from arti.storage.local import LocalFile, LocalFilePartition
Expand Down Expand Up @@ -43,13 +44,31 @@ def test_Graph(graph: Graph) -> None:
assert graph.artifacts.c.a.storage.includes_input_fingerprint_template
assert graph.artifacts.c.b.storage.includes_input_fingerprint_template
# NOTE: We may need to occasionally update this, but ensure graph.backend is not included.
assert graph.fingerprint == Fingerprint.from_int(4705012302096346878)
assert graph.fingerprint == Fingerprint.from_int(3139813064524317498)


def test_Graph_pickle(graph: Graph) -> None:
    """A Graph compares equal to itself after a pickle round trip."""
    restored = pickle.loads(pickle.dumps(graph))
    assert graph == restored


def test_Graph_copy(graph: Graph) -> None:
    """Copies of a Graph stay equal to the original and keep ``artifacts`` as an ArtifactBox."""
    # There are a few edge cases in pydantic when copying a model with a mapping subclass
    # field[1], so double check things are ok under various conditions.
    #
    # 1: https://github.com/pydantic/pydantic/issues/5225
    duplicates = [
        graph.copy(),
        graph.copy(exclude={"backend"}),
        graph.copy(include=set(graph.__fields__)),
    ]
    for duplicate in duplicates:
        assert graph == duplicate
        assert isinstance(duplicate.artifacts, ArtifactBox)
        assert graph.artifacts == duplicate.artifacts
        assert graph.fingerprint == duplicate.fingerprint
        assert hash(graph) == hash(duplicate)
        assert hash(duplicate) == duplicate.fingerprint.key


def test_Graph_literals(tmp_path: Path) -> None:
n_add_runs = 0

Expand Down Expand Up @@ -147,6 +166,7 @@ def test_Graph_snapshot() -> None:
# Ensure order independence
assert s.id == Fingerprint.combine(*reversed(id_components))

assert g.backend is s.backend # Ensure the backend is not copied
# Ensure metadata is written
with g.backend.connect() as conn:
assert conn.read_graph(g.name, g.fingerprint) == g
Expand Down