Skip to content

Commit

Permalink
feature(runner): add multiple output support
Browse files: browse the repository at this point in the history
  • Loading branch information
larme committed Aug 17, 2022
1 parent 1e6a65d commit e54024f
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 4 deletions.
4 changes: 3 additions & 1 deletion bentoml/_internal/frameworks/keras.py
Expand Up @@ -321,7 +321,7 @@ def _mapping(item: "KerasArgType") -> "tf_ext.TensorLike":

def _run_method(
runnable_self: KerasRunnable, *args: "KerasArgType"
) -> "ext.NpNDArray":
) -> "ext.NpNDArray" | t.Tuple["ext.NpNDArray", ...]:

params = Params["KerasArgType"](*args)

Expand All @@ -339,6 +339,8 @@ def _run_method(
).isinstance(res):
return t.cast("ext.NpNDArray", res.numpy())

if isinstance(res, list):
return tuple(res)
return res

return _run_method
Expand Down
4 changes: 2 additions & 2 deletions bentoml/_internal/runner/runnable.py
Expand Up @@ -149,6 +149,6 @@ def __set_name__(self, owner: t.Any, name: str):
@attr.define
class RunnableMethodConfig:
batchable: bool
batch_dim: tuple[int, int]
input_spec: AnyType | t.Tuple[AnyType, ...] | None = None
batch_dim: tuple[int, int | tuple[int, ...]]
input_spec: AnyType | tuple[AnyType, ...] | None = None
output_spec: AnyType | None = None
4 changes: 4 additions & 0 deletions bentoml/_internal/runner/runner_handle/remote.py
Expand Up @@ -185,6 +185,10 @@ async def async_run_method(
f"Bento payload decode error: invalid Content-Type '{content_type}'."
)

if content_type == "application/vnd.bentoml.multiple_outputs":
payloads = pickle.loads(body)
return tuple(AutoContainer.from_payload(payload) for payload in payloads)

container = content_type.strip("application/vnd.bentoml.")

try:
Expand Down
37 changes: 36 additions & 1 deletion bentoml/_internal/server/runner_app.py
Expand Up @@ -185,6 +185,41 @@ async def _run(requests: t.Iterable[Request]) -> list[Response]:
*batched_params.args, **batched_params.kwargs
)

server_str = f"BentoML-Runner/{self.runner.name}/{runner_method.name}/{self.worker_index}"

# multiple output branch
if isinstance(batch_ret, tuple):
output_num = len(batch_ret)
if isinstance(output_batch_dim, int):
output_batch_dim = (output_batch_dim,) * output_num
else:
assert (
len(output_batch_dim) == output_num
), "output_batch_dim length should be equal to the number of outputs"

payloadss = [
AutoContainer.batch_to_payloads(
batch_ret[idx], indices, batch_dim=output_batch_dim[idx]
)
for idx in range(output_num)
]

return [
Response(
pickle.dumps(payloads),
headers={
PAYLOAD_META_HEADER: json.dumps({}),
"Content-Type": "application/vnd.bentoml.multiple_outputs",
"Server": server_str,
},
)
for payloads in zip(*payloadss)
]

# single output branch
assert isinstance(
output_batch_dim, int
), "output_batch_dim's should be int for single output"
payloads = AutoContainer.batch_to_payloads(
batch_ret,
indices,
Expand All @@ -197,7 +232,7 @@ async def _run(requests: t.Iterable[Request]) -> list[Response]:
headers={
PAYLOAD_META_HEADER: json.dumps(payload.meta),
"Content-Type": f"application/vnd.bentoml.{payload.container}",
"Server": f"BentoML-Runner/{self.runner.name}/{runner_method.name}/{self.worker_index}",
"Server": server_str,
},
)
for payload in payloads
Expand Down

0 comments on commit e54024f

Please sign in to comment.