Skip to content

Commit

Permalink
Merge pull request #2 from trinhcon/fix/ndcg-score-with-test
Browse files Browse the repository at this point in the history
deprecation warning for negative ndcg
  • Loading branch information
trinhcon committed Mar 4, 2022
2 parents 689e72d + b628cb0 commit 07a2447
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
10 changes: 8 additions & 2 deletions sklearn/metrics/_ranking.py
Expand Up @@ -1622,11 +1622,17 @@ def ndcg_score(y_true, y_score, *, k=None, sample_weight=None, ignore_ties=False

if (isinstance(y_true, np.ndarray)):
if (y_true.min() < 0):
raise DeprecationWarning("ndcg_score should not use negative y_true values")
warnings.warn(
"ndcg_score should not use negative y_true values",
DeprecationWarning,
)
else:
for value in y_true:
if (value < 0):
raise DeprecationWarning("ndcg_score should not use negative y_true values")
warnings.warn(
"ndcg_score should not use negative y_true values",
DeprecationWarning,
)
return np.average(gain, weights=sample_weight)


Expand Down
17 changes: 17 additions & 0 deletions sklearn/metrics/tests/test_ranking.py
Expand Up @@ -1640,6 +1640,23 @@ def test_ndcg_ignore_ties_with_k():
ndcg_score(a, a, k=3, ignore_ties=True)
)

def test_ndcg_negative_ndarray_warn():
y_true = np.array([-0.89, -0.53, -0.47, 0.39, 0.56]).reshape(1,-1)
y_score = np.array([0.07,0.31,0.75,0.33,0.27]).reshape(1,-1)
expected_message = "ndcg_score should not use negative y_true values"
with pytest.warns(DeprecationWarning, match=expected_message):
ndcg_score(y_true, y_score)

def test_ndcg_negative_output():
y_true = np.array([-0.89, -0.53, -0.47, 0.39, 0.56]).reshape(1,-1)
y_score = np.array([0.07,0.31,0.75,0.33,0.27]).reshape(1,-1)
assert ndcg_score(y_true, y_score) == pytest.approx(396.0329)

def test_ndcg_positive_ndarray():
y_true = np.array([0.11, 0.47, 0.53, 1.39, 1.56]).reshape(1,-1)
y_score = np.array([1.07, 1.31, 1.75, 1.33, 1.27]).reshape(1,-1)
with pytest.warns(None):
ndcg_score(y_true, y_score)

def test_ndcg_invariant():
y_true = np.arange(70).reshape(7, 10)
Expand Down

0 comments on commit 07a2447

Please sign in to comment.