Skip to content

Commit

Permalink
fix(bug): Make model.with_options() additive (#2519)
Browse files Browse the repository at this point in the history
Co-authored-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
  • Loading branch information
ssheng and aarnphm committed Jun 11, 2022
1 parent e0d0f7f commit 03ee3e4
Show file tree
Hide file tree
Showing 11 changed files with 92 additions and 73 deletions.
12 changes: 1 addition & 11 deletions bentoml/_internal/frameworks/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,6 @@ class KerasOptions(ModelOptions):
include_optimizer: bool
partial_kwargs: t.Dict[str, t.Any] = attr.field(factory=dict)

@classmethod
def with_options(cls, **kwargs: t.Any) -> ModelOptions:
return cls(**kwargs)

@staticmethod
def to_dict(options: ModelOptions) -> dict[str, t.Any]:
return attr.asdict(options)


def get(tag_like: str | Tag) -> bentoml.Model:
model = bentoml.models.get(tag_like)
Expand Down Expand Up @@ -256,9 +248,7 @@ def get_runnable(
Private API: use :obj:`~bentoml.Model.to_runnable` instead.
"""

partial_kwargs: t.Dict[str, t.Any] = bento_model.info.options.get(
"partial_kwargs", dict()
)
partial_kwargs: t.Dict[str, t.Any] = bento_model.info.options.partial_kwargs # type: ignore

class KerasRunnable(Runnable):
SUPPORT_NVIDIA_GPU = True
Expand Down
15 changes: 11 additions & 4 deletions bentoml/_internal/frameworks/tensorflow_v2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import os
import pickle
import typing as t
import logging
Expand All @@ -9,10 +8,13 @@
import contextlib
from typing import TYPE_CHECKING

import attr

import bentoml
from bentoml import Tag
from bentoml import Runnable
from bentoml.models import ModelContext
from bentoml.models import ModelOptions
from bentoml.exceptions import NotFound
from bentoml.exceptions import MissingDependencyException

Expand Down Expand Up @@ -49,6 +51,13 @@
API_VERSION = "v1"


@attr.define
class TensorflowOptions(ModelOptions):
    """Options for the TensorFlow model."""

    # Read back in get_runnable() via ``bento_model.info.options.partial_kwargs``.
    # Declared with ``factory=dict`` so each instance starts with its own empty dict.
    partial_kwargs: t.Dict[str, t.Any] = attr.field(factory=dict)


def get(tag_like: str | Tag) -> bentoml.Model:
model = bentoml.models.get(tag_like)
if model.info.module not in (MODULE_NAME, __name__):
Expand Down Expand Up @@ -219,9 +228,7 @@ def get_runnable(
Private API: use :obj:`~bentoml.Model.to_runnable` instead.
"""

partial_kwargs: t.Dict[str, t.Any] = bento_model.info.options.get(
"partial_kwargs", dict()
)
partial_kwargs: t.Dict[str, t.Any] = bento_model.info.options.partial_kwargs

class TensorflowRunnable(Runnable):
SUPPORT_NVIDIA_GPU = True
Expand Down
14 changes: 0 additions & 14 deletions bentoml/_internal/frameworks/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,20 +91,8 @@ class TransformersOptions(ModelOptions):
]
)

pipeline: bool = attr.field(
default=True, validator=attr.validators.instance_of(bool)
)

kwargs: t.Dict[str, t.Any] = attr.field(factory=dict)

@classmethod
def with_options(cls, **kwargs: t.Any) -> ModelOptions:
return cls(**kwargs)

@staticmethod
def to_dict(options: ModelOptions) -> dict[str, t.Any]:
return attr.asdict(options)


def get(tag_like: str | Tag) -> Model:
model = bentoml.models.get(tag_like)
Expand Down Expand Up @@ -146,8 +134,6 @@ def load_model(
f"Model {bento_model.tag} was saved with module {bento_model.info.module}, failed loading with {MODULE_NAME}."
)

bento_model.info.parse_options(TransformersOptions)

pipeline_task: str = bento_model.info.options.task # type: ignore
pipeline_kwargs: t.Dict[str, t.Any] = bento_model.info.options.kwargs # type: ignore
pipeline_kwargs.update(kwargs)
Expand Down
38 changes: 20 additions & 18 deletions bentoml/_internal/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from typing import TYPE_CHECKING
from datetime import datetime
from datetime import timezone
from collections import UserDict

import fs
import attr
Expand Down Expand Up @@ -57,25 +56,17 @@ class ModelSignatureDict(t.TypedDict, total=False):
CUSTOM_OBJECTS_FILENAME = "custom_objects.pkl"


if TYPE_CHECKING:
ModelOptionsSuper = UserDict[str, t.Any]
else:
ModelOptionsSuper = UserDict


