Skip to content

Commit

Permalink
XGBoost: Fix type mismatch for CSR conversion in c_api (#3194)
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhliu committed Oct 27, 2022
1 parent f48893c commit ee77f6c
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 2 deletions.
3 changes: 3 additions & 0 deletions docs/project/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ substitutions:
- {{ Fix }} Shared libraries with version suffix are now handled correctly.
{pr}`3154`

- {{ Fix }} Scipy CSR data is now handled correctly in XGBoost.
{pr}`3194`

- Added a new CLI command `pyodide sekeleton` which creates a package build recipe.
`pyodide-build mkpkg` will be replaced by `pyodide sekeleton pypi`.
{pr}`3175`
Expand Down
1 change: 1 addition & 0 deletions packages/xgboost/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ source:
patches:
- patches/0001-Add-missing-template-type.patch
- patches/0002-Add-library-loading-path.patch
- patches/0003-Fix-type-mismatch-for-CSR-conversion-in-c_api.patch
build:
cflags: |
-DDMLC_USE_FOPEN64=0
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
From 4ac9a00d9e16b0879b4e734a4b604c7ce672894e Mon Sep 17 00:00:00 2001
From: Gyeongjae Choi <def6488@gmail.com>
Date: Mon, 9 May 2022 06:42:07 +0000
Subject: [PATCH 1/2] Add missing template type
Subject: [PATCH 1/3] Add missing template type

TODO: Remove this patch when XGBoost version is updated.
(Upstream PR: https://github.com/dmlc/xgboost/pull/7954)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
From 54c2a9faeb0b0169172c5ab53367e6092f132c5a Mon Sep 17 00:00:00 2001
From: Gyeongjae Choi <def6488@gmail.com>
Date: Mon, 9 May 2022 12:07:44 +0000
Subject: [PATCH 2/2] Add library loading path
Subject: [PATCH 2/3] Add library loading path

TODO: Remove this patch when XGBoost version is updated.
(Upstream PR: https://github.com/dmlc/xgboost/pull/7954)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
From 4ec1b506b424dd9e81fd7127f5712522800a5596 Mon Sep 17 00:00:00 2001
From: Yizhi Liu <liuyizhi@apache.org>
Date: Mon, 17 Oct 2022 15:16:45 -0700
Subject: [PATCH 3/3] Fix type mismatch for CSR conversion in c_api

TODO: Remove this patch when XGBoost version is updated.
(Upstream PR: https://github.com/dmlc/xgboost/pull/8369)

---
xgboost/core.py | 2 +-
xgboost/data.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/xgboost/core.py b/xgboost/core.py
index 36548d8..0246779 100644
--- a/xgboost/core.py
+++ b/xgboost/core.py
@@ -2119,7 +2119,7 @@ class Booster:
_array_interface(csr.indptr),
_array_interface(csr.indices),
_array_interface(csr.data),
- ctypes.c_size_t(csr.shape[1]),
+ c_bst_ulong(csr.shape[1]),
from_pystr_to_cstr(json.dumps(args)),
p_handle,
ctypes.byref(shape),
diff --git a/xgboost/data.py b/xgboost/data.py
index 119b354..b958436 100644
--- a/xgboost/data.py
+++ b/xgboost/data.py
@@ -88,7 +88,7 @@ def _from_scipy_csr(
_array_interface(data.indptr),
_array_interface(data.indices),
_array_interface(data.data),
- ctypes.c_size_t(data.shape[1]),
+ c_bst_ulong(data.shape[1]),
config,
ctypes.byref(handle),
)
--
2.35.1

18 changes: 18 additions & 0 deletions packages/xgboost/test_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,21 @@ def test_pandas_weight(selenium):
assert data.num_col() == kCols

np.testing.assert_array_equal(data.get_weight(), w)


@pytest.mark.driver_timeout(60)
@run_in_pyodide(packages=["xgboost", "numpy", "scipy"])
def test_scipy_sparse(selenium):
import numpy as np
import scipy
import xgboost as xgb

n_rows = 100
n_cols = 10
X = scipy.sparse.random(n_rows, n_cols, format="csr")
y = np.random.randn(n_rows)
dtrain = xgb.DMatrix(X, y)
booster = xgb.train({}, dtrain, num_boost_round=1)
copied_predt = booster.predict(xgb.DMatrix(X))
predt = booster.inplace_predict(X)
np.testing.assert_allclose(copied_predt, predt)

0 comments on commit ee77f6c

Please sign in to comment.