Support categorical data for dask functional interface and DQM.
* Implement categorical data support for the GPU GK-merge (Greenwald-Khanna sketch merge).
* Add support for the dask functional interface.
* Add support for DQM (DeviceQuantileDMatrix).
trivialfis committed Jun 18, 2021
1 parent dcd84b3 commit 8353b6e
Showing 15 changed files with 363 additions and 166 deletions.
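
The user-visible effect: enable_categorical=True is now accepted by DeviceQuantileDMatrix (and, further below, by the dask collections). A minimal sketch of the newly supported call, assuming cuDF input, a CUDA device, and the gpu_hist tree method; the toy data is made up for illustration:

    import cudf
    import xgboost as xgb

    # Toy input: one categorical and one numerical column.
    X = cudf.DataFrame({
        "cat": cudf.Series(["a", "b", "a", "c"], dtype="category"),
        "num": [1.0, 2.0, 3.0, 4.0],
    })
    y = cudf.Series([0, 1, 0, 1])

    # Before this commit the iterator path behind DeviceQuantileDMatrix
    # raised NotImplementedError for categorical data.
    Xy = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True)
    booster = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=10)
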
13 changes: 4 additions & 9 deletions python-package/xgboost/core.py
@@ -321,6 +321,7 @@ class DataIter:
def __init__(self):
self._handle = _ProxyDMatrix()
self.exception = None
self.enable_categorical = False

@property
def proxy(self):
@@ -346,13 +347,12 @@ def data_handle(
data,
feature_names=None,
feature_types=None,
enable_categorical=False,
**kwargs
):
from .data import dispatch_device_quantile_dmatrix_set_data
from .data import _device_quantile_transform
data, feature_names, feature_types = _device_quantile_transform(
data, feature_names, feature_types, enable_categorical,
data, feature_names, feature_types, self.enable_categorical,
)
dispatch_device_quantile_dmatrix_set_data(self.proxy, data)
self.proxy.set_info(
@@ -1106,15 +1106,10 @@ def _init(self, data, enable_categorical, **meta):
data = _transform_dlpack(data)
if _is_iter(data):
it = data
if enable_categorical:
raise NotImplementedError(
"categorical support is not enabled on data iterator."
)
else:
it = SingleBatchInternalIter(
data=data, enable_categorical=enable_categorical, **meta
)
it = SingleBatchInternalIter(data=data, **meta)

it.enable_categorical = enable_categorical
reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(it.reset_wrapper)
next_callback = ctypes.CFUNCTYPE(
ctypes.c_int,
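
The core.py change above moves enable_categorical off data_handle's signature and onto the iterator object itself, so a user-supplied DataIter takes the same path as SingleBatchInternalIter. A sketch under that reading, with a hypothetical two-batch cuDF iterator (the class name and data are made up):

    import cudf
    import xgboost as xgb

    class Batches(xgb.core.DataIter):
        """Hypothetical two-batch iterator over cuDF frames."""

        def __init__(self, frames, labels):
            self._frames, self._labels, self._it = frames, labels, 0
            super().__init__()

        def next(self, input_data):
            if self._it == len(self._frames):
                return 0  # no more batches
            input_data(data=self._frames[self._it], label=self._labels[self._it])
            self._it += 1
            return 1

        def reset(self):
            self._it = 0

    frames = [cudf.DataFrame({"c": cudf.Series(["a", "b"], dtype="category")}),
              cudf.DataFrame({"c": cudf.Series(["b", "a"], dtype="category")})]
    labels = [cudf.Series([0.0, 1.0]), cudf.Series([1.0, 0.0])]

    # _init attaches enable_categorical to the iterator, which data_handle
    # later reads back as self.enable_categorical.
    Xy = xgb.DeviceQuantileDMatrix(Batches(frames, labels), enable_categorical=True)
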
24 changes: 15 additions & 9 deletions python-package/xgboost/dask.py
@@ -182,7 +182,7 @@ def concat(value: Any) -> Any:  # pylint: disable=too-many-return-statements
lazy_isinstance(value[0], 'cudf.core.series', 'Series'):
from cudf import concat as CUDF_concat # pylint: disable=import-error
return CUDF_concat(value, axis=0)
if lazy_isinstance(value[0], 'cupy.core.core', 'ndarray'):
if lazy_isinstance(value[0], 'cupy._core.core', 'ndarray'):
import cupy
# pylint: disable=c-extension-no-member,no-member
d = cupy.cuda.runtime.getDevice()
@@ -258,17 +258,14 @@ def __init__(
self.feature_names = feature_names
self.feature_types = feature_types
self.missing = missing
self.enable_categorical = enable_categorical

if qid is not None and weight is not None:
raise NotImplementedError("per-group weight is not implemented.")
if group is not None:
raise NotImplementedError(
"group structure is not implemented, use qid instead."
)
if enable_categorical:
raise NotImplementedError(
"categorical support is not enabled on `DaskDMatrix`."
)

if len(data.shape) != 2:
raise ValueError(
@@ -311,7 +308,7 @@ async def _map_local_data(
qid: Optional[_DaskCollection] = None,
feature_weights: Optional[_DaskCollection] = None,
label_lower_bound: Optional[_DaskCollection] = None,
label_upper_bound: Optional[_DaskCollection] = None
label_upper_bound: Optional[_DaskCollection] = None,
) -> "DaskDMatrix":
'''Obtain references to local data.'''

@@ -430,6 +427,7 @@ def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
'feature_weights': self.feature_weights,
'meta_names': self.meta_names,
'missing': self.missing,
'enable_categorical': self.enable_categorical,
'parts': self.worker_map.get(worker_addr, None),
'is_quantile': self.is_quantile}

@@ -668,6 +666,7 @@ def _create_device_quantile_dmatrix(
missing: float,
parts: Optional[_DataParts],
max_bin: int,
enable_categorical: bool,
) -> DeviceQuantileDMatrix:
worker = distributed.get_worker()
if parts is None:
@@ -680,6 +679,7 @@
feature_names=feature_names,
feature_types=feature_types,
max_bin=max_bin,
enable_categorical=enable_categorical,
)
return d

@@ -709,6 +709,7 @@ def _create_device_quantile_dmatrix(
feature_types=feature_types,
nthread=worker.nthreads,
max_bin=max_bin,
enable_categorical=enable_categorical,
)
dmatrix.set_info(feature_weights=feature_weights)
return dmatrix
@@ -720,6 +721,7 @@ def _create_dmatrix(
feature_weights: Optional[Any],
meta_names: List[str],
missing: float,
enable_categorical: bool,
parts: Optional[_DataParts]
) -> DMatrix:
'''Get data that is local to the worker from DaskDMatrix.
@@ -734,9 +736,12 @@
if list_of_parts is None:
msg = 'worker {address} has an empty DMatrix. '.format(address=worker.address)
LOGGER.warning(msg)
d = DMatrix(numpy.empty((0, 0)),
feature_names=feature_names,
feature_types=feature_types)
d = DMatrix(
numpy.empty((0, 0)),
feature_names=feature_names,
feature_types=feature_types,
enable_categorical=enable_categorical,
)
return d

T = TypeVar('T')
@@ -764,6 +769,7 @@ def concat_or_none(data: Tuple[Optional[T], ...]) -> Optional[T]:
feature_names=feature_names,
feature_types=feature_types,
nthread=worker.nthreads,
enable_categorical=enable_categorical,
)
dmatrix.set_info(
base_margin=_base_margin,
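
With the NotImplementedError guard gone and enable_categorical included in _create_fn_args, the flag reaches every per-worker DMatrix. A sketch of the resulting dask usage, assuming a running distributed client with GPU workers and dask_cudf collections (the function name and data layout are made up):

    import xgboost as xgb

    def train_categorical(client, X, y):
        # X: dask_cudf.DataFrame with categorical columns; y: dask_cudf.Series
        # (assumed). enable_categorical is now forwarded to each worker's
        # DMatrix via _create_fn_args.
        Xy = xgb.dask.DaskDeviceQuantileDMatrix(client, X, y,
                                                enable_categorical=True)
        output = xgb.dask.train(client, {"tree_method": "gpu_hist"}, Xy,
                                num_boost_round=10)
        return output["booster"]
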
10 changes: 6 additions & 4 deletions src/common/device_helpers.cuh
@@ -1151,12 +1151,12 @@ struct SegmentedUniqueReduceOp {
* \return Number of unique values in total.
*/
template <typename DerivedPolicy, typename KeyInIt, typename KeyOutIt, typename ValInIt,
typename ValOutIt, typename Comp>
typename ValOutIt, typename CompValue, typename CompKey>
size_t
SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt val_first,
ValInIt val_last, KeyOutIt key_segments_out, ValOutIt val_out,
Comp comp) {
CompValue comp, CompKey comp_key = thrust::equal_to<size_t>{}) {
using Key = thrust::pair<size_t, typename thrust::iterator_traits<ValInIt>::value_type>;
auto unique_key_it = dh::MakeTransformIterator<Key>(
thrust::make_counting_iterator(static_cast<size_t>(0)),
@@ -1177,7 +1177,7 @@ SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec
exec, unique_key_it, unique_key_it + n_inputs,
val_first, reduce_it, val_out,
[=] __device__(Key const &l, Key const &r) {
if (l.first == r.first) {
if (comp_key(l.first, r.first)) {
// In the same segment.
return comp(l.second, r.second);
}
@@ -1195,7 +1195,9 @@ template <typename... Inputs,
* = nullptr>
size_t SegmentedUnique(Inputs &&...inputs) {
dh::XGBCachingDeviceAllocator<char> alloc;
return SegmentedUnique(thrust::cuda::par(alloc), std::forward<Inputs&&>(inputs)...);
return SegmentedUnique(thrust::cuda::par(alloc),
std::forward<Inputs &&>(inputs)...,
thrust::equal_to<size_t>{});
}

/**
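
For intuition: SegmentedUnique deduplicates values within segments by walking (segment, value) pairs, and the new comp_key parameter decides when two pairs share a key; the quantile code later instantiates it with a check that the feature is categorical. A host-side NumPy re-creation of the idea (illustrative only; the real implementation runs on device through thrust::unique_by_key):

    import numpy as np

    def segmented_unique(segments_ptr, values, comp_key, comp_value):
        """Keep the first element of each run of duplicates, per segment.

        segments_ptr: CSR-style offsets; segment i spans
            values[segments_ptr[i]:segments_ptr[i + 1]].
        comp_key(l, r): True when segment indices l, r count as the same key.
        comp_value(a, b): True when two values in one segment are duplicates.
        """
        seg_of = np.searchsorted(segments_ptr, np.arange(len(values)),
                                 side="right") - 1
        keep = []
        for i in range(len(values)):
            if keep:
                j = keep[-1]
                if comp_key(seg_of[j], seg_of[i]) and comp_value(values[j], values[i]):
                    continue  # duplicate of the last kept entry, same segment
            keep.append(i)
        kept_segs = seg_of[keep]
        # New offsets: number of kept entries in all preceding segments.
        new_ptr = np.searchsorted(kept_segs, np.arange(len(segments_ptr)))
        return new_ptr, values[np.array(keep)]

    ptr = np.array([0, 3, 5])                  # two segments (features)
    vals = np.array([1.0, 1.0, 2.0, 2.0, 3.0])
    is_cat = np.array([True, False])           # only segment 0 is categorical
    new_ptr, out = segmented_unique(
        ptr, vals,
        comp_key=lambda l, r: l == r and is_cat[l],  # the CompKey role
        comp_value=lambda a, b: a == b)              # the CompValue role
    print(new_ptr, out)  # [0 2 4] [1. 2. 2. 3.]
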
104 changes: 51 additions & 53 deletions src/common/hist_util.cu
@@ -129,60 +129,52 @@ void SortByWeight(dh::device_vector<float>* weights,
});
}

struct IsCatOp {
XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; }
};

void RemoveDuplicatedCategories(
int32_t device, MetaInfo const &info, Span<bst_row_t> d_cuts_ptr,
dh::device_vector<Entry> *p_sorted_entries,
dh::caching_device_vector<size_t>* p_column_sizes_scan) {
dh::caching_device_vector<size_t> *p_column_sizes_scan) {
auto d_feature_types = info.feature_types.ConstDeviceSpan();
auto& column_sizes_scan = *p_column_sizes_scan;
if (!info.feature_types.Empty() &&
thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types),
IsCatOp{})) {
auto& sorted_entries = *p_sorted_entries;
// Removing duplicated entries in categorical features.
dh::caching_device_vector<size_t> new_column_scan(column_sizes_scan.size());
dh::SegmentedUnique(
column_sizes_scan.data().get(),
column_sizes_scan.data().get() + column_sizes_scan.size(),
sorted_entries.begin(), sorted_entries.end(),
new_column_scan.data().get(), sorted_entries.begin(),
[=] __device__(Entry const &l, Entry const &r) {
if (l.index == r.index) {
if (IsCat(d_feature_types, l.index)) {
return l.fvalue == r.fvalue;
}
}
return false;
});
CHECK(!d_feature_types.empty());
auto &column_sizes_scan = *p_column_sizes_scan;
auto &sorted_entries = *p_sorted_entries;
// Removing duplicated entries in categorical features.
dh::caching_device_vector<size_t> new_column_scan(column_sizes_scan.size());
dh::SegmentedUnique(column_sizes_scan.data().get(),
column_sizes_scan.data().get() + column_sizes_scan.size(),
sorted_entries.begin(), sorted_entries.end(),
new_column_scan.data().get(), sorted_entries.begin(),
[=] __device__(Entry const &l, Entry const &r) {
if (l.index == r.index) {
if (IsCat(d_feature_types, l.index)) {
return l.fvalue == r.fvalue;
}
}
return false;
});

// Renew the column scan and cut scan based on categorical data.
auto d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan);
dh::caching_device_vector<SketchContainer::OffsetT> new_cuts_size(
info.num_col_ + 1);
auto d_new_cuts_size = dh::ToSpan(new_cuts_size);
auto d_new_columns_ptr = dh::ToSpan(new_column_scan);
CHECK_EQ(new_column_scan.size(), new_cuts_size.size());
dh::LaunchN(device, new_column_scan.size(), [=] __device__(size_t idx) {
d_old_column_sizes_scan[idx] = d_new_columns_ptr[idx];
if (idx == d_new_columns_ptr.size() - 1) {
return;
}
if (IsCat(d_feature_types, idx)) {
// Cut size is the same as number of categories in input.
d_new_cuts_size[idx] =
d_new_columns_ptr[idx + 1] - d_new_columns_ptr[idx];
} else {
d_new_cuts_size[idx] = d_cuts_ptr[idx + 1] - d_cuts_ptr[idx];
}
});
// Turn size into ptr.
thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(),
new_cuts_size.cend(), d_cuts_ptr.data());
}
// Renew the column scan and cut scan based on categorical data.
auto d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan);
dh::caching_device_vector<SketchContainer::OffsetT> new_cuts_size(
info.num_col_ + 1);
auto d_new_cuts_size = dh::ToSpan(new_cuts_size);
auto d_new_columns_ptr = dh::ToSpan(new_column_scan);
CHECK_EQ(new_column_scan.size(), new_cuts_size.size());
dh::LaunchN(device, new_column_scan.size(), [=] __device__(size_t idx) {
d_old_column_sizes_scan[idx] = d_new_columns_ptr[idx];
if (idx == d_new_columns_ptr.size() - 1) {
return;
}
if (IsCat(d_feature_types, idx)) {
// Cut size is the same as number of categories in input.
d_new_cuts_size[idx] =
d_new_columns_ptr[idx + 1] - d_new_columns_ptr[idx];
} else {
d_new_cuts_size[idx] = d_cuts_ptr[idx + 1] - d_cuts_ptr[idx];
}
});
// Turn size into ptr.
thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(),
new_cuts_size.cend(), d_cuts_ptr.data());
}
} // namespace detail

@@ -215,8 +207,11 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
0, sorted_entries.size(),
&cuts_ptr, &column_sizes_scan);
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries,
&column_sizes_scan);

if (sketch_container->HasCategorical()) {
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr,
&sorted_entries, &column_sizes_scan);
}

auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
@@ -281,8 +276,11 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
0, sorted_entries.size(),
&cuts_ptr, &column_sizes_scan);
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries,
&column_sizes_scan);
if (sketch_container->HasCategorical()) {
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr,
&sorted_entries, &column_sizes_scan);
}

auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();

// Extract cuts
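
The renewal step above can be checked with a toy example: a categorical feature's cut count becomes the number of surviving (deduplicated) entries, while a numerical feature keeps the cut count it already had. A host-side NumPy sketch of the LaunchN body and the final exclusive scan (toy sizes, for illustration):

    import numpy as np

    # Toy state after SegmentedUnique: 2 features, feature 0 categorical.
    is_cat = np.array([True, False])
    new_columns_ptr = np.array([0, 3, 8])   # 3 entries survive, then 5
    cuts_ptr = np.array([0, 6, 10])         # old cut pointers: 6 cuts, then 4

    new_cuts_size = np.empty(len(is_cat), dtype=np.int64)
    for idx in range(len(is_cat)):
        if is_cat[idx]:
            # Cut size is the number of surviving categories.
            new_cuts_size[idx] = new_columns_ptr[idx + 1] - new_columns_ptr[idx]
        else:
            # Numerical features keep their previous cut count.
            new_cuts_size[idx] = cuts_ptr[idx + 1] - cuts_ptr[idx]

    # "Turn size into ptr": the thrust::exclusive_scan at the end.
    cuts_ptr[:] = np.concatenate(([0], np.cumsum(new_cuts_size)))
    print(cuts_ptr)  # [0 3 7]: 3 categorical cuts, then the original 4 cuts
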
39 changes: 11 additions & 28 deletions src/common/quantile.cu
@@ -210,6 +210,7 @@ void MergeImpl(int32_t device, Span<SketchEntry const> const &d_x,
Span<bst_row_t const> const &x_ptr,
Span<SketchEntry const> const &d_y,
Span<bst_row_t const> const &y_ptr,
Span<FeatureType const> feature_types,
Span<SketchEntry> out,
Span<bst_row_t> out_ptr) {
dh::safe_cuda(cudaSetDevice(device));
@@ -408,31 +409,6 @@ size_t SketchContainer::ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_col
return n_uniques;
}

size_t SketchContainer::Unique() {
timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_));
this->columns_ptr_.SetDevice(device_);
Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
Span<SketchEntry> entries = dh::ToSpan(this->Current());
HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
scan_out.SetDevice(device_);
auto d_scan_out = scan_out.DeviceSpan();

d_column_scan = this->columns_ptr_.DeviceSpan();
size_t n_uniques = dh::SegmentedUnique(
d_column_scan.data(), d_column_scan.data() + d_column_scan.size(),
entries.data(), entries.data() + entries.size(), scan_out.DevicePointer(),
entries.data(),
detail::SketchUnique{});
this->columns_ptr_.Copy(scan_out);
CHECK(!this->columns_ptr_.HostCanRead());

this->Current().resize(n_uniques);
timer_.Stop(__func__);
return n_uniques;
}

void SketchContainer::Prune(size_t to) {
timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_));
@@ -490,13 +466,20 @@ void SketchContainer::Merge(Span<OffsetT const> d_that_columns_ptr,
this->Other().resize(this->Current().size() + that.size());
CHECK_EQ(d_that_columns_ptr.size(), this->columns_ptr_.Size());

MergeImpl(device_, this->Data(), this->ColumnsPtr(),
that, d_that_columns_ptr,
dh::ToSpan(this->Other()), columns_ptr_b_.DeviceSpan());
auto feature_types = this->FeatureTypes().ConstDeviceSpan();
MergeImpl(device_, this->Data(), this->ColumnsPtr(), that, d_that_columns_ptr,
feature_types, dh::ToSpan(this->Other()),
columns_ptr_b_.DeviceSpan());
this->columns_ptr_.Copy(columns_ptr_b_);
CHECK_EQ(this->columns_ptr_.Size(), num_columns_ + 1);
this->Alternate();

if (this->HasCategorical()) {
auto d_feature_types = this->FeatureTypes().ConstDeviceSpan();
this->Unique([d_feature_types] __device__(size_t l_fidx, size_t r_fidx) {
return l_fidx == r_fidx && IsCat(d_feature_types, l_fidx);
});
}
timer_.Stop(__func__);
}

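
The comparator passed to Unique above exists because merging two sketches can leave a categorical feature holding each category twice, and only categorical columns should be deduplicated. A standalone host-side toy check of that behaviour (NumPy, not the device code):

    import numpy as np

    ptr = np.array([0, 6, 9])
    vals = np.array([0., 0., 1., 1., 2., 2.,    # feature 0: merged categories
                     0.5, 1.5, 2.5])            # feature 1: numerical entries
    is_cat = [True, False]

    out, new_ptr = [], [0]
    for f in range(len(ptr) - 1):
        seg = vals[ptr[f]:ptr[f + 1]]
        # Mirrors `l_fidx == r_fidx && IsCat(...)`: dedup only categorical
        # features, leave numerical sketch entries untouched.
        seg = np.unique(seg) if is_cat[f] else seg
        out.extend(seg)
        new_ptr.append(len(out))
    print(new_ptr, out)  # [0, 3, 6] [0.0, 1.0, 2.0, 0.5, 1.5, 2.5]
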