class ModelOptions(ModelOptionsSuper):
@classmethod
def with_options(cls, **kwargs: t.Any) -> ModelOptions:
return cls(**kwargs)
@attr.define
class ModelOptions:
def with_options(self, **kwargs: t.Any) -> "ModelOptions":
return attr.evolve(self, **kwargs)

@staticmethod
def to_dict(options: ModelOptions) -> dict[str, t.Any]:
return dict(options)
return attr.asdict(options)


bentoml_cattr.register_structure_hook_func(
lambda cls: issubclass(cls, ModelOptions), lambda d, cls: cls.with_options(**d) # type: ignore
)
bentoml_cattr.register_structure_hook(ModelOptions, lambda d, cls: cls(**d))
bentoml_cattr.register_unstructure_hook(ModelOptions, lambda v: v.to_dict(v)) # type: ignore # pylint: disable=unnecessary-lambda # lambda required


Expand Down Expand Up @@ -545,9 +536,6 @@ def with_options(self, **kwargs: t.Any) -> ModelInfo:
def to_dict(self) -> t.Dict[str, t.Any]:
    """Return this object unstructured into a dictionary via ``bentoml_cattr``."""
    return bentoml_cattr.unstructure(self)  # type: ignore (incomplete cattr types)

def parse_options(self, options_class: type[ModelOptions]) -> None:
object.__setattr__(self, "options", options_class.with_options(**self.options))

@overload
def dump(self, stream: io.StringIO) -> io.BytesIO:
...
Expand Down Expand Up @@ -588,6 +576,20 @@ def from_yaml_file(stream: t.IO[t.Any]):
del yaml_content["context"]["pip_dependencies"]
yaml_content["context"]["framework_versions"] = {}

# register hook for model options
module_name: str = yaml_content["module"]
try:
module = importlib.import_module(module_name)
except (ValueError, ModuleNotFoundError) as e:
raise BentoMLException(
f"Module '{module_name}' defined in {MODEL_YAML_FILENAME} is not found."
) from e
if hasattr(module, "ModelOptions"):
bentoml_cattr.register_structure_hook(
ModelOptions,
lambda d, _: module.ModelOptions(**d),
)

try:
model_info = bentoml_cattr.structure(yaml_content, ModelInfo)
except TypeError as e: # pragma: no cover - simple error handling
Expand Down
1 change: 1 addition & 0 deletions bentoml/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ._internal.frameworks.keras import load_model
from ._internal.frameworks.keras import save_model
from ._internal.frameworks.keras import get_runnable
from ._internal.frameworks.keras import KerasOptions as ModelOptions # type: ignore # noqa

logger = logging.getLogger(__name__)

Expand Down
1 change: 1 addition & 0 deletions bentoml/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ._internal.frameworks.tensorflow_v2 import load_model
from ._internal.frameworks.tensorflow_v2 import save_model
from ._internal.frameworks.tensorflow_v2 import get_runnable
from ._internal.frameworks.tensorflow_v2 import TensorflowOptions as ModelOptions # type: ignore # noqa

logger = logging.getLogger(__name__)

Expand Down
1 change: 1 addition & 0 deletions bentoml/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ._internal.frameworks.transformers import load_model
from ._internal.frameworks.transformers import save_model
from ._internal.frameworks.transformers import get_runnable
from ._internal.frameworks.transformers import TransformersOptions as ModelOptions # type: ignore # noqa

logger = logging.getLogger(__name__)

Expand Down
18 changes: 15 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,27 @@ def fixture_change_test_dir(request: pytest.FixtureRequest):
def fixture_dummy_model_store(tmpdir_factory: "pytest.TempPathFactory") -> ModelStore:
store = ModelStore(tmpdir_factory.mktemp("models"))
with bentoml.models.create(
"testmodel", signatures={}, context=TEST_MODEL_CONTEXT, _model_store=store
"testmodel",
module=__name__,
signatures={},
context=TEST_MODEL_CONTEXT,
_model_store=store,
):
pass
with bentoml.models.create(
"testmodel", signatures={}, context=TEST_MODEL_CONTEXT, _model_store=store
"testmodel",
module=__name__,
signatures={},
context=TEST_MODEL_CONTEXT,
_model_store=store,
):
pass
with bentoml.models.create(
"anothermodel", signatures={}, context=TEST_MODEL_CONTEXT, _model_store=store
"anothermodel",
module=__name__,
signatures={},
context=TEST_MODEL_CONTEXT,
_model_store=store,
):
pass

