Skip to content

Commit

Permalink
Add back support for scipy.sparse.coo_matrix (#6162)
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 committed Sep 25, 2020
1 parent 72ef553 commit bd2b1ea
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
11 changes: 11 additions & 0 deletions python-package/xgboost/data.py
Expand Up @@ -82,6 +82,15 @@ def _from_scipy_csc(data, missing, feature_names, feature_types):
return handle, feature_names, feature_types


def _is_scipy_coo(data):
try:
import scipy
except ImportError:
scipy = None
return False
return isinstance(data, scipy.sparse.coo_matrix)


def _is_numpy_array(data):
return isinstance(data, (np.ndarray, np.matrix))

Expand Down Expand Up @@ -504,6 +513,8 @@ def dispatch_data_backend(data, missing, threads,
return _from_scipy_csr(data, missing, feature_names, feature_types)
if _is_scipy_csc(data):
return _from_scipy_csc(data, missing, feature_names, feature_types)
if _is_scipy_coo(data):
return _from_scipy_csr(data.tocsr(), missing, feature_names, feature_types)
if _is_numpy_array(data):
return _from_numpy_array(data, missing, threads, feature_names,
feature_types)
Expand Down
9 changes: 9 additions & 0 deletions tests/python/test_dmatrix.py
Expand Up @@ -76,6 +76,15 @@ def test_csc(self):
assert dtrain.num_row() == 3
assert dtrain.num_col() == 3

def test_coo(self):
    """A 3x3 scipy COO matrix is accepted by DMatrix and keeps its shape."""
    values = np.array([1, 2, 3, 4, 5, 6])
    row_idx = np.array([0, 2, 2, 0, 1, 2])
    col_idx = np.array([0, 0, 1, 2, 2, 2])
    coo = scipy.sparse.coo_matrix((values, (row_idx, col_idx)), shape=(3, 3))
    dmat = xgb.DMatrix(coo)
    # Both dimensions must survive the COO -> DMatrix conversion.
    assert dmat.num_row() == 3
    assert dmat.num_col() == 3

def test_np_view(self):
# Sliced Float32 array
y = np.array([12, 34, 56], np.float32)[::2]
Expand Down

0 comments on commit bd2b1ea

Please sign in to comment.