Skip to content

Commit

Permalink
experiment: Subclass distribution configs
Browse files Browse the repository at this point in the history
  • Loading branch information
markusdregi committed May 12, 2021
1 parent 7c652ba commit 4e70ec3
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 59 deletions.
53 changes: 19 additions & 34 deletions ert3/config/_parameters_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Tuple, Union
from pydantic import (
BaseModel,
ValidationError,
Expand Down Expand Up @@ -41,11 +41,9 @@ def _ensure_valid_name(name: str) -> str:
return name


class _DistributionInput(_ParametersConfig):
mean: Optional[float]
std: Optional[float]
lower_bound: Optional[float]
upper_bound: Optional[float]
class _GaussianInput(_ParametersConfig):
mean: float
std: float

@validator("std")
def _ensure_positive_std(cls, value): # type: ignore
Expand All @@ -56,10 +54,15 @@ def _ensure_positive_std(cls, value): # type: ignore
raise ValueError(f"Expected positive std, was {value}")
return value


class _UniformInput(_ParametersConfig):
lower_bound: float
upper_bound: float

@root_validator
def _ensure_lower_upper(cls, values): # type: ignore
low = values["lower_bound"]
up = values["upper_bound"]
low = values.get("lower_bound")
up = values.get("upper_bound")

if low is None or up is None:
return values
Expand All @@ -72,32 +75,14 @@ def _ensure_lower_upper(cls, values): # type: ignore
)


class _Distribution(_ParametersConfig):
type: Literal["gaussian", "uniform"]
input: _DistributionInput

@root_validator(pre=True)
def _ensure_gaussian_required_fields(cls, values): # type: ignore
dist_type = values["type"]
input_fields = values["input"].keys()
if dist_type == "uniform":
expected = {"lower_bound", "upper_bound"}
actual = set(input_fields)
if expected != actual:
raise AssertionError(
f"Expected distribution inputs: {expected}, was: {actual}"
)
elif values["type"] == "gaussian":
expected = {"mean", "std"}
actual = set(input_fields)
if expected != actual:
raise AssertionError(
f"Expected distribution inputs: {expected}, was: {actual}"
)
else:
raise ValueError(f"Validating unknown distribution: {dist_type}")
class _GaussianDistribution(_ParametersConfig):
type: Literal["gaussian"]
input: _GaussianInput


return values
class _UniformDistribution(_ParametersConfig):
type: Literal["uniform"]
input: _UniformInput


class _VariablesConfig(_ParametersConfig):
Expand Down Expand Up @@ -127,7 +112,7 @@ def __len__(self) -> int:
class _ParameterConfig(_ParametersConfig):
name: str
type: Literal["stochastic"]
distribution: _Distribution
distribution: Union[_GaussianDistribution, _UniformDistribution]
variables: _VariablesConfig

@validator("name")
Expand Down
41 changes: 16 additions & 25 deletions tests/ert3/config/test_parameters_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ def test_valid_gauss(mean, std):


@pytest.mark.parametrize(
("input_", "err_msg"),
"input_",
(
({"mean": 0, "std": 0}, "Expected positive std, was"),
({"mean": 0, "std": -1}, "Expected positive std, was"),
({"mean": 1}, "Expected distribution inputs"),
({"std": 1}, "Expected distribution inputs"),
({"mean": 0, "std": 1, "upper_bound": 10}, "Expected distribution inputs"),
{"mean": 0, "std": 0},
{"mean": 0, "std": -1},
{"mean": 1},
{"std": 1},
{"mean": 0, "std": 1, "upper_bound": 10},
),
)
def test_invalid_gauss(input_, err_msg):
def test_invalid_gauss(input_):
raw_config = [
{
"name": "my_parameter_group",
Expand All @@ -64,7 +64,7 @@ def test_invalid_gauss(input_, err_msg):
}
]

with pytest.raises(ert3.exceptions.ConfigValidationError, match=err_msg):
with pytest.raises(ert3.exceptions.ConfigValidationError):
ert3.config.load_parameters_config(raw_config)


Expand Down Expand Up @@ -105,18 +105,15 @@ def test_valid_uniform(lower_bound, upper_bound):


@pytest.mark.parametrize(
("input_", "err_msg"),
"input_",
(
({"lower_bound": 2, "upper_bound": 1}, "Expected lower_bound"),
({"lower_bound": 0}, "Expected distribution inputs"),
({"upper_bound": 0}, "Expected distribution inputs"),
(
{"lower_bound": 2, "upper_bound": 1, "mean": 0},
"Expected distribution inputs",
),
{"lower_bound": 2, "upper_bound": 1},
{"lower_bound": 0},
{"upper_bound": 0},
{"lower_bound": 2, "upper_bound": 1, "mean": 0},
),
)
def test_invalid_uniform(input_, err_msg):
def test_invalid_uniform(input_):
raw_config = [
{
"name": "my_parameter_group",
Expand All @@ -129,7 +126,7 @@ def test_invalid_uniform(input_, err_msg):
}
]

with pytest.raises(ert3.exceptions.ConfigValidationError, match=err_msg):
with pytest.raises(ert3.exceptions.ConfigValidationError):
ert3.config.load_parameters_config(raw_config)


Expand Down Expand Up @@ -282,9 +279,7 @@ def test_invalid_distribution():
}
]

with pytest.raises(
ert3.exceptions.ConfigValidationError, match="Validating unknown distribution"
):
with pytest.raises(ert3.exceptions.ConfigValidationError):
ert3.config.load_parameters_config(raw_config)


Expand Down Expand Up @@ -438,11 +433,7 @@ def test_multi_parameter_group():
assert param.distribution.type == "gaussian"
assert param.distribution.input.mean == 0
assert param.distribution.input.std == 1
assert param.distribution.input.lower_bound == None
assert param.distribution.input.upper_bound == None
else:
assert param.distribution.type == "uniform"
assert param.distribution.input.mean == None
assert param.distribution.input.std == None
assert param.distribution.input.lower_bound == 0
assert param.distribution.input.upper_bound == 1

0 comments on commit 4e70ec3

Please sign in to comment.