
Commit

Fix styling and typing

Also bypass cloudpipe/cloudpickle#403 by not storing the Record (which
subclasses pydantic.BaseModel) on the transmitter, splitting it up into
(data, index).
jondequinor committed Mar 25, 2021
1 parent b82af7b commit b1ef5d4
Showing 7 changed files with 137 additions and 109 deletions.
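The commit message above describes the cloudpickle workaround: instead of keeping a pydantic model on the transmitter, only its plain (data, index) fields are stored and the Record is rebuilt on load. A minimal standalone sketch of that pattern follows; the class name TransmitterSketch is illustrative, not the one in the diff below.

```python
from typing import Any, Optional

from pydantic import BaseModel


class Record(BaseModel):
    data: Any
    index: Optional[Any] = None


class TransmitterSketch:
    """Holds only plain (data, index) fields so that pickling the
    transmitter never has to pickle a pydantic.BaseModel instance."""

    def __init__(self) -> None:
        self._data: Optional[Any] = None
        self._index: Optional[Any] = None

    def set_transmitted(self, record: Record) -> None:
        # Unpack the model into plain attributes before storing it.
        self._data = record.data
        self._index = record.index

    def load(self) -> Record:
        # Rebuild the model on demand from the stored pieces.
        return Record(data=self._data, index=self._index)


transmitter = TransmitterSketch()
transmitter.set_transmitted(Record(data=[1.0, 2.0, 3.0]))
print(transmitter.load().data)
```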
153 changes: 87 additions & 66 deletions ert3/data/_record.py
@@ -1,16 +1,14 @@
import asyncio
import json
import shutil
import typing
import uuid
from abc import abstractmethod
from enum import Enum, auto
from functools import partial, wraps
from pathlib import Path
from typing import Awaitable, List, Mapping, Tuple, Union
from typing import Any, List, Mapping, Optional, Tuple, Union

import aiofiles
from aiofiles.os import wrap
from aiofiles.os import wrap # type: ignore
from pydantic import BaseModel, root_validator


@@ -127,36 +125,37 @@ def __init__(self, type_: RecordTransmitterType):
# TODO: implement state machine?
self._state = RecordTransmitterState.not_transmitted

@abstractmethod
def stream(self) -> asyncio.StreamReader:
# maybe private? Doesn't really make sense to give the user a stream?
pass

def set_transmitted(self):
def _set_transmitted(self):
self._state = RecordTransmitterState.transmitted

def is_transmitted(self):
return self._state == RecordTransmitterState.transmitted

@abstractmethod
async def dump(
self, location: Path, format: str = "json"
self, location: Path, mime: str = "text/json"
) -> None: # Should be RecordReference ?
# the result of this awaitable will be set to a RecordReference
# that has the folder into which this record was dumped
pass

@abstractmethod
async def load(self, format: str = "json") -> "asyncio.Future[Record]":
async def load(self) -> Record:
pass

@abstractmethod
async def transmit_data(
self,
data: Union[List[float], Mapping[int, float], Mapping[str, float], List[bytes]],
mime="text/json",
) -> None:
pass

@abstractmethod
async def transmit(
async def transmit_file(
self,
data_or_file: typing.Union[
Path,
Union[List[float], Mapping[int, float], Mapping[str, float], List[bytes]],
],
file: Path,
mime="text/json",
) -> None:
pass

@@ -172,38 +171,46 @@ def __init__(self, name: str, storage_path: Path):
self._uri: typing.Optional[str] = None

def set_transmitted(self, uri: Path):
super().set_transmitted()
super()._set_transmitted()
self._uri = str(uri)

async def transmit(
self,
data_or_file: typing.Union[
Path, List[float], Mapping[int, float], Mapping[str, float], List[bytes]
],
mime="text/json",
) -> None:
if self.is_transmitted():
raise RuntimeError("Record already transmitted")
if isinstance(data_or_file, Path) or isinstance(data_or_file, str):
async with aiofiles.open(str(data_or_file), mode="r") as f:
contents = await f.read()
record = Record(data=json.loads(contents))
else:
record = Record(data=data_or_file)

async def _transmit(self, record: Record, mime: str):
storage_uri = self._storage_path / self._concrete_key
async with aiofiles.open(storage_uri, mode="w") as f:
if mime == "text/json":
contents = json.dumps(record.data)
await f.write(contents)
elif mime == "application/x-python-code":
# XXX: An opaque record is a list of bytes... yes
# sonso or dan or jond: do something about this
await f.write(record.data[0].decode())
if isinstance(record.data, list) and isinstance(record.data[0], bytes):
await f.write(record.data[0].decode())
else:
raise TypeError(f"unexpected record data type {type(record.data)}")
else:
raise ValueError(f"unsupported mime {mime}")
self.set_transmitted(storage_uri)

async def transmit_data(
self,
data: Union[List[float], Mapping[int, float], Mapping[str, float], List[bytes]],
mime="text/json",
) -> None:
if self.is_transmitted():
raise RuntimeError("Record already transmitted")
record = Record(data=data)
return await self._transmit(record, mime)

async def transmit_file(
self,
file: Path,
mime="text/json",
) -> None:
if self.is_transmitted():
raise RuntimeError("File already transmitted")
async with aiofiles.open(str(file), mode="r") as f:
contents = await f.read()
record = Record(data=json.loads(contents))
return await self._transmit(record, mime)

async def load(self) -> Record:
if self._state != RecordTransmitterState.transmitted:
raise RuntimeError("cannot load untransmitted record")
@@ -212,56 +219,70 @@ async def load(self) -> Record:
return Record(data=json.loads(contents))

# TODO: should use Path
async def dump(self, location: str):
async def dump(self, location: Path, mime: str = "text/json") -> None:
if self._state != RecordTransmitterState.transmitted:
raise RuntimeError("cannot dump untransmitted record")
await _copy(self._uri, location)
await _copy(self._uri, str(location))


class InMemoryRecordTransmitter(RecordTransmitter):
TYPE: RecordTransmitterType = RecordTransmitterType.in_memory
# TODO: this field should be Record, but that does not work until
# https://github.com/cloudpipe/cloudpickle/issues/403 has been released.
_data: Optional[Any] = None
_index: Optional[Any] = None

def __init__(self, name: str):
super().__init__(type_=self.TYPE)
self._name = name
self._record = None

def set_transmitted(self, record: Record):
super().set_transmitted()
self._record = record
super()._set_transmitted()
self._data = record.data
self._index = record.index

async def transmit(
@abstractmethod
async def transmit_data(
self,
data_or_file: typing.Union[
Path, List[float], Mapping[int, float], Mapping[str, float], List[bytes]
],
data: Union[List[float], Mapping[int, float], Mapping[str, float], List[bytes]],
mime="text/json",
):
) -> None:
if self.is_transmitted():
raise RuntimeError("Record already transmitted")
if isinstance(data_or_file, Path) or isinstance(data_or_file, str):
async with aiofiles.open(data_or_file) as f:
contents = await f.read()
record = Record(data=json.loads(contents))
else:
record = Record(data=data_or_file)
record = Record(data=data)
self.set_transmitted(record=record)

@abstractmethod
async def transmit_file(
self,
file: Path,
mime="text/json",
) -> None:
if self.is_transmitted():
raise RuntimeError("Record already transmitted")
async with aiofiles.open(str(file)) as f:
contents = await f.read()
record = Record(data=json.loads(contents))
self.set_transmitted(record=record)

async def load(self):
return self._record
return Record(data=self._data, index=self._index)

# TODO: should use Path
async def dump(self, location: str, format: str = "text/json"):
if format is None:
format = "text/json"
async def dump(self, location: Path, mime: str = "text/json"):
if mime is None:
mime = "text/json"
if not self.is_transmitted():
raise RuntimeError("cannot dump untransmitted record")
async with aiofiles.open(location, mode="w") as f:
if format == "text/json":
await f.write(json.dumps(self._record.data))
elif format == "application/x-python-code":
# XXX: An opaque record is a list of bytes... yes
# sonso or dan or jond: do something about this
await f.write(self._record.data[0].decode())
if self._data is None:
raise ValueError("Cannot dump Record with no data")

async with aiofiles.open(str(location), mode="w") as f:
if mime == "text/json":
await f.write(json.dumps(self._data))
elif mime == "application/x-python-code":
if isinstance(self._data, list) and isinstance(self._data[0], bytes):
await f.write(self._data[0].decode())
else:
raise TypeError(f"unexpected record data type {type(self._data)}")
else:
raise ValueError(f"unsupported mime {format}")
raise ValueError(f"unsupported mime {mime}")
73 changes: 40 additions & 33 deletions ert3/evaluator/_evaluator.py
@@ -5,8 +5,6 @@
import typing
from collections import defaultdict

import aiofiles

import ert3
from ert3.config._stages_config import StagesConfig
from ert_shared.ensemble_evaluator.config import EvaluatorServerConfig
@@ -38,11 +36,13 @@ def _prepare_input(
inputs: "ert3.data.MultiEnsembleRecord",
evaluation_tmp_dir,
ensemble_size,
) -> typing.Dict[str, typing.Dict[str, typing.List["ert3.data.RecordTransmitter"]]]:
) -> typing.Dict[int, typing.Dict[str, "ert3.data.RecordTransmitter"]]:
tmp_input_folder = evaluation_tmp_dir / "prep_input_files"
os.makedirs(tmp_input_folder)
storage_config = ee_config["storage"]
transmitters = defaultdict(dict)
transmitters: typing.Dict[
int, typing.Dict[str, "ert3.data.RecordTransmitter"]
] = defaultdict(dict)

futures = []
for input_ in step_config.input:
@@ -56,25 +56,26 @@ def _prepare_input(
raise ValueError(
f"Unsupported transmitter type: {storage_config.get('type')}"
)
futures.append(transmitter.transmit(record.data))
futures.append(transmitter.transmit_data(record.data))
transmitters[iens][input_.record] = transmitter
asyncio.get_event_loop().run_until_complete(asyncio.gather(*futures))
for command in step_config.transportable_commands:
for iens in range(0, ensemble_size):
if storage_config.get("type") == "shared_disk":
transmitter = ert3.data.SharedDiskRecordTransmitter(
name=command.name,
storage_path=pathlib.Path(storage_config["storage_path"]),
)
else:
raise ValueError(
f"Unsupported transmitter type: {storage_config.get('type')}"
)
with open(command.location, "rb") as f:
asyncio.get_event_loop().run_until_complete(
transmitter.transmit([f.read()], mime=command.mime)
)
transmitters[iens][command.name] = transmitter
if step_config.transportable_commands is not None:
for command in step_config.transportable_commands:
for iens in range(0, ensemble_size):
if storage_config.get("type") == "shared_disk":
transmitter = ert3.data.SharedDiskRecordTransmitter(
name=command.name,
storage_path=pathlib.Path(storage_config["storage_path"]),
)
else:
raise ValueError(
f"Unsupported transmitter type: {storage_config.get('type')}"
)
with open(command.location, "rb") as f:
asyncio.get_event_loop().run_until_complete(
transmitter.transmit_data([f.read()], mime=command.mime)
)
transmitters[iens][command.name] = transmitter
return dict(transmitters)


@@ -83,12 +84,15 @@ def _prepare_output(
step_config: ert3.config._stages_config.Step,
evaluation_tmp_dir: pathlib.Path,
ensemble_size: int,
) -> typing.Dict[str, typing.Dict[str, typing.List["ert3.data.RecordTransmitter"]]]:
) -> typing.Dict[int, typing.Dict[str, "ert3.data.RecordTransmitter"]]:
# TODO: ensemble_size should rather be a list of ensemble ids
tmp_input_folder = evaluation_tmp_dir / "output_files"
os.makedirs(tmp_input_folder)
storage_config = ee_config["storage"]
transmitters = defaultdict(dict)
transmitters: typing.Dict[
int, typing.Dict[str, "ert3.data.RecordTransmitter"]
] = defaultdict(dict)

for output in step_config.output:
for iens in range(0, ensemble_size):
if storage_config.get("type") == "shared_disk":
@@ -121,7 +125,8 @@ def _build_ee_config(

stage_name = ensemble.forward_model.stages[0]
stage = stages_config.step_from_key(stage_name)
commands = stage.transportable_commands
assert stage is not None
commands = stage.transportable_commands if stage.transportable_commands else []
output_locations = [out.location for out in stage.output]
jobs = []

@@ -139,15 +144,16 @@ def command_location(name):
}
)

for script in stage.script:
name, *args = script.split()
jobs.append(
{
"name": name,
"executable": command_location(name),
"args": tuple(args),
}
)
if stage.script is not None:
for script in stage.script:
name, *args = script.split()
jobs.append(
{
"name": name,
"executable": command_location(name),
"args": tuple(args),
}
)

stages = [
{
@@ -160,6 +166,7 @@ def command_location(name):
"record": input_.record,
"location": input_.location,
"mime": input_.mime,
"is_executable": False,
}
for input_ in stage.input
]
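The evaluator changes correct the return annotation of _prepare_input and _prepare_output to Dict[int, Dict[str, RecordTransmitter]] (realization index → record name → transmitter) and annotate the defaultdict explicitly so mypy can check the nested assignments. A reduced sketch of that pattern, with a stand-in transmitter class:

```python
import typing
from collections import defaultdict


class RecordTransmitter:  # stand-in for ert3.data.RecordTransmitter
    pass


def prepare(ensemble_size: int) -> typing.Dict[int, typing.Dict[str, RecordTransmitter]]:
    # Without the annotation, defaultdict(dict) is Dict[Any, Any] to mypy
    # and the nested assignment below would go unchecked.
    transmitters: typing.Dict[
        int, typing.Dict[str, RecordTransmitter]
    ] = defaultdict(dict)
    for iens in range(ensemble_size):
        transmitters[iens]["coefficients"] = RecordTransmitter()
    # Return a plain dict so the default-inserting behaviour does not leak out.
    return dict(transmitters)


print(len(prepare(3)))
```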
2 changes: 1 addition & 1 deletion ert3/stats/_stats.py
@@ -1,4 +1,4 @@
import numpy as np
import numpy as np # type: ignore
import scipy.stats

import ert3
2 changes: 1 addition & 1 deletion ert_shared/ensemble_evaluator/entity/function_step.py
@@ -26,7 +26,7 @@ async def _load(io_name, transmitter):
function_output = func(**kwargs)

async def _transmit(io_name, transmitter, data):
await transmitter.transmit(data)
await transmitter.transmit_data(data)
return (io_name, transmitter)

futures = []
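In the function step, each output record is now sent with transmit_data and the per-output coroutines are gathered concurrently. The same pattern, reduced to a standalone sketch; DummyTransmitter and the record names are made up for illustration:

```python
import asyncio
from typing import Any, Dict, Tuple


class DummyTransmitter:
    """Stand-in for a RecordTransmitter exposing an async transmit_data()."""

    def __init__(self) -> None:
        self.data: Any = None

    async def transmit_data(self, data: Any) -> None:
        self.data = data


async def transmit_outputs(
    outputs: Dict[str, Any], transmitters: Dict[str, DummyTransmitter]
) -> Tuple[Any, ...]:
    async def _transmit(io_name: str, transmitter: DummyTransmitter, data: Any):
        await transmitter.transmit_data(data)
        return (io_name, transmitter)

    # One coroutine per output record, run concurrently.
    futures = [
        _transmit(name, transmitters[name], data) for name, data in outputs.items()
    ]
    return tuple(await asyncio.gather(*futures))


results = asyncio.run(
    transmit_outputs({"polynomial": [1.0, 4.0, 9.0]}, {"polynomial": DummyTransmitter()})
)
print([name for name, _ in results])
```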
2 changes: 1 addition & 1 deletion ert_shared/ensemble_evaluator/entity/unix_step.py
@@ -93,7 +93,7 @@ def run(self, inputs=None):
output.get_name()
]
futures.append(
outputs[output.get_name()].transmit(
outputs[output.get_name()].transmit_file(
os.path.join(run_path, output.get_path())
)
)
