Skip to content

Commit

Permalink
[pyspark] sort qid for SparkRanker (dmlc#8497)
Browse files Browse the repository at this point in the history
* [pyspark] sort qid for SparkRandker

* resolve comments
  • Loading branch information
wbo4958 authored and trivialfis committed Dec 6, 2022
1 parent 58bc225 commit e59ba25
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 32 deletions.
6 changes: 5 additions & 1 deletion python-package/xgboost/spark/core.py
@@ -1,7 +1,7 @@
# type: ignore
"""Xgboost pyspark integration submodule for core code."""
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
# pylint: disable=too-few-public-methods, too-many-lines
# pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
import json
from typing import Iterator, Optional, Tuple

Expand Down Expand Up @@ -728,6 +728,10 @@ def _fit(self, dataset):
else:
dataset = dataset.repartition(num_workers)

if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col):
# XGBoost requires qid to be sorted for each partition
dataset = dataset.sortWithinPartitions(alias.qid, ascending=True)

train_params = self._get_distributed_train_params(dataset)
booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
train_params
Expand Down
91 changes: 60 additions & 31 deletions tests/python/test_spark/test_spark_local.py
Expand Up @@ -390,28 +390,6 @@ def setUp(self):
"expected_prediction_with_base_margin",
],
)
self.ranker_df_train = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, 0),
(Vectors.dense(4.0, 5.0, 6.0), 1, 0),
(Vectors.dense(9.0, 4.0, 8.0), 2, 0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1),
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1),
],
["features", "label", "qid"],
)
self.ranker_df_test = self.session.createDataFrame(
[
(Vectors.dense(1.5, 2.0, 3.0), 0, -1.87988),
(Vectors.dense(4.5, 5.0, 6.0), 0, 0.29556),
(Vectors.dense(9.0, 4.5, 8.0), 0, 2.36570),
(Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.87988),
(Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -0.30612),
(Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 2.44826),
],
["features", "qid", "expected_prediction"],
)

self.reg_df_sparse_train = self.session.createDataFrame(
[
Expand Down Expand Up @@ -1039,15 +1017,6 @@ def test_classifier_with_sparse_optim(self):
for row1, row2 in zip(pred_result, pred_result2):
self.assertTrue(np.allclose(row1.probability, row2.probability, rtol=1e-3))

def test_ranker(self):
ranker = SparkXGBRanker(qid_col="qid")
assert ranker.getOrDefault(ranker.objective) == "rank:pairwise"
model = ranker.fit(self.ranker_df_train)
pred_result = model.transform(self.ranker_df_test).collect()

for row in pred_result:
assert np.isclose(row.prediction, row.expected_prediction, rtol=1e-3)

def test_empty_validation_data(self) -> None:
for tree_method in [
"hist",
Expand Down Expand Up @@ -1130,3 +1099,63 @@ def test_early_stop_param_validation(self):
def test_unsupported_params(self):
with pytest.raises(ValueError, match="evals_result"):
SparkXGBClassifier(evals_result={})


class XgboostRankerLocalTest(SparkTestCase):
def setUp(self):
self.session.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "8")
self.ranker_df_train = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, 0),
(Vectors.dense(4.0, 5.0, 6.0), 1, 0),
(Vectors.dense(9.0, 4.0, 8.0), 2, 0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1),
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1),
],
["features", "label", "qid"],
)
self.ranker_df_test = self.session.createDataFrame(
[
(Vectors.dense(1.5, 2.0, 3.0), 0, -1.87988),
(Vectors.dense(4.5, 5.0, 6.0), 0, 0.29556),
(Vectors.dense(9.0, 4.5, 8.0), 0, 2.36570),
(Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.87988),
(Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -0.30612),
(Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 2.44826),
],
["features", "qid", "expected_prediction"],
)
self.ranker_df_train_1 = self.session.createDataFrame(
[
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 9),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 9),
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 9),
(Vectors.dense(1.0, 2.0, 3.0), 0, 8),
(Vectors.dense(4.0, 5.0, 6.0), 1, 8),
(Vectors.dense(9.0, 4.0, 8.0), 2, 8),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 7),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 7),
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 7),
(Vectors.dense(1.0, 2.0, 3.0), 0, 6),
(Vectors.dense(4.0, 5.0, 6.0), 1, 6),
(Vectors.dense(9.0, 4.0, 8.0), 2, 6),
]
* 4,
["features", "label", "qid"],
)

def test_ranker(self):
ranker = SparkXGBRanker(qid_col="qid")
assert ranker.getOrDefault(ranker.objective) == "rank:pairwise"
model = ranker.fit(self.ranker_df_train)
pred_result = model.transform(self.ranker_df_test).collect()

for row in pred_result:
assert np.isclose(row.prediction, row.expected_prediction, rtol=1e-3)

def test_ranker_qid_sorted(self):
ranker = SparkXGBRanker(qid_col="qid", num_workers=4)
assert ranker.getOrDefault(ranker.objective) == "rank:pairwise"
model = ranker.fit(self.ranker_df_train_1)
model.transform(self.ranker_df_test).collect()

0 comments on commit e59ba25

Please sign in to comment.