Skip to content

Commit

Permalink
fix: from_sample override logic (#3202)
Browse files Browse the repository at this point in the history
  • Loading branch information
aarnphm committed Nov 9, 2022
1 parent 39faf23 commit 248979b
Show file tree
Hide file tree
Showing 6 changed files with 5 additions and 75 deletions.
22 changes: 1 addition & 21 deletions src/bentoml/_internal/io_descriptors/base.py
Expand Up @@ -81,7 +81,6 @@ class IODescriptor(ABC, _OpenAPIMeta, t.Generic[IOType]):
_rpc_content_type: str = "application/grpc"
_proto_fields: tuple[ProtoField]
_sample: IOType | None = None
_initialized: bool = False
_args: t.Sequence[t.Any]
_kwargs: dict[str, t.Any]

Expand All @@ -99,27 +98,9 @@ def __init_subclass__(cls, *, descriptor_id: str | None = None):
def __init__(self, **kwargs: t.Any) -> None:
...

def __getattr__(self, name: str) -> t.Any:
if not self._initialized:
self._lazy_init()
assert self._initialized
return object.__getattribute__(self, name)

def __dir__(self) -> t.Iterable[str]:
if not self._initialized:
self._lazy_init()
assert self._initialized
return object.__dir__(self)

def __repr__(self) -> str:
return self.__class__.__qualname__

def _lazy_init(self) -> None:
self._initialized = True
self.__init__(*self._args, **self._kwargs)
del self._args
del self._kwargs

@property
def sample(self) -> IOType | None:
return self._sample
Expand All @@ -131,8 +112,7 @@ def sample(self, value: IOType) -> None:
@classmethod
def from_sample(cls, sample: IOType | t.Any, **kwargs: t.Any) -> Self:
klass = cls(**kwargs)
sample = klass._from_sample(sample)
klass.sample = sample
klass.sample = klass._from_sample(sample)
return klass

@abstractmethod
Expand Down
1 change: 0 additions & 1 deletion src/bentoml/_internal/io_descriptors/file.py
Expand Up @@ -135,7 +135,6 @@ def _from_sample(self, sample: FileType | str) -> FileType:
)
if isinstance(sample, t.IO):
sample = FileLike[bytes](sample, "<sample>")
self._mime_type = filetype.guess_mime(sample)
elif isinstance(sample, (str, os.PathLike)):
p = resolve_user_filepath(sample, ctx=None)
self._mime_type = filetype.guess_mime(p)
Expand Down
1 change: 1 addition & 0 deletions src/bentoml/_internal/io_descriptors/image.py
Expand Up @@ -231,6 +231,7 @@ def _from_sample(self, sample: ImageType | str) -> ImageType:
sample = PIL.Image.open(f)
except PIL.UnidentifiedImageError as err:
raise BadInput(f"Failed to parse sample image file: {err}") from None
self._mime_type = img_type.mime
return sample

def to_spec(self) -> dict[str, t.Any]:
Expand Down
11 changes: 0 additions & 11 deletions src/bentoml/_internal/io_descriptors/numpy.py
Expand Up @@ -216,15 +216,6 @@ def __init__(
shape: tuple[int, ...] | None = None,
enforce_shape: bool = False,
):
if enforce_dtype and not dtype:
raise InvalidArgument(
"'dtype' must be specified when 'enforce_dtype=True'"
) from None
if enforce_shape and not shape:
raise InvalidArgument(
"'shape' must be specified when 'enforce_shape=True'"
) from None

if dtype and not isinstance(dtype, np.dtype):
# Convert from primitive type or type string, e.g.: np.dtype(float) or np.dtype("float64")
try:
Expand Down Expand Up @@ -440,8 +431,6 @@ async def predict(input: NDArray[np.int16]) -> NDArray[Any]:
) from None
self._dtype = sample.dtype
self._shape = sample.shape
self._enforce_dtype = True
self._enforce_shape = True
return sample

