Skip to content

Commit

Permalink
Implement basic COALESCE functionality (#823)
Browse files Browse the repository at this point in the history
* Implememnt basic COALESCE functionality

* Enable coalesce for operations

* Un xfail queries

* refactor coalesce to support columns as input

* Add tests for coalesce on columns

Co-authored-by: Chris Jarrett <cjarrett@exp02.aselab.nvidia.com>
Co-authored-by: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com>
  • Loading branch information
3 people committed Dec 1, 2022
1 parent d2896fa commit 6ae69a8
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 4 deletions.
13 changes: 12 additions & 1 deletion dask_sql/physical/rel/logical/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,18 @@ def _collect_aggregations(
filter_backend_col = None

try:
aggregation_function = self.AGGREGATION_MAPPING[aggregation_name]
# This unifies CPU and GPU behavior by ensuring that performing a
# sum on a null column results in null and not 0
if aggregation_name == "sum" and isinstance(df._meta, pd.DataFrame):
aggregation_function = AggregationSpecification(
dd.Aggregation(
name="custom_sum",
chunk=lambda s: s.sum(min_count=1),
agg=lambda s0: s0.sum(min_count=1),
)
)
else:
aggregation_function = self.AGGREGATION_MAPPING[aggregation_name]
except KeyError:
try:
aggregation_function = context.schema[schema_name].functions[
Expand Down
20 changes: 20 additions & 0 deletions dask_sql/physical/rex/core/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,25 @@ def overlay(self, s, replace, start, length=None):
return s


class CoalesceOperation(Operation):
def __init__(self):
super().__init__(self.coalesce)

def coalesce(self, *operands):
result = None
for operand in operands:
if is_frame(operand):
# Check if frame evaluates to nan or NA
if len(operand) == 1 and not operand.isnull().all().compute():
return operand if result is None else result.fillna(operand)
else:
result = operand if result is None else result.fillna(operand)
elif not pd.isna(operand):
return operand if result is None else result.fillna(operand)

return result


class ExtractOperation(Operation):
def __init__(self):
super().__init__(self.extract)
Expand Down Expand Up @@ -1059,6 +1078,7 @@ class RexCallPlugin(BaseRexPlugin):
"substr": SubStringOperation(),
"substring": SubStringOperation(),
"initcap": TensorScalarOperation(lambda x: x.str.title(), lambda x: x.title()),
"coalesce": CoalesceOperation(),
"replace": ReplaceOperation(),
# date/time operations
"extract": ExtractOperation(),
Expand Down
59 changes: 59 additions & 0 deletions tests/integration/test_rex.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,65 @@ def test_null(c):
assert_eq(df, expected_df)


@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
def test_coalesce(c, gpu):
df = dd.from_pandas(
pd.DataFrame({"a": [1, 2, 3], "b": [np.nan] * 3}), npartitions=1
)
c.create_table("df", df, gpu=gpu)

df = c.sql(
"""
SELECT
COALESCE(3, 5) as c1,
COALESCE(NULL, NULL) as c2,
COALESCE(NULL, 'hi') as c3,
COALESCE(NULL, NULL, 'bye', 5/0) as c4,
COALESCE(NULL, 3/2, NULL, 'fly') as c5,
COALESCE(SUM(b), 'why', 2.2) as c6,
COALESCE(NULL, MEAN(b), MEAN(a), 4/0) as c7
FROM df
"""
)

expected_df = pd.DataFrame(
{
"c1": [3],
"c2": [np.nan],
"c3": ["hi"],
"c4": ["bye"],
"c5": ["1"],
"c6": ["why"],
"c7": [2.0],
}
)

assert_eq(df, expected_df, check_dtype=False)

df = c.sql(
"""
SELECT
COALESCE(a, b) as c1,
COALESCE(b, a) as c2,
COALESCE(a, a) as c3,
COALESCE(b, b) as c4
FROM df
"""
)

expected_df = pd.DataFrame(
{
"c1": [1, 2, 3],
"c2": [1, 2, 3],
"c3": [1, 2, 3],
"c4": [np.nan] * 3,
}
)

assert_eq(df, expected_df, check_dtype=False)
c.drop_table("df")


def test_boolean_operations(c):
df = dd.from_pandas(pd.DataFrame({"b": [1, 0, -1]}), npartitions=1)
df["b"] = df["b"].apply(
Expand Down
3 changes: 0 additions & 3 deletions tests/unit/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,8 @@
69,
70,
72,
75,
77,
78,
80,
84,
86,
87,
88,
Expand Down

0 comments on commit 6ae69a8

Please sign in to comment.