Skip to content

Commit

Permalink
feat: configuration versioning 1
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
  • Loading branch information
aarnphm committed Oct 5, 2022
1 parent 7ecbacc commit 6c3cc60
Show file tree
Hide file tree
Showing 16 changed files with 183 additions and 361 deletions.
79 changes: 29 additions & 50 deletions bentoml/_internal/configuration/containers.py
Expand Up @@ -25,7 +25,6 @@
from ..context import component_context
from ..resource import CpuResource
from ..resource import system_resources
from ...exceptions import InvalidArgument
from ...exceptions import BentoMLConfigException
from ..utils.unflatten import unflatten

Expand All @@ -39,45 +38,24 @@

config_merger = Merger(
# merge dicts
[(dict, "merge")],
type_strategies=[(dict, "merge")],
# override all other types
["override"],
fallback_strategies=["override"],
# override conflicting types
["override"],
type_conflict_strategies=["override"],
)

logger = logging.getLogger(__name__)


RUNNER_CFG_KEYS = ["batching", "resources", "logging", "metrics", "timeout"]


CONFIG_LATEST_VERSION = 2


def migrate_up(from_version: int, to_version: int):
# All migrate up functions should only take a flattened config dict as input
migrate = getattr(
import_configuration_spec(version=from_version),
f"migrate_to_v{to_version}",
None,
)
if migrate is None:
raise InvalidArgument(
"Given version %d requires 'migrate_to_v%d' to update config to version %d"
% (from_version, to_version, to_version)
)
return migrate


