Skip to content

Commit

Permalink
Fix gpu_id with custom objective. (#7015)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 9, 2021
1 parent bd2ca54 commit 72f9daf
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/gbm/gbtree.cc
Expand Up @@ -228,7 +228,7 @@ void GBTree::DoBoost(DMatrix* p_fmat,
// break a lots of existing code.
auto device = tparam_.tree_method != TreeMethod::kGPUHist
? GenericParameter::kCpuId
: in_gpair->DeviceIdx();
: generic_param_->gpu_id;
auto out = MatrixView<float>(
&predt->predictions,
{static_cast<size_t>(p_fmat->Info().num_row_), static_cast<size_t>(ngroup)}, device);
Expand Down
11 changes: 8 additions & 3 deletions tests/python-gpu/test_gpu_basic_models.py
Expand Up @@ -6,11 +6,13 @@
sys.path.append("tests/python")
# Don't import the test class, otherwise they will run twice.
import test_callback as test_cb # noqa
import test_basic_models as test_bm
rng = np.random.RandomState(1994)


class TestGPUBasicModels:
cputest = test_cb.TestCallbacks()
cpu_test_cb = test_cb.TestCallbacks()
cpu_test_bm = test_bm.TestModels()

def run_cls(self, X, y, deterministic):
cls = xgb.XGBClassifier(tree_method='gpu_hist',
Expand All @@ -35,9 +37,12 @@ def run_cls(self, X, y, deterministic):

return hash(model_0), hash(model_1)

def test_custom_objective(self):
self.cpu_test_bm.run_custom_objective("gpu_hist")

def test_eta_decay_gpu_hist(self):
self.cputest.run_eta_decay('gpu_hist', True)
self.cputest.run_eta_decay('gpu_hist', False)
self.cpu_test_cb.run_eta_decay('gpu_hist', True)
self.cpu_test_cb.run_eta_decay('gpu_hist', False)

def test_deterministic_gpu_hist(self):
kRows = 1000
Expand Down
12 changes: 10 additions & 2 deletions tests/python/test_basic_models.py
Expand Up @@ -138,8 +138,13 @@ def test_boost_from_existing_model(self):
# behaviour is considered sub-optimal, feel free to change.
assert booster.num_boosted_rounds() == 4

def test_custom_objective(self):
param = {'max_depth': 2, 'eta': 1, 'objective': 'reg:logistic'}
def run_custom_objective(self, tree_method=None):
param = {
'max_depth': 2,
'eta': 1,
'objective': 'reg:logistic',
"tree_method": tree_method
}
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 10

Expand Down Expand Up @@ -181,6 +186,9 @@ def neg_evalerror(preds, dtrain):
if int(preds2[i] > 0.5) != labels[i]) / float(len(preds2))
assert err == err2

def test_custom_objective(self):
self.run_custom_objective()

def test_multi_eval_metric(self):
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
param = {'max_depth': 2, 'eta': 0.2, 'verbosity': 1,
Expand Down

0 comments on commit 72f9daf

Please sign in to comment.