Skip to content

Commit

Permalink
[SPARK-41064][CONNECT][PYTHON] Implement `DataFrame.crosstab` and `DataFrame.stat.crosstab`
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?
Implement `DataFrame.crosstab` and `DataFrame.stat.crosstab`

### Why are the changes needed?
To improve API coverage: Spark Connect's DataFrame should support the same stat functions as the existing PySpark DataFrame API.

### Does this PR introduce _any_ user-facing change?
Yes, this adds new user-facing APIs: `DataFrame.crosstab` and `DataFrame.stat.crosstab`.

### How was this patch tested?
Added unit tests.

Closes apache#38578 from zhengruifeng/connect_df_crosstab.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
  • Loading branch information
zhengruifeng authored and SandishKumarHN committed Dec 12, 2022
1 parent 23b591d commit 78be1b9
Show file tree
Hide file tree
Showing 9 changed files with 274 additions and 66 deletions.
19 changes: 19 additions & 0 deletions connector/connect/src/main/protobuf/spark/connect/relations.proto
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ message Relation {

// stat functions
StatSummary summary = 100;
StatCrosstab crosstab = 101;

Unknown unknown = 999;
}
Expand Down Expand Up @@ -284,6 +285,24 @@ message StatSummary {
repeated string statistics = 2;
}

// Computes a pair-wise frequency table of the given columns. Also known as a contingency table.
// The server will invoke 'Dataset.stat.crosstab' (same as 'StatFunctions.crossTabulate')
// to compute the results.
message StatCrosstab {
// (Required) The input relation to compute the frequency table over.
Relation input = 1;

// (Required) The name of the first column.
//
// Distinct values of this column become the first item of each output row.
string col1 = 2;

// (Required) The name of the second column.
//
// Distinct values of this column become the column names of the output DataFrame.
string col2 = 3;
}

// Rename columns on the input relation by the same length of names.
message RenameColumnsBySameLengthNames {
// Required. The input relation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,21 @@ package object dsl {
}
}

/**
 * DSL entry point for stat-function relations, mirroring `Dataset.stat`.
 */
implicit class DslStatFunctions(val logicalPlan: Relation) {

  /** Wraps `logicalPlan` in a [[proto.StatCrosstab]] relation for the two given columns. */
  def crosstab(col1: String, col2: String): Relation = {
    val crosstabRel = proto.StatCrosstab
      .newBuilder()
      .setInput(logicalPlan)
      .setCol1(col1)
      .setCol2(col2)
      .build()
    Relation.newBuilder().setCrosstab(crosstabRel).build()
  }
}

implicit class DslLogicalPlan(val logicalPlan: Relation) {
def select(exprs: Expression*): Relation = {
Relation
Expand Down Expand Up @@ -463,6 +478,8 @@ package object dsl {
Repartition.newBuilder().setInput(logicalPlan).setNumPartitions(num).setShuffle(true))
.build()

def stat: DslStatFunctions = new DslStatFunctions(logicalPlan)

def summary(statistics: String*): Relation = {
Relation
.newBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
transformSubqueryAlias(rel.getSubqueryAlias)
case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition)
case proto.Relation.RelTypeCase.SUMMARY => transformStatSummary(rel.getSummary)
case proto.Relation.RelTypeCase.CROSSTAB =>
transformStatCrosstab(rel.getCrosstab)
case proto.Relation.RelTypeCase.RENAME_COLUMNS_BY_SAME_LENGTH_NAMES =>
transformRenameColumnsBySamelenghtNames(rel.getRenameColumnsBySameLengthNames)
case proto.Relation.RelTypeCase.RENAME_COLUMNS_BY_NAME_TO_NAME_MAP =>
Expand Down Expand Up @@ -129,6 +131,14 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
.logicalPlan
}

// Translates a StatCrosstab relation by delegating to Dataset.stat.crosstab on the
// transformed child, then returning the resulting logical plan.
private def transformStatCrosstab(rel: proto.StatCrosstab): LogicalPlan = {
  val input = Dataset.ofRows(session, transformRelation(rel.getInput))
  input.stat.crosstab(rel.getCol1, rel.getCol2).logicalPlan
}

private def transformRenameColumnsBySamelenghtNames(
rel: proto.RenameColumnsBySameLengthNames): LogicalPlan = {
Dataset
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,12 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
sparkTestRelation.summary("count", "mean", "stddev"))
}

test("Test crosstab") {
// The connect DSL's crosstab must produce the same plan as Dataset.stat.crosstab.
comparePlans(
connectTestRelation.stat.crosstab("id", "name"),
sparkTestRelation.stat.crosstab("id", "name"))
}

