Skip to content

Commit

Permalink
Introduce parameters config
Browse files Browse the repository at this point in the history
  • Loading branch information
markusdregi committed May 13, 2021
1 parent d671800 commit 8142133
Show file tree
Hide file tree
Showing 10 changed files with 771 additions and 146 deletions.
2 changes: 1 addition & 1 deletion ert3/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from . import data
from . import stats
from . import config
from . import console
from . import evaluator
from . import stats
from . import workspace
from . import exceptions
from . import algorithms
Expand Down
3 changes: 3 additions & 0 deletions ert3/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ert3.config._ensemble_config import load_ensemble_config, EnsembleConfig
from ert3.config._stages_config import load_stages_config, StagesConfig, Function, Unix
from ert3.config._experiment_config import load_experiment_config, ExperimentConfig
from ert3.config._parameters_config import load_parameters_config, ParametersConfig

Step = Union[Function, Unix]

Expand All @@ -16,4 +17,6 @@
"Function",
"load_experiment_config",
"ExperimentConfig",
"load_parameters_config",
"ParametersConfig",
]
164 changes: 164 additions & 0 deletions ert3/config/_parameters_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import sys
from typing import Any, Dict, Iterator, List, Tuple, Union
from pydantic import (
BaseModel,
ValidationError,
root_validator,
validator,
StrictInt,
StrictStr,
)

import ert3

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal


# Type of the index handed to a distribution: a variable-length tuple whose
# entries are strict strings or strict integers.
_IndexType = Tuple[Union[StrictStr, StrictInt], ...]


# Shared base class for every model in the parameters configuration. It
# carries no fields of its own; it only pins down the pydantic settings that
# all the models below inherit.
class _ParametersConfig(BaseModel):
    class Config:
        # Keys that are not declared as fields are not accepted.
        extra = "forbid"
        # Models are treated as read-only after construction.
        allow_mutation = False
        validate_assignment = True
        validate_all = True
        arbitrary_types_allowed = True


def _ensure_valid_name(name: str) -> str:
if not name:
raise ValueError("Names cannot be of zero length")

if not all(c.isalpha() or c == "_" for c in name):
raise ValueError(
"Names are expected to only contain characters and `_`, was: {name}"
)

return name


# Input of a gaussian distribution: its mean and its standard deviation.
class _GaussianInput(_ParametersConfig):
    mean: float
    std: float

    @validator("std")
    def _ensure_positive_std(cls, value):  # type: ignore
        """Reject a standard deviation that is zero or negative."""
        non_positive = value is not None and value <= 0
        if non_positive:
            raise ValueError(f"Expected positive std, was {value}")
        # A None value is handed back untouched, like any accepted value.
        return value


# Input of a uniform distribution: the two bounds of its interval.
class _UniformInput(_ParametersConfig):
    lower_bound: float
    upper_bound: float

    @root_validator
    def _ensure_lower_upper(cls, values):  # type: ignore
        """Check that the bounds are ordered; equal bounds are accepted."""
        low, up = values.get("lower_bound"), values.get("upper_bound")

        # The ordering is only checked when both bounds are available.
        both_present = low is not None and up is not None
        if both_present and not low <= up:
            raise ValueError(
                f"Expected lower_bound ({low}) to be smaller than upper_bound ({up})"
            )

        return values


# Distribution entry tagged "gaussian"; its input carries mean and std.
class _GaussianDistribution(_ParametersConfig):
    # Literal tag that tells this model apart from the uniform one.
    type: Literal["gaussian"]
    input: _GaussianInput


# Distribution entry tagged "uniform"; its input carries the two bounds.
class _UniformDistribution(_ParametersConfig):
    # Literal tag that tells this model apart from the gaussian one.
    type: Literal["uniform"]
    input: _UniformInput


# Non-empty list of variable names belonging to one parameter group. The
# model wraps a plain list and forwards iteration, indexing and len() to it.
class _VariablesConfig(_ParametersConfig):
    __root__: List[str]

    @validator("__root__")
    def _ensure_variables(cls, variables):  # type: ignore
        """Reject a parameter group without any variables."""
        if len(variables) == 0:
            raise ValueError("Parameter group cannot have no variables")
        return variables

    @validator("__root__", each_item=True)
    def _ensure_valid_variable_names(cls, variable: Any) -> str:
        """Apply the shared name check to every single variable name."""
        return _ensure_valid_name(variable)

    def __iter__(self) -> Iterator[str]:  # type: ignore
        return iter(self.__root__)

    def __getitem__(self, item: int) -> str:
        return self.__root__[item]

    def __len__(self) -> int:
        return len(self.__root__)


