forked from bentoml/BentoML
/
inference_api.py
171 lines (145 loc) · 6.26 KB
/
inference_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from __future__ import annotations
import re
import typing as t
import inspect
from typing import Optional
import yaml
from ..types import is_compatible_type
from ..context import InferenceApiContext as Context
from ...exceptions import InvalidArgument
from ..io_descriptors import IODescriptor
# API names / routes that user-defined APIs may not use, because they are
# reserved for infrastructure endpoints (checked by InferenceAPI's name and
# route validators).
RESERVED_API_NAMES = "index swagger docs metrics healthz livez readyz".split()
class InferenceAPI:
    """A user-defined inference callback bound to input/output IO descriptors.

    On construction, the callback's signature is checked against
    ``input_descriptor.input_type()``:

    * If ``input_type()`` returns a ``dict`` (multi-input), every parameter of
      the callback must be a key of that dict, except for at most one context
      parameter (named ``context``/``ctx`` or annotated with ``Context``).
    * Otherwise the callback receives the input as its first parameter and may
      take one optional second (context) parameter.

    Attributes:
        name: API name (defaults to the callback's ``__name__``).
        doc: API documentation (defaults to the callback's ``__doc__``).
        route: URL route (defaults to ``name``).
        func: The user-defined callback.
        input: The input IO descriptor.
        output: The output IO descriptor.
        multi_input: Whether ``input_type()`` returned a ``dict``.
        needs_ctx: Whether the callback takes a context parameter.
        ctx_param: Name of the context parameter, or ``None`` if there is none.
    """

    def __init__(
        self,
        user_defined_callback: t.Callable[..., t.Any],
        input_descriptor: IODescriptor[t.Any],
        output_descriptor: IODescriptor[t.Any],
        name: Optional[str],
        doc: Optional[str] = None,
        route: Optional[str] = None,
    ):
        """Validate and bind ``user_defined_callback`` as an inference API.

        Args:
            user_defined_callback: The function that performs inference.
            input_descriptor: Descriptor whose ``input_type()`` drives the
                signature check.
            output_descriptor: Descriptor for the API's output.
            name: API name; ``None`` means use the callback's ``__name__``.
            doc: API documentation; ``None`` means use the callback's
                ``__doc__``.
            route: URL route; ``None`` means use the API name.

        Raises:
            InvalidArgument: If the name or route is invalid or reserved.
            ValueError: If the callback's parameters do not match the input
                descriptor.
            TypeError: If a resolved parameter annotation is incompatible
                with the expected type.
        """
        # Fall back to the callback's own name and docstring, and reuse the
        # API name as the route, when they are not given explicitly.
        name = user_defined_callback.__name__ if name is None else name
        doc = user_defined_callback.__doc__ if doc is None else doc
        route = name if route is None else route

        InferenceAPI._validate_name(name)
        InferenceAPI._validate_route(route)

        self.name = name
        self.needs_ctx = False
        self.ctx_param: Optional[str] = None
        self.func = user_defined_callback

        input_type = input_descriptor.input_type()
        self.multi_input = isinstance(input_type, dict)

        sig = inspect.signature(user_defined_callback)
        if len(sig.parameters) == 0:
            raise ValueError("Expected API function to take parameters.")

        if self.multi_input:
            self._check_multi_input_signature(sig, input_type)
        else:
            self._check_single_input_signature(sig, input_type)

        self.input = input_descriptor
        self.output = output_descriptor
        self.doc = doc
        self.route = route

    @staticmethod
    def _is_resolved_annotation(annotation: t.Any) -> bool:
        """Return True if ``annotation`` is a real type rather than empty/unresolved."""
        # String (unresolved) annotations are not types, so they are skipped;
        # ``Signature.empty`` is itself a class and must be excluded explicitly.
        return (
            isinstance(annotation, t.Type) and annotation != inspect.Signature.empty
        )

    def _check_multi_input_signature(
        self, sig: inspect.Signature, input_type: t.Dict[str, t.Any]
    ) -> None:
        """Validate a callback that takes one parameter per key of ``input_type``.

        Records the context parameter (if any) in ``needs_ctx``/``ctx_param``.

        Raises:
            ValueError: On an unknown parameter, a second context parameter,
                or a parameter count that does not match ``input_type``.
            TypeError: If a resolved annotation is incompatible with its input.
        """
        # note: in python 3.6 kwarg order was not guaranteed to be preserved,
        # though it is in practice.
        for key, param in sig.parameters.items():
            if key not in input_type:
                # The only parameter allowed outside input_type is the context.
                if key in ("context", "ctx") or param.annotation == Context:
                    if self.needs_ctx:
                        raise ValueError(
                            f"API function has two context parameters: '{self.ctx_param}' and '{key}'; it should only have one."
                        )
                    self.needs_ctx = True
                    self.ctx_param = key
                    continue
                raise ValueError(
                    f"API function has extra parameter with name '{key}'."
                )

            annotation = param.annotation
            if self._is_resolved_annotation(annotation) and not is_compatible_type(
                input_type[key], annotation
            ):
                raise TypeError(
                    f"Expected type of argument '{key}' to be '{input_type[key]}', got '{annotation}'"
                )

        # Catches input keys that the callback does not accept.
        expected_args = len(input_type) + (1 if self.needs_ctx else 0)
        if len(sig.parameters) != expected_args:
            raise ValueError(
                f"expected API function to have arguments ({', '.join(input_type.keys())}, [context]), got ({', '.join(sig.parameters.keys())})"
            )

    def _check_single_input_signature(
        self, sig: inspect.Signature, input_type: t.Any
    ) -> None:
        """Validate a callback shaped ``(input)`` or ``(input, context)``.

        Records the context parameter (if any) in ``needs_ctx``/``ctx_param``.

        Raises:
            ValueError: If the callback takes more than two parameters.
            TypeError: If a resolved annotation on the input parameter is
                incompatible with ``input_type``, or a resolved annotation on
                the context parameter is not ``Context``.
        """
        params = list(sig.parameters.values())
        first = params[0]
        if self._is_resolved_annotation(first.annotation) and not is_compatible_type(
            input_type, first.annotation
        ):
            raise TypeError(
                f"Expected type of argument '{first.name}' to be '{input_type}', got '{first.annotation}'"
            )

        if len(params) > 2:
            raise ValueError("API function should only take one or two arguments")
        if len(params) == 2:
            second = params[1]
            self.needs_ctx = True
            # Record the context parameter's name, as the multi-input path does.
            self.ctx_param = second.name
            if (
                self._is_resolved_annotation(second.annotation)
                and second.annotation != Context
            ):
                # The second parameter is the context, so report Context (not
                # the input type) as the expected type.
                raise TypeError(
                    f"Expected type of argument '{second.name}' to be '{Context}', got '{second.annotation}'"
                )

    def __str__(self):
        return f"{self.__class__.__name__}({str(self.input)} → {str(self.output)})"

    @staticmethod
    def _validate_name(api_name: str) -> None:
        """Raise InvalidArgument unless ``api_name`` is a non-reserved identifier."""
        if not api_name.isidentifier():
            raise InvalidArgument(
                f"Invalid API name: '{api_name}', a valid identifier may only contain letters,"
                " numbers, underscores and not starting with a number."
            )
        if api_name in RESERVED_API_NAMES:
            raise InvalidArgument(
                f"Reserved API name: '{api_name}' is reserved for infra endpoints"
            )

    @staticmethod
    def _validate_route(route: str) -> None:
        """Raise InvalidArgument if ``route`` has illegal URL characters or is reserved."""
        # Reject '?' or '#' anywhere, or a route starting with '//' or ':'.
        # https://tools.ietf.org/html/rfc3986#page-22
        if re.search(r"[?#]|^//|^:", route):
            raise InvalidArgument(f"The path {route} contains illegal url characters")
        if route in RESERVED_API_NAMES:
            raise InvalidArgument(
                f"Reserved API route: '{route}' is reserved for infra endpoints"
            )
def _InferenceAPI_dumper(dumper: yaml.Dumper, api: InferenceAPI) -> yaml.Node:
return dumper.represent_dict(
{
"route": api.route,
"doc": api.doc,
"input": api.input.__class__.__name__,
"output": api.output.__class__.__name__,
}
)
# Register _InferenceAPI_dumper with PyYAML (default Dumper) so that dumping an
# InferenceAPI object emits the route/doc/input/output mapping it builds.
yaml.add_representer(InferenceAPI, _InferenceAPI_dumper)