Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TypeSystem.extends #250

Merged
merged 11 commits into from
Jul 19, 2022
53 changes: 40 additions & 13 deletions src/arti/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from arti.internal.models import Model
from arti.internal.type_hints import lenient_issubclass
from arti.internal.utils import class_name, frozendict, register
from arti.internal.utils import NoCopyDict, class_name, frozendict, register

DEFAULT_ANONYMOUS_NAME = "anon"

Expand Down Expand Up @@ -307,31 +307,31 @@ def matches_artigraph(cls, type_: Type, *, hints: dict[str, Any]) -> bool:
return isinstance(type_, cls.artigraph)

@classmethod
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any]) -> Type:
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any], type_system: "TypeSystem") -> Type:
raise NotImplementedError()

@classmethod
def matches_system(cls, type_: Any, *, hints: dict[str, Any]) -> bool:
raise NotImplementedError()

@classmethod
def to_system(cls, type_: Type, *, hints: dict[str, Any]) -> Any:
def to_system(cls, type_: Type, *, hints: dict[str, Any], type_system: "TypeSystem") -> Any:
raise NotImplementedError()


# _ScalarClassTypeAdapter can be used for scalars defined as python types (eg: int or str for the
# python TypeSystem).
class _ScalarClassTypeAdapter(TypeAdapter):
@classmethod
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any]) -> Type:
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any], type_system: "TypeSystem") -> Type:
return cls.artigraph()

@classmethod
def matches_system(cls, type_: Any, *, hints: dict[str, Any]) -> bool:
return lenient_issubclass(type_, cls.system)

@classmethod
def to_system(cls, type_: Type, *, hints: dict[str, Any]) -> Any:
def to_system(cls, type_: Type, *, hints: dict[str, Any], type_system: "TypeSystem") -> Any:
return cls.system

@classmethod
Expand All @@ -358,7 +358,12 @@ def generate(
class TypeSystem(Model):
key: str

_adapter_by_key: dict[str, type[TypeAdapter]] = PrivateAttr(default_factory=dict)
# NOTE: Use a NoCopyDict to avoid copies of the registry. Otherwise, TypeSystems that extend
# this TypeSystem will only see the adapters registered *as of initialization* (as pydantic
# would deepcopy the TypeSystems in the `extends` argument).
_adapter_by_key: NoCopyDict[str, type[TypeAdapter]] = PrivateAttr(default_factory=NoCopyDict)

extends: "tuple[TypeSystem, ...]" = ()

def register_adapter(self, adapter: type[TypeAdapter]) -> type[TypeAdapter]:
return register(self._adapter_by_key, adapter.key, adapter)
Expand All @@ -367,14 +372,36 @@ def register_adapter(self, adapter: type[TypeAdapter]) -> type[TypeAdapter]:
def _priority_sorted_adapters(self) -> Iterator[type[TypeAdapter]]:
return reversed(sorted(self._adapter_by_key.values(), key=attrgetter("priority")))

def to_artigraph(self, type_: Any, *, hints: dict[str, Any]) -> Type:
def to_artigraph(
self, type_: Any, *, hints: dict[str, Any], root_type_system: "Optional[TypeSystem]" = None
) -> Type:
root_type_system = root_type_system or self
for adapter in self._priority_sorted_adapters:
if adapter.matches_system(type_, hints=hints):
return adapter.to_artigraph(type_, hints=hints)
raise NotImplementedError(f"No {self} adapter for system type: {type_}.")

def to_system(self, type_: Type, *, hints: dict[str, Any]) -> Any:
return adapter.to_artigraph(type_, hints=hints, type_system=root_type_system)
for type_system in self.extends:
try:
return type_system.to_artigraph(
type_, hints=hints, root_type_system=root_type_system
)
except NotImplementedError:
pass
raise NotImplementedError(f"No {root_type_system} adapter for system type: {type_}.")

def to_system(
self, type_: Type, *, hints: dict[str, Any], root_type_system: "Optional[TypeSystem]" = None
) -> Any:
root_type_system = root_type_system or self
for adapter in self._priority_sorted_adapters:
if adapter.matches_artigraph(type_, hints=hints):
return adapter.to_system(type_, hints=hints)
raise NotImplementedError(f"No {self} adapter for Artigraph type: {type_}.")
return adapter.to_system(type_, hints=hints, type_system=root_type_system)
for type_system in self.extends:
try:
return type_system.to_system(type_, hints=hints, root_type_system=root_type_system)
except NotImplementedError:
pass
raise NotImplementedError(f"No {root_type_system} adapter for Artigraph type: {type_}.")


# Fix ForwardRefs in outer_type_, pending: https://github.com/samuelcolvin/pydantic/pull/4249
TypeSystem.__fields__["extends"].outer_type_ = tuple[TypeSystem, ...]
Comment on lines +406 to +407
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI: I added this `outer_type_` override (the two lines above) to fix the "runtime type checking" done on inputs to our Models.

95 changes: 60 additions & 35 deletions src/arti/types/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from arti import types
from arti.internal.utils import classproperty
from arti.types import TypeSystem

pyarrow_type_system = types.TypeSystem(key="pyarrow")

Expand All @@ -25,15 +26,17 @@ def _is_system(cls) -> Callable[[pa.DataType], bool]:
return getattr(pa.types, f"is_{cls.system.__name__}") # type: ignore

@classmethod
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any]) -> types.Type:
def to_artigraph(
cls, type_: Any, *, hints: dict[str, Any], type_system: TypeSystem
) -> types.Type:
return cls.artigraph()

