Skip to content

Commit

Permalink
Fix slice.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jul 31, 2020
1 parent 04eb3b4 commit 05692db
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 36 deletions.
37 changes: 5 additions & 32 deletions include/xgboost/data.h
Expand Up @@ -89,44 +89,17 @@ class MetaInfo {
* \brief Type of each feature. Automatically set when feature_type_names is specified.
*/
HostDeviceVector<FeatureType> feature_types;

/*!
* \brief Weight of each feature, used to define the probability of each feature being
* selected when using column sampling.
*/
HostDeviceVector<float> feature_weigths;

/*! \brief default constructor */
MetaInfo() = default;
/*! \brief move constructor, compiler-generated */
MetaInfo(MetaInfo&& that) = default;
/*! \brief move assignment, compiler-generated */
MetaInfo& operator=(MetaInfo&& that) = default;
/*!
 * \brief Copy assignment: overwrite this object with a copy of `that`.
 *
 * Scalar counters, group_ptr_, feature_names and feature_type_names are
 * assigned directly.  Every HostDeviceVector member is first resized to the
 * size of its counterpart in `that` and then filled through Copy().
 *
 * \param that The meta info to copy from.
 * \return Reference to this object.
 */
MetaInfo& operator=(MetaInfo const& that) {
  // Self-assignment: skip the work instead of copying each buffer onto itself.
  if (this == &that) {
    return *this;
  }

  this->num_row_ = that.num_row_;
  this->num_col_ = that.num_col_;
  this->num_nonzero_ = that.num_nonzero_;

  // Each vector is resized to the source size before Copy()
  // (presumably Copy() expects matching sizes -- confirm against HostDeviceVector).
  this->labels_.Resize(that.labels_.Size());
  this->labels_.Copy(that.labels_);

  this->group_ptr_ = that.group_ptr_;

  this->weights_.Resize(that.weights_.Size());
  this->weights_.Copy(that.weights_);

  this->base_margin_.Resize(that.base_margin_.Size());
  this->base_margin_.Copy(that.base_margin_);

  this->labels_lower_bound_.Resize(that.labels_lower_bound_.Size());
  this->labels_lower_bound_.Copy(that.labels_lower_bound_);

  this->labels_upper_bound_.Resize(that.labels_upper_bound_.Size());
  this->labels_upper_bound_.Copy(that.labels_upper_bound_);

  this->feature_weigths.Resize(that.feature_weigths.Size());
  this->feature_weigths.Copy(that.feature_weigths);

  this->feature_names = that.feature_names;
  this->feature_type_names = that.feature_type_names;
  this->feature_types.Resize(that.feature_types.Size());
  this->feature_types.Copy(that.feature_types);
  return *this;
}
/*! \brief copy assignment is disabled */
MetaInfo& operator=(MetaInfo const& that) = delete;

/*!
* \brief Validate all metainfo.
Expand Down
3 changes: 3 additions & 0 deletions src/data/data.cc
Expand Up @@ -293,6 +293,9 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
} else {
out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs);
}

out.feature_weigths.Resize(this->feature_weigths.Size());
out.feature_weigths.Copy(this->feature_weigths);
return out;
}

Expand Down
15 changes: 11 additions & 4 deletions tests/python/test_dmatrix.py
Expand Up @@ -75,6 +75,9 @@ def test_slice(self):
X = rng.randn(100, 100)
y = rng.randint(low=0, high=3, size=100)
d = xgb.DMatrix(X, y)
fw = rng.randn(100)
d.feature_weights = fw

eval_res_0 = {}
booster = xgb.train(
{'num_class': 3, 'objective': 'multi:softprob'}, d,
Expand All @@ -85,13 +88,17 @@ def test_slice(self):
d.set_base_margin(predt)

ridxs = [1, 2, 3, 4, 5, 6]
d = d.slice(ridxs)
sliced_margin = d.get_float_info('base_margin')
sliced = d.slice(ridxs)
np.testing.assert_equal(sliced.feature_weights, d.feature_weights)

sliced = d.slice(ridxs)
sliced_margin = sliced.get_float_info('base_margin')
assert sliced_margin.shape[0] == len(ridxs) * 3

eval_res_1 = {}
xgb.train({'num_class': 3, 'objective': 'multi:softprob'}, d,
num_boost_round=2, evals=[(d, 'd')], evals_result=eval_res_1)
xgb.train({'num_class': 3, 'objective': 'multi:softprob'}, sliced,
num_boost_round=2, evals=[(sliced, 'd')],
evals_result=eval_res_1)

eval_res_0 = eval_res_0['d']['merror']
eval_res_1 = eval_res_1['d']['merror']
Expand Down

0 comments on commit 05692db

Please sign in to comment.