From e888c199adbbd7f6d6800c9c611104cae9124887 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20CORBASSON?= Date: Thu, 5 Dec 2019 13:25:16 +0100 Subject: [PATCH] Merge with pull request #706 by @Mr-F https://github.com/Mr-F --- agate/aggregations/max.py | 2 +- agate/aggregations/mean.py | 7 +++++-- agate/aggregations/min.py | 2 +- tests/test_aggregations.py | 23 +++++++++++++++++------ 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/agate/aggregations/max.py b/agate/aggregations/max.py index dee11ca1..c6dde6c0 100644 --- a/agate/aggregations/max.py +++ b/agate/aggregations/max.py @@ -32,7 +32,7 @@ def validate(self, table): if not (isinstance(column.data_type, Number) or isinstance(column.data_type, Date) or isinstance(column.data_type, DateTime)): - raise DataTypeError('Max can only be applied to columns containing DateTime orNumber data.') + raise DataTypeError('Max can only be applied to columns containing DateTime or Number data.') def run(self, table): column = table.columns[self._column_name] diff --git a/agate/aggregations/mean.py b/agate/aggregations/mean.py index 6a83c1a3..d2de20c9 100644 --- a/agate/aggregations/mean.py +++ b/agate/aggregations/mean.py @@ -35,7 +35,10 @@ def validate(self, table): def run(self, table): column = table.columns[self._column_name] + num_of_values = len(column.values_without_nulls()) + # If the column has no non-null values then return null. 
+ if num_of_values == 0: + return None sum_total = self._sum.run(table) - - return sum_total / len(column.values_without_nulls()) + return sum_total / num_of_values diff --git a/agate/aggregations/min.py b/agate/aggregations/min.py index 1829ac7a..59e2c8d0 100644 --- a/agate/aggregations/min.py +++ b/agate/aggregations/min.py @@ -32,7 +32,7 @@ def validate(self, table): if not (isinstance(column.data_type, Number) or isinstance(column.data_type, Date) or isinstance(column.data_type, DateTime)): - raise DataTypeError('Min can only be applied to columns containing DateTime orNumber data.') + raise DataTypeError('Min can only be applied to columns containing DateTime or Number data.') def run(self, table): column = table.columns[self._column_name] diff --git a/tests/test_aggregations.py b/tests/test_aggregations.py index d244130b..56e942f4 100644 --- a/tests/test_aggregations.py +++ b/tests/test_aggregations.py @@ -211,17 +211,17 @@ def test_max(self): class TestNumberAggregation(unittest.TestCase): def setUp(self): self.rows = ( - (Decimal('1.1'), Decimal('2.19'), 'a'), - (Decimal('2.7'), Decimal('3.42'), 'b'), - (None, Decimal('4.1'), 'c'), - (Decimal('2.7'), Decimal('3.42'), 'c') + (Decimal('1.1'), Decimal('2.19'), 'a', None), + (Decimal('2.7'), Decimal('3.42'), 'b', None), + (None, Decimal('4.1'), 'c', None), + (Decimal('2.7'), Decimal('3.42'), 'c', None) ) self.number_type = Number() self.text_type = Text() - self.column_names = ['one', 'two', 'three'] - self.column_types = [self.number_type, self.number_type, self.text_type] + self.column_names = ['one', 'two', 'three', 'four'] + self.column_types = [self.number_type, self.number_type, self.text_type, self.number_type] self.table = Table(self.rows, self.column_names, self.column_types) @@ -272,6 +272,17 @@ def test_mean(self): self.assertEqual(Mean('two').run(self.table), Decimal('3.2825')) + def test_mean_all_nulls(self): + """ + Test to confirm mean of only nulls doesn't cause a critical error. 
+ + The assumption here is that if you attempt to perform a mean + calculation on a column which contains only null values, then a null + value should be returned to the caller. + :return: + """ + self.assertIsNone(Mean('four').run(self.table)) + def test_mean_with_nulls(self): warnings.simplefilter('ignore')