@classmethod
def matches_system(cls, type_: Any, *, hints: dict[str, Any]) -> bool:
return isinstance(type_, pa.DataType) and cls._is_system(type_)

@classmethod
def to_system(cls, type_: types.Type, *, hints: dict[str, Any]) -> Any:
def to_system(cls, type_: types.Type, *, hints: dict[str, Any], type_system: TypeSystem) -> Any:
return cls.system()


Expand Down Expand Up @@ -80,7 +83,9 @@ class BinaryTypeAdapter(_PyarrowTypeAdapter):
system = pa.binary

@classmethod
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any]) -> types.Type:
def to_artigraph(
cls, type_: Any, *, hints: dict[str, Any], type_system: TypeSystem
) -> types.Type:
if isinstance(type_, pa.FixedSizeBinaryType):
return cls.artigraph(byte_size=type_.byte_width)
return cls.artigraph()
Expand All @@ -92,7 +97,7 @@ def matches_system(cls, type_: Any, *, hints: dict[str, Any]) -> bool:
return super().matches_system(type_, hints=hints) or pa.types.is_fixed_size_binary(type_)

@classmethod
def to_system(cls, type_: types.Type, *, hints: dict[str, Any]) -> Any:
def to_system(cls, type_: types.Type, *, hints: dict[str, Any], type_system: TypeSystem) -> Any:
assert isinstance(type_, cls.artigraph)
return cls.system(length=-1 if type_.byte_size is None else type_.byte_size)

Expand Down Expand Up @@ -122,7 +127,7 @@ def matches_system(cls, type_: Any, *, hints: dict[str, Any]) -> bool:
return False

@classmethod
def to_system(cls, type_: types.Type, *, hints: dict[str, Any]) -> Any:
def to_system(cls, type_: types.Type, *, hints: dict[str, Any], type_system: TypeSystem) -> Any:
assert isinstance(type_, cls.artigraph)
return pa.binary() if type_.format == "WKB" else pa.string()

Expand All @@ -133,19 +138,21 @@ class ListTypeAdapter(_PyarrowTypeAdapter):
system = pa.list_

@classmethod
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any]) -> types.Type:
def to_artigraph(
cls, type_: Any, *, hints: dict[str, Any], type_system: TypeSystem
) -> types.Type:
return cls.artigraph(
element=pyarrow_type_system.to_artigraph(type_.value_type, hints=hints),
element=type_system.to_artigraph(type_.value_type, hints=hints),
)

