Merge dispatching into median.
trivialfis committed Sep 13, 2022
1 parent 070e6f6 commit 48112fe
Showing 3 changed files with 25 additions and 21 deletions.
src/common/stats.cu (13 changes: 8 additions & 5 deletions)
@@ -2,11 +2,14 @@
  * Copyright 2022 by XGBoost Contributors
  */
 
-#include "common.h"
-#include "stats.cuh"
-#include "xgboost/generic_parameters.h"
-#include "xgboost/host_device_vector.h"
-#include "xgboost/linalg.h"
+#include <thrust/iterator/counting_iterator.h>  // thrust::make_counting_iterator
+
+#include "common.h"  // common::OptionalWeights
+#include "device_helpers.cuh"  // dh::MakeTransformIterator, tcbegin, tcend
+#include "stats.cuh"  // common::SegmentedQuantile, common::SegmentedWeightedQuantile
+#include "xgboost/generic_parameters.h"  // Context
+#include "xgboost/host_device_vector.h"  // HostDeviceVector
+#include "xgboost/linalg.h"  // linalg::TensorView, UnravelIndex, Apply
 
 namespace xgboost {
 namespace common {
src/common/stats.h (26 changes: 16 additions & 10 deletions)
@@ -103,23 +103,29 @@ inline float Median(Context const*, linalg::TensorView<float const, 2>, common::
 #endif  // !defined(XGBOOST_USE_CUDA)
 }  // namespace cuda
 
-inline float Median(Context const* ctx, linalg::TensorView<float const, 2> t,
-                    common::OptionalWeights weights) {
+inline float Median(Context const* ctx, linalg::Tensor<float, 2> const& t,
+                    HostDeviceVector<float> const& weights) {
   if (!ctx->IsCPU()) {
-    return cuda::Median(ctx, t, weights);
+    weights.SetDevice(ctx->gpu_id);
+    auto opt_weights = OptionalWeights(weights.ConstDeviceSpan());
+    auto t_v = t.View(ctx->gpu_id);
+    return cuda::Median(ctx, t_v, opt_weights);
   }
+
+  auto opt_weights = OptionalWeights(weights.ConstHostSpan());
+  auto t_v = t.HostView();
   auto iter = common::MakeIndexTransformIter(
-      [&](size_t i) { return linalg::detail::Apply(t, linalg::UnravelIndex(i, t.Shape())); });
+      [&](size_t i) { return linalg::detail::Apply(t_v, linalg::UnravelIndex(i, t_v.Shape())); });
   float q{0};
-  if (weights.weights.empty()) {
-    q = common::Quantile(0.5, iter, iter + t.Size());
+  if (opt_weights.Empty()) {
+    q = common::Quantile(0.5, iter, iter + t_v.Size());
   } else {
-    CHECK_NE(t.Shape(1), 0);
+    CHECK_NE(t_v.Shape(1), 0);
     auto w_it = common::MakeIndexTransformIter([&](size_t i) {
-      auto sample_idx = i / t.Shape(1);
-      return weights[sample_idx];
+      auto sample_idx = i / t_v.Shape(1);
+      return opt_weights[sample_idx];
     });
-    q = common::WeightedQuantile(0.5, iter, iter + t.Size(), w_it);
+    q = common::WeightedQuantile(0.5, iter, iter + t_v.Size(), w_it);
   }
   return q;
 }
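The merged Median() now owns the host/device decision: callers hand over the 2-dim tensor and the raw HostDeviceVector of weights, and the function itself moves the weights to the device, builds the OptionalWeights view, and chooses the CPU or CUDA path. A hypothetical call-site sketch (not part of this commit; the wrapper name LabelMedian and the include paths are illustrative assumptions):

#include "common/stats.h"                // common::Median; path assumed relative to src/
#include "xgboost/data.h"                // MetaInfo
#include "xgboost/generic_parameters.h"  // Context

namespace xgboost {
// No ctx->IsCPU() branch and no manual SetDevice() at the call site anymore;
// the dispatch happens inside common::Median().
inline float LabelMedian(Context const* ctx, MetaInfo const& info) {
  return common::Median(ctx, info.labels, info.weights_);
}
}  // namespace xgboost
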
src/objective/regression_obj.cu (7 changes: 1 addition & 6 deletions)
@@ -711,13 +711,8 @@ class MeanAbsoluteError : public ObjFunction {
     if (info.num_row_ == 0) {
       out(0) = 0;
       invalid++;
-    } else if (ctx_->IsCPU()) {
-      out(0) = common::Median(ctx_, info.labels.HostView(),
-                              common::OptionalWeights{info.weights_.ConstHostSpan()});
     } else {
-      info.weights_.SetDevice(ctx_->gpu_id);
-      out(0) = common::Median(ctx_, info.labels.View(ctx_->gpu_id),
-                              common::OptionalWeights{info.weights_.DeviceSpan()});
+      out(0) = common::Median(ctx_, info.labels, info.weights_);
     }
 
     auto world = static_cast<float>(rabit::GetWorldSize());
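The caller above shrinks to a single line because the CPU/GPU branch and the SetDevice() call moved into common::Median(). A standalone sketch of that pattern (plain C++, not XGBoost code; every name below is made up for illustration), centralizing the dispatch behind one entry point:

#include <algorithm>  // std::nth_element, std::max_element
#include <cassert>    // assert
#include <vector>     // std::vector

// Minimal stand-in for XGBoost's Context: gpu_id < 0 means "run on the CPU".
struct Context {
  int gpu_id{-1};
  bool IsCPU() const { return gpu_id < 0; }
};

namespace detail {
// Placeholder for a device implementation (the CUDA path would live in a .cu
// file); never reached in this CPU-only sketch.
inline float MedianOnDevice(std::vector<float> const&) { return 0.0f; }

// Unweighted median on the host via partial sorting.
inline float MedianOnHost(std::vector<float> v) {
  assert(!v.empty());
  auto mid = v.begin() + v.size() / 2;
  std::nth_element(v.begin(), mid, v.end());
  float upper = *mid;
  if (v.size() % 2 == 1) {
    return upper;
  }
  // Even count: average the two middle elements.
  float lower = *std::max_element(v.begin(), mid);
  return (lower + upper) / 2.0f;
}
}  // namespace detail

// Single public entry point: the host/device decision is an implementation
// detail, so call sites never branch on the context themselves.
inline float Median(Context const* ctx, std::vector<float> const& values) {
  if (!ctx->IsCPU()) {
    return detail::MedianOnDevice(values);
  }
  return detail::MedianOnHost(values);
}

int main() {
  Context ctx;  // gpu_id defaults to -1, so this runs the host path
  assert(Median(&ctx, {3.0f, 1.0f, 4.0f, 2.0f}) == 2.5f);
  return 0;
}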
