diff --git a/src/data/array_interface.h b/src/data/array_interface.h index b7ca311439ae..6524f4512407 100644 --- a/src/data/array_interface.h +++ b/src/data/array_interface.h @@ -210,27 +210,28 @@ class ArrayInterfaceHandler { } static void ExtractStride(std::map const &column, - size_t strides[2], size_t rows, size_t cols, size_t itemsize) { + size_t *stride_r, size_t *stride_c, size_t rows, + size_t cols, size_t itemsize) { auto strides_it = column.find("strides"); if (strides_it == column.cend() || IsA(strides_it->second)) { // default strides - strides[0] = cols; - strides[1] = 1; + *stride_r = cols; + *stride_c = 1; } else { // strides specified by the array interface auto const &j_strides = get(strides_it->second); CHECK_LE(j_strides.size(), 2) << ArrayInterfaceErrors::Dimension(2); - strides[0] = get(j_strides[0]) / itemsize; + *stride_r = get(j_strides[0]) / itemsize; size_t n = 1; if (j_strides.size() == 2) { n = get(j_strides[1]) / itemsize; } - strides[1] = n; + *stride_c = n; } - auto valid = rows * strides[0] + cols * strides[1] >= (rows * cols); + auto valid = rows * (*stride_r) + cols * (*stride_c) >= (rows * cols); CHECK(valid) << "Invalid strides in array." - << " strides: (" << strides[0] << "," << strides[1] + << " strides: (" << (*stride_r) << "," << (*stride_c) << "), shape: (" << rows << ", " << cols << ")"; } @@ -281,8 +282,8 @@ class ArrayInterface { << "Masked array is not yet supported."; } - ArrayInterfaceHandler::ExtractStride(array, strides, num_rows, num_cols, - typestr[2] - '0'); + ArrayInterfaceHandler::ExtractStride(array, &stride_row, &stride_col, + num_rows, num_cols, typestr[2] - '0'); auto stream_it = array.find("stream"); if (stream_it != array.cend() && !IsA(stream_it->second)) { @@ -323,8 +324,8 @@ class ArrayInterface { num_rows = std::max(num_rows, static_cast(num_cols)); num_cols = 1; - strides[0] = std::max(strides[0], strides[1]); - strides[1] = 1; + stride_row = std::max(stride_row, stride_col); + stride_col = 1; } void AssignType(StringView typestr) { @@ -406,13 +407,14 @@ class ArrayInterface { template XGBOOST_DEVICE T GetElement(size_t r, size_t c) const { return this->DispatchCall( - [=](auto *p_values) -> T { return p_values[strides[0] * r + strides[1] * c]; }); + [=](auto *p_values) -> T { return p_values[stride_row * r + stride_col * c]; }); } RBitField8 valid; bst_row_t num_rows; bst_feature_t num_cols; - size_t strides[2]{0, 0}; + size_t stride_row{0}; + size_t stride_col{0}; void* data; Type type; }; diff --git a/tests/python/test_dmatrix.py b/tests/python/test_dmatrix.py index 52086378ad5b..1b5ad266d3d9 100644 --- a/tests/python/test_dmatrix.py +++ b/tests/python/test_dmatrix.py @@ -29,6 +29,9 @@ def set_base_margin_info(DType, DMatrixT, tm: str): with pytest.raises(ValueError, match=r".*base_margin.*"): xgb.train({"tree_method": tm}, Xy) + # FIXME(jiamingy): Currently the metainfo has no concept of shape. If you pass a + # base_margin with shape (n_classes, n_samples) to XGBoost the result is undefined. + class TestDMatrix: def test_warn_missing(self):