Skip to content

Commit

Permalink
fix: make DocList properly a Generic
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Mar 11, 2024
1 parent 951679c commit 1f35aca
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 145 deletions.
96 changes: 0 additions & 96 deletions aux.py

This file was deleted.

26 changes: 15 additions & 11 deletions docarray/array/any_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
cast,
overload,
Tuple,
get_args,
get_origin,
)

import numpy as np
Expand All @@ -35,6 +33,11 @@

if sys.version_info >= (3, 12):
from types import GenericAlias
else:
try:
from typing import GenericAlias
except:
from typing import _GenericAlias as GenericAlias

T = TypeVar('T', bound='AnyDocArray')
T_doc = TypeVar('T_doc', bound=BaseDocWithoutId)
Expand All @@ -48,7 +51,7 @@
)


class AnyDocArray(AbstractType, Sequence[T_doc], Generic[T_doc]):
class AnyDocArray(Sequence[T_doc], Generic[T_doc], AbstractType):
doc_type: Type[BaseDocWithoutId]
__typed_da__: Dict[Type['AnyDocArray'], Dict[Type[BaseDocWithoutId], Type]] = {}

Expand All @@ -57,7 +60,6 @@ def __repr__(self):

@classmethod
def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]):
print(f' hey here {item}')
if not isinstance(item, type):
if sys.version_info < (3, 12):
return Generic.__class_getitem__.__func__(cls, item) # type: ignore
Expand All @@ -76,10 +78,11 @@ def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]):
if item not in cls.__typed_da__[cls]:
# Promote to global scope so multiprocessing can pickle it
global _DocArrayTyped

class _DocArrayTyped(cls, Generic[T_doc]): # type: ignore
doc_type: Type[BaseDocWithoutId] = cast(Type[BaseDocWithoutId], item)
# __origin__: Type['AnyDocArray'] = cls # add this
# __args__: Tuple[Any, ...] = (item,) # add this
__origin__: Type['AnyDocArray'] = cls # add this
__args__: Tuple[Any, ...] = (item,) # add this

for field in _DocArrayTyped.doc_type._docarray_fields().keys():

Expand Down Expand Up @@ -110,12 +113,13 @@ def _setter(self, value):
_DocArrayTyped, f'{cls.__name__}[{item.__name__}]', globals()
)

cls.__typed_da__[cls][item] = _DocArrayTyped
if sys.version_info < (3, 12):
cls.__typed_da__[cls][item] = Generic.__class_getitem__.__func__(_DocArrayTyped, item) # type: ignore
# this do nothing that checking that item is valid type var or str
# Keep the approach in #1147 to be compatible with lower versions of Python.
else:
cls.__typed_da__[cls][item] = GenericAlias(_DocArrayTyped, item)

Check warning on line 121 in docarray/array/any_array.py

View check run for this annotation

Codecov / codecov/patch

docarray/array/any_array.py#L121

Added line #L121 was not covered by tests

print(f'return {cls.__typed_da__[cls][item]}')
a = get_args(cls.__typed_da__[cls][item])
print(f'a {a}')
print(f'get_origin {get_origin(cls.__typed_da__[cls][item])}')
return cls.__typed_da__[cls][item]

@overload
Expand Down
47 changes: 12 additions & 35 deletions docarray/array/doc_list/doc_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
overload,
Callable,
get_args,
Generic
)

from pydantic import parse_obj_as
Expand All @@ -31,7 +30,6 @@
from docarray.utils._internal.pydantic import is_pydantic_v2

if is_pydantic_v2:
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema

from docarray.utils._internal._typing import safe_issubclass
Expand All @@ -48,11 +46,7 @@


class DocList(
ListAdvancedIndexing[T_doc],
PushPullMixin,
IOMixinDocList,
AnyDocArray[T_doc],
Generic[T_doc]
ListAdvancedIndexing[T_doc], PushPullMixin, IOMixinDocList, AnyDocArray[T_doc]
):
"""
DocList is a container of Documents.
Expand Down Expand Up @@ -363,32 +357,15 @@ def __repr__(self):
def __get_pydantic_core_schema__(
cls, source: Any, handler: Callable[[Any], core_schema.CoreSchema]
) -> core_schema.CoreSchema:
def get_args_2(tp):
"""Get type arguments with all substitutions performed.
For unions, basic simplifications used by Union constructor are performed.
Examples::
get_args(Dict[str, int]) == (str, int)
get_args(int) == ()
get_args(Union[int, Union[T, int], str][int]) == (int, str)
get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int])
get_args(Callable[[], T][int]) == ([], int)
"""
from typing import _GenericAlias, get_origin
import collections
if isinstance(tp, _GenericAlias):
res = tp.__args__
if get_origin(tp) is collections.abc.Callable and res[0] is not Ellipsis:
res = (list(res[:-1]), res[-1])
return res
else:
print(f'IN ELSE')
return ()

instance_schema = core_schema.is_instance_schema(cls)

Check warning on line 360 in docarray/array/doc_list/doc_list.py

View check run for this annotation

Codecov / codecov/patch

docarray/array/doc_list/doc_list.py#L360

Added line #L360 was not covered by tests
print(f'instance_schema {instance_schema} and {handler}')
args = get_args_2(DocList[BaseDocWithoutId])
print(f' args {args}')
return core_schema.with_info_after_validator_function(
function=cls.validate,
schema=core_schema.list_schema(core_schema.any_schema()))

args = get_args(source)
if args:
sequence_t_schema = handler(Sequence[args[0]])

Check warning on line 364 in docarray/array/doc_list/doc_list.py

View check run for this annotation

Codecov / codecov/patch

docarray/array/doc_list/doc_list.py#L362-L364

Added lines #L362 - L364 were not covered by tests
else:
sequence_t_schema = handler(Sequence)

Check warning on line 366 in docarray/array/doc_list/doc_list.py

View check run for this annotation

Codecov / codecov/patch

docarray/array/doc_list/doc_list.py#L366

Added line #L366 was not covered by tests

non_instance_schema = core_schema.with_info_after_validator_function(

Check warning on line 368 in docarray/array/doc_list/doc_list.py

View check run for this annotation

Codecov / codecov/patch

docarray/array/doc_list/doc_list.py#L368

Added line #L368 was not covered by tests
lambda v, i: DocList(v), sequence_t_schema
)
return core_schema.union_schema([instance_schema, non_instance_schema])

Check warning on line 371 in docarray/array/doc_list/doc_list.py

View check run for this annotation

Codecov / codecov/patch

docarray/array/doc_list/doc_list.py#L371

Added line #L371 was not covered by tests
6 changes: 3 additions & 3 deletions docarray/documents/legacy/legacy_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
from __future__ import annotations

from typing import Any, Dict, Optional, List, Union
from typing import Any, Dict, Optional

from docarray import BaseDoc, DocList
from docarray.typing import AnyEmbedding, AnyTensor
Expand Down Expand Up @@ -50,8 +50,8 @@ class LegacyDocument(BaseDoc):
"""

tensor: Optional[AnyTensor] = None
chunks: Optional[Union[DocList[LegacyDocument], List[LegacyDocument]]] = None
matches: Optional[Union[DocList[LegacyDocument], List[LegacyDocument]]] = None
chunks: Optional[DocList[LegacyDocument]] = None
matches: Optional[DocList[LegacyDocument]] = None
blob: Optional[bytes] = None
text: Optional[str] = None
url: Optional[str] = None
Expand Down

0 comments on commit 1f35aca

Please sign in to comment.