Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EXPERIMENTAL[grpc]: synonymous configuration fields #2980

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
253 changes: 177 additions & 76 deletions bentoml/_internal/configuration/containers.py
Expand Up @@ -51,6 +51,7 @@
"zipkin",
"jaeger",
"otlp",
"in_memory",
)
_check_otlp_protocol: t.Callable[[str], bool] = lambda s: s in (
"grpc",
Expand Down Expand Up @@ -178,12 +179,37 @@ def _is_ip_address(addr: str) -> bool:
"tracing": {
"type": Or(And(str, Use(str.lower), _check_tracing_type), None),
"sample_rate": Or(And(float, lambda i: i >= 0 and i <= 1), None),
"timeout": Or(int, None),
"max_tag_value_length": Or(int, None),
"excluded_urls": Or([str], str, None),
Optional("zipkin"): {"url": Or(str, None)},
Optional("jaeger"): {"address": Or(str, None), "port": Or(int, None)},
Optional("otlp"): {
"zipkin": {
"endpoint": Or(str, None),
},
"jaeger": {
"protocol": Or(
And(str, Use(str.lower), lambda d: d in ["thrift", "grpc"]), None
),
"collector_endpoint": Or(str, None),
"thrift": {
"agent_host_name": Or(str, None),
"agent_port": Or(int, None),
"udp_split_oversized_batches": Or(bool, None),
},
"grpc": {
"insecure": Or(bool, None),
},
},
"otlp": {
"protocol": Or(And(str, Use(str.lower), _check_otlp_protocol), None),
"url": Or(str, None),
"endpoint": Or(str, None),
"compression": Or(
And(str, lambda d: d in {"gzip", "none", "deflate"}), None
),
"http": {"certificate_file": Or(str, None), "headers": Or(dict, None)},
"grpc": {
"insecure": Or(bool, None),
"headers": Or(lambda d: isinstance(d, t.Sequence), None),
},
},
},
Optional("yatai"): {
Expand All @@ -205,9 +231,88 @@ def _is_ip_address(addr: str) -> bool:
}
)

_WARNING_MESSAGE = (
"field 'api_server.%s' is deprecated and has been renamed to 'api_server.http.%s'"
)
_WARNING_MESSAGE = "Field '%s.%s' is deprecated and has been renamed to '%s'."


def v1_to_v2_migration(config_merger: Merger, override_config: dict[str, t.Any]):
# api_server compat
if "api_server" in override_config:
deprecated = ["port", "host", "cors"]
api_server_config = override_config["api_server"]
# max_request_size is deprecated
if "max_request_size" in api_server_config:
logger.warning(
"Field 'api_server.max_request_size' is deprecated and has become obsolete."
)
api_server_config.pop("max_request_size")
# check if user are using older configuration
if "http" not in api_server_config:
api_server_config["http"] = {}
# compat layer
for field in deprecated:
if field in api_server_config:
old_field = api_server_config.pop(field)
api_server_config["http"][field] = old_field
logger.warning(
_WARNING_MESSAGE, "api_server", field, f"api_server.http.{field}"
)
config_merger.merge(override_config["api_server"], api_server_config)
assert all(
key not in override_config["api_server"]
for key in ["cors", "max_request_size", "host", "port"]
)
# tracing compat
if "tracing" in override_config:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a utility function rename_field might be useful? e.g. rename_field(config, "zipkin.url", "zipkin.endpoint") or something? I think it would make this easier to read.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sg.

tracing_config = override_config["tracing"]
for exporter in ["zipkin", "otlp"]:
# 'url' should be called 'endpoint' to be consistent with exporters docs
if exporter in tracing_config:
exporter_config = tracing_config[exporter]
if "url" in exporter_config:
old_field = exporter_config.pop("url")
exporter_config["endpoint"] = old_field
logger.warning(
_WARNING_MESSAGE,
f"tracing.{exporter}",
"url",
f"tracing.{exporter}.endpoint",
)
assert "url" not in override_config["tracing"][exporter]
if "jaeger" in tracing_config:
jaeger = tracing_config["jaeger"]
# check if user are using older configuration
if "thrift" not in jaeger:
jaeger["thrift"] = {}
# default to thrift for HTTP if any of the old fields are present
# This is for users who are using older configuration.
# compat layer
if "address" in jaeger:
old_field = jaeger.pop("address")
jaeger["thrift"]["agent_host_name"] = old_field
logger.warning(
_WARNING_MESSAGE,
"tracing.jaeger",
"address",
"tracing.jaeger.agent_host_name",
)
if "port" in jaeger:
old_field = jaeger.pop("port")
jaeger["thrift"]["agent_port"] = old_field
logger.warning(
_WARNING_MESSAGE,
"tracing.jaeger",
"port",
"tracing.jaeger.agent_port",
)
# ideally we also want to sync this protocol when users is using serve-grpc
# since jaeger will only export gRPC traces if the protocol is set to gRPC.
if "protocol" not in jaeger and len(jaeger["thrift"]) != 0:
jaeger["protocol"] = "thrift"
assert all(
key not in override_config["tracing"]["jaeger"]
for key in ["address", "port"]
)
config_merger.merge(override_config["tracing"], tracing_config)


class BentoMLConfiguration:
Expand Down Expand Up @@ -242,32 +347,12 @@ def __init__(
with open(override_config_file, "rb") as f:
override_config: dict[str, t.Any] = yaml.safe_load(f)

# compatibility layer with old configuration pre gRPC features
# api_server.[cors|port|host] -> api_server.http.$^
if "api_server" in override_config:
user_api_config = override_config["api_server"]
# max_request_size is deprecated
if "max_request_size" in user_api_config:
logger.warning(
"'api_server.max_request_size' is deprecated and has become obsolete."
)
user_api_config.pop("max_request_size")
# check if user are using older configuration
if "http" not in user_api_config:
user_api_config["http"] = {}
# then migrate these fields to newer configuration fields.
for field in ["port", "host", "cors"]:
if field in user_api_config:
old_field = user_api_config.pop(field)
user_api_config["http"][field] = old_field
logger.warning(_WARNING_MESSAGE, field, field)

config_merger.merge(override_config["api_server"], user_api_config)

assert all(
key not in override_config["api_server"]
for key in ["cors", "max_request_size", "host", "port"]
)
# v1 -> v2 changes follow:
# - api_server.[max_request_size|cors|port|host] -> api_server.http.$^
# - tracing.jaeger.* -> tracing.jaeger.http.*
# - add tracing.jaeger.grpc.* with default values
# TODO: follow up versioning and deprecation generation.
v1_to_v2_migration(config_merger, override_config)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should only call this if there's no version specified, i.e. it's a v1 config.

We need to figure out some way to enforce people to version their configs though.

Copy link
Member Author

@aarnphm aarnphm Sep 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking to have a version field

version: v1
...

if no version is specified, then we default that to v1


config_merger.merge(self.config, override_config)

Expand All @@ -293,6 +378,13 @@ def __init__(
raise BentoMLConfigException(
f"Failed to parse config options from the env var: {e}. \n *** Note: You can use '\"' to quote the key if it contains special characters. ***"
) from None
# v1 -> v2 changes follow:
# - api_server.[max_request_size|cors|port|host] -> api_server.http.$^
# - tracing.jaeger.* -> tracing.jaeger.http.*
# - add tracing.jaeger.grpc.* with default values
# TODO: follow up versioning and deprecation generation.
v1_to_v2_migration(config_merger, override_config)

config_merger.merge(self.config, override_config)

if override_config_file is not None or override_config_values is not None:
Expand All @@ -303,8 +395,8 @@ def __init__(
SCHEMA.validate(self.config)
except SchemaError as e:
raise BentoMLConfigException(
"Invalid configuration file was given."
) from e
f"Invalid configuration file was given: {e}"
) from None

def _finalize(self):
global_runner_cfg = {k: self.config["runners"][k] for k in RUNNER_CFG_KEYS}
Expand Down Expand Up @@ -472,12 +564,12 @@ def metrics_client(
@staticmethod
def tracer_provider(
tracer_type: str = Provide[config.tracing.type],
sample_rate: t.Optional[float] = Provide[config.tracing.sample_rate],
zipkin_server_url: t.Optional[str] = Provide[config.tracing.zipkin.url],
jaeger_server_address: t.Optional[str] = Provide[config.tracing.jaeger.address],
jaeger_server_port: t.Optional[int] = Provide[config.tracing.jaeger.port],
otlp_server_protocol: t.Optional[str] = Provide[config.tracing.otlp.protocol],
otlp_server_url: t.Optional[str] = Provide[config.tracing.otlp.url],
sample_rate: float | None = Provide[config.tracing.sample_rate],
timeout: int | None = Provide[config.tracing.timeout],
max_tag_value_length: int | None = Provide[config.tracing.max_tag_value_length],
zipkin: dict[str, t.Any] = Provide[config.tracing.zipkin],
jaeger: dict[str, t.Any] = Provide[config.tracing.jaeger],
otlp: dict[str, t.Any] = Provide[config.tracing.otlp],
):
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.resources import Resource
Expand All @@ -494,9 +586,8 @@ def tracer_provider(

if sample_rate is None:
sample_rate = 0.0

_check_sample_rate(sample_rate)
resource = {}

# User can optionally configure the resource with the following environment variables. Only
# configure resource if user has not explicitly configured it.
if (
Expand All @@ -511,52 +602,62 @@ def tracer_provider(
resource[SERVICE_NAMESPACE] = component_context.bento_name
if component_context.bento_version:
resource[SERVICE_VERSION] = component_context.bento_version

# create tracer provider
provider = TracerProvider(
sampler=ParentBasedTraceIdRatio(sample_rate),
resource=Resource.create(resource),
)

if tracer_type == "zipkin" and zipkin_server_url is not None:
if tracer_type == "zipkin" and any(zipkin.values()):
from opentelemetry.exporter.zipkin.json import ZipkinExporter

exporter = ZipkinExporter(endpoint=zipkin_server_url)
provider.add_span_processor(BatchSpanProcessor(exporter))
_check_sample_rate(sample_rate)
return provider
elif (
tracer_type == "jaeger"
and jaeger_server_address is not None
and jaeger_server_port is not None
):
from opentelemetry.exporter.jaeger.thrift import JaegerExporter

exporter = ZipkinExporter(
endpoint=zipkin.get("endpoint"),
max_tag_value_length=max_tag_value_length,
timeout=timeout,
)
elif tracer_type == "jaeger" and any(jaeger.values()):
protocol = jaeger.get("protocol")
if protocol == "thrift":
from opentelemetry.exporter.jaeger.thrift import JaegerExporter
elif protocol == "grpc":
from opentelemetry.exporter.jaeger.proto.grpc import JaegerExporter
else:
raise InvalidArgument(
f"Invalid 'tracing.jaeger.protocol' value: {protocol}"
) from None
exporter = JaegerExporter(
agent_host_name=jaeger_server_address, agent_port=jaeger_server_port
collector_endpoint=jaeger.get("collector_endpoint"),
max_tag_value_length=max_tag_value_length,
timeout=timeout,
**jaeger[protocol],
)
provider.add_span_processor(BatchSpanProcessor(exporter))
_check_sample_rate(sample_rate)
return provider
elif (
tracer_type == "otlp"
and otlp_server_protocol is not None
and otlp_server_url is not None
):
if otlp_server_protocol == "grpc":
elif tracer_type == "otlp" and any(otlp.values()):
protocol = otlp.get("protocol")
if protocol == "grpc":
from opentelemetry.exporter.otlp.proto.grpc import trace_exporter

elif otlp_server_protocol == "http":
elif protocol == "http":
from opentelemetry.exporter.otlp.proto.http import trace_exporter
else:
raise InvalidArgument(
f"Invalid otlp protocol: {otlp_server_protocol}"
) from None
exporter = trace_exporter.OTLPSpanExporter(endpoint=otlp_server_url)
provider.add_span_processor(BatchSpanProcessor(exporter))
_check_sample_rate(sample_rate)
return provider
raise InvalidArgument(f"Invalid otlp protocol: {protocol}") from None
exporter = trace_exporter.OTLPSpanExporter(
endpoint=otlp.get("endpoint", None),
compression=otlp.get("compression", None),
timeout=timeout,
**otlp[protocol],
)
elif tracer_type == "in_memory":
# This will be used during testing, user shouldn't use this otherwise.
# We won't document this in documentation.
from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
InMemorySpanExporter,
)

exporter = InMemorySpanExporter()
else:
return provider
# When exporter is set
provider.add_span_processor(BatchSpanProcessor(exporter))
return provider

@providers.SingletonFactory
@staticmethod
Expand Down
23 changes: 19 additions & 4 deletions bentoml/_internal/configuration/default_configuration.yaml
Expand Up @@ -61,11 +61,26 @@ tracing:
type: zipkin
sample_rate: ~
excluded_urls: ~
timeout: ~
max_tag_value_length: ~
zipkin:
url: ~
endpoint: ~
jaeger:
address: ~
port: ~
protocol: thrift
collector_endpoint: ~
thrift:
agent_host_name: ~
agent_port: ~
udp_split_oversized_batches: ~
grpc:
insecure: ~
otlp:
protocol: ~
url: ~
endpoint: ~
compression: ~
http:
certificate_file: ~
headers: ~
grpc:
headers: ~
insecure: ~