Skip to content

Commit

Permalink
fix(build_config): validation when NoneType (#3187)
Browse files Browse the repository at this point in the history
  • Loading branch information
aarnphm committed Nov 4, 2022
1 parent 0469d8a commit aec7fa8
Showing 1 changed file with 32 additions and 6 deletions.
38 changes: 32 additions & 6 deletions src/bentoml/_internal/bento/build_config.py
Expand Up @@ -638,13 +638,15 @@ def _python_options_structure_hook(d: t.Any, _: t.Type[PythonOptions]) -> Python


if TYPE_CHECKING:
OptionsCls = t.Union[DockerOptions, CondaOptions, PythonOptions]
OptionsCls = DockerOptions | CondaOptions | PythonOptions


def dict_options_converter(
options_type: t.Type[OptionsCls],
) -> t.Callable[[t.Union[OptionsCls, t.Dict[str, t.Any]]], t.Any]:
def _converter(value: t.Union[OptionsCls, t.Dict[str, t.Any]]) -> options_type:
) -> t.Callable[[OptionsCls | dict[str, t.Any]], t.Any]:
def _converter(value: OptionsCls | dict[str, t.Any]) -> options_type:
if value is None:
return options_type()
if isinstance(value, dict):
return options_type(**value)
return value
Expand Down Expand Up @@ -673,18 +675,38 @@ class BentoBuildConfig:
include: t.Optional[t.List[str]] = None
exclude: t.Optional[t.List[str]] = None
docker: DockerOptions = attr.field(
factory=DockerOptions,
default=None,
converter=dict_options_converter(DockerOptions),
)
python: PythonOptions = attr.field(
factory=PythonOptions,
default=None,
converter=dict_options_converter(PythonOptions),
)
conda: CondaOptions = attr.field(
factory=CondaOptions,
default=None,
converter=dict_options_converter(CondaOptions),
)

if TYPE_CHECKING:

# NOTE: This is to ensure that BentoBuildConfig __init__
# satisfies type checker. docker, python, and conda accepts
# dict[str, t.Any] since our converter will handle the conversion.
# There is no way to tell type checker signatures of the converter from attrs
# if given attribute is alrady has a type annotation.
def __init__(
self,
service: str,
description: str | None = ...,
labels: dict[str, t.Any] | None = ...,
include: list[str] | None = ...,
exclude: list[str] | None = ...,
docker: DockerOptions | dict[str, t.Any] | None = ...,
python: PythonOptions | dict[str, t.Any] | None = ...,
conda: CondaOptions | dict[str, t.Any] | None = ...,
) -> None:
...

def __attrs_post_init__(self) -> None:
use_conda = not self.conda.is_empty()
use_cuda = self.docker.cuda_version is not None
Expand Down Expand Up @@ -715,6 +737,10 @@ def __attrs_post_init__(self) -> None:
)

if self.docker.cuda_version is not None:
if _spec.supported_cuda_versions is None:
raise BentoMLException(
f"{self.docker.distro} does not support CUDA, while 'docker.cuda_version={self.docker.cuda_version}' is provided."
)
if self.docker.cuda_version != "default" and (
self.docker.cuda_version not in _spec.supported_cuda_versions
):
Expand Down

0 comments on commit aec7fa8

Please sign in to comment.