from __future__ import annotations

import typing as t
import logging
from types import ModuleType
from typing import TYPE_CHECKING
from pathlib import Path
from functools import partial

import cloudpickle

import bentoml

from ..tag import Tag
from ..types import LazyType
from ..models import Model
from ..utils.pkg import get_pkg_version
from ...exceptions import NotFound
from ..models.model import ModelContext
from ..models.model import PartialKwargsModelOptions
from .common.pytorch import torch
from .common.pytorch import PyTorchTensorContainer

__all__ = ["load_model", "save_model", "get_runnable", "get", "PyTorchTensorContainer"]

MODULE_NAME = "bentoml.pytorch"
MODEL_FILENAME = "saved_model.pt"
API_VERSION = "v1"

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from ..models.model import ModelSignaturesType


def get(tag_like: str | Tag) -> Model:
    """Fetch a model from the local BentoML model store and verify it was saved with this module."""
    model = bentoml.models.get(tag_like)
    if model.info.module not in (MODULE_NAME, __name__):
        raise NotFound(
            f"Model {model.tag} was saved with module {model.info.module}, not loading with {MODULE_NAME}."
        )
    return model
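

# A minimal usage sketch, assuming a PyTorch model has already been saved under the
# hypothetical tag "my_pytorch_model:latest" in the local model store:
def _example_get_usage() -> None:  # pragma: no cover - illustrative sketch
    bento_model = get("my_pytorch_model:latest")
    print(bento_model.tag, bento_model.info.module)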


class ModelOptions(PartialKwargsModelOptions):
    """Model options; the compile-related fields are forwarded to ``torch.compile`` by ``get_runnable``."""

    fullgraph: bool = False
    dynamic: bool = False
    backend: t.Union[str, t.Callable[..., t.Any]] = "inductor"
    mode: t.Optional[str] = None
    options: t.Optional[t.Dict[str, t.Union[str, int, bool]]] = None
    disable: bool = False
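

# A sketch of how these options map onto ``torch.compile`` keyword arguments: on
# torch >= 2.0, ``get_runnable`` forwards them through ``load_model`` so that the
# loaded module is compiled roughly as below (hypothetical helper for illustration):
def _example_compile_with_options(model: torch.nn.Module, opts: ModelOptions) -> torch.nn.Module:  # pragma: no cover
    return torch.compile(
        model,
        fullgraph=opts.fullgraph,
        dynamic=opts.dynamic,
        backend=opts.backend,
        mode=opts.mode,
        options=opts.options,
        disable=opts.disable,
    )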


def load_model(
    bentoml_model: str | Tag | Model,
    device_id: t.Optional[str] = "cpu",
    **compile_kwargs: t.Any,
) -> torch.nn.Module:
    """
    Load a model from a BentoML Model with the given tag or model instance.

    Args:
        bentoml_model (:code:`Union[str, Tag, Model]`):
            Tag of a saved model in the BentoML local modelstore, or a BentoML :obj:`~bentoml.Model` instance.
        device_id (:code:`str`, `optional`, default to :code:`cpu`):
            Optional device to put the given model on. Refer to `device attributes <https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device>`_.
        compile_kwargs:
            Additional keyword arguments passed to :code:`torch.compile` when torch>=2.0.0 is installed.

    Returns:
        :obj:`torch.nn.Module`: an instance of :code:`torch.nn.Module` from the BentoML modelstore.

    Examples:

    .. code-block:: python

        import bentoml
        model = bentoml.pytorch.load_model('lit_classifier:latest', device_id="cuda:0")
    """
    if isinstance(bentoml_model, (str, Tag)):
        bentoml_model = get(bentoml_model)

    if bentoml_model.info.module not in (MODULE_NAME, __name__):
        raise NotFound(
            f"Model {bentoml_model.tag} was saved with module {bentoml_model.info.module}, not loading with {MODULE_NAME}."
        )

    weight_file = bentoml_model.path_of(MODEL_FILENAME)
    with Path(weight_file).open("rb") as file:
        model: torch.nn.Module = torch.load(file, map_location=device_id)

    # torch.compile is only available in torch >= 2.0
    if get_pkg_version("torch") >= "2.0.0":
        return t.cast("torch.nn.Module", torch.compile(model, **compile_kwargs))
    return model
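

# A minimal usage sketch, assuming the "lit_classifier:latest" model from the docstring
# example exists locally; extra keyword arguments reach ``torch.compile`` on torch >= 2.0:
def _example_load_model_usage() -> None:  # pragma: no cover - illustrative sketch
    model = load_model("lit_classifier:latest", device_id="cpu", mode="reduce-overhead")
    model.eval()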


def save_model(
    name: str,
    model: torch.nn.Module,
    *,
    signatures: ModelSignaturesType | None = None,
    labels: t.Dict[str, str] | None = None,
    custom_objects: t.Dict[str, t.Any] | None = None,
    external_modules: t.List[ModuleType] | None = None,
    metadata: t.Dict[str, t.Any] | None = None,
) -> bentoml.Model:
    """
    Save a model instance to the BentoML modelstore.

    Args:
        name (:code:`str`):
            Name for the given model instance. This should pass the Python identifier check.
        model (:code:`torch.nn.Module`):
            Instance of the model to be saved.
        signatures (:code:`ModelSignaturesType`, `optional`, default to :code:`None`):
            A dictionary of method names and their corresponding signatures.
        labels (:code:`Dict[str, str]`, `optional`, default to :code:`None`):
            User-defined labels for managing models, e.g. team=nlp, stage=dev.
        custom_objects (:code:`Dict[str, Any]`, `optional`, default to :code:`None`):
            User-defined additional Python objects to be saved alongside the model,
            e.g. a tokenizer instance, preprocessor function, model configuration json.
        external_modules (:code:`List[ModuleType]`, `optional`, default to :code:`None`):
            User-defined additional Python modules to be saved alongside the model or custom objects,
            e.g. a tokenizer module, preprocessor module, model configuration module.
        metadata (:code:`Dict[str, Any]`, `optional`, default to :code:`None`):
            Custom metadata for the given model.

    Returns:
        :obj:`~bentoml.Model`: A BentoML model entry whose tag has the format `name:version`, where `name` is the user-defined model name and `version` is generated by BentoML.

    Examples:

    .. code-block:: python

        import torch
        import torch.nn as nn
        import torch.nn.functional as F
        import bentoml

        class NGramLanguageModeler(nn.Module):

            def __init__(self, vocab_size, embedding_dim, context_size):
                super(NGramLanguageModeler, self).__init__()
                self.embeddings = nn.Embedding(vocab_size, embedding_dim)
                self.linear1 = nn.Linear(context_size * embedding_dim, 128)
                self.linear2 = nn.Linear(128, vocab_size)

            def forward(self, inputs):
                embeds = self.embeddings(inputs).view((1, -1))
                out = F.relu(self.linear1(embeds))
                out = self.linear2(out)
                log_probs = F.log_softmax(out, dim=1)
                return log_probs

        bento_model = bentoml.pytorch.save_model(
            "ngrams", NGramLanguageModeler(len(vocab), EMBEDDING_DIM, CONTEXT_SIZE)
        )
        # example tag: ngrams:20201012_DE43A2

    Integration with Torch Hub and BentoML:

    .. code-block:: python

        import torch
        import bentoml

        resnet50 = torch.hub.load("pytorch/vision", "resnet50", pretrained=True)
        ...
        # trained a custom resnet50
        bento_model = bentoml.pytorch.save_model("resnet50", resnet50)
    """
    if not LazyType("torch.nn.Module").isinstance(model):
        raise TypeError(f"Given model ({model}) is not a torch.nn.Module.")

    context: ModelContext = ModelContext(
        framework_name="torch",
        framework_versions={"torch": get_pkg_version("torch")},
    )

    if signatures is None:
        signatures = {"__call__": {"batchable": False}}
        logger.info(
            'Using the default model signature for PyTorch (%s) for model "%s".',
            signatures,
            name,
        )

    with bentoml.models.create(
        name,
        module=MODULE_NAME,
        api_version=API_VERSION,
        labels=labels,
        signatures=signatures,
        custom_objects=custom_objects,
        external_modules=external_modules,
        options=ModelOptions(),
        context=context,
        metadata=metadata,
    ) as bento_model:
        weight_file = bento_model.path_of(MODEL_FILENAME)
        with open(weight_file, "wb") as file:
            torch.save(model, file, pickle_module=cloudpickle)  # type: ignore
        return bento_model
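

# A minimal usage sketch, assuming ``my_module`` is a trained ``torch.nn.Module``; the
# model name, labels, and metadata values below are hypothetical:
def _example_save_model_usage(my_module: torch.nn.Module) -> bentoml.Model:  # pragma: no cover
    return save_model(
        "my_classifier",
        my_module,
        signatures={"__call__": {"batchable": True, "batch_dim": 0}},
        labels={"team": "nlp", "stage": "dev"},
        metadata={"framework": "pytorch"},
    )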


def get_runnable(bento_model: Model):
    """
    Private API: use :obj:`~bentoml.Model.to_runnable` instead.
    """
    from .common.pytorch import partial_class
    from .common.pytorch import PytorchModelRunnable
    from .common.pytorch import make_pytorch_runnable_method

    opts = t.cast(ModelOptions, bento_model.info.options)

    if get_pkg_version("torch") >= "2.0.0":
        _load_model = partial(
            load_model,
            fullgraph=opts.fullgraph,
            dynamic=opts.dynamic,
            backend=opts.backend,
            mode=opts.mode,
            options=opts.options,
            disable=opts.disable,
        )
    else:
        _load_model = load_model

    runnable_class = partial_class(
        PytorchModelRunnable, bento_model=bento_model, loader=_load_model
    )
    for method_name, options in bento_model.info.signatures.items():
        method_partial_kwargs = opts.partial_kwargs.get(method_name)
        runnable_class.add_method(
            make_pytorch_runnable_method(method_name, method_partial_kwargs),
            name=method_name,
            batchable=options.batchable,
            batch_dim=options.batch_dim,
            input_spec=options.input_spec,
            output_spec=options.output_spec,
        )
    return runnable_class
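

# A minimal usage sketch: as the docstring notes, ``get_runnable`` is private API, so a
# caller would normally go through ``Model.to_runnable``. The tag "my_classifier:latest"
# and the input tensor shape are hypothetical:
def _example_runnable_usage() -> None:  # pragma: no cover - illustrative sketch
    bento_model = get("my_classifier:latest")
    runnable = bento_model.to_runnable()()
    # each saved signature (here the default "__call__") is exposed as a runnable method
    prediction = runnable.__call__(torch.randn(1, 4))
    print(prediction)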