Skip to content

Commit

Permalink
Add TypeSystem.extends (#250)
Browse files (browse the repository at this point in the history)
Signed-off-by: James Oswald <james.oswald@dynotx.com>

Co-authored-by: Jacob Hayes <jacob.r.hayes@gmail.com>
Signed-off-by: Jacob Hayes <jacob.r.hayes@gmail.com>
  • Loading branch information
JamesOswald and JacobHayes committed Jul 19, 2022
1 parent 73a9107 commit b85405a
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 127 deletions.
53 changes: 40 additions & 13 deletions src/arti/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from arti.internal.models import Model
from arti.internal.type_hints import lenient_issubclass
from arti.internal.utils import class_name, frozendict, register
from arti.internal.utils import NoCopyDict, class_name, frozendict, register

DEFAULT_ANONYMOUS_NAME = "anon"

Expand Down Expand Up @@ -307,31 +307,31 @@ def matches_artigraph(cls, type_: Type, *, hints: dict[str, Any]) -> bool:
return isinstance(type_, cls.artigraph)

@classmethod
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any]) -> Type:
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any], type_system: "TypeSystem") -> Type:
raise NotImplementedError()

@classmethod
def matches_system(cls, type_: Any, *, hints: dict[str, Any]) -> bool:
raise NotImplementedError()

@classmethod
def to_system(cls, type_: Type, *, hints: dict[str, Any]) -> Any:
def to_system(cls, type_: Type, *, hints: dict[str, Any], type_system: "TypeSystem") -> Any:
raise NotImplementedError()


# _ScalarClassTypeAdapter can be used for scalars defined as Python types (e.g. int or str for
# the python TypeSystem).
class _ScalarClassTypeAdapter(TypeAdapter):
@classmethod
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any]) -> Type:
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any], type_system: "TypeSystem") -> Type:
return cls.artigraph()

@classmethod
def matches_system(cls, type_: Any, *, hints: dict[str, Any]) -> bool:
return lenient_issubclass(type_, cls.system)

@classmethod
def to_system(cls, type_: Type, *, hints: dict[str, Any]) -> Any:
def to_system(cls, type_: Type, *, hints: dict[str, Any], type_system: "TypeSystem") -> Any:
return cls.system

