Skip to content

Commit

Permalink
fix: async bento client methods (#3152)
Browse files Browse the repository at this point in the history
* fix: async bento client methods

* format?

* format

* address review comment
  • Loading branch information
sauyon committed Oct 28, 2022
1 parent 78d3e00 commit 4220b51
Showing 1 changed file with 33 additions and 10 deletions.
43 changes: 33 additions & 10 deletions src/bentoml/client.py
@@ -1,9 +1,13 @@
from __future__ import annotations

import json
import typing as t
import asyncio
import functools
from abc import ABC
from abc import abstractmethod
from http.client import HTTPConnection
from urllib.parse import urlparse

import aiohttp
import starlette.requests
Expand All @@ -27,15 +31,30 @@ def __init__(self, svc: Service, server_url: str):

for name, api in self._svc.apis.items():
if not hasattr(self, name):
setattr(self, name, functools.partial(self._call, _bentoml_api=api))
setattr(
self, name, functools.partial(self._sync_call, _bentoml_api=api)
)

for name, api in self._svc.apis.items():
if not hasattr(self, f"async_{name}"):
setattr(
self,
f"async_{name}",
functools.partial(self._call, _bentoml_api=api),
)

def call(self, api_name: str, inp: t.Any = None, **kwargs: t.Any) -> t.Any:
asyncio.run(self.async_call(api_name, inp))
def call(self, bentoml_api_name: str, inp: t.Any = None, **kwargs: t.Any) -> t.Any:
return asyncio.run(self.async_call(bentoml_api_name, inp))

async def async_call(
self, api_name: str, inp: t.Any = None, **kwargs: t.Any
self, bentoml_api_name: str, inp: t.Any = None, **kwargs: t.Any
) -> t.Any:
return self._call(inp, _bentoml_api=self._svc.apis[api_name])
return await self._call(inp, _bentoml_api=self._svc.apis[bentoml_api_name])

def _sync_call(
self, inp: t.Any = None, *, _bentoml_api: InferenceAPI, **kwagrs: t.Any
):
return asyncio.run(self._call(inp, _bentoml_api=_bentoml_api))

@abstractmethod
async def _call(
Expand All @@ -44,12 +63,16 @@ async def _call(
raise NotImplementedError

@staticmethod
async def from_url(server_url: str) -> Client:
def from_url(server_url: str) -> Client:
server_url = server_url if "://" in server_url else "http://" + server_url
url_parts = urlparse(server_url)

# TODO: SSL and grpc support
# connection is passed off to the client, and so is not closed
async with aiohttp.ClientSession(server_url) as conn:
async with conn.get("/docs.json") as resp:
openapi_spec = await resp.json()
conn = HTTPConnection(url_parts.netloc)
conn.request("GET", "/docs.json")
resp = conn.getresponse()
openapi_spec = json.load(resp)
conn.close()

dummy_service = Service(openapi_spec["info"]["title"])

Expand Down

0 comments on commit 4220b51

Please sign in to comment.