Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(configuration): override options with envvar #3018

Merged
merged 8 commits into from Sep 27, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 9 additions & 0 deletions bentoml/_internal/configuration/__init__.py
Expand Up @@ -30,6 +30,7 @@ class version_mod:
DEBUG_ENV_VAR = "BENTOML_DEBUG"
QUIET_ENV_VAR = "BENTOML_QUIET"
CONFIG_ENV_VAR = "BENTOML_CONFIG"
CONFIG_OVERRIDE_ENV_VAR = "BENTOML_CONFIG_OPTIONS"
# https://github.com/grpc/grpc/blob/master/doc/environment_variables.md
GRPC_DEBUG_ENV_VAR = "GRPC_VERBOSITY"

Expand Down Expand Up @@ -97,6 +98,13 @@ def get_bentoml_config_file_from_env() -> t.Optional[str]:
return None


def get_bentoml_override_config_from_env() -> t.Optional[str]:
    """Return the raw config-override string from the environment, if set.

    Reads the environment variable named by ``CONFIG_OVERRIDE_ENV_VAR``,
    which carries user-local config options for customizing bentoml.

    Returns:
        The variable's value exactly as stored in ``os.environ``, or ``None``
        when the variable is not set.
    """
    # ``os.environ.get`` already yields None for a missing key, so a separate
    # membership test (and the second lookup it implies) is unnecessary.
    return os.environ.get(CONFIG_OVERRIDE_ENV_VAR)


def set_debug_mode(enabled: bool) -> None:
os.environ[DEBUG_ENV_VAR] = str(enabled)
os.environ[GRPC_DEBUG_ENV_VAR] = "DEBUG"
Expand Down Expand Up @@ -147,6 +155,7 @@ def load_global_config(bentoml_config_file: t.Optional[str] = None):

bentoml_configuration = BentoMLConfiguration(
override_config_file=bentoml_config_file,
override_config_values=get_bentoml_override_config_from_env(),
)

BentoMLContainer.config.set(bentoml_configuration.as_dict())
Expand Down
74 changes: 48 additions & 26 deletions bentoml/_internal/configuration/containers.py
Expand Up @@ -21,10 +21,12 @@
from deepmerge.merger import Merger

from . import expand_env_var
from ..utils import split_with_quotes
from ..utils import validate_or_create_dir
from ..context import component_context
from ..resource import system_resources
from ...exceptions import BentoMLConfigException
from ..utils.unflatten import unflatten