# One named parameter group: a distribution together with the variables it
# is indexed by.
class _ParameterConfig(_ParametersConfig):
    name: str
    type: Literal["stochastic"]
    distribution: Union[_GaussianDistribution, _UniformDistribution]
    variables: _VariablesConfig

    @validator("name")
    def _ensure_valid_group_name(cls, value: Any) -> str:
        """Apply the shared name check to the group name."""
        return _ensure_valid_name(value)

    def as_distribution(self) -> ert3.stats.Distribution:
        """Build the ``ert3.stats`` distribution described by this group.

        The variables of the group are passed on as the distribution index.

        Raises:
            ValueError: If the distribution type is neither ``gaussian`` nor
                ``uniform``.
        """
        dist = self.distribution
        index: _IndexType = tuple(self.variables)

        if dist.type == "gaussian":
            assert dist.input.mean is not None
            assert dist.input.std is not None
            return ert3.stats.Gaussian(
                dist.input.mean,
                dist.input.std,
                index=index,
            )

        if dist.type == "uniform":
            assert dist.input.lower_bound is not None
            assert dist.input.upper_bound is not None
            return ert3.stats.Uniform(
                dist.input.lower_bound,
                dist.input.upper_bound,
                index=index,
            )

        raise ValueError(f"Unknown distribution type: {dist.type}")


# The complete parameters configuration: a list of parameter groups. The
# model wraps a plain list and forwards len(), indexing and iteration to it.
class ParametersConfig(_ParametersConfig):
    __root__: List[_ParameterConfig]

    def __len__(self) -> int:
        return len(self.__root__)

    def __getitem__(self, item: int) -> _ParameterConfig:
        return self.__root__[item]

    def __iter__(self) -> Iterator[_ParameterConfig]:  # type: ignore
        return iter(self.__root__)


def load_parameters_config(config: List[Dict[str, Any]]) -> ParametersConfig:
    """Validate ``config`` and return it as a ``ParametersConfig``.

    Args:
        config: The raw parameters configuration, one dict per parameter
            group.

    Returns:
        The validated configuration.

    Raises:
        ert3.exceptions.ConfigValidationError: If validation fails. The
            original pydantic ``ValidationError`` is attached as the cause.
    """
    try:
        return ParametersConfig.parse_obj(config)
    except ValidationError as err:
        # Chain explicitly so the traceback reports the pydantic error as the
        # direct cause rather than as an error raised while handling another.
        raise ert3.exceptions.ConfigValidationError(
            str(err), source="parameters"
        ) from err
16 changes: 14 additions & 2 deletions ert3/console/_console.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import ert3

from ert3.config import EnsembleConfig, StagesConfig, ExperimentConfig
from ert3.config import EnsembleConfig, StagesConfig, ExperimentConfig, ParametersConfig


_ERT3_DESCRIPTION = (
Expand Down Expand Up @@ -160,10 +160,12 @@ def _run(workspace: Path, args: Any) -> None:
ensemble = _load_ensemble_config(workspace, args.experiment_name)
stages_config = _load_stages_config(workspace)
experiment_config = _load_experiment_config(workspace, args.experiment_name)
parameters_config = _load_parameters_config(workspace)
ert3.engine.run(
ensemble,
stages_config,
experiment_config,
parameters_config,
workspace,
args.experiment_name,
)
Expand All @@ -177,8 +179,13 @@ def _export(workspace: Path, args: Any) -> None:
def _record(workspace: Path, args: Any) -> None:
assert args.sub_cmd == "record"
if args.sub_record_cmd == "sample":
parameters_config = _load_parameters_config(workspace)
ert3.engine.sample_record(
workspace, args.parameter_group, args.record_name, args.ensemble_size
workspace,
parameters_config,
args.parameter_group,
args.record_name,
args.ensemble_size,
)
elif args.sub_record_cmd == "load":
ert3.engine.load_record(workspace, args.record_name, args.record_file)
Expand Down Expand Up @@ -260,3 +267,8 @@ def _load_experiment_config(workspace: Path, experiment_name: str) -> Experiment
)
with open(experiment_config) as f:
return ert3.config.load_experiment_config(yaml.safe_load(f))


def _load_parameters_config(workspace: Path) -> ParametersConfig:
    """Read ``parameters.yml`` from ``workspace`` and return it validated."""
    parameters_file = workspace / "parameters.yml"
    with open(parameters_file) as stream:
        raw_config = yaml.safe_load(stream)
    return ert3.config.load_parameters_config(raw_config)
20 changes: 13 additions & 7 deletions ert3/engine/_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from pathlib import Path

import ert3
from ert3.engine import _utils


def load_record(workspace: Path, record_name: str, record_file: Path) -> None:
Expand All @@ -20,19 +19,26 @@ def load_record(workspace: Path, record_name: str, record_file: Path) -> None:
)


