diff --git a/shap/explainers/_tree.py b/shap/explainers/_tree.py index b17a1e081..1f68cf7ef 100644 --- a/shap/explainers/_tree.py +++ b/shap/explainers/_tree.py @@ -1427,7 +1427,9 @@ class XGBTreeModelLoader(object): """ def __init__(self, xgb_model): # new in XGBoost 1.1, 'binf' is appended to the buffer - self.buf = xgb_model.save_raw().lstrip(b'binf') + self.buf = xgb_model.save_raw() + if self.buf.startswith(b'binf'): + self.buf = self.buf[4:] self.pos = 0 # load the model parameters diff --git a/tests/explainers/test_tree.py b/tests/explainers/test_tree.py index c5da9a3d5..0bdda734f 100644 --- a/tests/explainers/test_tree.py +++ b/tests/explainers/test_tree.py @@ -1115,3 +1115,20 @@ def objective_function(x): result_et.models[-1].predict(et_df)) assert np.allclose(shap_values_rf.sum(1) + explainer_rf.expected_value, result_rf.models[-1].predict(rf_df)) + + +def test_xgboost_buffer_strip(): + # test to make sure bug #1864 doesn't get reintroduced + xgboost = pytest.importorskip("xgboost") + X = np.array([[1, 2, 3, 4, 5], [3, 3, 3, 2, 4]]) + y = np.array([1, 0]) + # specific values (e.g. 1.3) caused the bug previously + model = xgboost.XGBRegressor(base_score=1.3) + model.fit(X, y, eval_metric="rmse") + # previous bug did .lstrip('binf'), so would have incorrectly handled + # buffer starting with binff + assert model.get_booster().save_raw().startswith(b"binff") + + # after this fix, this line should not error + explainer = shap.TreeExplainer(model) + assert isinstance(explainer, shap.explainers.Tree)