@classmethod
Expand All @@ -358,7 +358,12 @@ def generate(
class TypeSystem(Model):
key: str

_adapter_by_key: dict[str, type[TypeAdapter]] = PrivateAttr(default_factory=dict)
# NOTE: Use a NoCopyDict so the adapter registry is shared rather than copied. Otherwise, a
# TypeSystem that extends this one would only see the adapters registered *as of its own
# initialization* (as pydantic would deepcopy the TypeSystems in the `extends` argument).
_adapter_by_key: NoCopyDict[str, type[TypeAdapter]] = PrivateAttr(default_factory=NoCopyDict)

extends: "tuple[TypeSystem, ...]" = ()

def register_adapter(self, adapter: type[TypeAdapter]) -> type[TypeAdapter]:
return register(self._adapter_by_key, adapter.key, adapter)
Expand All @@ -367,14 +372,36 @@ def register_adapter(self, adapter: type[TypeAdapter]) -> type[TypeAdapter]:
def _priority_sorted_adapters(self) -> Iterator[type[TypeAdapter]]:
return reversed(sorted(self._adapter_by_key.values(), key=attrgetter("priority")))

def to_artigraph(self, type_: Any, *, hints: dict[str, Any]) -> Type:
def to_artigraph(
self, type_: Any, *, hints: dict[str, Any], root_type_system: "Optional[TypeSystem]" = None
) -> Type:
root_type_system = root_type_system or self
for adapter in self._priority_sorted_adapters:
if adapter.matches_system(type_, hints=hints):
return adapter.to_artigraph(type_, hints=hints)
raise NotImplementedError(f"No {self} adapter for system type: {type_}.")

def to_system(self, type_: Type, *, hints: dict[str, Any]) -> Any:
return adapter.to_artigraph(type_, hints=hints, type_system=root_type_system)
for type_system in self.extends:
try:
return type_system.to_artigraph(
type_, hints=hints, root_type_system=root_type_system
)
except NotImplementedError:
pass
raise NotImplementedError(f"No {root_type_system} adapter for system type: {type_}.")

def to_system(
self, type_: Type, *, hints: dict[str, Any], root_type_system: "Optional[TypeSystem]" = None
) -> Any:
root_type_system = root_type_system or self
for adapter in self._priority_sorted_adapters:
if adapter.matches_artigraph(type_, hints=hints):
return adapter.to_system(type_, hints=hints)
raise NotImplementedError(f"No {self} adapter for Artigraph type: {type_}.")
return adapter.to_system(type_, hints=hints, type_system=root_type_system)
for type_system in self.extends:
try:
return type_system.to_system(type_, hints=hints, root_type_system=root_type_system)
except NotImplementedError:
pass
raise NotImplementedError(f"No {root_type_system} adapter for Artigraph type: {type_}.")


# Fix ForwardRefs in outer_type_, pending: https://github.com/samuelcolvin/pydantic/pull/4249
TypeSystem.__fields__["extends"].outer_type_ = tuple[TypeSystem, ...]
95 changes: 60 additions & 35 deletions src/arti/types/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from arti import types
from arti.internal.utils import classproperty
from arti.types import TypeSystem

pyarrow_type_system = types.TypeSystem(key="pyarrow")

Expand All @@ -25,15 +26,17 @@ def _is_system(cls) -> Callable[[pa.DataType], bool]:
return getattr(pa.types, f"is_{cls.system.__name__}") # type: ignore

@classmethod
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any]) -> types.Type:
def to_artigraph(
cls, type_: Any, *, hints: dict[str, Any], type_system: TypeSystem
) -> types.Type:
return cls.artigraph()

@classmethod
def matches_system(cls, type_: Any, *, hints: dict[str, Any]) -> bool:
return isinstance(type_, pa.DataType) and cls._is_system(type_)

@classmethod
def to_system(cls, type_: types.Type, *, hints: dict[str, Any]) -> Any:
def to_system(cls, type_: types.Type, *, hints: dict[str, Any], type_system: TypeSystem) -> Any:
return cls.system()


Expand Down Expand Up @@ -80,7 +83,9 @@ class BinaryTypeAdapter(_PyarrowTypeAdapter):
system = pa.binary

@classmethod
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any]) -> types.Type:
def to_artigraph(
cls, type_: Any, *, hints: dict[str, Any], type_system: TypeSystem
) -> types.Type:
if isinstance(type_, pa.FixedSizeBinaryType):
return cls.artigraph(byte_size=type_.byte_width)
return cls.artigraph()
Expand All @@ -92,7 +97,7 @@ def matches_system(cls, type_: Any, *, hints: dict[str, Any]) -> bool:
return super().matches_system(type_, hints=hints) or pa.types.is_fixed_size_binary(type_)

@classmethod
def to_system(cls, type_: types.Type, *, hints: dict[str, Any]) -> Any:
def to_system(cls, type_: types.Type, *, hints: dict[str, Any], type_system: TypeSystem) -> Any:
assert isinstance(type_, cls.artigraph)
return cls.system(length=-1 if type_.byte_size is None else type_.byte_size)

Expand Down Expand Up @@ -122,7 +127,7 @@ def matches_system(cls, type_: Any, *, hints: dict[str, Any]) -> bool:
return False

@classmethod
def to_system(cls, type_: types.Type, *, hints: dict[str, Any]) -> Any:
def to_system(cls, type_: types.Type, *, hints: dict[str, Any], type_system: TypeSystem) -> Any:
assert isinstance(type_, cls.artigraph)
return pa.binary() if type_.format == "WKB" else pa.string()

Expand All @@ -133,19 +138,21 @@ class ListTypeAdapter(_PyarrowTypeAdapter):
system = pa.list_

