Skip to content

Commit

Permalink
Fix weighted samples in multi-class AUC. (#7300)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Oct 11, 2021
1 parent 69d3b1b commit 298af6f
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/metric/auc.cu
Expand Up @@ -291,7 +291,7 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
// labels is a vector of size n_samples.
float label = labels[idx % n_samples] == class_id;

float w = get_weight(i % n_samples);
float w = weights.empty() ? 1.0f : weights[d_sorted_idx[i] % n_samples];
float fp = (1.0 - label) * w;
float tp = label * w;
return thrust::make_pair(fp, tp);
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/helpers.cc
Expand Up @@ -143,7 +143,7 @@ void CheckRankingObjFunction(std::unique_ptr<xgboost::ObjFunction> const& obj,
}

xgboost::bst_float GetMetricEval(xgboost::Metric * metric,
xgboost::HostDeviceVector<xgboost::bst_float> preds,
xgboost::HostDeviceVector<xgboost::bst_float> const& preds,
std::vector<xgboost::bst_float> labels,
std::vector<xgboost::bst_float> weights,
std::vector<xgboost::bst_uint> groups) {
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/helpers.h
Expand Up @@ -86,7 +86,7 @@ void CheckRankingObjFunction(std::unique_ptr<xgboost::ObjFunction> const& obj,

xgboost::bst_float GetMetricEval(
xgboost::Metric * metric,
xgboost::HostDeviceVector<xgboost::bst_float> preds,
xgboost::HostDeviceVector<xgboost::bst_float> const& preds,
std::vector<xgboost::bst_float> labels,
std::vector<xgboost::bst_float> weights = std::vector<xgboost::bst_float>(),
std::vector<xgboost::bst_uint> groups = std::vector<xgboost::bst_uint>());
Expand Down
10 changes: 10 additions & 0 deletions tests/cpp/metric/test_auc.cc
Expand Up @@ -90,6 +90,16 @@ TEST(Metric, DeclareUnifiedTest(MultiAUC)) {
},
{0, 1, 1}); // no class 2.
EXPECT_TRUE(std::isnan(auc)) << auc;

HostDeviceVector<float> predts{
0.0f, 1.0f, 0.0f,
1.0f, 0.0f, 0.0f,
0.0f, 0.0f, 1.0f,
0.0f, 0.0f, 1.0f,
};
std::vector<float> labels {1.0f, 0.0f, 2.0f, 1.0f};
auc = GetMetricEval(metric, predts, labels, {1.0f, 2.0f, 3.0f, 4.0f});
ASSERT_GT(auc, 0.714);
}

TEST(Metric, DeclareUnifiedTest(RankingAUC)) {
Expand Down
8 changes: 5 additions & 3 deletions tests/python-gpu/test_gpu_eval_metrics.py
Expand Up @@ -13,9 +13,11 @@ class TestGPUEvalMetrics:
def test_roc_auc_binary(self, n_samples):
self.cpu_test.run_roc_auc_binary("gpu_hist", n_samples)

@pytest.mark.parametrize("n_samples", [4, 100, 1000])
def test_roc_auc_multi(self, n_samples):
self.cpu_test.run_roc_auc_multi("gpu_hist", n_samples)
@pytest.mark.parametrize(
"n_samples,weighted", [(4, False), (100, False), (1000, False), (1000, True)]
)
def test_roc_auc_multi(self, n_samples, weighted):
self.cpu_test.run_roc_auc_multi("gpu_hist", n_samples, weighted)

@pytest.mark.parametrize("n_samples", [4, 100, 1000])
def test_roc_auc_ltr(self, n_samples):
Expand Down
38 changes: 25 additions & 13 deletions tests/python/test_eval_metrics.py
Expand Up @@ -191,11 +191,11 @@ def run_roc_auc_binary(self, tree_method, n_samples):
np.testing.assert_allclose(skl_auc, auc, rtol=1e-6)

@pytest.mark.skipif(**tm.no_sklearn())
@pytest.mark.parametrize("n_samples", [4, 100, 1000])
@pytest.mark.parametrize("n_samples", [100, 1000])
def test_roc_auc(self, n_samples):
self.run_roc_auc_binary("hist", n_samples)

def run_roc_auc_multi(self, tree_method, n_samples):
def run_roc_auc_multi(self, tree_method, n_samples, weighted):
import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import roc_auc_score
Expand All @@ -213,8 +213,14 @@ def run_roc_auc_multi(self, tree_method, n_samples):
n_classes=n_classes,
random_state=rng
)

Xy = xgb.DMatrix(X, y)
if weighted:
weights = rng.randn(n_samples)
weights -= weights.min()
weights /= weights.max()
else:
weights = None

Xy = xgb.DMatrix(X, y, weight=weights)
booster = xgb.train(
{
"tree_method": tree_method,
Expand All @@ -226,16 +232,22 @@ def run_roc_auc_multi(self, tree_method, n_samples):
num_boost_round=8,
)
score = booster.predict(Xy)
skl_auc = roc_auc_score(y, score, average="weighted", multi_class="ovr")
skl_auc = roc_auc_score(
y, score, average="weighted", sample_weight=weights, multi_class="ovr"
)
auc = float(booster.eval(Xy).split(":")[1])
np.testing.assert_allclose(skl_auc, auc, rtol=1e-6)

X = rng.randn(*X.shape)
score = booster.predict(xgb.DMatrix(X))
skl_auc = roc_auc_score(y, score, average="weighted", multi_class="ovr")
auc = float(booster.eval(xgb.DMatrix(X, y)).split(":")[1])
np.testing.assert_allclose(skl_auc, auc, rtol=1e-6)

@pytest.mark.parametrize("n_samples", [4, 100, 1000])
def test_roc_auc_multi(self, n_samples):
self.run_roc_auc_multi("hist", n_samples)
score = booster.predict(xgb.DMatrix(X, weight=weights))
skl_auc = roc_auc_score(
y, score, average="weighted", sample_weight=weights, multi_class="ovr"
)
auc = float(booster.eval(xgb.DMatrix(X, y, weight=weights)).split(":")[1])
np.testing.assert_allclose(skl_auc, auc, rtol=1e-5)

@pytest.mark.parametrize(
"n_samples,weighted", [(4, False), (100, False), (1000, False), (1000, True)]
)
def test_roc_auc_multi(self, n_samples, weighted):
self.run_roc_auc_multi("hist", n_samples, weighted)

0 comments on commit 298af6f

Please sign in to comment.