Expand Down
27 changes: 14 additions & 13 deletions tests/integration/frameworks/test_transformers_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,36 +39,36 @@ def pt_gpt2_pipeline():
(
"text-generation",
transformers.pipeline(task="text-generation"), # type: ignore
{"pipeline": True, "task": "text-generation"},
{"pipeline": True, "task": "text-generation"},
{},
{"task": "text-generation", "kwargs": {}},
"A Bento box is a ",
),
(
"text-generation",
transformers.pipeline(task="text-generation"), # type: ignore
{"pipeline": True, "task": "text-generation", "kwargs": {"a": 1}},
{"pipeline": True, "task": "text-generation", "kwargs": {"a": 1}},
{"kwargs": {"a": 1}},
{"task": "text-generation", "kwargs": {"a": 1}},
"A Bento box is a ",
),
(
"text-generation",
tf_gpt2_pipeline(),
{"pipeline": True, "task": "text-generation"},
{"pipeline": True, "task": "text-generation"},
{},
{"task": "text-generation", "kwargs": {}},
"A Bento box is a ",
),
(
"text-generation",
pt_gpt2_pipeline(),
{"pipeline": True, "task": "text-generation"},
{"pipeline": True, "task": "text-generation"},
{},
{"task": "text-generation", "kwargs": {}},
"A Bento box is a ",
),
(
"image-classification",
transformers.pipeline("image-classification"), # type: ignore
{"pipeline": True, "task": "image-classification"},
{"pipeline": True, "task": "image-classification"},
{},
{"task": "image-classification", "kwargs": {}},
Image.open(
requests.get(
"http://images.cocodataset.org/val2017/000000039769.jpg",
Expand All @@ -79,8 +79,8 @@ def pt_gpt2_pipeline():
(
"text-classification",
transformers.pipeline("text-classification"), # type: ignore
{"pipeline": True, "task": "text-classification"},
{"pipeline": True, "task": "text-classification"},
{},
{"task": "text-classification", "kwargs": {}},
"BentoML is an awesome library for machine learning.",
),
],
Expand All @@ -101,7 +101,8 @@ def test_transformers(
)
assert bento_model.tag == tag
assert bento_model.info.context.framework_name == "transformers"
assert dict(bento_model.info.options) == expected_options
assert bento_model.info.options.task == expected_options["task"] # type: ignore
assert bento_model.info.options.kwargs == expected_options["kwargs"] # type: ignore

runnable: bentoml.Runnable = bentoml.transformers.get_runnable(bento_model)()
output_data = runnable(input_data) # type: ignore
Expand Down
20 changes: 13 additions & 7 deletions tests/unit/_internal/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
from datetime import timezone

import fs
import attr
import numpy as np
import pytest
import fs.errors

from bentoml import Tag
from bentoml.exceptions import BentoMLException
from bentoml._internal.models import ModelContext
from bentoml._internal.models import ModelOptions
from bentoml._internal.models import ModelOptions as InternalModelOptions
from bentoml._internal.models.model import Model
from bentoml._internal.models.model import ModelInfo
from bentoml._internal.models.model import ModelStore
Expand All @@ -32,7 +33,7 @@
expected_yaml = """\
name: test
version: v1
module: testmodule
module: test_model
labels:
label: stringvalue
options:
Expand Down Expand Up @@ -77,20 +78,24 @@
"""


class TestModelOption(ModelOptions):
@attr.define
class TestModelOptions(InternalModelOptions):
option_a: int
option_b: str
option_c: list[float]


ModelOptions = TestModelOptions


def test_model_info(tmpdir: "Path"):
start = datetime.now(timezone.utc)
modelinfo_a = ModelInfo(
tag=Tag("tag"),
module="module",
api_version="v1",
labels={},
options=ModelOptions(),
options=TestModelOptions(option_a=42, option_b="foo", option_c=[0.1, 0.2]),
metadata={},
context=TEST_MODEL_CONTEXT,
signatures={"predict": {"batchable": True}},
Expand All @@ -102,9 +107,9 @@ def test_model_info(tmpdir: "Path"):
assert start <= modelinfo_a.creation_time <= end

tag = Tag("test", "v1")
module = "testmodule"
module = __name__
labels = {"label": "stringvalue"}
options = TestModelOption(option_a=1, option_b="foo", option_c=[0.1, 0.2])
options = TestModelOptions(option_a=1, option_b="foo", option_c=[0.1, 0.2])
metadata = {"a": 0.1, "b": 1, "c": np.array([2, 3, 4], dtype=np.uint32)}
# TODO: add test cases for input_spec and output_spec
signatures = {
Expand Down Expand Up @@ -183,10 +188,11 @@ def __call__(self, y: int) -> int:
def fixture_bento_model():
model = Model.create(
"testmodel",
module="foo",
module=__name__,
api_version="v1",
signatures={},
context=TEST_MODEL_CONTEXT,
options=TestModelOptions(option_a=1, option_b="foo", option_c=[0.1, 0.2]),
custom_objects={
"add": AdditionClass(add_num_1),
},
Expand Down

0 comments on commit 03ee3e4

Please sign in to comment.