Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
  • Loading branch information
aarnphm committed Mar 17, 2023
1 parent bcc10ac commit f04e60a
Showing 1 changed file with 33 additions and 10 deletions.
43 changes: 33 additions & 10 deletions src/bentoml/_internal/frameworks/pytorch.py
Expand Up @@ -5,18 +5,19 @@
from types import ModuleType
from typing import TYPE_CHECKING
from pathlib import Path
from functools import partial

import cloudpickle

import bentoml
from bentoml import Tag

from ..tag import Tag
from ..types import LazyType
from ..models import Model
from ..utils.pkg import get_pkg_version
from ...exceptions import NotFound
from ..models.model import ModelContext
from ..models.model import PartialKwargsModelOptions as ModelOptions
from ..models.model import PartialKwargsModelOptions
from .common.pytorch import torch
from .common.pytorch import PyTorchTensorContainer

Expand All @@ -43,9 +44,19 @@ def get(tag_like: str | Tag) -> Model:
return model


class ModelOptions(PartialKwargsModelOptions):
    """
    Options stored with a saved PyTorch model.

    Extends ``PartialKwargsModelOptions`` with one field per keyword that
    ``get_runnable`` binds onto ``load_model``; ``load_model`` in turn hands
    its extra keyword arguments to ``torch.compile``. The values below are
    the defaults used when nothing else is configured. Refer to the
    ``torch.compile`` documentation for what each keyword means.
    """

    # Boolean switches, all off by default.
    fullgraph: bool = False
    dynamic: bool = False

    # Either a backend name or a callable.
    backend: t.Union[str, t.Callable[..., t.Any]] = "inductor"

    # Optional settings; ``None`` means the option was left unset.
    mode: t.Optional[str] = None
    options: t.Optional[t.Dict[str, t.Union[str, int, bool]]] = None

    disable: bool = False


def load_model(
bentoml_model: str | Tag | Model,
device_id: t.Optional[str] = "cpu",
**compile_kwargs: t.Any,
) -> torch.nn.Module:
"""
Load a model from a BentoML Model with given name.
Expand Down Expand Up @@ -76,13 +87,15 @@ def load_model(

weight_file = bentoml_model.path_of(MODEL_FILENAME)
with Path(weight_file).open("rb") as file:
model: "torch.nn.Module" = torch.load(file, map_location=device_id)
model: torch.nn.Module = torch.load(file, map_location=device_id)
if get_pkg_version("torch") >= "2.0.0":
return t.cast("torch.nn.Module", torch.compile(model, **compile_kwargs))
return model


def save_model(
name: str,
model: "torch.nn.Module",
model: torch.nn.Module,
*,
signatures: ModelSignaturesType | None = None,
labels: t.Dict[str, str] | None = None,
Expand Down Expand Up @@ -195,15 +208,25 @@ def get_runnable(bento_model: Model):
from .common.pytorch import PytorchModelRunnable
from .common.pytorch import make_pytorch_runnable_method

partial_kwargs: t.Dict[str, t.Any] = bento_model.info.options.partial_kwargs # type: ignore
opts = t.cast(ModelOptions, bento_model.info.options)
if get_pkg_version("torch") >= "2.0.0":
_load_model = partial(
load_model,
fullgraph=opts.fullgraph,
dynamic=opts.dynamic,
backend=opts.backend,
mode=opts.mode,
options=opts.options,
disable=opts.disable,
)
else:
_load_model = load_model

runnable_class: type[PytorchModelRunnable] = partial_class(
PytorchModelRunnable,
bento_model=bento_model,
loader=load_model,
runnable_class = partial_class(
PytorchModelRunnable, bento_model=bento_model, loader=_load_model
)
for method_name, options in bento_model.info.signatures.items():
method_partial_kwargs = partial_kwargs.get(method_name)
method_partial_kwargs = opts.partial_kwargs.get(method_name)
runnable_class.add_method(
make_pytorch_runnable_method(method_name, method_partial_kwargs),
name=method_name,
Expand Down

0 comments on commit f04e60a

Please sign in to comment.