New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[pyspark] sort qid for SparkRanker #8497
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -383,28 +383,6 @@ def setUp(self): | |
"expected_prediction_with_base_margin", | ||
], | ||
) | ||
self.ranker_df_train = self.session.createDataFrame( | ||
[ | ||
(Vectors.dense(1.0, 2.0, 3.0), 0, 0), | ||
(Vectors.dense(4.0, 5.0, 6.0), 1, 0), | ||
(Vectors.dense(9.0, 4.0, 8.0), 2, 0), | ||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1), | ||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1), | ||
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1), | ||
], | ||
["features", "label", "qid"], | ||
) | ||
self.ranker_df_test = self.session.createDataFrame( | ||
[ | ||
(Vectors.dense(1.5, 2.0, 3.0), 0, -1.87988), | ||
(Vectors.dense(4.5, 5.0, 6.0), 0, 0.29556), | ||
(Vectors.dense(9.0, 4.5, 8.0), 0, 2.36570), | ||
(Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.87988), | ||
(Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -0.30612), | ||
(Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 2.44826), | ||
], | ||
["features", "qid", "expected_prediction"], | ||
) | ||
|
||
self.reg_df_sparse_train = self.session.createDataFrame( | ||
[ | ||
|
@@ -1033,15 +1011,6 @@ def test_classifier_with_sparse_optim(self): | |
for row1, row2 in zip(pred_result, pred_result2): | ||
self.assertTrue(np.allclose(row1.probability, row2.probability, rtol=1e-3)) | ||
|
||
def test_ranker(self): | ||
ranker = SparkXGBRanker(qid_col="qid") | ||
assert ranker.getOrDefault(ranker.objective) == "rank:pairwise" | ||
model = ranker.fit(self.ranker_df_train) | ||
pred_result = model.transform(self.ranker_df_test).collect() | ||
|
||
for row in pred_result: | ||
assert np.isclose(row.prediction, row.expected_prediction, rtol=1e-3) | ||
|
||
def test_empty_validation_data(self) -> None: | ||
for tree_method in [ | ||
"hist", | ||
|
@@ -1124,3 +1093,80 @@ def test_early_stop_param_validation(self): | |
def test_unsupported_params(self): | ||
with pytest.raises(ValueError, match="evals_result"): | ||
SparkXGBClassifier(evals_result={}) | ||
|
||
|
||
class XgboostRankerLocalTest(SparkTestCase): | ||
def setUp(self): | ||
self.session.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "8") | ||
self.ranker_df_train = self.session.createDataFrame( | ||
[ | ||
(Vectors.dense(1.0, 2.0, 3.0), 0, 0), | ||
(Vectors.dense(4.0, 5.0, 6.0), 1, 0), | ||
(Vectors.dense(9.0, 4.0, 8.0), 2, 0), | ||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1), | ||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1), | ||
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1), | ||
], | ||
["features", "label", "qid"], | ||
) | ||
self.ranker_df_test = self.session.createDataFrame( | ||
[ | ||
(Vectors.dense(1.5, 2.0, 3.0), 0, -1.87988), | ||
(Vectors.dense(4.5, 5.0, 6.0), 0, 0.29556), | ||
(Vectors.dense(9.0, 4.5, 8.0), 0, 2.36570), | ||
(Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.87988), | ||
(Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -0.30612), | ||
(Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 2.44826), | ||
], | ||
["features", "qid", "expected_prediction"], | ||
) | ||
self.ranker_df_train_1 = self.session.createDataFrame( | ||
[ | ||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 9), | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
How did you produce this data and the expected result? Please try not to use hardcoded results.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Yeah, the qid is in descending order; without the fix, it will throw an exception.
||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 9), | ||
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 9), | ||
(Vectors.dense(1.0, 2.0, 3.0), 0, 8), | ||
(Vectors.dense(4.0, 5.0, 6.0), 1, 8), | ||
(Vectors.dense(9.0, 4.0, 8.0), 2, 8), | ||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 7), | ||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 7), | ||
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 7), | ||
(Vectors.dense(1.0, 2.0, 3.0), 0, 6), | ||
(Vectors.dense(4.0, 5.0, 6.0), 1, 6), | ||
(Vectors.dense(9.0, 4.0, 8.0), 2, 6), | ||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 5), | ||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 5), | ||
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 5), | ||
(Vectors.dense(1.0, 2.0, 3.0), 0, 4), | ||
(Vectors.dense(4.0, 5.0, 6.0), 1, 4), | ||
(Vectors.dense(9.0, 4.0, 8.0), 2, 4), | ||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 3), | ||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 3), | ||
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 3), | ||
(Vectors.dense(1.0, 2.0, 3.0), 0, 2), | ||
(Vectors.dense(4.0, 5.0, 6.0), 1, 2), | ||
(Vectors.dense(9.0, 4.0, 8.0), 2, 2), | ||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1), | ||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1), | ||
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1), | ||
(Vectors.dense(1.0, 2.0, 3.0), 0, 0), | ||
(Vectors.dense(4.0, 5.0, 6.0), 1, 0), | ||
(Vectors.dense(9.0, 4.0, 8.0), 2, 0), | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Nit: do we need to hardcode such a long data list?
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Done.
||
], | ||
["features", "label", "qid"], | ||
) | ||
|
||
def test_ranker(self): | ||
ranker = SparkXGBRanker(qid_col="qid") | ||
assert ranker.getOrDefault(ranker.objective) == "rank:pairwise" | ||
model = ranker.fit(self.ranker_df_train) | ||
pred_result = model.transform(self.ranker_df_test).collect() | ||
|
||
for row in pred_result: | ||
assert np.isclose(row.prediction, row.expected_prediction, rtol=1e-3) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
@wbo4958 This is not only checking the exception.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
This test was moved from https://github.com/dmlc/xgboost/pull/8497/files#diff-3b3ca1f9bd10767b61c3eab170a027b67408881dcf57e4e992c2caa47d660ff5L386-L407; I didn't change it.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Ah ... that's a headache — I'm blocked by these tests and don't know how to recreate them...
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Yes, we can open a follow-up PR to refactor these tests so they don't rely on hardcoded values.
||
|
||
def test_ranker_qid_sorted(self): | ||
ranker = SparkXGBRanker(qid_col="qid", num_workers=2) | ||
assert ranker.getOrDefault(ranker.objective) == "rank:pairwise" | ||
model = ranker.fit(self.ranker_df_train_1) | ||
model.transform(self.ranker_df_test).collect() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
What's the purpose of this test?
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
To test whether the SparkRanker will throw an exception.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Nit: add ascending=True explicitly.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Done.