test("Test toDF") {
// Column renaming via toDF must yield the same plan on both sides.
comparePlans(connectTestRelation.toDF("col1", "col2"), sparkTestRelation.toDF("col1", "col2"))
}
Expand Down
62 changes: 62 additions & 0 deletions python/pyspark/sql/connect/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,18 @@ def intersectAll(self, other: "DataFrame") -> "DataFrame":
def where(self, condition: Expression) -> "DataFrame":
    # Alias for filter(); kept for parity with the PySpark DataFrame API.
    return self.filter(condition)

@property
def stat(self) -> "DataFrameStatFunctions":
    """Returns a :class:`DataFrameStatFunctions` for statistic functions.

    .. versionadded:: 3.4.0

    Returns
    -------
    :class:`DataFrameStatFunctions`
    """
    # A fresh wrapper is created on each access; it holds no state beyond the DataFrame.
    return DataFrameStatFunctions(self)

def summary(self, *statistics: str) -> "DataFrame":
_statistics: List[str] = list(statistics)
for s in _statistics:
Expand All @@ -511,6 +523,41 @@ def summary(self, *statistics: str) -> "DataFrame":
session=self._session,
)

def crosstab(self, col1: str, col2: str) -> "DataFrame":
    """
    Computes a pair-wise frequency table of the given columns. Also known as a contingency
    table. The number of distinct values for each column should be less than 1e4. At most 1e6
    non-zero pair frequencies will be returned.

    The first column of each row will be the distinct values of `col1` and the column names
    will be the distinct values of `col2`. The name of the first column will be `$col1_$col2`.
    Pairs that have no occurrences will have zero as their counts.
    :func:`DataFrame.crosstab` and :func:`DataFrameStatFunctions.crosstab` are aliases.

    .. versionadded:: 3.4.0

    Parameters
    ----------
    col1 : str
        The name of the first column. Distinct items will make the first item of
        each row.
    col2 : str
        The name of the second column. Distinct items will make the column names
        of the :class:`DataFrame`.

    Returns
    -------
    :class:`DataFrame`
        Frequency matrix of two columns.
    """
    # Validate argument types eagerly so the error surfaces on the client,
    # not as a confusing server-side failure.
    if not isinstance(col1, str):
        raise TypeError(f"'col1' must be str, but got {type(col1).__name__}")
    if not isinstance(col2, str):
        raise TypeError(f"'col2' must be str, but got {type(col2).__name__}")
    crosstab_plan = plan.StatCrosstab(child=self._plan, col1=col1, col2=col2)
    return DataFrame.withPlan(crosstab_plan, session=self._session)

def _get_alias(self) -> Optional[str]:
p = self._plan
while p is not None:
Expand Down Expand Up @@ -579,3 +626,18 @@ def explain(self) -> str:
return self._session.explain_string(query)
else:
return ""


class DataFrameStatFunctions:
    """Functionality for statistic functions with :class:`DataFrame`.

    .. versionadded:: 3.4.0
    """

    def __init__(self, df: DataFrame) -> None:
        # The underlying DataFrame whose stat methods this wrapper exposes.
        self.df = df

    def crosstab(self, col1: str, col2: str) -> DataFrame:
        # Delegates to DataFrame.crosstab; the two entry points are aliases.
        return self.df.crosstab(col1, col2)

    # Share the docstring so both aliases document identically.
    crosstab.__doc__ = DataFrame.crosstab.__doc__
