Skip to content

Commit

Permalink
rename in_models_tuples and use global_execution_env
Browse files Browse the repository at this point in the history
  • Loading branch information
jmorel committed Dec 3, 2019
1 parent a0b9acb commit fe376b4
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 60 deletions.
4 changes: 2 additions & 2 deletions substratest/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ class Meta:


@dataclasses.dataclass
class ComputePlanCreated(_Asset, _FutureMixin):
class ComputePlanCreated(_Asset, _ComputePlanFutureMixin):
compute_plan_id: str
traintuple_keys: typing.List[str]
composite_traintuple_keys: typing.List[str]
Expand All @@ -372,7 +372,7 @@ class ComputePlanCreated(_Asset, _FutureMixin):


@dataclasses.dataclass
class ComputePlan(_Asset):
class ComputePlan(_Asset, _ComputePlanFutureMixin):
compute_plan_id: str
objective_key: str
traintuples: typing.List[str]
Expand Down
14 changes: 7 additions & 7 deletions substratest/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,30 +317,30 @@ class ComputePlanSpec(_Spec):
aggregatetuples: typing.List[ComputePlanAggregatetupleSpec]
testtuples: typing.List[ComputePlanTesttupleSpec]

def add_traintuple(self, algo, dataset, data_samples, in_models_tuples=None, tag=''):
in_models_tuples = in_models_tuples or []
def add_traintuple(self, algo, dataset, data_samples, in_models=None, tag=''):
in_models = in_models or []
spec = ComputePlanTraintupleSpec(
algo_key=algo.key,
traintuple_id=random_uuid(),
data_manager_key=dataset.key,
train_data_sample_keys=_get_keys(data_samples),
in_models_ids=[t.id for t in in_models_tuples],
in_models_ids=[t.id for t in in_models],
tag=tag,
)
self.traintuples.append(spec)
return spec

def add_aggregatetuple(self, aggregate_algo, worker, in_models_tuples=None, tag=''):
in_models_tuples = in_models_tuples or []
def add_aggregatetuple(self, aggregate_algo, worker, in_models=None, tag=''):
in_models = in_models or []

for t in in_models_tuples:
for t in in_models:
assert isinstance(t, (ComputePlanTraintupleSpec, ComputePlanCompositeTraintupleSpec))

spec = ComputePlanAggregatetupleSpec(
aggregatetuple_id=random_uuid(),
algo_key=aggregate_algo.key,
worker=worker,
in_models_ids=[t.id for t in in_models_tuples],
in_models_ids=[t.id for t in in_models],
tag=tag,
)
self.aggregatetuples.append(spec)
Expand Down
77 changes: 26 additions & 51 deletions tests/test_execution_compute_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_compute_plan(global_execution_env):
algo=algo_2,
dataset=dataset_1,
data_samples=dataset_1.train_data_sample_keys,
in_models_tuples=[traintuple_spec_1, traintuple_spec_2],
in_models=[traintuple_spec_1, traintuple_spec_2],
)

# submit compute plan and wait for it to complete
Expand Down Expand Up @@ -96,15 +96,15 @@ def test_compute_plan_single_session_success(global_execution_env):
algo=algo,
dataset=dataset,
data_samples=[data_sample_2],
in_models_tuples=[traintuple_spec_1]
in_models=[traintuple_spec_1]
)
cp_spec.add_testtuple(traintuple_spec_2)

traintuple_spec_3 = cp_spec.add_traintuple(
algo=algo,
dataset=dataset,
data_samples=[data_sample_3],
in_models_tuples=[traintuple_spec_2]
in_models=[traintuple_spec_2]
)
cp_spec.add_testtuple(traintuple_spec_3)

Expand Down Expand Up @@ -156,15 +156,15 @@ def test_compute_plan_single_session_failure(global_execution_env):
algo=algo,
dataset=dataset,
data_samples=[data_sample_2],
in_models_tuples=[traintuple_spec_1]
in_models=[traintuple_spec_1]
)
cp_spec.add_testtuple(traintuple_spec_2)

traintuple_spec_3 = cp_spec.add_traintuple(
algo=algo,
dataset=dataset,
data_samples=[data_sample_3],
in_models_tuples=[traintuple_spec_2]
in_models=[traintuple_spec_2]
)
cp_spec.add_testtuple(traintuple_spec_3)

