Skip to content

Commit

Permalink
Pass correct split_type to GPU predictor (#6491)
Browse files Browse the repository at this point in the history
* Pass correct split_type to GPU predictor

* Add a test
  • Loading branch information
hcho3 committed Dec 12, 2020
1 parent 0d483cb commit c31e3ef
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/predictor/gpu_predictor.cu
Expand Up @@ -272,7 +272,11 @@ PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin],
d_cat_tree_segments[tree_idx - tree_begin + 1] -
d_cat_tree_segments[tree_idx - tree_begin]);
float leaf = GetLeafWeight(global_idx, d_tree, d_tree_split_types,
auto tree_split_types =
d_tree_split_types.subspan(d_tree_segments[tree_idx - tree_begin],
d_tree_segments[tree_idx - tree_begin + 1] -
d_tree_segments[tree_idx - tree_begin]);
float leaf = GetLeafWeight(global_idx, d_tree, tree_split_types,
tree_cat_ptrs,
tree_categories,
&loader);
Expand Down
35 changes: 35 additions & 0 deletions tests/python-gpu/test_gpu_prediction.py
Expand Up @@ -3,8 +3,17 @@

import numpy as np
import xgboost as xgb
from xgboost.compat import PANDAS_INSTALLED

from hypothesis import given, strategies, assume, settings, note

# The hypothesis pandas strategies are only imported when pandas is available.
# Otherwise each name is bound to a do-nothing callable so that module-level
# references to them (for example inside decorator arguments) still resolve.
if PANDAS_INSTALLED:
    from hypothesis.extra.pandas import column, data_frames, range_indexes
else:
    def noop(*args, **kwargs):
        """Accept any arguments, do nothing, and return None."""

    column = data_frames = range_indexes = noop

sys.path.append("tests/python")
import testing as tm
from test_predict import run_threaded_predict # noqa
Expand Down Expand Up @@ -259,3 +268,29 @@ def test_predict_leaf_dart(self, param, dataset):
param['booster'] = 'dart'
param['tree_method'] = 'gpu_hist'
self.run_predict_leaf_booster(param, 10, dataset)

@pytest.mark.skipif(**tm.no_sklearn())
@pytest.mark.skipif(**tm.no_pandas())
@given(df=data_frames([column('x0', elements=strategies.integers(min_value=0, max_value=3)),
                       column('x1', elements=strategies.integers(min_value=0, max_value=5))],
                      index=range_indexes(min_size=20, max_size=50)))
@settings(deadline=None)
def test_predict_categorical_split(self, df):
    """Check GPU prediction on categorical features against the training RMSE.

    A model is trained with ``gpu_hist`` on two categorical columns and
    evaluated on the training matrix itself.  The RMSE recomputed from
    ``bst.predict`` must match the last ``rmse`` value recorded during
    training, i.e. the predictor and the trainer must agree on how the
    trees route each row.

    Parameters
    ----------
    df : pandas.DataFrame
        Hypothesis-generated frame with integer columns ``x0`` (0..3) and
        ``x1`` (0..5) and between 20 and 50 rows.
    """
    from sklearn.metrics import mean_squared_error

    df = df.astype('category')
    x0, x1 = df['x0'].to_numpy(), df['x1'].to_numpy()
    # Label is a deterministic function of both features.
    y = (x0 * 10 - 20) + (x1 - 2)
    dtrain = xgb.DMatrix(df, label=y, enable_categorical=True)

    params = {'tree_method': 'gpu_hist', 'predictor': 'gpu_predictor',
              'enable_experimental_json_serialization': True,
              'max_depth': 3, 'learning_rate': 1.0, 'base_score': 0.0, 'eval_metric': 'rmse'}

    eval_history = {}
    bst = xgb.train(params, dtrain, num_boost_round=5, evals=[(dtrain, 'train')],
                    verbose_eval=False, evals_result=eval_history)

    pred = bst.predict(dtrain)
    # Take the square root explicitly instead of passing ``squared=False``:
    # that keyword was deprecated in scikit-learn 1.4 and removed in 1.6,
    # while sqrt(MSE) gives the same value on every version.
    rmse = np.sqrt(mean_squared_error(y_true=y, y_pred=pred))
    np.testing.assert_almost_equal(rmse, eval_history['train']['rmse'][-1], decimal=5)

0 comments on commit c31e3ef

Please sign in to comment.