if TYPE_CHECKING:
from bentoml._internal.models import ModelStore
Expand Down Expand Up @@ -214,6 +216,7 @@ class BentoMLConfiguration:
def __init__(
self,
override_config_file: t.Optional[str] = None,
override_config_values: t.Optional[str] = None,
validate_schema: bool = True,
):
# Load default configuration
Expand All @@ -223,15 +226,6 @@ def __init__(
with open(default_config_file, "rb") as f:
self.config: t.Dict[str, t.Any] = yaml.safe_load(f)

if validate_schema:
try:
SCHEMA.validate(self.config)
except SchemaError as e:
raise BentoMLConfigException(
"Default configuration 'default_configuration.yml' does not"
" conform to the required schema."
) from e

# User override configuration
if override_config_file is not None:
logger.info("Applying user config override from %s" % override_config_file)
Expand Down Expand Up @@ -271,26 +265,54 @@ def __init__(

config_merger.merge(self.config, override_config)

global_runner_cfg = {k: self.config["runners"][k] for k in RUNNER_CFG_KEYS}
for key in self.config["runners"]:
if key not in RUNNER_CFG_KEYS:
runner_cfg = self.config["runners"][key]
if override_config_values is not None:
    # The override string holds whitespace-separated ``key=value`` entries;
    # double quotes protect a token from being split on whitespace or "=".
    logger.info(
        "Applying user config override from ENV VAR: %s", override_config_values
    )
    lines = split_with_quotes(
        override_config_values,
        sep=r"\s+",
        quote='"',
        use_regex=True,
    )
    override_config_map: t.Dict[str, t.Any] = {}
    for line in lines:
        if not line:
            # Leading/trailing whitespace (or an empty variable) produces
            # empty tokens; they carry no option and are skipped.
            continue
        pair = split_with_quotes(line, sep="=", quote='"')
        if len(pair) != 2:
            # Report malformed entries as a config error instead of letting
            # tuple unpacking fail with an opaque ValueError.
            raise BentoMLConfigException(
                f'Failed to parse config options from the env var: expected key=value but got {line!r}. \n *** Note: You can use " to quote the key if it contains special characters. ***'
            )
        key, raw_value = pair
        # safe_load: the value comes from the environment, and yaml.load()
        # without an explicit Loader is rejected by PyYAML >= 6.
        override_config_map[key] = yaml.safe_load(raw_value)
    try:
        # NOTE(review): the flat-key syntax accepted here is defined by
        # utils.unflatten, which is not visible in this diff.
        override_config = unflatten(override_config_map)  # type: ignore
    except ValueError as e:
        raise BentoMLConfigException(
            f'Failed to parse config options from the env var: {e}. \n *** Note: You can use " to quote the key if it contains special characters. ***'
        ) from None
    config_merger.merge(self.config, override_config)

# key is a runner name
if runner_cfg.get("resources") == "system":
runner_cfg["resources"] = system_resources()
if override_config_file is not None or override_config_values is not None:
self._finish_config()

self.config["runners"][key] = config_merger.merge(
deepcopy(global_runner_cfg), runner_cfg
)
if validate_schema:
try:
SCHEMA.validate(self.config)
except SchemaError as e:
raise BentoMLConfigException(
"Default configuration 'default_configuration.yml' does not"
" conform to the required schema."
) from e

if validate_schema:
try:
SCHEMA.validate(self.config)
except SchemaError as e:
raise BentoMLConfigException(
"Invalid configuration file was given."
) from e
def _finish_config(self):
    """Fill every named runner's config with the shared runner settings.

    Entries of ``self.config["runners"]`` whose key is listed in
    ``RUNNER_CFG_KEYS`` are treated as the shared settings; every other key
    is a runner name. For each runner, a ``resources`` value of ``"system"``
    is replaced by ``system_resources()``, and the runner's entry is then
    replaced by the result of merging it onto a deep copy of the shared
    settings via ``config_merger``.
    """
    runners = self.config["runners"]
    shared_cfg = {name: runners[name] for name in RUNNER_CFG_KEYS}
    for name in runners:
        if name in RUNNER_CFG_KEYS:
            # Shared setting, not a runner name: nothing to expand.
            continue
        per_runner_cfg = runners[name]
        if per_runner_cfg.get("resources") == "system":
            per_runner_cfg["resources"] = system_resources()
        # Merge onto a copy so the shared settings are never mutated.
        runners[name] = config_merger.merge(
            deepcopy(shared_cfg),
            per_runner_cfg,
        )

def override(self, keys: t.List[str], value: t.Any):
if keys is None:
Expand Down
39 changes: 39 additions & 0 deletions bentoml/_internal/utils/__init__.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
import re
import sys
import uuid
import random
Expand Down Expand Up @@ -171,6 +172,44 @@ def _(*args: P.args, **kwargs: P.kwargs) -> t.Optional[_T_co]:
return _


def split_with_quotes(
    s: str,
    sep: str = ",",
    quote: str = '"',
    use_regex: bool = False,
) -> list[str]:
    """
    Split ``s`` on ``sep``, ignoring separators inside quoted sections.

    Quoted sections are kept verbatim, *including* the surrounding quote
    characters, so callers can still tell quoted text from unquoted text:

    >>> split_with_quotes('a,b,"c,d",e')
    ['a', 'b', '"c,d"', 'e']

    Args:
        s: The string to split.
        sep: The separator. Treated as a literal string unless ``use_regex``
            is true, in which case it is inserted into the pattern as-is.
        quote: The quote character delimiting protected sections.
        use_regex: Interpret ``sep`` (and ``quote``) as regex fragments
            instead of escaping them.

    Returns:
        The pieces of ``s`` between separators. An input without any
        separator yields a single-element list.

    Raises:
        ValueError: If ``use_regex`` is true and ``sep`` contains ``(`` or
            ``)``; extra groups would change the layout of ``re.split``'s
            result that the parsing below relies on.
    """
    if use_regex:
        # Explicit raise rather than ``assert`` so the check survives ``-O``.
        if "(" in sep or ")" in sep:
            raise ValueError("sep cannot contain '(' or ')' when using regex")
    else:
        quote = re.escape(quote)
        sep = re.escape(sep)
    pattern = "({quote}[^{quote}]*{quote})|({sep})".format(quote=quote, sep=sep)

    # With exactly two capture groups, re.split returns
    #   [text, quoted, sep, text, quoted, sep, ..., text]
    # where, for each match, one of ``quoted``/``sep`` is None.
    tokens = re.split(pattern, s)
    parts: list[str] = []
    pending: list[str] = []
    for start in range(0, len(tokens), 3):
        pending.append(tokens[start])
        if start + 2 >= len(tokens):
            # Trailing text after the last match: no more groups follow.
            break
        quoted, separator = tokens[start + 1], tokens[start + 2]
        if separator is not None:
            # A real separator ends the current part.
            parts.append("".join(pending))
            pending = []
        elif quoted is not None:
            # A quoted section belongs to the current part, quotes included.
            pending.append(quoted)
    parts.append("".join(pending))
    return parts


@contextlib.contextmanager
def reserve_free_port(
host: str = "localhost",
Expand Down