diff --git a/src/bentoml/_internal/client/__init__.py b/src/bentoml/_internal/client/__init__.py index e41d518810b..38a9253052a 100644 --- a/src/bentoml/_internal/client/__init__.py +++ b/src/bentoml/_internal/client/__init__.py @@ -57,6 +57,16 @@ async def async_call( inp, _bentoml_api=self._svc.apis[bentoml_api_name], **kwargs ) + @abstractmethod + def wait_until_server_ready( + self, + *, + server_url: str | None = None, + timeout: int = 30, + **kwargs: t.Any, + ) -> None: + raise NotImplementedError + @t.overload @classmethod @abstractmethod diff --git a/src/bentoml/_internal/client/grpc.py b/src/bentoml/_internal/client/grpc.py index ba9b99ac9f0..94081baf20e 100644 --- a/src/bentoml/_internal/client/grpc.py +++ b/src/bentoml/_internal/client/grpc.py @@ -1,6 +1,8 @@ from __future__ import annotations +import time import typing as t +import asyncio import logging import functools from typing import TYPE_CHECKING @@ -28,7 +30,9 @@ import grpc from grpc import aio + from grpc_health.v1 import health_pb2 as pb_health from google.protobuf import json_format as _json_format + from google.protobuf.internal import python_message from ..types import PathType from ...grpc.v1.service_pb2 import Response @@ -41,6 +45,7 @@ class ClientCredentials(t.TypedDict): else: grpc, aio = import_grpc() + pb_health = LazyLoader("pb_health", globals(), "grpc_health.v1.health_pb2") _json_format = LazyLoader( "_json_format", globals(), @@ -80,6 +85,7 @@ def __init__( for k, v in ssl_client_credentials.items() } ) + self._call_rpc = f"/bentoml.grpc.{protocol_version}.BentoService/Call" super().__init__(svc, server_url) @property @@ -100,14 +106,40 @@ def channel(self): interceptors=self._interceptors, ) + def wait_until_server_ready( + self, + *, + server_url: str | None = None, + timeout: int = 30, + check_interval: float = 1, + **kwargs: t.Any, + ) -> None: + start_time = time.time() + while time.time() - start_time < timeout: + try: + res = asyncio.run( + self._health(service_name=self._call_rpc, timeout=timeout) + ) + if res.status == pb_health.HealthCheckResponse.SERVING: + break + else: + asyncio.run(asyncio.sleep(check_interval)) + except aio.AioRpcError as err: + logger.debug("[%s] Retrying to connect to the host %s", err, server_url) + asyncio.run(asyncio.sleep(check_interval)) + raise TimeoutError( + f"Timed out waiting {timeout} seconds for server at '{server_url}' to be ready." + ) + @cached_property - def _rpc_metadata(self): + def _rpc_metadata(self) -> dict[str, dict[str, t.Any]]: # Currently all RPCs in BentoService are unary-unary + # NOTE: we will set the types of the stubs to be Any. return { method: {"input_type": input_type, "output_type": output_type} for method, input_type, output_type in ( ( - f"/bentoml.grpc.{self._protocol_version}.BentoService/Call", + self._call_rpc, self._pb.Request, self._pb.Response, ), @@ -116,9 +148,21 @@ def _rpc_metadata(self): self._pb.ServiceMetadataRequest, self._pb.ServiceMetadataResponse, ), + ( + "/grpc.health.v1.Health/Check", + pb_health.HealthCheckRequest, + pb_health.HealthCheckResponse, + ), ) } + async def _health(self, service_name: str, *, timeout: int = 30) -> t.Any: + return await self._invoke( + method_name="/grpc.health.v1.Health/Check", + service=service_name, + _grpc_channel_timeout=timeout, + ) + async def _invoke(self, method_name: str, **attrs: t.Any): # channel kwargs include timeout, metadata, credentials, wait_for_ready and compression # to pass it in kwargs add prefix _grpc_channel_ diff --git a/src/bentoml/_internal/client/http.py b/src/bentoml/_internal/client/http.py index db97bda4dc4..e50682beb61 100644 --- a/src/bentoml/_internal/client/http.py +++ b/src/bentoml/_internal/client/http.py @@ -1,8 +1,12 @@ from __future__ import annotations import json +import time +import socket import typing as t import logging +import urllib.error +import urllib.request from http.client import HTTPConnection from urllib.parse import urlparse @@ -21,6 +25,35 @@ class HTTPClient(Client): + def wait_until_server_ready( + self, + *, + server_url: str | None = None, + timeout: int = 30, + check_interval: int = 1, + # set kwargs here to omit gRPC kwargs + **kwargs: t.Any, + ) -> None: + start_time = time.time() + if server_url is None: + server_url = self.server_url + + proxy_handler = urllib.request.ProxyHandler({}) + opener = urllib.request.build_opener(proxy_handler) + logger.debug("Waiting for host %s to be ready.", server_url) + while time.time() - start_time < timeout: + try: + if opener.open(f"http://{server_url}/readyz", timeout=1).status == 200: + break + else: + time.sleep(check_interval) + except (ConnectionError, urllib.error.URLError, socket.timeout) as err: + logger.debug("[%s] Retrying to connect to the host %s", err, server_url) + time.sleep(check_interval) + raise TimeoutError( + f"Timed out waiting {timeout} seconds for server at '{server_url}' to be ready." + ) + @classmethod def from_url(cls, server_url: str, **kwargs: t.Any) -> HTTPClient: server_url = server_url if "://" in server_url else "http://" + server_url diff --git a/src/bentoml/_internal/server/server.py b/src/bentoml/_internal/server/server.py index a94b1d05d8a..17553431e5e 100644 --- a/src/bentoml/_internal/server/server.py +++ b/src/bentoml/_internal/server/server.py @@ -28,12 +28,11 @@ def client(self): return self.get_client() def get_client(self): - from bentoml.client import Client + from ..client import Client - Client.wait_until_server_is_ready( - host=self.host, port=self.port, timeout=self.timeout - ) - return Client.from_url(f"http://localhost:{self.port}") + client = Client.from_url(f"http://{self.host}:{self.port}") + client.wait_until_server_ready(timeout=10) + return client def stop(self) -> None: self.process.kill() @@ -47,7 +46,7 @@ def __enter__(self): def __exit__( self, - exc_type: type[BaseException], + exc_type: type[BaseException] | None, exc_value: BaseException, traceback_type: TracebackType, ):