/
text.py
174 lines (129 loc) · 5.24 KB
/
text.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
from __future__ import annotations
import typing as t
from typing import TYPE_CHECKING
from starlette.requests import Request
from starlette.responses import Response
from bentoml.exceptions import BentoMLException
from .base import IODescriptor
from ..utils.http import set_cookies
from ..service.openapi import SUCCESS_DESCRIPTION
from ..utils.lazy_loader import LazyLoader
from ..service.openapi.specification import Schema
from ..service.openapi.specification import MediaType
# The imports in the first branch are only executed by static type checkers;
# the module uses ``from __future__ import annotations``, so annotations that
# mention these names are not evaluated at runtime.
if TYPE_CHECKING:
    from google.protobuf import wrappers_pb2
    from typing_extensions import Self
    from .base import SpecDict
    from .base import OpenAPIResponse
    from ..context import InferenceApiContext as Context
else:
    # At runtime ``wrappers_pb2`` is bound to a LazyLoader for
    # ``google.protobuf.wrappers_pb2`` instead of being imported directly.
    # NOTE(review): presumably this defers the protobuf import until first
    # attribute access -- confirm against LazyLoader.
    wrappers_pb2 = LazyLoader("wrappers_pb2", globals(), "google.protobuf.wrappers_pb2")