@classmethod
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any]) -> types.Type:
def to_artigraph(
cls, type_: Any, *, hints: dict[str, Any], type_system: TypeSystem
) -> types.Type:
return cls.artigraph(
element=pyarrow_type_system.to_artigraph(type_.value_type, hints=hints),
element=type_system.to_artigraph(type_.value_type, hints=hints),
)

@classmethod
def matches_system(cls, type_: Any, *, hints: dict[str, Any]) -> bool:
return cast(bool, pa.types.is_list(type_))

@classmethod
def to_system(cls, type_: types.Type, *, hints: dict[str, Any]) -> Any:
def to_system(cls, type_: types.Type, *, hints: dict[str, Any], type_system: TypeSystem) -> Any:
assert isinstance(type_, cls.artigraph)
return cls.system(value_type=pyarrow_type_system.to_system(type_.element, hints=hints))
return cls.system(value_type=type_system.to_system(type_.element, hints=hints))


@pyarrow_type_system.register_adapter
Expand All @@ -154,22 +161,24 @@ class MapTypeAdapter(_PyarrowTypeAdapter):
system = pa.map_

@classmethod
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any]) -> types.Type:
def to_artigraph(
cls, type_: Any, *, hints: dict[str, Any], type_system: TypeSystem
) -> types.Type:
return cls.artigraph(
key=pyarrow_type_system.to_artigraph(type_.key_type, hints=hints),
value=pyarrow_type_system.to_artigraph(type_.item_type, hints=hints),
key=type_system.to_artigraph(type_.key_type, hints=hints),
value=type_system.to_artigraph(type_.item_type, hints=hints),
)

@classmethod
def matches_system(cls, type_: Any, *, hints: dict[str, Any]) -> bool:
return cast(bool, pa.types.is_map(type_))

@classmethod
def to_system(cls, type_: types.Type, *, hints: dict[str, Any]) -> Any:
def to_system(cls, type_: types.Type, *, hints: dict[str, Any], type_system: TypeSystem) -> Any:
assert isinstance(type_, cls.artigraph)
return cls.system(
key_type=pyarrow_type_system.to_system(type_.key, hints=hints),
item_type=pyarrow_type_system.to_system(type_.value, hints=hints),
key_type=type_system.to_system(type_.key, hints=hints),
item_type=type_system.to_system(type_.value, hints=hints),
)


Expand All @@ -179,30 +188,37 @@ class StructTypeAdapter(_PyarrowTypeAdapter):
system = pa.struct

@classmethod
def _field_to_artigraph(cls, type_: Any, *, hints: dict[str, Any]) -> types.Type:
ret = pyarrow_type_system.to_artigraph(type_.type, hints=hints)
def _field_to_artigraph(
cls, type_: Any, *, hints: dict[str, Any], type_system: TypeSystem
) -> types.Type:
ret = type_system.to_artigraph(type_.type, hints=hints)
if type_.nullable != ret.nullable: # Avoid setting nullable if matching to minimize repr
ret = ret.copy(update={"nullable": type_.nullable})
return ret

@classmethod
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any]) -> types.Type:
def to_artigraph(
cls, type_: Any, *, hints: dict[str, Any], type_system: TypeSystem
) -> types.Type:
return cls.artigraph(
fields={field.name: cls._field_to_artigraph(field, hints=hints) for field in type_}
fields={
field.name: cls._field_to_artigraph(field, hints=hints, type_system=type_system)
for field in type_
}
)

@classmethod
def _field_to_system(cls, name: str, type_: types.Type, *, hints: dict[str, Any]) -> Any:
return pa.field(
name, pyarrow_type_system.to_system(type_, hints=hints), nullable=type_.nullable
)
def _field_to_system(
cls, name: str, type_: types.Type, *, hints: dict[str, Any], type_system: TypeSystem
) -> Any:
return pa.field(name, type_system.to_system(type_, hints=hints), nullable=type_.nullable)

