Skip to content

Commit

Permalink
Fix filtering callable objects in skl xgb param. (#6466)
Browse files Browse the repository at this point in the history
Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
  • Loading branch information
trivialfis and hcho3 committed Dec 5, 2020
1 parent 05e5563 commit d6386e4
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python-package/xgboost/sklearn.py
Expand Up @@ -398,7 +398,7 @@ def get_xgb_params(self):
'importance_type', 'kwargs', 'missing', 'n_estimators', 'use_label_encoder'}
filtered = dict()
for k, v in params.items():
if k not in wrapper_specific:
if k not in wrapper_specific and not callable(v):
filtered[k] = v
return filtered

Expand Down
15 changes: 15 additions & 0 deletions tests/python/test_with_sklearn.py
Expand Up @@ -399,6 +399,21 @@ def dummy_objective(y_true, y_preds):
X, y
)

cls = xgb.XGBClassifier(use_label_encoder=False, n_estimators=1)
cls.fit(X, y)

is_called = [False]

def wrapped(y, p):
is_called[0] = True
return logregobj(y, p)

cls.set_params(objective=wrapped)
cls.predict(X) # no throw
cls.fit(X, y)

assert is_called[0]


def test_sklearn_api():
from sklearn.datasets import load_iris
Expand Down

0 comments on commit d6386e4

Please sign in to comment.