class BentoMLConfiguration:
def __init__(
self,
override_config_file: str | None = None,
override_config_values: str | None = None,
*,
validate_schema: bool = True,
use_version: int = CONFIG_LATEST_VERSION,
use_version: int = 1,
):
# Load default configuration with latest version.
self.config = get_default_config(version=use_version)
Expand All @@ -93,23 +71,16 @@ def __init__(
# If users does not define a version, we then by default assume they are using v1
# and we will migrate it to latest version
logger.debug(
"User config does not define a version, assuming given config is version 1. Migrating to version %d..."
"User config does not define a version, assuming given config is version %d..."
% use_version
)
current = 1
current = use_version
else:
current = override["version"]
if current < use_version:
# Each version of configuration should have its own migration functions to the latest version
# i.e: v1 should have a migration function to v2 called "migrate_to_v2"
# Note that we should always migrate up, not down.
logger.debug(
"Migrating from version %d to version %d..."
% (current, use_version)
)
override = migrate_up(current, use_version)(
override_config=dict(flatten_dict(override))
)
compat = getattr(import_configuration_spec(current), "compat_layer", None)
# Running compatibliity layer if it exists
if compat:
override = compat(override_config=dict(flatten_dict(override)))
config_merger.merge(self.config, override)

if override_config_values is not None:
Expand All @@ -135,19 +106,21 @@ def __init__(
if "version" in override_config_map:
override_version = override_config_map["version"]
logger.debug(
"Found defined 'version=%d' in BENTOML_CONFIG_OPTIONS. We will migrate up to latest configuration if possible."
"Found defined 'version=%d' in BENTOML_CONFIG_OPTIONS."
% override_version
)
if override_version < use_version:
override_config_map = migrate_up(override_version, use_version)(
override_config=override_config_map
)
compat = getattr(
import_configuration_spec(override_version), "compat_layer", None
)
# Running compatibliity layer if it exists
if compat:
override_config_map = compat(override_config=override_config_map)
# Previous behaviour, before configuration versioning.
try:
override = unflatten(override_config_map)
except ValueError as e:
raise BentoMLConfigException(
f"Failed to parse config options from the env var: {e}.\n *** Note: You can use '\"' to quote the key if it contains special characters. ***"
f"Failed to parse config options from the env var:\n{e}.\n*** Note: You can use '\"' to quote the key if it contains special characters. ***"
) from None
config_merger.merge(self.config, override)

Expand All @@ -163,14 +136,20 @@ def __init__(
) from None

def _finalize(self):
global_runner_cfg = {k: self.config["runners"][k] for k in RUNNER_CFG_KEYS}
for key in self.config["runners"]:
if key not in RUNNER_CFG_KEYS:
runner_cfg = self.config["runners"][key]
GLOBAL_RUNNERS_KEY = ["batching", "resources", "logging", "metrics", "timeout"]
global_runner_cfg = {k: self.config["runners"][k] for k in GLOBAL_RUNNERS_KEY}
custom_runners_cfg = dict(
filter(
lambda kv: kv[0] not in GLOBAL_RUNNERS_KEY,
self.config["runners"].items(),
)
)
if custom_runners_cfg:
for runner_name, runner_cfg in custom_runners_cfg.items():
# key is a runner name
if runner_cfg.get("resources") == "system":
runner_cfg["resources"] = system_resources()
self.config["runners"][key] = config_merger.merge(
self.config["runners"][runner_name] = config_merger.merge(
deepcopy(global_runner_cfg),
runner_cfg,
)
Expand Down
@@ -1,4 +1,4 @@
version: 2
version: 1
api_server:
workers: ~ # cpu_count() will be used when null
timeout: 60
Expand Down Expand Up @@ -97,10 +97,9 @@ api_server:
grpc:
headers: ~
insecure: ~

runners:
timeout: 300
resources: ~
timeout: 300
batching:
enabled: true
max_batch_size: 100
Expand Down
13 changes: 8 additions & 5 deletions bentoml/_internal/configuration/helpers.py
Expand Up @@ -67,17 +67,18 @@ def rename_fields(
d[replace_with] = d.pop(current)


punctuation = r"""!"#$%&'()*+,-./:;<=>?@[\]^`{|}~"""


def flatten_dict(
d: t.MutableMapping[str, t.Any],
parent: str = "",
sep: str = ".",
) -> t.Generator[tuple[str, t.Any], None, None]:
"""Flatten nested dictionary into a single level dictionary."""
for k, v in d.items():
# TODO: we probably need to find a better way
# to normalize slash and special characters keys.
key = f'"{k}"' if "/" in k else k
nkey = parent + sep + key if parent else key
k = f'"{k}"' if any(i in punctuation for i in k) else k
nkey = parent + sep + k if parent else k
if isinstance(v, t.MutableMapping):
yield from flatten_dict(
t.cast(t.MutableMapping[str, t.Any], v), parent=nkey, sep=sep
Expand All @@ -99,7 +100,9 @@ def load_config_file(path: str) -> dict[str, t.Any]:

def get_default_config(version: int) -> dict[str, t.Any]:
config = load_config_file(
os.path.join(os.path.dirname(__file__), f"v{version}", "defaults.yaml")
os.path.join(
os.path.dirname(__file__), f"v{version}", "default_configuration.yaml"
)
)
mod = import_configuration_spec(version)
assert hasattr(mod, "SCHEMA"), (
Expand Down
120 changes: 84 additions & 36 deletions bentoml/_internal/configuration/v1/__init__.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import re
import typing as t

import schema as s
Expand All @@ -9,44 +10,66 @@
from ..helpers import rename_fields
from ..helpers import ensure_larger_than
from ..helpers import is_valid_ip_address
from ..helpers import ensure_iterable_type
from ..helpers import validate_tracing_type
from ..helpers import validate_otlp_protocol
from ..helpers import ensure_larger_than_zero
from ...utils.metrics import DEFAULT_BUCKET
from ...utils.unflatten import unflatten

__all__ = ["SCHEMA", "migrate_to_v2"]
__all__ = ["SCHEMA", "compat_layer"]

_TRACING_CONFIG = {
"type": s.Or(s.And(str, s.Use(str.lower), validate_tracing_type), None),
TRACING_CFG = {
"exporter_type": s.Or(s.And(str, s.Use(str.lower), validate_tracing_type), None),
"sample_rate": s.Or(s.And(float, ensure_range(0, 1)), None),
"timeout": s.Or(s.And(int, ensure_larger_than_zero), None),
"max_tag_value_length": s.Or(int, None),
"excluded_urls": s.Or([str], str, None),
s.Optional("zipkin"): {"url": s.Or(str, None)},
s.Optional("jaeger"): {"address": s.Or(str, None), "port": s.Or(int, None)},
s.Optional("otlp"): {
"zipkin": {
"endpoint": s.Or(str, None),
},
"jaeger": {
"protocol": s.Or(
s.And(str, s.Use(str.lower), lambda d: d in ["thrift", "grpc"]),
None,
),
"collector_endpoint": s.Or(str, None),
"thrift": {
"agent_host_name": s.Or(str, None),
"agent_port": s.Or(int, None),
"udp_split_oversized_batches": s.Or(bool, None),
},
"grpc": {
"insecure": s.Or(bool, None),
},
},
"otlp": {
"protocol": s.Or(s.And(str, s.Use(str.lower), validate_otlp_protocol), None),
"url": s.Or(str, None),
"endpoint": s.Or(str, None),
"compression": s.Or(
s.And(str, lambda d: d in {"gzip", "none", "deflate"}), None
),
"http": {
"certificate_file": s.Or(str, None),
"headers": s.Or(dict, None),
},
"grpc": {
"insecure": s.Or(bool, None),
"headers": s.Or(lambda d: isinstance(d, t.Sequence), None),
},
},
}
_API_SERVER_CONFIG = {
"host": s.And(str, is_valid_ip_address),
"port": s.And(int, ensure_larger_than_zero),
"workers": s.Or(s.And(int, ensure_larger_than_zero), None),
"timeout": s.And(int, ensure_larger_than_zero),
"backlog": s.And(int, ensure_larger_than(64)),
"max_request_size": s.And(int, ensure_larger_than_zero),
s.Optional("ssl"): {
s.Optional("certfile"): s.Or(str, None),
s.Optional("keyfile"): s.Or(str, None),
s.Optional("keyfile_password"): s.Or(str, None),
s.Optional("version"): s.Or(s.And(int, ensure_larger_than_zero), None),
s.Optional("cert_reqs"): s.Or(int, None),
s.Optional("ca_certs"): s.Or(str, None),
s.Optional("ciphers"): s.Or(str, None),
},
"metrics": {
"enabled": bool,
"namespace": str,
s.Optional("duration"): {
s.Optional("buckets", default=DEFAULT_BUCKET): s.Or(
s.And(list, ensure_iterable_type(float)), None
),
s.Optional("min"): s.Or(s.And(float, ensure_larger_than_zero), None),
s.Optional("max"): s.Or(s.And(float, ensure_larger_than_zero), None),
s.Optional("factor"): s.Or(s.And(float, ensure_larger_than(1.0)), None),
Expand All @@ -59,16 +82,48 @@
"request_content_type": s.Or(bool, None),
"response_content_length": s.Or(bool, None),
"response_content_type": s.Or(bool, None),
"format": {
"trace_id": str,
"span_id": str,
},
},
},
"cors": {
"http": {
"host": s.And(str, is_valid_ip_address),
"port": s.And(int, ensure_larger_than_zero),
"cors": {
"enabled": bool,
"allow_origin": s.Or(str, None),
"allow_origin_regex": s.Or(s.And(str, s.Use(re.compile)), None),
"allow_credentials": s.Or(bool, None),
"allow_headers": s.Or([str], str, None),
"allow_methods": s.Or([str], str, None),
"max_age": s.Or(int, None),
"expose_headers": s.Or([str], str, None),
},
},
"grpc": {
"host": s.And(str, is_valid_ip_address),
"port": s.And(int, ensure_larger_than_zero),
"metrics": {
"port": s.And(int, ensure_larger_than_zero),
"host": s.And(str, is_valid_ip_address),
},
"reflection": {"enabled": bool},
"max_concurrent_streams": s.Or(int, None),
"max_message_length": s.Or(int, None),
"maximum_concurrent_rpcs": s.Or(int, None),
},
"tracing": TRACING_CFG,
s.Optional("ssl"): {
"enabled": bool,
"access_control_allow_origin": s.Or(str, None),
"access_control_allow_credentials": s.Or(bool, None),
"access_control_allow_headers": s.Or([str], str, None),
"access_control_allow_methods": s.Or([str], str, None),
"access_control_max_age": s.Or(int, None),
"access_control_expose_headers": s.Or([str], str, None),
s.Optional("certfile"): s.Or(str, None),
s.Optional("keyfile"): s.Or(str, None),
s.Optional("keyfile_password"): s.Or(str, None),
s.Optional("version"): s.Or(s.And(int, ensure_larger_than_zero), None),
s.Optional("cert_reqs"): s.Or(int, None),
s.Optional("ca_certs"): s.Or(str, None),
s.Optional("ciphers"): s.Or(str, None),
},
}
_RUNNER_CONFIG = {
Expand Down Expand Up @@ -103,25 +158,18 @@
**_RUNNER_CONFIG,
s.Optional(str): _RUNNER_CONFIG,
},
"tracing": _TRACING_CONFIG,
"logging": {
"formatting": {
"trace_id_format": str,
"span_id_format": str,
}
},
}
)


def migrate_to_v2(*, override_config: dict[str, t.Any]):
def compat_layer(*, override_config: dict[str, t.Any]):
# We will use a flattened config to make it easier to migrate,
# Then we will convert it back to a nested config.
if depth(override_config) > 1:
raise ValueError("'override_config' must be a flattened dictionary.") from None

# We will set version of the migration to 2
override_config["version"] = 2
if "version" not in override_config:
override_config["version"] = 1

# First we migrate api_server field
# 1. remove api_server.max_request_size (deprecated)
Expand Down

0 comments on commit 6c3cc60

Please sign in to comment.