From f04e60ae47180cfd911a4b972862270958ad23f3 Mon Sep 17 00:00:00 2001
From: Aaron <29749331+aarnphm@users.noreply.github.com>
Date: Fri, 17 Mar 2023 03:55:44 -0700
Subject: [PATCH] wip

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
---
 src/bentoml/_internal/frameworks/pytorch.py | 43 ++++++++++++++++-----
 1 file changed, 33 insertions(+), 10 deletions(-)

diff --git a/src/bentoml/_internal/frameworks/pytorch.py b/src/bentoml/_internal/frameworks/pytorch.py
index c7ca74a8165..97c17705ce4 100644
--- a/src/bentoml/_internal/frameworks/pytorch.py
+++ b/src/bentoml/_internal/frameworks/pytorch.py
@@ -5,18 +5,19 @@
 from types import ModuleType
 from typing import TYPE_CHECKING
 from pathlib import Path
+from functools import partial
 
 import cloudpickle
 
 import bentoml
-from bentoml import Tag
+from ..tag import Tag
 
 from ..types import LazyType
 from ..models import Model
 from ..utils.pkg import get_pkg_version
 from ...exceptions import NotFound
 from ..models.model import ModelContext
-from ..models.model import PartialKwargsModelOptions as ModelOptions
+from ..models.model import PartialKwargsModelOptions
 
 from .common.pytorch import torch
 from .common.pytorch import PyTorchTensorContainer
@@ -43,9 +44,19 @@ def get(tag_like: str | Tag) -> Model:
     return model
 
 
+class ModelOptions(PartialKwargsModelOptions):
+    fullgraph: bool = False
+    dynamic: bool = False
+    backend: t.Union[str, t.Callable[..., t.Any]] = "inductor"
+    mode: t.Optional[str] = None
+    options: t.Optional[t.Dict[str, t.Union[str, int, bool]]] = None
+    disable: bool = False
+
+
 def load_model(
     bentoml_model: str | Tag | Model,
     device_id: t.Optional[str] = "cpu",
+    **compile_kwargs: t.Any,
 ) -> torch.nn.Module:
     """
     Load a model from a BentoML Model with given name.
@@ -76,13 +87,15 @@ def load_model(
 
     weight_file = bentoml_model.path_of(MODEL_FILENAME)
     with Path(weight_file).open("rb") as file:
-        model: "torch.nn.Module" = torch.load(file, map_location=device_id)
+        model: torch.nn.Module = torch.load(file, map_location=device_id)
+    if get_pkg_version("torch") >= "2.0.0":
+        return t.cast("torch.nn.Module", torch.compile(model, **compile_kwargs))
     return model
 
 
 def save_model(
     name: str,
-    model: "torch.nn.Module",
+    model: torch.nn.Module,
     *,
     signatures: ModelSignaturesType | None = None,
     labels: t.Dict[str, str] | None = None,
@@ -195,15 +208,25 @@ def get_runnable(bento_model: Model):
     from .common.pytorch import PytorchModelRunnable
     from .common.pytorch import make_pytorch_runnable_method
 
-    partial_kwargs: t.Dict[str, t.Any] = bento_model.info.options.partial_kwargs  # type: ignore
+    opts = t.cast(ModelOptions, bento_model.info.options)
+    if get_pkg_version("torch") >= "2.0.0":
+        _load_model = partial(
+            load_model,
+            fullgraph=opts.fullgraph,
+            dynamic=opts.dynamic,
+            backend=opts.backend,
+            mode=opts.mode,
+            options=opts.options,
+            disable=opts.disable,
+        )
+    else:
+        _load_model = load_model
 
-    runnable_class: type[PytorchModelRunnable] = partial_class(
-        PytorchModelRunnable,
-        bento_model=bento_model,
-        loader=load_model,
+    runnable_class = partial_class(
+        PytorchModelRunnable, bento_model=bento_model, loader=_load_model
     )
     for method_name, options in bento_model.info.signatures.items():
-        method_partial_kwargs = partial_kwargs.get(method_name)
+        method_partial_kwargs = opts.partial_kwargs.get(method_name)
         runnable_class.add_method(
             make_pytorch_runnable_method(method_name, method_partial_kwargs),
             name=method_name,
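
For context, a minimal usage sketch of what this patch enables, assuming the patch is applied and torch >= 2.0 is installed. The model class Net and the tag name "my_torch_model" are hypothetical; the backend and mode keywords are standard torch.compile() options, which the patched load_model() forwards via **compile_kwargs. On older torch versions the same call simply returns the eager module.

import bentoml
import torch
import torch.nn as nn


class Net(nn.Module):
    # Hypothetical module, only to illustrate the save/load round trip.
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


# Save a plain torch.nn.Module under a (hypothetical) tag.
bento_model = bentoml.pytorch.save_model("my_torch_model", Net())

# load_model() now accepts **compile_kwargs; on torch >= 2.0 they are
# forwarded to torch.compile(), otherwise the eager module is returned as-is.
model = bentoml.pytorch.load_model(
    bento_model.tag,
    device_id="cpu",
    backend="inductor",
    mode="reduce-overhead",
)

print(model(torch.randn(1, 4)))

When the model is served through a Runner instead, get_runnable() reads the same fields (fullgraph, dynamic, backend, mode, options, disable) from the stored ModelOptions on bento_model.info.options; how those options get populated at save time is not yet wired up in this wip patch.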