32 changes: 32 additions & 0 deletions python/pyspark/sql/connect/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,3 +830,35 @@ def _repr_html_(self) -> str:
</li>
</ul>
"""


class StatCrosstab(LogicalPlan):
    """Logical plan node that emits a `StatCrosstab` relation in the Connect proto.

    Carries the two column names alongside the child plan; the server resolves
    it via `Dataset.stat.crosstab`.
    """

    def __init__(self, child: Optional["LogicalPlan"], col1: str, col2: str) -> None:
        super().__init__(child)
        self.col1 = col1
        self.col2 = col2

    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
        # A crosstab is meaningless without an input relation.
        assert self._child is not None

        rel = proto.Relation()
        rel.crosstab.input.CopyFrom(self._child.plan(session))
        rel.crosstab.col1 = self.col1
        rel.crosstab.col2 = self.col2
        return rel

    def print(self, indent: int = 0) -> str:
        pad = " " * indent
        return f"{pad}<Crosstab col1='{self.col1}' col2='{self.col2}'>"

    def _repr_html_(self) -> str:
        return f"""
        <ul>
           <li>
              <b>Crosstab</b><br />
              Col1: {self.col1} <br />
              Col2: {self.col2} <br />
              {self._child_repr_()}
           </li>
        </ul>
        """
134 changes: 68 additions & 66 deletions python/pyspark/sql/connect/proto/relations_pb2.py

Large diffs are not rendered by default.

50 changes: 50 additions & 0 deletions python/pyspark/sql/connect/proto/relations_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class Relation(google.protobuf.message.Message):
RENAME_COLUMNS_BY_SAME_LENGTH_NAMES_FIELD_NUMBER: builtins.int
RENAME_COLUMNS_BY_NAME_TO_NAME_MAP_FIELD_NUMBER: builtins.int
SUMMARY_FIELD_NUMBER: builtins.int
CROSSTAB_FIELD_NUMBER: builtins.int
UNKNOWN_FIELD_NUMBER: builtins.int
@property
def common(self) -> global___RelationCommon: ...
Expand Down Expand Up @@ -122,6 +123,8 @@ class Relation(google.protobuf.message.Message):
def summary(self) -> global___StatSummary:
"""stat functions"""
@property
def crosstab(self) -> global___StatCrosstab: ...
@property
def unknown(self) -> global___Unknown: ...
def __init__(
self,
Expand All @@ -146,6 +149,7 @@ class Relation(google.protobuf.message.Message):
rename_columns_by_same_length_names: global___RenameColumnsBySameLengthNames | None = ...,
rename_columns_by_name_to_name_map: global___RenameColumnsByNameToNameMap | None = ...,
summary: global___StatSummary | None = ...,
crosstab: global___StatCrosstab | None = ...,
unknown: global___Unknown | None = ...,
) -> None: ...
def HasField(
Expand All @@ -155,6 +159,8 @@ class Relation(google.protobuf.message.Message):
b"aggregate",
"common",
b"common",
"crosstab",
b"crosstab",
"deduplicate",
b"deduplicate",
"filter",
Expand Down Expand Up @@ -204,6 +210,8 @@ class Relation(google.protobuf.message.Message):
b"aggregate",
"common",
b"common",
"crosstab",
b"crosstab",
"deduplicate",
b"deduplicate",
"filter",
Expand Down Expand Up @@ -268,6 +276,7 @@ class Relation(google.protobuf.message.Message):
"rename_columns_by_same_length_names",
"rename_columns_by_name_to_name_map",
"summary",
"crosstab",
"unknown",
] | None: ...

Expand Down Expand Up @@ -1141,6 +1150,47 @@ class StatSummary(google.protobuf.message.Message):

global___StatSummary = StatSummary

# NOTE(review): this stub is generated by protoc from relations.proto —
# regenerate it rather than hand-editing when the proto changes.
class StatCrosstab(google.protobuf.message.Message):
    """Computes a pair-wise frequency table of the given columns. Also known as a contingency table.
    It will invoke 'Dataset.stat.crosstab' (same as 'StatFunctions.crossTabulate')
    to compute the results.
    """

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    INPUT_FIELD_NUMBER: builtins.int
    COL1_FIELD_NUMBER: builtins.int
    COL2_FIELD_NUMBER: builtins.int
    @property
    def input(self) -> global___Relation:
        """(Required) The input relation."""
    col1: builtins.str
    """(Required) The name of the first column.

    Distinct items will make the first item of each row.
    """
    col2: builtins.str
    """(Required) The name of the second column.

    Distinct items will make the column names of the DataFrame.
    """
    def __init__(
        self,
        *,
        input: global___Relation | None = ...,
        col1: builtins.str = ...,
        col2: builtins.str = ...,
    ) -> None: ...
    def HasField(
        self, field_name: typing_extensions.Literal["input", b"input"]
    ) -> builtins.bool: ...
    def ClearField(
        self,
        field_name: typing_extensions.Literal["col1", b"col1", "col2", b"col2", "input", b"input"],
    ) -> None: ...

global___StatCrosstab = StatCrosstab

class RenameColumnsBySameLengthNames(google.protobuf.message.Message):
"""Rename columns on the input relation by the same length of names."""

Expand Down
10 changes: 10 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_plan_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,16 @@ def test_summary(self):
["count", "mean", "stddev", "min", "25%"],
)

def test_crosstab(self):
    """Both crosstab entry points must populate the crosstab relation in the proto plan."""
    df = self.connect.readTable(table_name=self.tbl_name)

    # Direct DataFrame.crosstab on a derived DataFrame.
    proto_plan = df.filter(df.col_name > 3).crosstab("col_a", "col_b")._plan.to_proto(self.connect)
    self.assertEqual("col_a", proto_plan.root.crosstab.col1)
    self.assertEqual("col_b", proto_plan.root.crosstab.col2)

    # The DataFrame.stat.crosstab alias must produce the same relation fields.
    proto_plan = df.stat.crosstab("col_a", "col_b")._plan.to_proto(self.connect)
    self.assertEqual("col_a", proto_plan.root.crosstab.col1)
    self.assertEqual("col_b", proto_plan.root.crosstab.col2)

def test_limit(self):
df = self.connect.readTable(table_name=self.tbl_name)
limit_plan = df.limit(10)._plan.to_proto(self.connect)
Expand Down

0 comments on commit 78be1b9

Please sign in to comment.