Skip to content

Commit

Permalink
fix: make DocList properly a Generic
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Mar 11, 2024
1 parent 951679c commit 12cd670
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 142 deletions.
96 changes: 0 additions & 96 deletions aux.py

This file was deleted.

20 changes: 9 additions & 11 deletions docarray/array/any_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
cast,
overload,
Tuple,
get_args,
get_origin,
)

try:
from typing import GenericAlias
except:
from typing import _GenericAlias as GenericAlias

import numpy as np

from docarray.base_doc.doc import BaseDocWithoutId
Expand Down Expand Up @@ -48,7 +51,7 @@
)


class AnyDocArray(AbstractType, Sequence[T_doc], Generic[T_doc]):
class AnyDocArray(Sequence[T_doc], Generic[T_doc], AbstractType):
doc_type: Type[BaseDocWithoutId]
__typed_da__: Dict[Type['AnyDocArray'], Dict[Type[BaseDocWithoutId], Type]] = {}

Expand All @@ -57,7 +60,6 @@ def __repr__(self):

@classmethod
def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]):
print(f' hey here {item}')
if not isinstance(item, type):
if sys.version_info < (3, 12):
return Generic.__class_getitem__.__func__(cls, item) # type: ignore
Expand All @@ -78,8 +80,8 @@ def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]):
global _DocArrayTyped
class _DocArrayTyped(cls, Generic[T_doc]): # type: ignore
doc_type: Type[BaseDocWithoutId] = cast(Type[BaseDocWithoutId], item)
# __origin__: Type['AnyDocArray'] = cls # add this
# __args__: Tuple[Any, ...] = (item,) # add this
__origin__: Type['AnyDocArray'] = cls # add this
__args__: Tuple[Any, ...] = (item,) # add this

for field in _DocArrayTyped.doc_type._docarray_fields().keys():

Expand Down Expand Up @@ -110,12 +112,8 @@ def _setter(self, value):
_DocArrayTyped, f'{cls.__name__}[{item.__name__}]', globals()
)

cls.__typed_da__[cls][item] = _DocArrayTyped
cls.__typed_da__[cls][item] = GenericAlias(_DocArrayTyped, item)

print(f'return {cls.__typed_da__[cls][item]}')
a = get_args(cls.__typed_da__[cls][item])
print(f'a {a}')
print(f'get_origin {get_origin(cls.__typed_da__[cls][item])}')
return cls.__typed_da__[cls][item]

@overload
Expand Down
44 changes: 12 additions & 32 deletions docarray/array/doc_list/doc_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
overload,
Callable,
get_args,
Generic
)

from pydantic import parse_obj_as
Expand All @@ -31,7 +30,6 @@
from docarray.utils._internal.pydantic import is_pydantic_v2

if is_pydantic_v2:
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema

from docarray.utils._internal._typing import safe_issubclass
Expand All @@ -51,8 +49,7 @@ class DocList(
ListAdvancedIndexing[T_doc],
PushPullMixin,
IOMixinDocList,
AnyDocArray[T_doc],
Generic[T_doc]
AnyDocArray[T_doc]
):
"""
DocList is a container of Documents.
Expand Down Expand Up @@ -363,32 +360,15 @@ def __repr__(self):
def __get_pydantic_core_schema__(
cls, source: Any, handler: Callable[[Any], core_schema.CoreSchema]
) -> core_schema.CoreSchema:
def get_args_2(tp):
"""Get type arguments with all substitutions performed.
For unions, basic simplifications used by Union constructor are performed.
Examples::
get_args(Dict[str, int]) == (str, int)
get_args(int) == ()
get_args(Union[int, Union[T, int], str][int]) == (int, str)
get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int])
get_args(Callable[[], T][int]) == ([], int)
"""
from typing import _GenericAlias, get_origin
import collections
if isinstance(tp, _GenericAlias):
res = tp.__args__
if get_origin(tp) is collections.abc.Callable and res[0] is not Ellipsis:
res = (list(res[:-1]), res[-1])
return res
else:
print(f'IN ELSE')
return ()

instance_schema = core_schema.is_instance_schema(cls)

Check warning on line 363 in docarray/array/doc_list/doc_list.py

View check run for this annotation

Codecov / codecov/patch

docarray/array/doc_list/doc_list.py#L363

Added line #L363 was not covered by tests
print(f'instance_schema {instance_schema} and {handler}')
args = get_args_2(DocList[BaseDocWithoutId])
print(f' args {args}')
return core_schema.with_info_after_validator_function(
function=cls.validate,
schema=core_schema.list_schema(core_schema.any_schema()))

args = get_args(source)
if args:
sequence_t_schema = handler(Sequence[args[0]])

Check warning on line 367 in docarray/array/doc_list/doc_list.py

View check run for this annotation

Codecov / codecov/patch

docarray/array/doc_list/doc_list.py#L365-L367

Added lines #L365 - L367 were not covered by tests
else:
sequence_t_schema = handler(Sequence)

Check warning on line 369 in docarray/array/doc_list/doc_list.py

View check run for this annotation

Codecov / codecov/patch

docarray/array/doc_list/doc_list.py#L369

Added line #L369 was not covered by tests

non_instance_schema = core_schema.with_info_after_validator_function(

Check warning on line 371 in docarray/array/doc_list/doc_list.py

View check run for this annotation

Codecov / codecov/patch

docarray/array/doc_list/doc_list.py#L371

Added line #L371 was not covered by tests
lambda v, i: DocList(v), sequence_t_schema
)
return core_schema.union_schema([instance_schema, non_instance_schema])

Check warning on line 374 in docarray/array/doc_list/doc_list.py

View check run for this annotation

Codecov / codecov/patch

docarray/array/doc_list/doc_list.py#L374

Added line #L374 was not covered by tests
6 changes: 3 additions & 3 deletions docarray/documents/legacy/legacy_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
from __future__ import annotations

from typing import Any, Dict, Optional, List, Union
from typing import Any, Dict, Optional

from docarray import BaseDoc, DocList
from docarray.typing import AnyEmbedding, AnyTensor
Expand Down Expand Up @@ -50,8 +50,8 @@ class LegacyDocument(BaseDoc):
"""

tensor: Optional[AnyTensor] = None
chunks: Optional[Union[DocList[LegacyDocument], List[LegacyDocument]]] = None
matches: Optional[Union[DocList[LegacyDocument], List[LegacyDocument]]] = None
chunks: Optional[DocList[LegacyDocument]] = None
matches: Optional[DocList[LegacyDocument]] = None
blob: Optional[bytes] = None
text: Optional[str] = None
url: Optional[str] = None
Expand Down

0 comments on commit 12cd670

Please sign in to comment.