From f0005070110abab49473470b3ba1564dc59668d1 Mon Sep 17 00:00:00 2001 From: fis Date: Sat, 5 Dec 2020 04:28:06 +0800 Subject: [PATCH] Fix filtering callable objects in skl xgb param. --- python-package/xgboost/sklearn.py | 2 +- tests/python/test_with_sklearn.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 2703b8160d51..d3b2a1bf85b2 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -398,7 +398,7 @@ def get_xgb_params(self): 'importance_type', 'kwargs', 'missing', 'n_estimators', 'use_label_encoder'} filtered = dict() for k, v in params.items(): - if k not in wrapper_specific: + if k not in wrapper_specific and not callable(v): filtered[k] = v return filtered diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 318c349f3eba..8a4f17ffb66a 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -399,6 +399,21 @@ def dummy_objective(y_true, y_preds): X, y ) + cls = xgb.XGBClassifier(use_label_encoder=False, n_estimators=1) + cls.fit(X, y) + + is_called = [False] + + def wrapped(y, p): + is_called[0] = True + return logregobj(y, p) + + cls.set_params(objective=wrapped) + cls.predict(X) # no throw + cls.fit(X, y) + + assert is_called[0] + def test_sklearn_api(): from sklearn.datasets import load_iris