def _get_distribution(
    parameter_group_name: str, parameters_config: ert3.config.ParametersConfig
) -> ert3.stats.Distribution:
    """Return the distribution of the first group named ``parameter_group_name``.

    Raises:
        ValueError: If no parameter group carries that name.
    """
    not_found = object()
    candidates = (
        group for group in parameters_config if group.name == parameter_group_name
    )
    match = next(candidates, not_found)
    if match is not_found:
        raise ValueError(f"No parameter group found named: {parameter_group_name}")
    return match.as_distribution()


# pylint: disable=too-many-arguments
def sample_record(
workspace: Path,
parameters_config: ert3.config.ParametersConfig,
parameter_group_name: str,
record_name: str,
ensemble_size: int,
experiment_name: Optional[str] = None,
) -> None:
parameters = _utils.load_parameters(workspace)

if parameter_group_name not in parameters:
raise ValueError(f"No parameter group found named: {parameter_group_name}")
distribution = parameters[parameter_group_name]

distribution = _get_distribution(parameter_group_name, parameters_config)
ensrecord = ert3.data.EnsembleRecord(
records=[distribution.sample() for _ in range(ensemble_size)]
)
Expand Down
34 changes: 27 additions & 7 deletions ert3/engine/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import List, Dict

import ert3
from ert3.engine import _utils


def _prepare_experiment(
Expand All @@ -23,12 +22,14 @@ def _prepare_experiment(
)


# pylint: disable=too-many-arguments
def _prepare_experiment_record(
record_name: str,
record_source: List[str],
ensemble_size: int,
experiment_name: str,
workspace_root: pathlib.Path,
parameters_config: ert3.config.ParametersConfig,
) -> None:
if record_source[0] == "storage":
assert len(record_source) == 2
Expand All @@ -44,6 +45,7 @@ def _prepare_experiment_record(
elif record_source[0] == "stochastic":
ert3.engine.sample_record(
workspace_root,
parameters_config,
record_source[1],
record_name,
ensemble_size,
Expand All @@ -55,6 +57,7 @@ def _prepare_experiment_record(

def _prepare_evaluation(
ensemble: ert3.config.EnsembleConfig,
parameters_config: ert3.config.ParametersConfig,
workspace_root: pathlib.Path,
experiment_name: str,
) -> None:
Expand All @@ -68,15 +71,23 @@ def _prepare_evaluation(
record_source = input_record.source.split(".")

_prepare_experiment_record(
record_name, record_source, ensemble.size, experiment_name, workspace_root
record_name,
record_source,
ensemble.size,
experiment_name,
workspace_root,
parameters_config,
)


def _load_ensemble_parameters(
ensemble: ert3.config.EnsembleConfig,
parameters_config: ert3.config.ParametersConfig,
workspace: pathlib.Path,
) -> Dict[str, ert3.stats.Distribution]:
parameters = _utils.load_parameters(workspace)
all_distributions = {
param.name: param.as_distribution() for param in parameters_config
}

ensemble_parameters = {}
for input_record in ensemble.input:
Expand All @@ -85,16 +96,19 @@ def _load_ensemble_parameters(
assert len(record_source) == 2
assert record_source[0] == "stochastic"
parameter_group_name = record_source[1]
ensemble_parameters[record_name] = parameters[parameter_group_name]
ensemble_parameters[record_name] = all_distributions[parameter_group_name]
return ensemble_parameters


def _prepare_sensitivity(
ensemble: ert3.config.EnsembleConfig,
parameters_config: ert3.config.ParametersConfig,
workspace_root: pathlib.Path,
experiment_name: str,
) -> None:
parameter_distributions = _load_ensemble_parameters(ensemble, workspace_root)
parameter_distributions = _load_ensemble_parameters(
ensemble, parameters_config, workspace_root
)
input_records = ert3.algorithms.one_at_the_time(parameter_distributions)

_prepare_experiment(workspace_root, experiment_name, ensemble, len(input_records))
Expand Down Expand Up @@ -164,18 +178,24 @@ def _evaluate(
_store_responses(workspace_root, experiment_name, responses)


# pylint: disable=too-many-arguments
def run(
ensemble: ert3.config.EnsembleConfig,
stages_config: ert3.config.StagesConfig,
experiment_config: ert3.config.ExperimentConfig,
parameters_config: ert3.config.ParametersConfig,
workspace_root: pathlib.Path,
experiment_name: str,
) -> None:

if experiment_config.type == "evaluation":
_prepare_evaluation(ensemble, workspace_root, experiment_name)
_prepare_evaluation(
ensemble, parameters_config, workspace_root, experiment_name
)
elif experiment_config.type == "sensitivity":
_prepare_sensitivity(ensemble, workspace_root, experiment_name)
_prepare_sensitivity(
ensemble, parameters_config, workspace_root, experiment_name
)
else:
raise ValueError(f"Unknown experiment type {experiment_config.type}")

Expand Down

0 comments on commit 8142133

Please sign in to comment.