diff --git a/src/bentoml/client.py b/src/bentoml/client.py index cef67a1eac..4b08fd4514 100644 --- a/src/bentoml/client.py +++ b/src/bentoml/client.py @@ -1,9 +1,13 @@ from __future__ import annotations +import json import typing as t +import asyncio import functools from abc import ABC from abc import abstractmethod +from http.client import HTTPConnection +from urllib.parse import urlparse import aiohttp import starlette.requests @@ -27,15 +31,30 @@ def __init__(self, svc: Service, server_url: str): for name, api in self._svc.apis.items(): if not hasattr(self, name): - setattr(self, name, functools.partial(self._call, _bentoml_api=api)) + setattr( + self, name, functools.partial(self._sync_call, _bentoml_api=api) + ) + + for name, api in self._svc.apis.items(): + if not hasattr(self, f"async_{name}"): + setattr( + self, + f"async_{name}", + functools.partial(self._call, _bentoml_api=api), + ) - def call(self, api_name: str, inp: t.Any = None, **kwargs: t.Any) -> t.Any: - asyncio.run(self.async_call(api_name, inp)) + def call(self, bentoml_api_name: str, inp: t.Any = None, **kwargs: t.Any) -> t.Any: + return asyncio.run(self.async_call(bentoml_api_name, inp)) async def async_call( - self, api_name: str, inp: t.Any = None, **kwargs: t.Any + self, bentoml_api_name: str, inp: t.Any = None, **kwargs: t.Any ) -> t.Any: - return self._call(inp, _bentoml_api=self._svc.apis[api_name]) + return await self._call(inp, _bentoml_api=self._svc.apis[bentoml_api_name]) + + def _sync_call( + self, inp: t.Any = None, *, _bentoml_api: InferenceAPI, **kwagrs: t.Any + ): + return asyncio.run(self._call(inp, _bentoml_api=_bentoml_api)) @abstractmethod async def _call( @@ -44,12 +63,16 @@ async def _call( raise NotImplementedError @staticmethod - async def from_url(server_url: str) -> Client: + def from_url(server_url: str) -> Client: + server_url = server_url if "://" in server_url else "http://" + server_url + url_parts = urlparse(server_url) + # TODO: SSL and grpc support - # connection is passed off to the client, and so is not closed - async with aiohttp.ClientSession(server_url) as conn: - async with conn.get("/docs.json") as resp: - openapi_spec = await resp.json() + conn = HTTPConnection(url_parts.netloc) + conn.request("GET", "/docs.json") + resp = conn.getresponse() + openapi_spec = json.load(resp) + conn.close() dummy_service = Service(openapi_spec["info"]["title"])