From 173e6d65ff0257c21c29f2053e0755dafe91e921 Mon Sep 17 00:00:00 2001 From: Danny Sepler Date: Sun, 15 Mar 2020 15:04:35 -0400 Subject: [PATCH 1/2] Sum should work with time deltas, also isinstance cleanups --- AUTHORS.rst | 1 + agate/aggregations/max.py | 10 +++------- agate/aggregations/min.py | 10 +++------- agate/aggregations/sum.py | 18 +++++++++++++----- agate/table/pivot.py | 2 +- tests/test_aggregations.py | 12 ++++++++++++ 6 files changed, 33 insertions(+), 20 deletions(-) diff --git a/AUTHORS.rst b/AUTHORS.rst index 83967f92..28091202 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -41,3 +41,4 @@ agate is made by a community. The following individuals have contributed code, d * `Kartik Agaram `_ * `Loïc Corbasson `_ * `Robert Schütz `_ +* `Danny Sepler `_ diff --git a/agate/aggregations/max.py b/agate/aggregations/max.py index 470fab73..11b7cec5 100644 --- a/agate/aggregations/max.py +++ b/agate/aggregations/max.py @@ -21,18 +21,14 @@ def __init__(self, column_name): def get_aggregate_data_type(self, table): column = table.columns[self._column_name] - if (isinstance(column.data_type, Number) or - isinstance(column.data_type, Date) or - isinstance(column.data_type, DateTime)): + if isinstance(column.data_type, (Number, Date, DateTime)): return column.data_type def validate(self, table): column = table.columns[self._column_name] - if not (isinstance(column.data_type, Number) or - isinstance(column.data_type, Date) or - isinstance(column.data_type, DateTime)): - raise DataTypeError('Min can only be applied to columns containing DateTime orNumber data.') + if not isinstance(column.data_type, (Number, Date, DateTime)): + raise DataTypeError('Max can only be applied to columns containing DateTime, Date or Number data.') def run(self, table): column = table.columns[self._column_name] diff --git a/agate/aggregations/min.py b/agate/aggregations/min.py index e74914de..974dfbda 100644 --- a/agate/aggregations/min.py +++ b/agate/aggregations/min.py @@ -21,18 +21,14 @@ 
def __init__(self, column_name): def get_aggregate_data_type(self, table): column = table.columns[self._column_name] - if (isinstance(column.data_type, Number) or - isinstance(column.data_type, Date) or - isinstance(column.data_type, DateTime)): + if isinstance(column.data_type, (Number, Date, DateTime)): return column.data_type def validate(self, table): column = table.columns[self._column_name] - if not (isinstance(column.data_type, Number) or - isinstance(column.data_type, Date) or - isinstance(column.data_type, DateTime)): - raise DataTypeError('Min can only be applied to columns containing DateTime orNumber data.') + if not isinstance(column.data_type, (Number, Date, DateTime)): + raise DataTypeError('Min can only be applied to columns containing DateTime, Date or Number data.') def run(self, table): column = table.columns[self._column_name] diff --git a/agate/aggregations/sum.py b/agate/aggregations/sum.py index efe793d9..0d45342a 100644 --- a/agate/aggregations/sum.py +++ b/agate/aggregations/sum.py @@ -1,7 +1,8 @@ #!/usr/bin/env python +import datetime from agate.aggregations.base import Aggregation -from agate.data_types import Number +from agate.data_types import Number, TimeDelta from agate.exceptions import DataTypeError @@ -16,15 +17,22 @@ def __init__(self, column_name): self._column_name = column_name def get_aggregate_data_type(self, table): - return Number() + column = table.columns[self._column_name] + + if isinstance(column.data_type, (Number, TimeDelta)): + return column.data_type def validate(self, table): column = table.columns[self._column_name] - if not isinstance(column.data_type, Number): - raise DataTypeError('Sum can only be applied to columns containing Number data.') + if not isinstance(column.data_type, (Number, TimeDelta)): + raise DataTypeError('Sum can only be applied to columns containing Number or TimeDelta data.') def run(self, table): column = table.columns[self._column_name] - return sum(column.values_without_nulls()) + start 
= 0 + if isinstance(column.data_type, TimeDelta): + start = datetime.timedelta() + + return sum(column.values_without_nulls(), start) diff --git a/agate/table/pivot.py b/agate/table/pivot.py index 32d2548c..1d90cbc7 100644 --- a/agate/table/pivot.py +++ b/agate/table/pivot.py @@ -110,7 +110,7 @@ def apply_computation(table): if pivot is not None: groups = groups.group_by(pivot) - column_type = aggregation.get_aggregate_data_type(groups) + column_type = aggregation.get_aggregate_data_type(self) table = groups.aggregate([ (aggregation_name, aggregation) diff --git a/tests/test_aggregations.py b/tests/test_aggregations.py index d244130b..1f3c90b0 100644 --- a/tests/test_aggregations.py +++ b/tests/test_aggregations.py @@ -207,6 +207,18 @@ def test_max(self): Max('test').validate(table) self.assertEqual(Max('test').run(table), datetime.datetime(1994, 3, 3, 6, 31)) + def test_sum(self): + rows = [ + [datetime.timedelta(seconds=10)], + [datetime.timedelta(seconds=20)], + ] + + table = Table(rows, ['test'], [TimeDelta()]) + + self.assertIsInstance(Sum('test').get_aggregate_data_type(table), TimeDelta) + Sum('test').validate(table) + self.assertEqual(Sum('test').run(table), datetime.timedelta(seconds=30)) + class TestNumberAggregation(unittest.TestCase): def setUp(self): From 29ccc9d994601105880a6fb489063307ae2c2192 Mon Sep 17 00:00:00 2001 From: Danny Sepler Date: Sun, 15 Mar 2020 20:42:51 -0400 Subject: [PATCH 2/2] another isinstance --- agate/computations/change.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/agate/computations/change.py b/agate/computations/change.py index 927bbeb1..4925f67b 100644 --- a/agate/computations/change.py +++ b/agate/computations/change.py @@ -27,11 +27,7 @@ def __init__(self, before_column_name, after_column_name): def get_computed_data_type(self, table): before_column = table.columns[self._before_column_name] - if isinstance(before_column.data_type, Date): - return TimeDelta() - elif 
isinstance(before_column.data_type, DateTime): - return TimeDelta() - elif isinstance(before_column.data_type, TimeDelta): + if isinstance(before_column.data_type, (Date, DateTime, TimeDelta)): return TimeDelta() elif isinstance(before_column.data_type, Number): return Number()