Expand All @@ -184,40 +184,19 @@ def test_compute_plan_single_session_failure(global_execution_env):
assert set(cp_created.testtuple_keys) == set(cp.testtuples)


def test_compute_plan_aggregate_composite_traintuples(factory, session_1, session_2):
def test_compute_plan_aggregate_composite_traintuples(global_execution_env):
"""
Compute plan version of the `test_aggregate_composite_traintuples` method from `test_execution.py`
"""
aggregate_worker = session_1.node_id
sessions = [session_1, session_2]
factory, network = global_execution_env
sessions = [s.copy() for s in network.sessions]

aggregate_worker = sessions[0].node_id
number_of_rounds = 2

# register objectives, datasets, and data samples
datasets = []
for s in sessions:
# register one dataset per node
spec = factory.create_dataset()
dataset = s.add_dataset(spec)
datasets.append(dataset)

# register one data sample per dataset per round of aggregation
for _ in range(number_of_rounds):
spec = factory.create_data_sample(test_only=False, datasets=[dataset])
s.add_data_sample(spec)
# reload datasets (to ensure they are properly linked with the created data samples)
datasets = [
sessions[i].get_dataset(d.key)
for i, d in enumerate(list(datasets))
]
# register test data on first node
spec = factory.create_data_sample(test_only=True, datasets=[datasets[0]])
test_data_sample = sessions[0].add_data_sample(spec)
# register objective on first node
spec = factory.create_objective(
dataset=datasets[0],
data_samples=[test_data_sample],
)
objective = sessions[0].add_objective(spec)
datasets = sessions[0].state.datasets + sessions[1].state.datasets
objective = sessions[0].state.objectives[0]

# register algos on first node
spec = factory.create_composite_algo()
Expand Down Expand Up @@ -253,7 +232,7 @@ def test_compute_plan_aggregate_composite_traintuples(factory, session_1, sessio
spec = cp_spec.add_aggregatetuple(
aggregate_algo=aggregate_algo,
worker=aggregate_worker,
in_models_tuples=composite_traintuple_specs,
in_models=composite_traintuple_specs,
)

# save state of round
Expand All @@ -266,46 +245,42 @@ def test_compute_plan_aggregate_composite_traintuples(factory, session_1, sessio
traintuple_spec=composite_traintuple_spec,
)

cp = session_1.add_compute_plan(cp_spec).future().wait()
tuples = (cp.list_traintuple(session_1) +
cp.list_composite_traintuples(session_1) +
cp.list_aggregate_tuples(session_1) +
cp.list_testtuples(session_1))
cp = sessions[0].add_compute_plan(cp_spec).future().wait()
tuples = (cp.list_traintuple(sessions) +
cp.list_composite_traintuples(sessions[0]) +
cp.list_aggregate_tuples(sessions[0]) +
cp.list_testtuples(sessions[0]))
for t in tuples:
assert t.status == 'done'


def test_compute_plan_circular_dependency_failure(factory, session):
spec = factory.create_dataset()
dataset = session.add_dataset(spec)
def test_compute_plan_circular_dependency_failure(global_execution_env):
factory, network = global_execution_env
session = network.sessions[0].copy()

dataset = session.state.datasets[0]
objective = session.state.objectives[0]

spec = factory.create_algo()
algo = session.add_algo(spec)

spec = factory.create_data_sample(test_only=False, datasets=[dataset])
data_sample = session.add_data_sample(spec)

spec = factory.create_objective(dataset=dataset)
objective = session.add_objective(spec)

cp_spec = factory.create_compute_plan(objective=objective)

traintuple_spec_1 = cp_spec.add_traintuple(
dataset=dataset,
algo=algo,
data_samples=[data_sample]
data_samples=dataset.train_data_sample_keys
)

traintuple_spec_2 = cp_spec.add_traintuple(
dataset=dataset,
algo=algo,
data_samples=[data_sample]
data_samples=dataset.train_data_sample_keys
)

traintuple_spec_1.in_models_ids.append(traintuple_spec_2.id)
traintuple_spec_2.in_models_ids.append(traintuple_spec_1.id)

# TODO make sur the creation is rejected
with pytest.raises(substra.exceptions.InvalidRequest) as e:
session.add_compute_plan(cp_spec)

Expand Down

0 comments on commit fe376b4

Please sign in to comment.