Skip to content

Commit

Permalink
Generic compute plan (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmorel authored and inalgnu committed Dec 5, 2019
1 parent 2ab87d8 commit e5b069c
Show file tree
Hide file tree
Showing 4 changed files with 300 additions and 77 deletions.
93 changes: 79 additions & 14 deletions substratest/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,17 @@
# Default value of the `timeout` argument of the futures' wait() methods.
FUTURE_TIMEOUT = 120 # seconds


class Future:
class BaseFuture(abc.ABC):
    """Abstract interface implemented by the future classes of this module.

    Concrete subclasses must provide both ``wait`` and ``get``.
    """

    @abc.abstractmethod
    def wait(self, timeout=FUTURE_TIMEOUT, raises=True):
        """Wait for the wrapped asset; to be implemented by subclasses.

        ``timeout`` defaults to ``FUTURE_TIMEOUT`` (seconds).
        NOTE(review): ``raises`` presumably selects whether a failure is
        raised or silently returned -- the exact meaning is defined by each
        implementation; confirm there.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get(self):
        """Return the wrapped asset; to be implemented by subclasses."""
        raise NotImplementedError


class Future(BaseFuture):
"""Future asset."""
# mapper from asset class name to client getter method
_methods = {
Expand Down Expand Up @@ -50,17 +60,57 @@ def get(self):
return self._asset


class _FutureMixin(abc.ABC):
class ComputePlanFuture(BaseFuture):
    """Future over a compute plan, i.e. over every tuple that belongs to it."""

    def __init__(self, compute_plan, session):
        """Store the compute plan asset and the session used to query it."""
        self._compute_plan = compute_plan
        self._session = session

    def wait(self, timeout=FUTURE_TIMEOUT, raises=False):
        """Wait until all tuples are completed (done or failed).

        Traintuples, composite traintuples and aggregatetuples are waited for
        in increasing ``rank`` order, then the testtuples.

        Args:
            timeout: value forwarded to each tuple future's ``wait``; it is
                applied per tuple, not to the whole plan.
            raises: forwarded to each tuple future's ``wait``. Accepted so the
                signature is compatible with ``BaseFuture.wait``; it defaults
                to ``False`` (unlike the base declaration) because this
                method has always waited with ``raises=False``.

        Returns:
            The compute plan as returned by ``get()`` once waiting is over.
        """
        plan = self._compute_plan
        session = self._session

        ranked = (
            plan.list_traintuple(session)
            + plan.list_composite_traintuple(session)
            + plan.list_aggregatetuple(session)
        )
        tuples = sorted(ranked, key=lambda t: t.rank)
        # testtuples do not have a rank attribute: keep them out of the sort
        # and wait for them last
        tuples += plan.list_testtuple(session)

        for tuple_ in tuples:
            tuple_.future().wait(timeout, raises=raises)

        return self.get()

    def get(self):
        """Fetch the compute plan again from the session, by its id."""
        return self._session.get_compute_plan(self._compute_plan.compute_plan_id)


class _BaseFutureMixin(abc.ABC):
_future_cls = None

def attach(self, session):
"""Attach session to asset."""
self._session = session
return self

def future(self):
"""Returns future from asset."""
return self._future_cls(self, self._session)


class _FutureMixin(_BaseFutureMixin):
    """Mixin for assets that expose ``status`` and ``key``; builds ``Future``.

    ``attach`` is inherited from ``_BaseFutureMixin``.
    """
    _future_cls = Future

    def future(self):
        """Return a ``Future`` for this asset.

        Raises:
            AssertionError: if the asset has no ``status`` or ``key``
                attribute.
        """
        assert hasattr(self, 'status')
        assert hasattr(self, 'key')
        # single construction path: the base class instantiates _future_cls
        return super().future()


class _ComputePlanFutureMixin(_BaseFutureMixin):
    """Mixin whose inherited ``future()`` builds a ``ComputePlanFuture``."""
    _future_cls = ComputePlanFuture


def _convert(name):
Expand Down Expand Up @@ -321,23 +371,38 @@ class Meta:


@dataclasses.dataclass
class ComputePlanCreated(_Asset):
class ComputePlan(_Asset, _ComputePlanFutureMixin):
compute_plan_id: str
objective_key: str
traintuple_keys: typing.List[str]
composite_traintuple_keys: typing.List[str]
aggregatetuple_keys: typing.List[str]
testtuple_keys: typing.List[str]

def __post_init__(self):
if self.composite_traintuple_keys is None:
self.composite_traintuple_keys = []

@dataclasses.dataclass
class ComputePlan(_Asset):
compute_plan_id: str
algo_key: str
objective_key: str
traintuples: typing.List[str]
testtuples: typing.List[str]
if self.aggregatetuple_keys is None:
self.aggregatetuple_keys = []

def __post_init__(self):
if self.testtuples is None:
self.testtuples = []
if self.testtuple_keys is None:
self.testtuple_keys = []

def list_traintuple(self, session):
    """Return the traintuples listed by ``session`` for this compute plan."""
    plan_filter = f'traintuple:computePlanId:{self.compute_plan_id}'
    return session.list_traintuple(filters=[plan_filter])

def list_composite_traintuple(self, session):
    """Return the composite traintuples listed by ``session`` for this plan."""
    plan_filter = f'composite_traintuple:computePlanId:{self.compute_plan_id}'
    return session.list_composite_traintuple(filters=[plan_filter])

def list_aggregatetuple(self, session):
    """Return the aggregatetuples listed by ``session`` for this compute plan."""
    plan_filter = f'aggregatetuple:computePlanId:{self.compute_plan_id}'
    return session.list_aggregatetuple(filters=[plan_filter])

def list_testtuple(self, session):
    """Return the testtuples of this compute plan, looked up by key.

    One ``testtuple:key:<key>`` filter is built per entry of
    ``testtuple_keys`` and passed to ``session.list_testtuple``.

    Returns:
        Whatever ``session.list_testtuple`` returns, or an empty list when
        the plan has no testtuple key (the session is not queried then).
    """
    if not self.testtuple_keys:
        # An empty filter list would not restrict the query to this plan,
        # so do not send it at all.
        return []
    filters = [f'testtuple:key:{k}' for k in self.testtuple_keys]
    return session.list_testtuple(filters=filters)


@dataclasses.dataclass(frozen=True)
Expand Down
12 changes: 6 additions & 6 deletions substratest/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def add_testtuple(self, spec):

def add_compute_plan(self, spec):
    """Submit a compute plan spec and return the resulting compute plan.

    The response of the underlying client is loaded as an
    ``assets.ComputePlan``, attached to this session and recorded in
    ``self.state.compute_plans``.
    """
    res = self._client.add_compute_plan(spec.to_dict())
    # attach the session so that compute_plan.future() is usable
    compute_plan = assets.ComputePlan.load(res).attach(self)
    self.state.compute_plans.append(compute_plan)
    return compute_plan

Expand All @@ -156,7 +156,7 @@ def list_compute_plan(self, *args, **kwargs):

def get_compute_plan(self, *args, **kwargs):
    """Fetch a compute plan and return it attached to this session.

    All arguments are forwarded unchanged to the underlying client. The
    loaded plan is also passed to ``self.state.update_compute_plan``.
    """
    res = self._client.get_compute_plan(*args, **kwargs)
    # load once, and attach the session so that compute_plan.future() is usable
    compute_plan = assets.ComputePlan.load(res).attach(self)
    self.state.update_compute_plan(compute_plan)
    return compute_plan

Expand Down Expand Up @@ -214,7 +214,7 @@ def get_traintuple(self, *args, **kwargs):

def list_traintuple(self, *args, **kwargs):
    """List traintuples, each one attached to this session.

    All arguments are forwarded unchanged to the underlying client.
    """
    res = self._client.list_traintuple(*args, **kwargs)
    # the session must be attached for .future() to work on the results
    return [assets.Traintuple.load(x).attach(self) for x in res]

def get_aggregatetuple(self, *args, **kwargs):
res = self._client.get_aggregatetuple(*args, **kwargs)
Expand All @@ -224,7 +224,7 @@ def get_aggregatetuple(self, *args, **kwargs):

def list_aggregatetuple(self, *args, **kwargs):
    """List aggregatetuples, each one attached to this session.

    All arguments are forwarded unchanged to the underlying client.
    """
    res = self._client.list_aggregatetuple(*args, **kwargs)
    # the session must be attached for .future() to work on the results
    return [assets.Aggregatetuple.load(x).attach(self) for x in res]

def get_composite_traintuple(self, *args, **kwargs):
res = self._client.get_composite_traintuple(*args, **kwargs)
Expand All @@ -234,7 +234,7 @@ def get_composite_traintuple(self, *args, **kwargs):

def list_composite_traintuple(self, *args, **kwargs):
    """List composite traintuples, each one attached to this session.

    All arguments are forwarded unchanged to the underlying client.
    """
    res = self._client.list_composite_traintuple(*args, **kwargs)
    # the session must be attached for .future() to work on the results
    return [assets.CompositeTraintuple.load(x).attach(self) for x in res]

def get_testtuple(self, *args, **kwargs):
res = self._client.get_testtuple(*args, **kwargs)
Expand All @@ -244,7 +244,7 @@ def get_testtuple(self, *args, **kwargs):

def list_testtuple(self, *args, **kwargs):
    """List testtuples, each one attached to this session.

    All arguments are forwarded unchanged to the underlying client.
    """
    res = self._client.list_testtuple(*args, **kwargs)
    # the session must be attached for .future() to work on the results
    return [assets.Testtuple.load(x).attach(self) for x in res]

def list_node(self, *args, **kwargs):
res = self._client.list_node(*args, **kwargs)
Expand Down
98 changes: 87 additions & 11 deletions substratest/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class Permissions:


# Public permissions with no explicitly authorized ids.
DEFAULT_PERMISSIONS = Permissions(public=True, authorized_ids=[])
# Non-public permissions with no explicitly authorized ids; fallback for the
# out trunk model permissions of compute plan composite traintuples.
DEFAULT_OUT_MODEL_PERMISSIONS = Permissions(public=False, authorized_ids=[])


@dataclasses.dataclass
Expand Down Expand Up @@ -243,12 +244,46 @@ class TesttupleSpec(_Spec):

@dataclasses.dataclass
class ComputePlanTraintupleSpec:
    """Specification of one traintuple inside a compute plan spec."""
    algo_key: str
    data_manager_key: str
    # NOTE(review): filled from _get_keys(data_samples) by
    # ComputePlanSpec.add_traintuple -- presumably a list of keys rather than
    # a single str; confirm against _get_keys.
    train_data_sample_keys: str
    # id of this traintuple within the plan (random_uuid() when created through
    # ComputePlanSpec.add_traintuple)
    traintuple_id: str
    # ids of the tuple specs given as input models
    in_models_ids: typing.List[str]
    tag: str

    @property
    def id(self):
        """Return ``traintuple_id``, under the name common to all tuple specs."""
        return self.traintuple_id


@dataclasses.dataclass
class ComputePlanAggregatetupleSpec(_Spec):
    """Specification of one aggregatetuple inside a compute plan spec."""
    # id of this aggregatetuple within the plan (random_uuid() when created
    # through ComputePlanSpec.add_aggregatetuple)
    aggregatetuple_id: str
    algo_key: str
    worker: str
    # ids of the tuple specs given as input models
    in_models_ids: typing.List[str]
    tag: str

    @property
    def id(self):
        """Return ``aggregatetuple_id``, under the name common to all tuple specs."""
        return self.aggregatetuple_id


@dataclasses.dataclass
class ComputePlanCompositeTraintupleSpec(_Spec):
    """Specification of one composite traintuple inside a compute plan spec."""
    # id of this composite traintuple within the plan (random_uuid() when
    # created through ComputePlanSpec.add_composite_traintuple)
    composite_traintuple_id: str
    algo_key: str
    # None when add_composite_traintuple is called without a dataset
    data_manager_key: str
    # NOTE(review): filled from _get_keys(data_samples) -- presumably a list
    # of keys rather than a single str; confirm against _get_keys.
    train_data_sample_keys: str
    # ids of the specs providing the input head / trunk models; None when
    # add_composite_traintuple is called without the corresponding model
    in_head_model_id: str
    in_trunk_model_id: str
    tag: str
    # NOTE(review): add_composite_traintuple falls back to
    # DEFAULT_OUT_MODEL_PERMISSIONS, a Permissions instance, while the
    # annotation says Dict -- confirm which one is expected.
    out_trunk_model_permissions: typing.Dict

    @property
    def id(self):
        """Return ``composite_traintuple_id``, under the name common to all tuple specs."""
        return self.composite_traintuple_id


@dataclasses.dataclass
class ComputePlanTesttupleSpec:
Expand Down Expand Up @@ -276,26 +311,69 @@ def _get_keys(obj, field='key'):

@dataclasses.dataclass
class ComputePlanSpec(_Spec):
algo_key: str
objective_key: str
traintuples: typing.List[ComputePlanTraintupleSpec]
composite_traintuples: typing.List[ComputePlanCompositeTraintupleSpec]
aggregatetuples: typing.List[ComputePlanAggregatetupleSpec]
testtuples: typing.List[ComputePlanTesttupleSpec]

def add_traintuple(self, dataset, data_samples, traintuple_specs=None, tag=None):
traintuple_specs = traintuple_specs or []
def add_traintuple(self, algo, dataset, data_samples, in_models=None, tag=''):
in_models = in_models or []
spec = ComputePlanTraintupleSpec(
algo_key=algo.key,
traintuple_id=random_uuid(),
data_manager_key=dataset.key,
train_data_sample_keys=_get_keys(data_samples),
in_models_ids=[t.traintuple_id for t in traintuple_specs],
tag=tag or '',
in_models_ids=[t.id for t in in_models],
tag=tag,
)
self.traintuples.append(spec)
return spec

def add_aggregatetuple(self, aggregate_algo, worker, in_models=None, tag=''):
    """Create an aggregatetuple spec, register it on the plan and return it.

    Args:
        aggregate_algo: object whose ``key`` becomes the spec's ``algo_key``.
        worker: stored as is on the spec.
        in_models: traintuple or composite traintuple specs of this plan
            used as input models; a falsy value means no input model.
        tag: stored as is on the spec.

    Raises:
        AssertionError: if an input model is neither a traintuple spec nor a
            composite traintuple spec.
    """
    if not in_models:
        in_models = []

    for in_model in in_models:
        assert isinstance(in_model, (ComputePlanTraintupleSpec, ComputePlanCompositeTraintupleSpec))

    new_id = random_uuid()
    algo_key = aggregate_algo.key
    in_models_ids = [in_model.id for in_model in in_models]

    spec = ComputePlanAggregatetupleSpec(
        aggregatetuple_id=new_id,
        algo_key=algo_key,
        worker=worker,
        in_models_ids=in_models_ids,
        tag=tag,
    )
    self.aggregatetuples.append(spec)
    return spec

def add_composite_traintuple(self, composite_algo, dataset=None, data_samples=None,
                             in_head_model=None, in_trunk_model=None,
                             out_trunk_model_permissions=None, tag=''):
    """Create a composite traintuple spec, register it on the plan and return it.

    Args:
        composite_algo: object whose ``key`` becomes the spec's ``algo_key``.
        dataset: object whose ``key`` becomes ``data_manager_key``; when
            falsy the key is ``None``.
        data_samples: passed to ``_get_keys``; a falsy value means ``[]``.
        in_head_model: composite traintuple spec providing the head model.
        in_trunk_model: composite traintuple or aggregatetuple spec
            providing the trunk model.
        out_trunk_model_permissions: falls back to
            ``DEFAULT_OUT_MODEL_PERMISSIONS`` when falsy.
        tag: stored as is on the spec.

    Raises:
        AssertionError: if both input models are given and one of them has
            an unexpected type (types are only checked in that case).
    """
    if not data_samples:
        data_samples = []

    if in_head_model and in_trunk_model:
        assert isinstance(in_head_model, ComputePlanCompositeTraintupleSpec)
        assert isinstance(
            in_trunk_model,
            (ComputePlanCompositeTraintupleSpec, ComputePlanAggregatetupleSpec)
        )

    # values are computed in the same order as the spec fields are declared
    new_id = random_uuid()
    algo_key = composite_algo.key
    data_manager_key = dataset.key if dataset else None
    sample_keys = _get_keys(data_samples)
    head_id = in_head_model.id if in_head_model else None
    trunk_id = in_trunk_model.id if in_trunk_model else None
    permissions = out_trunk_model_permissions or DEFAULT_OUT_MODEL_PERMISSIONS

    spec = ComputePlanCompositeTraintupleSpec(
        composite_traintuple_id=new_id,
        algo_key=algo_key,
        data_manager_key=data_manager_key,
        train_data_sample_keys=sample_keys,
        in_head_model_id=head_id,
        in_trunk_model_id=trunk_id,
        out_trunk_model_permissions=permissions,
        tag=tag,
    )
    self.composite_traintuples.append(spec)
    return spec

def add_testtuple(self, traintuple_spec, tag=None):
spec = ComputePlanTesttupleSpec(
traintuple_id=traintuple_spec.traintuple_id,
traintuple_id=traintuple_spec.id,
tag=tag or '',
)
self.testtuples.append(spec)
Expand Down Expand Up @@ -486,8 +564,6 @@ def create_composite_traintuple(self, algo=None, objective=None, dataset=None,
permissions=None):
data_samples = data_samples or []

kwargs = {}

if head_traintuple and trunk_traintuple:
assert isinstance(head_traintuple, assets.CompositeTraintuple)
assert isinstance(
Expand All @@ -511,7 +587,6 @@ def create_composite_traintuple(self, algo=None, objective=None, dataset=None,
compute_plan_id=compute_plan_id,
rank=rank,
out_trunk_model_permissions=permissions or DEFAULT_PERMISSIONS,
**kwargs,
)

def create_testtuple(self, traintuple=None, tag=None):
Expand All @@ -520,10 +595,11 @@ def create_testtuple(self, traintuple=None, tag=None):
tag=tag,
)

def create_compute_plan(self, algo=None, objective=None):
def create_compute_plan(self, objective=None):
return ComputePlanSpec(
algo_key=algo.key if algo else None,
objective_key=objective.key if objective else None,
traintuples=[],
composite_traintuples=[],
aggregatetuples=[],
testtuples=[],
)

0 comments on commit e5b069c

Please sign in to comment.