diff --git a/src/data/iterative_dmatrix.h b/src/data/iterative_dmatrix.h index bcb98093be91..976b87d56ee8 100644 --- a/src/data/iterative_dmatrix.h +++ b/src/data/iterative_dmatrix.h @@ -97,8 +97,9 @@ class IterativeDMatrix : public DMatrix { batch_param_ = BatchParam{d, max_bin}; batch_param_.sparse_thresh = 0.2; // default from TrainParam - ctx_.UpdateAllowUnknown(Args{{"nthread", std::to_string(nthread)}}); - if (d == Context::kCpuId) { + ctx_.UpdateAllowUnknown( + Args{{"nthread", std::to_string(nthread)}, {"gpu_id", std::to_string(d)}}); + if (ctx_.IsCPU()) { this->InitFromCPU(iter_handle, missing, ref); } else { this->InitFromCUDA(iter_handle, missing, ref);