async def from_proto(self, field: pb.NDArray | bytes) -> ext.NpNDArray:
Expand Down
30 changes: 2 additions & 28 deletions src/bentoml/_internal/io_descriptors/pandas.py
Expand Up @@ -325,19 +325,6 @@ def __init__(
enforce_shape: bool = False,
default_format: t.Literal["json", "parquet", "csv"] = "json",
):
if enforce_dtype and dtype is None:
raise ValueError(
"'dtype' must be specified if 'enforce_dtype' is True"
) from None
if enforce_shape and shape is None:
raise ValueError(
"'shape' must be specified if 'enforce_shape' is True"
) from None
if apply_column_names and columns is None:
raise ValueError(
"'columns' must be specified if 'apply_column_names' is True"
) from None

self._orient: ext.DataFrameOrient = orient
self._columns = columns
self._apply_column_names = apply_column_names
Expand Down Expand Up @@ -431,10 +418,8 @@ def predict(inputs: pd.DataFrame) -> pd.DataFrame: ...
) from None
self._shape = sample.shape
self._columns = [str(i) for i in list(sample.columns)]
self._dtype = True
self._enforce_dtype = True
self._enforce_shape = True
self._apply_column_names = True
if self._dtype is None:
self._dtype = True # infer dtype automatically
return sample

def _convert_dtype(
Expand Down Expand Up @@ -828,15 +813,6 @@ def __init__(
shape: tuple[int, ...] | None = None,
enforce_shape: bool = False,
):
if enforce_dtype and dtype is None:
raise ValueError(
"'dtype' must be specified if 'enforce_dtype' is True"
) from None
if enforce_shape and shape is None:
raise ValueError(
"'shape' must be specified if 'enforce_shape' is True"
) from None

self._orient: ext.SeriesOrient = orient
self._dtype = dtype
self._enforce_dtype = enforce_dtype
Expand Down Expand Up @@ -885,8 +861,6 @@ def predict(inputs: pd.Series) -> pd.Series: ...
sample = pd.Series(sample)
self._dtype = sample.dtype
self._shape = sample.shape
self._enforce_dtype = True
self._enforce_shape = True
return sample

def input_type(self) -> LazyType[ext.PdSeries]:
Expand Down
15 changes: 1 addition & 14 deletions tests/unit/_internal/io/test_numpy.py
Expand Up @@ -10,7 +10,6 @@

from bentoml.io import NumpyNdarray
from bentoml.exceptions import BadInput
from bentoml.exceptions import InvalidArgument
from bentoml.exceptions import BentoMLException
from bentoml._internal.service.openapi.specification import Schema

Expand All @@ -29,7 +28,7 @@ class ExampleGeneric(str, np.generic):


example = np.zeros((2, 2, 3, 2))
from_example = NumpyNdarray.from_sample(example)
from_example = NumpyNdarray.from_sample(example, enforce_dtype=True, enforce_shape=True)


def test_invalid_dtype():
Expand All @@ -43,15 +42,6 @@ def test_invalid_dtype():
assert "expects a 'numpy.array'" in str(e.value)


def test_invalid_init():
with pytest.raises(InvalidArgument) as exc_info:
NumpyNdarray(enforce_dtype=True)
assert "'dtype' must be specified" in str(exc_info.value)
with pytest.raises(InvalidArgument) as exc_info:
NumpyNdarray(enforce_shape=True)
assert "'shape' must be specified" in str(exc_info.value)


@pytest.mark.parametrize("dtype, expected", [("float", "number"), (">U8", "integer")])
def test_numpy_to_openapi_types(dtype: str, expected: str):
assert NumpyNdarray(dtype=dtype)._openapi_types() == expected # type: ignore (private functions warning)
Expand Down Expand Up @@ -123,9 +113,6 @@ def test_verify_numpy_ndarray(caplog: LogCaptureFixture):

# test cases where reshape is failed
example = NumpyNdarray.from_sample(np.ones((2, 2, 3)))
# Note that from_sample now lazy load the IO descriptor
example._enforce_shape = False
example._enforce_dtype = False
with caplog.at_level(logging.DEBUG):
example.validate_array(np.array("asdf"))
assert "Failed to reshape" in caplog.text
Expand Down

0 comments on commit 248979b

Please sign in to comment.