@classmethod
def matches_system(cls, type_: Any, *, hints: dict[str, Any]) -> bool:
return cast(bool, pa.types.is_list(type_))

@classmethod
def to_system(cls, type_: types.Type, *, hints: dict[str, Any]) -> Any:
def to_system(cls, type_: types.Type, *, hints: dict[str, Any], type_system: TypeSystem) -> Any:
assert isinstance(type_, cls.artigraph)
return cls.system(value_type=pyarrow_type_system.to_system(type_.element, hints=hints))
return cls.system(value_type=type_system.to_system(type_.element, hints=hints))


@pyarrow_type_system.register_adapter
Expand All @@ -154,22 +161,24 @@ class MapTypeAdapter(_PyarrowTypeAdapter):
system = pa.map_

@classmethod
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any]) -> types.Type:
def to_artigraph(
cls, type_: Any, *, hints: dict[str, Any], type_system: TypeSystem
) -> types.Type:
return cls.artigraph(
key=pyarrow_type_system.to_artigraph(type_.key_type, hints=hints),
value=pyarrow_type_system.to_artigraph(type_.item_type, hints=hints),
key=type_system.to_artigraph(type_.key_type, hints=hints),
value=type_system.to_artigraph(type_.item_type, hints=hints),
)

@classmethod
def matches_system(cls, type_: Any, *, hints: dict[str, Any]) -> bool:
return cast(bool, pa.types.is_map(type_))

@classmethod
def to_system(cls, type_: types.Type, *, hints: dict[str, Any]) -> Any:
def to_system(cls, type_: types.Type, *, hints: dict[str, Any], type_system: TypeSystem) -> Any:
assert isinstance(type_, cls.artigraph)
return cls.system(
key_type=pyarrow_type_system.to_system(type_.key, hints=hints),
item_type=pyarrow_type_system.to_system(type_.value, hints=hints),
key_type=type_system.to_system(type_.key, hints=hints),
item_type=type_system.to_system(type_.value, hints=hints),
)


Expand All @@ -179,30 +188,37 @@ class StructTypeAdapter(_PyarrowTypeAdapter):
system = pa.struct

@classmethod
def _field_to_artigraph(cls, type_: Any, *, hints: dict[str, Any]) -> types.Type:
ret = pyarrow_type_system.to_artigraph(type_.type, hints=hints)
def _field_to_artigraph(
cls, type_: Any, *, hints: dict[str, Any], type_system: TypeSystem
) -> types.Type:
ret = type_system.to_artigraph(type_.type, hints=hints)
if type_.nullable != ret.nullable: # Avoid setting nullable if matching to minimize repr
ret = ret.copy(update={"nullable": type_.nullable})
return ret

@classmethod
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any]) -> types.Type:
def to_artigraph(
cls, type_: Any, *, hints: dict[str, Any], type_system: TypeSystem
) -> types.Type:
return cls.artigraph(
fields={field.name: cls._field_to_artigraph(field, hints=hints) for field in type_}
fields={
field.name: cls._field_to_artigraph(field, hints=hints, type_system=type_system)
for field in type_
}
)

@classmethod
def _field_to_system(cls, name: str, type_: types.Type, *, hints: dict[str, Any]) -> Any:
return pa.field(
name, pyarrow_type_system.to_system(type_, hints=hints), nullable=type_.nullable
)
def _field_to_system(
cls, name: str, type_: types.Type, *, hints: dict[str, Any], type_system: TypeSystem
) -> Any:
return pa.field(name, type_system.to_system(type_, hints=hints), nullable=type_.nullable)

@classmethod
def to_system(cls, type_: types.Type, *, hints: dict[str, Any]) -> Any:
def to_system(cls, type_: types.Type, *, hints: dict[str, Any], type_system: TypeSystem) -> Any:
assert isinstance(type_, cls.artigraph)
return cls.system(
[
cls._field_to_system(name, subtype, hints=hints)
cls._field_to_system(name, subtype, hints=hints, type_system=type_system)
for name, subtype in type_.fields.items()
]
)
Expand All @@ -224,26 +240,31 @@ def matches_artigraph(cls, type_: types.Type, *, hints: dict[str, Any]) -> bool:
)

