Skip to content

Commit

Permalink
Try to avoid using array.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Oct 29, 2021
1 parent b10f886 commit 7f4b240
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
28 changes: 15 additions & 13 deletions src/data/array_interface.h
Expand Up @@ -210,27 +210,28 @@ class ArrayInterfaceHandler {
}

static void ExtractStride(std::map<std::string, Json> const &column,
size_t strides[2], size_t rows, size_t cols, size_t itemsize) {
size_t *stride_r, size_t *stride_c, size_t rows,
size_t cols, size_t itemsize) {
auto strides_it = column.find("strides");
if (strides_it == column.cend() || IsA<Null>(strides_it->second)) {
// default strides
strides[0] = cols;
strides[1] = 1;
*stride_r = cols;
*stride_c = 1;
} else {
// strides specified by the array interface
auto const &j_strides = get<Array const>(strides_it->second);
CHECK_LE(j_strides.size(), 2) << ArrayInterfaceErrors::Dimension(2);
strides[0] = get<Integer const>(j_strides[0]) / itemsize;
*stride_r = get<Integer const>(j_strides[0]) / itemsize;
size_t n = 1;
if (j_strides.size() == 2) {
n = get<Integer const>(j_strides[1]) / itemsize;
}
strides[1] = n;
*stride_c = n;
}

auto valid = rows * strides[0] + cols * strides[1] >= (rows * cols);
auto valid = rows * (*stride_r) + cols * (*stride_c) >= (rows * cols);
CHECK(valid) << "Invalid strides in array."
<< " strides: (" << strides[0] << "," << strides[1]
<< " strides: (" << (*stride_r) << "," << (*stride_c)
<< "), shape: (" << rows << ", " << cols << ")";
}

Expand Down Expand Up @@ -281,8 +282,8 @@ class ArrayInterface {
<< "Masked array is not yet supported.";
}

ArrayInterfaceHandler::ExtractStride(array, strides, num_rows, num_cols,
typestr[2] - '0');
ArrayInterfaceHandler::ExtractStride(array, &stride_row, &stride_col,
num_rows, num_cols, typestr[2] - '0');

auto stream_it = array.find("stream");
if (stream_it != array.cend() && !IsA<Null>(stream_it->second)) {
Expand Down Expand Up @@ -323,8 +324,8 @@ class ArrayInterface {
num_rows = std::max(num_rows, static_cast<size_t>(num_cols));
num_cols = 1;

strides[0] = std::max(strides[0], strides[1]);
strides[1] = 1;
stride_row = std::max(stride_row, stride_col);
stride_col = 1;
}

void AssignType(StringView typestr) {
Expand Down Expand Up @@ -406,13 +407,14 @@ class ArrayInterface {
template <typename T = float>
XGBOOST_DEVICE T GetElement(size_t r, size_t c) const {
return this->DispatchCall(
[=](auto *p_values) -> T { return p_values[strides[0] * r + strides[1] * c]; });
[=](auto *p_values) -> T { return p_values[stride_row * r + stride_col * c]; });
}

RBitField8 valid;
bst_row_t num_rows;
bst_feature_t num_cols;
size_t strides[2]{0, 0};
size_t stride_row{0};
size_t stride_col{0};
void* data;
Type type;
};
Expand Down
3 changes: 3 additions & 0 deletions tests/python/test_dmatrix.py
Expand Up @@ -29,6 +29,9 @@ def set_base_margin_info(DType, DMatrixT, tm: str):
with pytest.raises(ValueError, match=r".*base_margin.*"):
xgb.train({"tree_method": tm}, Xy)

# FIXME(jiamingy): Currently the metainfo has no concept of shape. If you pass a
# base_margin with shape (n_classes, n_samples) to XGBoost the result is undefined.


class TestDMatrix:
def test_warn_missing(self):
Expand Down

0 comments on commit 7f4b240

Please sign in to comment.