diff --git a/src/metric/auc.cu b/src/metric/auc.cu
index 63c928c53a44..2091fe148fcb 100644
--- a/src/metric/auc.cu
+++ b/src/metric/auc.cu
@@ -291,7 +291,7 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
     // labels is a vector of size n_samples.
     float label = labels[idx % n_samples] == class_id;
 
-    float w = get_weight(i % n_samples);
+    float w = weights.empty() ? 1.0f : weights[d_sorted_idx[i] % n_samples];
     float fp = (1.0 - label) * w;
     float tp = label * w;
     return thrust::make_pair(fp, tp);
diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc
index 94516be04b9a..1e4731454d9b 100644
--- a/tests/cpp/helpers.cc
+++ b/tests/cpp/helpers.cc
@@ -143,7 +143,7 @@ void CheckRankingObjFunction(std::unique_ptr<xgboost::ObjFunction> const& obj,
 }
 
 xgboost::bst_float GetMetricEval(xgboost::Metric * metric,
-                                 xgboost::HostDeviceVector<xgboost::bst_float> preds,
+                                 xgboost::HostDeviceVector<xgboost::bst_float> const& preds,
                                  std::vector<xgboost::bst_float> labels,
                                  std::vector<xgboost::bst_float> weights,
                                  std::vector<xgboost::bst_uint> groups) {
diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h
index f5f88aff1711..fc9b594c0183 100644
--- a/tests/cpp/helpers.h
+++ b/tests/cpp/helpers.h
@@ -86,7 +86,7 @@ void CheckRankingObjFunction(std::unique_ptr<xgboost::ObjFunction> const& obj,
 
 xgboost::bst_float GetMetricEval(
     xgboost::Metric * metric,
-    xgboost::HostDeviceVector<xgboost::bst_float> preds,
+    xgboost::HostDeviceVector<xgboost::bst_float> const& preds,
     std::vector<xgboost::bst_float> labels,
     std::vector<xgboost::bst_float> weights = std::vector<xgboost::bst_float>(),
     std::vector<xgboost::bst_uint> groups = std::vector<xgboost::bst_uint>());
diff --git a/tests/cpp/metric/test_auc.cc b/tests/cpp/metric/test_auc.cc
index 310efaa4ddea..76535ad94f3b 100644
--- a/tests/cpp/metric/test_auc.cc
+++ b/tests/cpp/metric/test_auc.cc
@@ -90,6 +90,16 @@ TEST(Metric, DeclareUnifiedTest(MultiAUC)) {
                       },
                       {0, 1, 1});  // no class 2.
   EXPECT_TRUE(std::isnan(auc)) << auc;
+
+  HostDeviceVector<float> predts{
+      0.0f, 1.0f, 0.0f,
+      1.0f, 0.0f, 0.0f,
+      0.0f, 0.0f, 1.0f,
+      0.0f, 0.0f, 1.0f,
+  };
+  std::vector<float> labels {1.0f, 0.0f, 2.0f, 1.0f};
+  auc = GetMetricEval(metric, predts, labels, {1.0f, 2.0f, 3.0f, 4.0f});
+  ASSERT_GT(auc, 0.714);
 }
 
 TEST(Metric, DeclareUnifiedTest(RankingAUC)) {
diff --git a/tests/python-gpu/test_gpu_eval_metrics.py b/tests/python-gpu/test_gpu_eval_metrics.py
index 36b6a70868b9..f2b605c8b302 100644
--- a/tests/python-gpu/test_gpu_eval_metrics.py
+++ b/tests/python-gpu/test_gpu_eval_metrics.py
@@ -13,9 +13,11 @@ class TestGPUEvalMetrics:
     def test_roc_auc_binary(self, n_samples):
         self.cpu_test.run_roc_auc_binary("gpu_hist", n_samples)
 
-    @pytest.mark.parametrize("n_samples", [4, 100, 1000])
-    def test_roc_auc_multi(self, n_samples):
-        self.cpu_test.run_roc_auc_multi("gpu_hist", n_samples)
+    @pytest.mark.parametrize(
+        "n_samples,weighted", [(4, False), (100, False), (1000, False), (1000, True)]
+    )
+    def test_roc_auc_multi(self, n_samples, weighted):
+        self.cpu_test.run_roc_auc_multi("gpu_hist", n_samples, weighted)
 
     @pytest.mark.parametrize("n_samples", [4, 100, 1000])
     def test_roc_auc_ltr(self, n_samples):
diff --git a/tests/python/test_eval_metrics.py b/tests/python/test_eval_metrics.py
index 71a8691fe7f2..877f3ef33447 100644
--- a/tests/python/test_eval_metrics.py
+++ b/tests/python/test_eval_metrics.py
@@ -191,11 +191,11 @@ def run_roc_auc_binary(self, tree_method, n_samples):
         np.testing.assert_allclose(skl_auc, auc, rtol=1e-6)
 
     @pytest.mark.skipif(**tm.no_sklearn())
-    @pytest.mark.parametrize("n_samples", [4, 100, 1000])
+    @pytest.mark.parametrize("n_samples", [100, 1000])
     def test_roc_auc(self, n_samples):
         self.run_roc_auc_binary("hist", n_samples)
 
-    def run_roc_auc_multi(self, tree_method, n_samples):
+    def run_roc_auc_multi(self, tree_method, n_samples, weighted):
         import numpy as np
         from sklearn.datasets import make_classification
         from sklearn.metrics import roc_auc_score
@@ -213,8 +213,14 @@ def run_roc_auc_multi(self, tree_method, n_samples):
             n_classes=n_classes, random_state=rng
         )
-
-        Xy = xgb.DMatrix(X, y)
+        if weighted:
+            weights = rng.randn(n_samples)
+            weights -= weights.min()
+            weights /= weights.max()
+        else:
+            weights = None
+
+        Xy = xgb.DMatrix(X, y, weight=weights)
         booster = xgb.train(
             {
                 "tree_method": tree_method,
@@ -226,16 +232,22 @@ def run_roc_auc_multi(self, tree_method, n_samples):
             num_boost_round=8,
         )
         score = booster.predict(Xy)
-        skl_auc = roc_auc_score(y, score, average="weighted", multi_class="ovr")
+        skl_auc = roc_auc_score(
+            y, score, average="weighted", sample_weight=weights, multi_class="ovr"
+        )
         auc = float(booster.eval(Xy).split(":")[1])
         np.testing.assert_allclose(skl_auc, auc, rtol=1e-6)
 
         X = rng.randn(*X.shape)
-        score = booster.predict(xgb.DMatrix(X))
-        skl_auc = roc_auc_score(y, score, average="weighted", multi_class="ovr")
-        auc = float(booster.eval(xgb.DMatrix(X, y)).split(":")[1])
-        np.testing.assert_allclose(skl_auc, auc, rtol=1e-6)
-
-    @pytest.mark.parametrize("n_samples", [4, 100, 1000])
-    def test_roc_auc_multi(self, n_samples):
-        self.run_roc_auc_multi("hist", n_samples)
+        score = booster.predict(xgb.DMatrix(X, weight=weights))
+        skl_auc = roc_auc_score(
+            y, score, average="weighted", sample_weight=weights, multi_class="ovr"
+        )
+        auc = float(booster.eval(xgb.DMatrix(X, y, weight=weights)).split(":")[1])
+        np.testing.assert_allclose(skl_auc, auc, rtol=1e-5)
+
+    @pytest.mark.parametrize(
+        "n_samples,weighted", [(4, False), (100, False), (1000, False), (1000, True)]
+    )
+    def test_roc_auc_multi(self, n_samples, weighted):
+        self.run_roc_auc_multi("hist", n_samples, weighted)
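
The quantity the new weighted tests compare against, and that the auc.cu change computes on the GPU, is a one-vs-rest AUC built from weighted true/false positive totals per class, with the per-class AUCs averaged by each class's total positive weight (sklearn's average="weighted" behaviour). The sketch below is not part of the patch; it is a minimal NumPy reference under the assumptions that prediction scores are untied and that weighted_ovr_auc is just an illustrative name. The point it mirrors is that the sample weight must follow the sorted sample, as in weights[d_sorted_idx[i] % n_samples], rather than the sorted position i % n_samples.

# Illustrative NumPy reference (not part of the patch) for weighted one-vs-rest
# AUC.  Assumes no tied prediction scores; `weighted_ovr_auc` is a made-up name.
import numpy as np


def weighted_ovr_auc(y, score, weights):
    """Per-class ROC area from weighted TP/FP counts, averaged over classes
    by each class's total positive weight (sklearn's average="weighted")."""
    n_samples, n_classes = score.shape
    aucs, class_pos_weight = [], []
    for c in range(n_classes):
        label = (y == c).astype(float)
        order = np.argsort(-score[:, c])  # sort samples by this class's score, descending
        w = weights[order]                # weight follows the sorted *sample*, the analogue
                                          # of weights[d_sorted_idx[i] % n_samples] in auc.cu
        tp = np.concatenate([[0.0], np.cumsum(label[order] * w)])          # weighted true positives
        fp = np.concatenate([[0.0], np.cumsum((1.0 - label[order]) * w)])  # weighted false positives
        aucs.append(np.trapz(tp, fp) / (tp[-1] * fp[-1]))                  # area under the weighted ROC
        class_pos_weight.append((label * weights).sum())
    return np.average(aucs, weights=class_pos_weight)


if __name__ == "__main__":
    from sklearn.metrics import roc_auc_score

    rng = np.random.RandomState(1994)
    n, k = 1000, 4
    y = rng.randint(0, k, size=n)
    score = rng.rand(n, k)
    score /= score.sum(axis=1, keepdims=True)  # sklearn expects probability rows summing to 1
    weights = rng.rand(n)
    print(weighted_ovr_auc(y, score, weights))
    print(roc_auc_score(y, score, average="weighted",
                        sample_weight=weights, multi_class="ovr"))

Run as a script, the two printed values should agree closely on untied random scores, the same agreement the updated run_roc_auc_multi test asserts with np.testing.assert_allclose.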