fix: from_sample override logic #3202

Merged
merged 3 commits on Nov 9, 2022
22 changes: 1 addition & 21 deletions src/bentoml/_internal/io_descriptors/base.py
@@ -81,7 +81,6 @@ class IODescriptor(ABC, _OpenAPIMeta, t.Generic[IOType]):
     _rpc_content_type: str = "application/grpc"
     _proto_fields: tuple[ProtoField]
     _sample: IOType | None = None
-    _initialized: bool = False
     _args: t.Sequence[t.Any]
     _kwargs: dict[str, t.Any]

@@ -99,27 +98,9 @@ def __init_subclass__(cls, *, descriptor_id: str | None = None):
     def __init__(self, **kwargs: t.Any) -> None:
         ...

-    def __getattr__(self, name: str) -> t.Any:
-        if not self._initialized:
-            self._lazy_init()
-        assert self._initialized
-        return object.__getattribute__(self, name)
-
-    def __dir__(self) -> t.Iterable[str]:
-        if not self._initialized:
-            self._lazy_init()
-        assert self._initialized
-        return object.__dir__(self)
-
     def __repr__(self) -> str:
         return self.__class__.__qualname__

-    def _lazy_init(self) -> None:
-        self._initialized = True
-        self.__init__(*self._args, **self._kwargs)
-        del self._args
-        del self._kwargs
-
     @property
     def sample(self) -> IOType | None:
         return self._sample
@@ -131,8 +112,7 @@ def sample(self, value: IOType) -> None:
     @classmethod
     def from_sample(cls, sample: IOType | t.Any, **kwargs: t.Any) -> Self:
         klass = cls(**kwargs)
-        sample = klass._from_sample(sample)
-        klass.sample = sample
+        klass.sample = klass._from_sample(sample)
         return klass

     @abstractmethod
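For orientation, here is a minimal sketch of the pattern this hunk leaves in place (simplified, hypothetical class names, not the actual BentoML code): the descriptor is now constructed eagerly, so __init__ runs with the user-supplied kwargs first and _from_sample only derives extra values from the sample, instead of re-running __init__ later through the removed lazy-init path.

class SimpleDescriptor:
    # Hypothetical stand-in for IODescriptor, only to illustrate the new flow.

    def __init__(self, **kwargs):
        self._sample = None

    def _from_sample(self, sample):
        # subclasses would derive dtype/shape/mime type from the sample here
        return sample

    @classmethod
    def from_sample(cls, sample, **kwargs):
        klass = cls(**kwargs)                      # user-supplied kwargs applied first
        klass.sample = klass._from_sample(sample)  # sample only fills in derived fields
        return klass

    @property
    def sample(self):
        return self._sample

    @sample.setter
    def sample(self, value):
        self._sample = value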
1 change: 0 additions & 1 deletion src/bentoml/_internal/io_descriptors/file.py
@@ -135,7 +135,6 @@ def _from_sample(self, sample: FileType | str) -> FileType:
             )
         if isinstance(sample, t.IO):
             sample = FileLike[bytes](sample, "<sample>")
-            self._mime_type = filetype.guess_mime(sample)
         elif isinstance(sample, (str, os.PathLike)):
             p = resolve_user_filepath(sample, ctx=None)
             self._mime_type = filetype.guess_mime(p)
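A hedged usage sketch (placeholder file name, not from the PR): a path sample still lets File guess the mime type from the file on disk, while an in-memory stream sample no longer overrides it, so pass mime_type explicitly in that case if the default is not what you want.

from bentoml.io import File

from_path = File.from_sample("./sample.pdf")  # mime type guessed from the file on disk
with open("./sample.pdf", "rb") as f:
    from_stream = File.from_sample(f, mime_type="application/pdf")  # set it yourself for streams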
1 change: 1 addition & 0 deletions src/bentoml/_internal/io_descriptors/image.py
@@ -231,6 +231,7 @@ def _from_sample(self, sample: ImageType | str) -> ImageType:
                     sample = PIL.Image.open(f)
             except PIL.UnidentifiedImageError as err:
                 raise BadInput(f"Failed to parse sample image file: {err}") from None
+            self._mime_type = img_type.mime
         return sample

     def to_spec(self) -> dict[str, t.Any]:
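As a hedged sketch of the effect (placeholder path, not part of the diff): constructing the descriptor from a sample image now records that image's detected mime type on the descriptor itself.

from bentoml.io import Image

input_spec = Image.from_sample("./sample.png")  # descriptor's mime type follows the sample, e.g. image/png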
11 changes: 0 additions & 11 deletions src/bentoml/_internal/io_descriptors/numpy.py
@@ -216,15 +216,6 @@ def __init__(
         shape: tuple[int, ...] | None = None,
         enforce_shape: bool = False,
     ):
-        if enforce_dtype and not dtype:
-            raise InvalidArgument(
-                "'dtype' must be specified when 'enforce_dtype=True'"
-            ) from None
-        if enforce_shape and not shape:
-            raise InvalidArgument(
-                "'shape' must be specified when 'enforce_shape=True'"
-            ) from None
-
         if dtype and not isinstance(dtype, np.dtype):
             # Convert from primitive type or type string, e.g.: np.dtype(float) or np.dtype("float64")
             try:
@@ -440,8 +431,6 @@ async def predict(input: NDArray[np.int16]) -> NDArray[Any]:
             ) from None
         self._dtype = sample.dtype
         self._shape = sample.shape
-        self._enforce_dtype = True
-        self._enforce_shape = True
         return sample

     async def from_proto(self, field: pb.NDArray | bytes) -> ext.NpNDArray:
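A hedged usage sketch of the resulting behaviour (assumed defaults, not part of the diff): from_sample still infers dtype and shape from the sample, but it no longer switches enforce_dtype/enforce_shape on by itself, so strict validation must be requested explicitly.

import numpy as np
from bentoml.io import NumpyNdarray

sample = np.zeros((2, 2, 3, 2))
strict = NumpyNdarray.from_sample(sample, enforce_dtype=True, enforce_shape=True)
lenient = NumpyNdarray.from_sample(sample)  # dtype/shape recorded but not enforced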
30 changes: 2 additions & 28 deletions src/bentoml/_internal/io_descriptors/pandas.py
@@ -325,19 +325,6 @@ def __init__(
         enforce_shape: bool = False,
         default_format: t.Literal["json", "parquet", "csv"] = "json",
     ):
-        if enforce_dtype and dtype is None:
-            raise ValueError(
-                "'dtype' must be specified if 'enforce_dtype' is True"
-            ) from None
-        if enforce_shape and shape is None:
-            raise ValueError(
-                "'shape' must be specified if 'enforce_shape' is True"
-            ) from None
-        if apply_column_names and columns is None:
-            raise ValueError(
-                "'columns' must be specified if 'apply_column_names' is True"
-            ) from None
-
         self._orient: ext.DataFrameOrient = orient
         self._columns = columns
         self._apply_column_names = apply_column_names
@@ -431,10 +418,8 @@ def predict(inputs: pd.DataFrame) -> pd.DataFrame: ...
             ) from None
         self._shape = sample.shape
         self._columns = [str(i) for i in list(sample.columns)]
-        self._dtype = True
-        self._enforce_dtype = True
-        self._enforce_shape = True
-        self._apply_column_names = True
+        if self._dtype is None:
+            self._dtype = True  # infer dtype automatically
         return sample

     def _convert_dtype(
@@ -828,15 +813,6 @@ def __init__(
         shape: tuple[int, ...] | None = None,
         enforce_shape: bool = False,
     ):
-        if enforce_dtype and dtype is None:
-            raise ValueError(
-                "'dtype' must be specified if 'enforce_dtype' is True"
-            ) from None
-        if enforce_shape and shape is None:
-            raise ValueError(
-                "'shape' must be specified if 'enforce_shape' is True"
-            ) from None
-
         self._orient: ext.SeriesOrient = orient
         self._dtype = dtype
         self._enforce_dtype = enforce_dtype
@@ -885,8 +861,6 @@ def predict(inputs: pd.Series) -> pd.Series: ...
             sample = pd.Series(sample)
         self._dtype = sample.dtype
         self._shape = sample.shape
-        self._enforce_dtype = True
-        self._enforce_shape = True
         return sample

     def input_type(self) -> LazyType[ext.PdSeries]:
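A similar hedged sketch for the pandas descriptors (example data is illustrative): options passed to from_sample are forwarded to __init__ and are no longer overridden afterwards, and _dtype is only set to True (infer automatically) when no dtype was given.

import pandas as pd
from bentoml.io import PandasDataFrame, PandasSeries

df = pd.DataFrame({"a": [1, 2], "b": [3.0, 4.0]})
df_spec = PandasDataFrame.from_sample(df, orient="records", enforce_shape=True)

series_spec = PandasSeries.from_sample(pd.Series([1, 2, 3]))  # shape/dtype inferred, enforcement stays off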
15 changes: 1 addition & 14 deletions tests/unit/_internal/io/test_numpy.py
@@ -10,7 +10,6 @@

 from bentoml.io import NumpyNdarray
 from bentoml.exceptions import BadInput
-from bentoml.exceptions import InvalidArgument
 from bentoml.exceptions import BentoMLException
 from bentoml._internal.service.openapi.specification import Schema

@@ -29,7 +28,7 @@ class ExampleGeneric(str, np.generic):


 example = np.zeros((2, 2, 3, 2))
-from_example = NumpyNdarray.from_sample(example)
+from_example = NumpyNdarray.from_sample(example, enforce_dtype=True, enforce_shape=True)


 def test_invalid_dtype():
@@ -43,15 +42,6 @@ def test_invalid_dtype():
     assert "expects a 'numpy.array'" in str(e.value)


-def test_invalid_init():
-    with pytest.raises(InvalidArgument) as exc_info:
-        NumpyNdarray(enforce_dtype=True)
-    assert "'dtype' must be specified" in str(exc_info.value)
-    with pytest.raises(InvalidArgument) as exc_info:
-        NumpyNdarray(enforce_shape=True)
-    assert "'shape' must be specified" in str(exc_info.value)
-
-
 @pytest.mark.parametrize("dtype, expected", [("float", "number"), (">U8", "integer")])
 def test_numpy_to_openapi_types(dtype: str, expected: str):
     assert NumpyNdarray(dtype=dtype)._openapi_types() == expected  # type: ignore (private functions warning)
@@ -123,9 +113,6 @@ def test_verify_numpy_ndarray(caplog: LogCaptureFixture):

     # test cases where reshape is failed
     example = NumpyNdarray.from_sample(np.ones((2, 2, 3)))
-    # Note that from_sample now lazy load the IO descriptor
-    example._enforce_shape = False
-    example._enforce_dtype = False
     with caplog.at_level(logging.DEBUG):
         example.validate_array(np.array("asdf"))
     assert "Failed to reshape" in caplog.text
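One could also add a check like the following (a sketch, not part of this PR; it relies on private attributes whose defaults are assumed to stay False): from_sample without explicit flags should leave enforcement off.

def test_from_sample_leaves_enforcement_off():
    descriptor = NumpyNdarray.from_sample(np.ones((2, 2, 3)))
    assert descriptor._enforce_dtype is False  # assumed default from __init__
    assert descriptor._enforce_shape is False  # assumed default from __init__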