# Media type used by the Text descriptor for HTTP payloads and OpenAPI content.
MIME_TYPE = "text/plain"
class Text(IODescriptor[str], descriptor_id="bentoml.io.Text"):
    """
    :obj:`Text` defines API specification for the inputs/outputs of a Service. :obj:`Text`
    represents strings for all incoming requests/outgoing responses as specified in
    your API function signature.

    A sample GPT2 service implementation:

    .. code-block:: python
       :caption: `service.py`

       from __future__ import annotations

       import bentoml
       from bentoml.io import Text

       runner = bentoml.tensorflow.get('gpt2:latest').to_runner()

       svc = bentoml.Service("gpt2-generation", runners=[runner])

       @svc.api(input=Text(), output=Text())
       def predict(text: str) -> str:
           res = runner.run(text)
           return res['generated_text']

    Users can then serve this service with :code:`bentoml serve`:

    .. code-block:: bash

       % bentoml serve ./service.py:svc --reload

    Users can then send requests to the newly started services with any client:

    .. tab-set::

        .. tab-item:: Bash

            .. code-block:: bash

                % curl -X POST -H "Content-Type: text/plain" \\
                        --data 'Not for nothing did Orin say that people outdoors.' \\
                        http://0.0.0.0:3000/predict

        .. tab-item:: Python

            .. code-block:: python
               :caption: `request.py`

               import requests
               requests.post(
                   "http://0.0.0.0:3000/predict",
                   headers = {"content-type":"text/plain"},
                   data = 'Not for nothing did Orin say that people outdoors.'
               ).text

    .. note::

        :obj:`Text` is not designed to take any ``args`` or ``kwargs`` during initialization.

    Returns:
        :obj:`Text`: IO Descriptor that represents strings type.
    """

    # NOTE(review): presumably the protobuf message field(s) this descriptor
    # maps to -- confirm against IODescriptor.
    _proto_fields = ("text",)
    # Media type used for the OpenAPI content entries and for HTTP responses.
    _mime_type = MIME_TYPE

    def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
        """
        Create a :obj:`Text` descriptor. No configuration is accepted.

        Raises:
            BentoMLException: if any positional or keyword argument is passed.
        """
        if args or kwargs:
            raise BentoMLException(
                f"'{self.__class__.__name__}' is not designed to take any args or kwargs during initialization."
            ) from None

    def _from_sample(self, sample: str | bytes) -> str:
        """
        Normalize a sample value to ``str``.

        Args:
            sample: a string, or bytes that are decoded as UTF-8.

        Returns:
            The sample as a ``str``.
        """
        if isinstance(sample, bytes):
            return sample.decode("utf-8")
        return sample

    def input_type(self) -> t.Type[str]:
        """Return the Python type this descriptor handles, which is ``str``."""
        return str

    def to_spec(self) -> dict[str, t.Any]:
        """Return the serializable spec of this descriptor: its ``descriptor_id`` under ``"id"``."""
        return {"id": self.descriptor_id}

    @classmethod
    def from_spec(cls, spec: SpecDict) -> Self:
        """
        Build a descriptor from a spec.

        ``spec`` is not inspected, because :obj:`Text` takes no configuration.
        """
        return cls()

    def openapi_schema(self) -> Schema:
        """Return the OpenAPI schema for this descriptor: a plain ``string``."""
        return Schema(type="string")

    def openapi_components(self) -> dict[str, t.Any] | None:
        """Return ``None``; this descriptor contributes no OpenAPI components."""
        return None

    def openapi_example(self):
        """
        Return ``self.sample`` converted with ``str()`` for use as the OpenAPI example.

        NOTE(review): if ``self.sample`` can be ``None`` this yields the literal
        string ``"None"`` -- confirm that is intended.
        """
        return str(self.sample)

    def openapi_request_body(self) -> dict[str, t.Any]:
        """Return the OpenAPI request body: a required ``text/plain`` string payload."""
        media = MediaType(schema=self.openapi_schema(), example=self.openapi_example())
        return {
            "content": {self._mime_type: media},
            "required": True,
            "x-bentoml-io-descriptor": self.to_spec(),
        }

    def openapi_responses(self) -> OpenAPIResponse:
        """Return the OpenAPI response object: a ``text/plain`` string payload."""
        media = MediaType(schema=self.openapi_schema(), example=self.openapi_example())
        return {
            "description": SUCCESS_DESCRIPTION,
            "content": {self._mime_type: media},
            "x-bentoml-io-descriptor": self.to_spec(),
        }

    async def from_http_request(self, request: Request) -> str:
        """
        Read the full request body and decode it as UTF-8.

        Args:
            request: the incoming Starlette request.

        Returns:
            The decoded request body.
        """
        body = await request.body()
        # ``bytes.decode`` already returns ``str``; no extra conversion is needed.
        return body.decode("utf-8")

    async def to_http_response(self, obj: str, ctx: Context | None = None) -> Response:
        """
        Wrap ``obj`` in a ``text/plain`` Starlette response.

        Args:
            obj: the string to send back.
            ctx: optional inference context. When given, its response metadata
                is used as headers, its status code is applied, and its
                cookies are set on the response.

        Returns:
            The HTTP response carrying ``obj``.
        """
        if ctx is None:
            return Response(obj, media_type=self._mime_type)

        res = Response(
            obj,
            media_type=self._mime_type,
            headers=ctx.response.metadata,  # type: ignore (bad starlette types)
            status_code=ctx.response.status_code,
        )
        set_cookies(res, ctx.response.cookies)
        return res

    async def from_proto(self, field: wrappers_pb2.StringValue | bytes) -> str:
        """
        Extract a string from a protobuf field.

        Args:
            field: raw bytes (decoded as UTF-8) or a ``StringValue`` wrapper
                message (its ``value`` is returned).

        Returns:
            The string carried by ``field``.

        Raises:
            BentoMLException: if ``field`` is neither ``bytes`` nor a ``StringValue``.
        """
        if isinstance(field, bytes):
            return field.decode("utf-8")
        # Validate explicitly: an ``assert`` would be stripped under ``python -O``
        # and let an unexpected message type through.
        if not isinstance(field, wrappers_pb2.StringValue):
            raise BentoMLException(
                f"'{self.__class__.__name__}' expects a 'StringValue' message or 'bytes', got '{type(field).__name__}' instead."
            )
        return field.value

    async def to_proto(self, obj: str) -> wrappers_pb2.StringValue:
        """Wrap ``obj`` in a protobuf ``StringValue`` message."""
        return wrappers_pb2.StringValue(value=obj)