Use namedtuple field names instead of indices in private_spark (OpenMined#238)
MashaTelyatnikova committed Feb 2, 2022
1 parent a69143b commit 1f21249
Showing 3 changed files with 43 additions and 30 deletions.
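The core change, in isolation: DPEngine.aggregate() returns per-partition metrics as a namedtuple, and callers now read the field they want by name instead of indexing into a list. A minimal, self-contained sketch of the pattern (MetricsTuple is an illustrative stand-in for the namedtuple that pipeline_dp builds internally):

    import collections

    # Illustrative stand-in for the per-partition metrics namedtuple
    # that pipeline_dp's DPEngine.aggregate() produces.
    MetricsTuple = collections.namedtuple('MetricsTuple', ['count'])

    dp_result = [("pk1", MetricsTuple(count=42))]

    # Old style: metrics came back as a list, so callers indexed blindly:
    #   count = metrics[0]
    # New style: name the field.
    partition_key, metrics = dp_result[0]
    assert metrics.count == 42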
2 changes: 1 addition & 1 deletion pipeline_dp/private_beam.py
@@ -202,7 +202,7 @@ def expand(self, pcol: pvalue.PCollection) -> pvalue.PCollection:
         # aggregate() returns a namedtuple of metrics for each partition key.
         # Here is only one metric - count. Extract it from the namedtuple.
         dp_result = backend.map_values(dp_result, lambda v: v.count,
-                                       "Extract sum")
+                                       "Extract count")
         # dp_result : (partition_key, dp_count)
 
         return dp_result
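For context, backend.map_values applies a function to the value of every (key, value) pair; the string argument is a stage name used only for labeling. A rough pure-Python equivalent of the extraction step above, on a plain list rather than a real pipeline collection (behavioral sketch only):

    import collections

    MetricsTuple = collections.namedtuple('MetricsTuple', ['count'])  # illustrative

    def map_values(col, fn, stage_name):
        # Local-list stand-in for the backend operation: keep each key,
        # transform each value; stage_name only labels the pipeline stage.
        return [(key, fn(value)) for key, value in col]

    pairs = [("pk1", MetricsTuple(count=3)), ("pk2", MetricsTuple(count=7))]
    assert map_values(pairs, lambda v: v.count, "Extract count") == \
        [("pk1", 3), ("pk2", 7)]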
36 changes: 20 additions & 16 deletions pipeline_dp/private_spark.py
@@ -86,11 +86,12 @@ def mean(self, mean_params: aggregate_params.MeanParams) -> RDD:
             value_extractor=lambda x: mean_params.value_extractor(x[1]))
 
         dp_result = dp_engine.aggregate(self._rdd, params, data_extractors)
-        # dp_result : (partition_key, [dp_mean])
+        # dp_result : (partition_key, (mean=dp_mean))
 
-        # aggregate() returns a list of metrics for each partition key.
-        # Here is only one metric - sum. Remove list.
-        dp_result = backend.map_values(dp_result, lambda v: v[0], "Unnest list")
+        # aggregate() returns a namedtuple of metrics for each partition key.
+        # Here is only one metric - mean. Extract it from the namedtuple.
+        dp_result = backend.map_values(dp_result, lambda v: v.mean,
+                                       "Extract mean")
         # dp_result : (partition_key, dp_mean)
 
         return dp_result
@@ -122,11 +123,12 @@ def sum(self, sum_params: aggregate_params.SumParams) -> RDD:
             value_extractor=lambda x: sum_params.value_extractor(x[1]))
 
         dp_result = dp_engine.aggregate(self._rdd, params, data_extractors)
-        # dp_result : (partition_key, [dp_sum])
+        # dp_result : (partition_key, (sum=dp_sum))
 
-        # aggregate() returns a list of metrics for each partition key.
-        # Here is only one metric - sum. Remove list.
-        dp_result = backend.map_values(dp_result, lambda v: v[0], "Unnest list")
+        # aggregate() returns a namedtuple of metrics for each partition key.
+        # Here is only one metric - sum. Extract it from the namedtuple.
+        dp_result = backend.map_values(dp_result, lambda v: v.sum,
+                                       "Extract sum")
         # dp_result : (partition_key, dp_sum)
 
         return dp_result
@@ -157,11 +159,12 @@ def count(self, count_params: aggregate_params.CountParams) -> RDD:
             value_extractor=lambda x: None)
 
         dp_result = dp_engine.aggregate(self._rdd, params, data_extractors)
-        # dp_result : (partition_key, [dp_count])
+        # dp_result : (partition_key, (count=dp_count))
 
-        # aggregate() returns a list of metrics for each partition key.
-        # Here is only one metric - count. Remove list.
-        dp_result = backend.map_values(dp_result, lambda v: v[0], "Unnest list")
+        # aggregate() returns a namedtuple of metrics for each partition key.
+        # Here is only one metric - count. Extract it from the namedtuple.
+        dp_result = backend.map_values(dp_result, lambda v: v.count,
+                                       "Extract count")
         # dp_result : (partition_key, dp_count)
 
         return dp_result
@@ -194,11 +197,12 @@ def privacy_id_count(
             value_extractor=lambda x: None)
 
         dp_result = dp_engine.aggregate(self._rdd, params, data_extractors)
-        # dp_result : (partition_key, [dp_privacy_id_count])
+        # dp_result : (partition_key, (privacy_id_count=dp_privacy_id_count))
 
-        # aggregate() returns a list of metrics for each partition key.
-        # Here is only one metric - privacy_id_count. Remove list.
-        dp_result = backend.map_values(dp_result, lambda v: v[0], "Unnest list")
+        # aggregate() returns a namedtuple of metrics for each partition key.
+        # Here is only one metric - privacy_id_count. Extract it from the namedtuple.
+        dp_result = backend.map_values(dp_result, lambda v: v.privacy_id_count,
+                                       "Extract privacy id count")
         # dp_result : (partition_key, dp_privacy_id_count)
 
         return dp_result
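All four methods now end with the same extraction step and differ only in the field name (mean, sum, count, privacy_id_count). If the pattern keeps recurring, it could be factored out; a hedged sketch of one way to do that with operator.attrgetter (not part of this commit, illustrative names only):

    import collections
    import operator

    # Illustrative metric tuples; the real ones come from pipeline_dp.
    MeanTuple = collections.namedtuple('MeanTuple', ['mean'])
    SumTuple = collections.namedtuple('SumTuple', ['sum'])

    def extract_metric(dp_result, metric_name):
        # Generalizes the per-method lambdas (v.mean, v.sum, v.count, ...):
        # pull one named field out of each partition's metrics tuple.
        getter = operator.attrgetter(metric_name)
        return [(pk, getter(metrics)) for pk, metrics in dp_result]

    assert extract_metric([("pk1", MeanTuple(mean=2.0))], "mean") == [("pk1", 2.0)]
    assert extract_metric([("pk1", SumTuple(sum=3.0))], "sum") == [("pk1", 3.0)]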
35 changes: 22 additions & 13 deletions tests/private_spark_test.py
@@ -16,7 +16,7 @@
 from unittest.mock import patch
 import unittest
 import sys
-
+import collections
 import pipeline_dp
 from pipeline_dp import aggregate_params as agg
 from pipeline_dp import budget_accounting, private_spark
@@ -83,8 +83,10 @@ def test_mean_calls_aggregate_with_correct_params(self, mock_aggregate):
         # Arrange
         dist_data = PrivateRDDTest.sc.parallelize([(1, 2.0, "pk1"),
                                                    (2, 2.0, "pk1")])
-        mock_aggregate.return_value = PrivateRDDTest.sc.parallelize([(2.0,
-                                                                      ["pk1"])])
+        MetricsTuple = collections.namedtuple('MetricsTuple', ['mean'])
+        mock_aggregate.return_value = PrivateRDDTest.sc.parallelize([
+            ("pk1", MetricsTuple(mean=2.0))
+        ])
         budget_accountant = budget_accounting.NaiveBudgetAccountant(1, 1e-10)
 
         def privacy_id_extractor(x):
@@ -124,7 +126,7 @@ def privacy_id_extractor(x):
                 public_partitions=mean_params.public_partitions)
         self.assertEqual(args[1], params)
 
-        self.assertEqual(actual_result.collect(), [(2.0, "pk1")])
+        self.assertEqual(actual_result.collect(), [("pk1", 2.0)])
 
     def test_mean_returns_sensible_result(self):
         # Arrange
@@ -216,8 +218,10 @@ def test_sum_calls_aggregate_with_correct_params(self, mock_aggregate):
         # Arrange
         dist_data = PrivateRDDTest.sc.parallelize([(1, 1.0, "pk1"),
                                                    (2, 2.0, "pk1")])
-        mock_aggregate.return_value = PrivateRDDTest.sc.parallelize([(3.0,
-                                                                      ["pk1"])])
+        MetricsTuple = collections.namedtuple('MetricsTuple', ['sum'])
+        mock_aggregate.return_value = PrivateRDDTest.sc.parallelize([
+            ("pk1", MetricsTuple(sum=3.0))
+        ])
         budget_accountant = budget_accounting.NaiveBudgetAccountant(1, 1e-10)
 
         def privacy_id_extractor(x):
@@ -257,7 +261,7 @@ def privacy_id_extractor(x):
                 public_partitions=sum_params.public_partitions)
         self.assertEqual(args[1], params)
 
-        self.assertEqual(actual_result.collect(), [(3.0, "pk1")])
+        self.assertEqual(actual_result.collect(), [("pk1", 3.0)])
 
     def test_sum_returns_sensible_result(self):
         # Arrange
@@ -347,8 +351,10 @@ def privacy_id_extractor(x):
     def test_count_calls_aggregate_with_correct_params(self, mock_aggregate):
         # Arrange
         dist_data = PrivateRDDTest.sc.parallelize([(1, "pk1"), (2, "pk1")])
-        mock_aggregate.return_value = PrivateRDDTest.sc.parallelize([(2,
-                                                                      ["pk1"])])
+        MetricsTuple = collections.namedtuple('MetricsTuple', ['count'])
+        mock_aggregate.return_value = PrivateRDDTest.sc.parallelize([
+            ("pk1", MetricsTuple(count=2))
+        ])
         budget_accountant = budget_accounting.NaiveBudgetAccountant(1, 1e-10)
 
         def privacy_id_extractor(x):
@@ -383,7 +389,7 @@ def privacy_id_extractor(x):
                 public_partitions=count_params.public_partitions)
         self.assertEqual(args[1], params)
 
-        self.assertEqual(actual_result.collect(), [(2, "pk1")])
+        self.assertEqual(actual_result.collect(), [("pk1", 2)])
 
     def test_count_returns_sensible_result(self):
         # Arrange
@@ -471,8 +477,11 @@ def test_privacy_id_count_calls_aggregate_with_correct_params(
             self, mock_aggregate):
         # Arrange
         dist_data = PrivateRDDTest.sc.parallelize([(1, "pk1"), (2, "pk1")])
-        mock_aggregate.return_value = PrivateRDDTest.sc.parallelize([(2,
-                                                                      ["pk1"])])
+        MetricsTuple = collections.namedtuple('MetricsTuple',
+                                              ['privacy_id_count'])
+        mock_aggregate.return_value = PrivateRDDTest.sc.parallelize([
+            ("pk1", MetricsTuple(privacy_id_count=2))
+        ])
         budget_accountant = budget_accounting.NaiveBudgetAccountant(1, 1e-10)
 
         def privacy_id_extractor(x):
@@ -505,7 +514,7 @@ def privacy_id_extractor(x):
                 public_partitions=privacy_id_count_params.public_partitions)
         self.assertEqual(args[1], params)
 
-        self.assertEqual([(2, "pk1")], actual_result.collect())
+        self.assertEqual([("pk1", 2)], actual_result.collect())
 
     def test_privacy_id_count_returns_sensible_result(self):
         # Arrange
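The updated tests mock aggregate() to return the new (partition_key, namedtuple) shape and then assert on the extracted value, which is why the expected tuples flipped from (2.0, "pk1") to ("pk1", 2.0). The essence of that pattern, condensed so it runs without Spark (unittest.mock stands in for the patched engine method):

    import collections
    from unittest import mock

    MetricsTuple = collections.namedtuple('MetricsTuple', ['count'])

    engine = mock.Mock()
    # Mocked aggregate() output matches the engine's new shape:
    # (partition_key, metrics-namedtuple).
    engine.aggregate.return_value = [("pk1", MetricsTuple(count=2))]

    # The extraction step under test: read the field by name.
    dp_result = [(pk, metrics.count) for pk, metrics in engine.aggregate()]
    assert dp_result == [("pk1", 2)]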