@classmethod
def to_system(cls, type_: types.Type, *, hints: dict[str, Any]) -> Any:
def to_system(cls, type_: types.Type, *, hints: dict[str, Any], type_system: TypeSystem) -> Any:
assert isinstance(type_, cls.artigraph)
return cls.system(
[
cls._field_to_system(name, subtype, hints=hints)
cls._field_to_system(name, subtype, hints=hints, type_system=type_system)
for name, subtype in type_.fields.items()
]
)
Expand All @@ -224,26 +240,31 @@ def matches_artigraph(cls, type_: types.Type, *, hints: dict[str, Any]) -> bool:
)

@classmethod
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any]) -> types.Type:
def to_artigraph(
cls, type_: Any, *, hints: dict[str, Any], type_system: TypeSystem
) -> types.Type:
kwargs = {}
# NOTE: pyarrow converts all metadata keys/values to bytes
if type_.metadata and b"artigraph" in type_.metadata:
kwargs = json.loads(type_.metadata[b"artigraph"].decode())
for key in ["partition_by", "cluster_by"]:
if key in kwargs: # pragma: no cover
kwargs[key] = tuple(kwargs[key])
return cls.artigraph(element=StructTypeAdapter.to_artigraph(type_, hints=hints), **kwargs)
return cls.artigraph(
element=StructTypeAdapter.to_artigraph(type_, hints=hints, type_system=type_system),
**kwargs,
)

@classmethod
def matches_system(cls, type_: Any, *, hints: dict[str, Any]) -> bool:
return isinstance(type_, pa.lib.Schema)

@classmethod
def to_system(cls, type_: types.Type, *, hints: dict[str, Any]) -> Any:
def to_system(cls, type_: types.Type, *, hints: dict[str, Any], type_system: TypeSystem) -> Any:
assert isinstance(type_, cls.artigraph)
assert isinstance(type_.element, types.Struct)
return cls.system(
StructTypeAdapter.to_system(type_.element, hints=hints),
StructTypeAdapter.to_system(type_.element, hints=hints, type_system=type_system),
metadata={
"artigraph": json.dumps(
{
Expand All @@ -269,15 +290,17 @@ def unit_to_precision(cls) -> dict[str, str]:
return {v: k for k, v in cls.precision_to_unit.items()}

@classmethod
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any]) -> types.Type:
def to_artigraph(
cls, type_: Any, *, hints: dict[str, Any], type_system: TypeSystem
) -> types.Type:
if (precision := cls.unit_to_precision.get(type_.unit)) is None: # pragma: no cover
raise ValueError(
f"{type_}.unit must be one of {tuple(cls.unit_to_precision)}, got {type_.unit}"
)
return cls.artigraph(precision=precision)

@classmethod
def to_system(cls, type_: types.Type, *, hints: dict[str, Any]) -> Any:
def to_system(cls, type_: types.Type, *, hints: dict[str, Any], type_system: TypeSystem) -> Any:
precision = type_.precision # type: ignore
if (unit := cls.precision_to_unit.get(precision)) is None: # pragma: no cover
raise ValueError(
Expand All @@ -302,19 +325,21 @@ class TimestampTypeAdapter(_BaseTimeTypeAdapter):
system = pa.timestamp

@classmethod
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any]) -> types.Type:
def to_artigraph(
cls, type_: Any, *, hints: dict[str, Any], type_system: TypeSystem
) -> types.Type:
tz = type_.tz.upper()
if tz != "UTC":
raise ValueError(f"Timestamp {type_}.tz must be in UTC, got {tz}")
return super().to_artigraph(type_, hints=hints)
return super().to_artigraph(type_, hints=hints, type_system=type_system)

@classmethod
def matches_system(cls, type_: Any, *, hints: dict[str, Any]) -> bool:
return super().matches_system(type_, hints=hints) and type_.tz is not None

@classmethod
def to_system(cls, type_: types.Type, *, hints: dict[str, Any]) -> Any:
ts = super().to_system(type_, hints=hints)
def to_system(cls, type_: types.Type, *, hints: dict[str, Any], type_system: TypeSystem) -> Any:
ts = super().to_system(type_, hints=hints, type_system=type_system)
return cls.system(ts.unit, "UTC")


Expand Down

0 comments on commit b85405a

Please sign in to comment.