@classmethod
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any]) -> types.Type:
def to_artigraph(
cls, type_: Any, *, hints: dict[str, Any], type_system: TypeSystem
) -> types.Type:
kwargs = {}
# NOTE: pyarrow converts all metadata keys/values to bytes
if type_.metadata and b"artigraph" in type_.metadata:
kwargs = json.loads(type_.metadata[b"artigraph"].decode())
for key in ["partition_by", "cluster_by"]:
if key in kwargs: # pragma: no cover
kwargs[key] = tuple(kwargs[key])
return cls.artigraph(element=StructTypeAdapter.to_artigraph(type_, hints=hints), **kwargs)
return cls.artigraph(
element=StructTypeAdapter.to_artigraph(type_, hints=hints, type_system=type_system),
**kwargs,
)

@classmethod
def matches_system(cls, type_: Any, *, hints: dict[str, Any]) -> bool:
return isinstance(type_, pa.lib.Schema)

@classmethod
def to_system(cls, type_: types.Type, *, hints: dict[str, Any]) -> Any:
def to_system(cls, type_: types.Type, *, hints: dict[str, Any], type_system: TypeSystem) -> Any:
assert isinstance(type_, cls.artigraph)
assert isinstance(type_.element, types.Struct)
return cls.system(
StructTypeAdapter.to_system(type_.element, hints=hints),
StructTypeAdapter.to_system(type_.element, hints=hints, type_system=type_system),
metadata={
"artigraph": json.dumps(
{
Expand All @@ -269,15 +290,17 @@ def unit_to_precision(cls) -> dict[str, str]:
return {v: k for k, v in cls.precision_to_unit.items()}

@classmethod
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any]) -> types.Type:
def to_artigraph(
cls, type_: Any, *, hints: dict[str, Any], type_system: TypeSystem
) -> types.Type:
if (precision := cls.unit_to_precision.get(type_.unit)) is None: # pragma: no cover
raise ValueError(
f"{type_}.unit must be one of {tuple(cls.unit_to_precision)}, got {type_.unit}"
)
return cls.artigraph(precision=precision)

@classmethod
def to_system(cls, type_: types.Type, *, hints: dict[str, Any]) -> Any:
def to_system(cls, type_: types.Type, *, hints: dict[str, Any], type_system: TypeSystem) -> Any:
precision = type_.precision # type: ignore
if (unit := cls.precision_to_unit.get(precision)) is None: # pragma: no cover
raise ValueError(
Expand All @@ -302,19 +325,21 @@ class TimestampTypeAdapter(_BaseTimeTypeAdapter):
system = pa.timestamp

@classmethod
def to_artigraph(cls, type_: Any, *, hints: dict[str, Any]) -> types.Type:
def to_artigraph(
cls, type_: Any, *, hints: dict[str, Any], type_system: TypeSystem
) -> types.Type:
tz = type_.tz.upper()
if tz != "UTC":
raise ValueError(f"Timestamp {type_}.tz must be in UTC, got {tz}")
return super().to_artigraph(type_, hints=hints)
return super().to_artigraph(type_, hints=hints, type_system=type_system)

@classmethod
def matches_system(cls, type_: Any, *, hints: dict[str, Any]) -> bool:
return super().matches_system(type_, hints=hints) and type_.tz is not None

@classmethod
def to_system(cls, type_: types.Type, *, hints: dict[str, Any]) -> Any:
ts = super().to_system(type_, hints=hints)
def to_system(cls, type_: types.Type, *, hints: dict[str, Any], type_system: TypeSystem) -> Any:
ts = super().to_system(type_, hints=hints, type_system=type_system)
return cls.system(ts.unit, "UTC")


Expand Down