Skip to content

Commit

Permalink
Add convergence test for sparse datasets. (#7922)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed May 23, 2022
1 parent f6babc8 commit 474366c
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 1 deletion.
17 changes: 17 additions & 0 deletions tests/python-gpu/test_gpu_updaters.py
Expand Up @@ -53,6 +53,23 @@ def test_gpu_hist(self, param, num_rounds, dataset):
note(result)
assert tm.non_increasing(result["train"][dataset.metric])

@given(tm.sparse_datasets_strategy)
@settings(deadline=None, print_blob=True)
def test_sparse(self, dataset):
    """Check that ``hist`` and ``gpu_hist`` agree on sparse datasets.

    Each tree method is run through ``train_result`` with ``max_bin=64`` on a
    freshly built DMatrix (``16`` is passed as the last argument -- presumably
    the number of boosting rounds; verify against ``train_result``).  For both
    runs the training metric must never increase, and the two metric histories
    must match within a relative tolerance of 1e-2.
    """
    # Use the metric declared by the dataset everywhere instead of a
    # hard-coded "rmse" key, so the comparison below always looks at the same
    # series as the monotonicity checks.
    metric = dataset.metric

    param = {"tree_method": "hist", "max_bin": 64}
    hist_result = train_result(param, dataset.get_dmat(), 16)
    note(hist_result)
    assert tm.non_increasing(hist_result["train"][metric])

    param = {"tree_method": "gpu_hist", "max_bin": 64}
    gpu_hist_result = train_result(param, dataset.get_dmat(), 16)
    note(gpu_hist_result)
    assert tm.non_increasing(gpu_hist_result["train"][metric])

    np.testing.assert_allclose(
        hist_result["train"][metric], gpu_hist_result["train"][metric], rtol=1e-2
    )

@given(strategies.integers(10, 400), strategies.integers(3, 8),
strategies.integers(1, 2), strategies.integers(4, 7))
@settings(deadline=None, print_blob=True)
Expand Down
17 changes: 17 additions & 0 deletions tests/python/test_updaters.py
Expand Up @@ -99,6 +99,23 @@ def test_hist(self, param, hist_param, num_rounds, dataset):
note(result)
assert tm.non_increasing(result['train'][dataset.metric])

@given(tm.sparse_datasets_strategy)
@settings(deadline=None, print_blob=True)
def test_sparse(self, dataset):
    """Check that ``hist`` and ``approx`` agree on sparse datasets.

    Each tree method is run through ``train_result`` with ``max_bin=64`` on a
    freshly built DMatrix (``16`` is passed as the last argument -- presumably
    the number of boosting rounds; verify against ``train_result``).  For both
    runs the training metric must never increase, and the two metric histories
    must match under ``assert_allclose``'s default tolerance.
    """
    # Use the metric declared by the dataset everywhere instead of a
    # hard-coded "rmse" key, so the comparison below always looks at the same
    # series as the monotonicity checks.
    metric = dataset.metric

    param = {"tree_method": "hist", "max_bin": 64}
    hist_result = train_result(param, dataset.get_dmat(), 16)
    note(hist_result)
    assert tm.non_increasing(hist_result["train"][metric])

    param = {"tree_method": "approx", "max_bin": 64}
    approx_result = train_result(param, dataset.get_dmat(), 16)
    note(approx_result)
    assert tm.non_increasing(approx_result["train"][metric])

    np.testing.assert_allclose(
        hist_result["train"][metric], approx_result["train"][metric]
    )

def test_hist_categorical(self):
# hist must be the same as exact on all-categorical data
dpath = 'demo/data/'
Expand Down
126 changes: 125 additions & 1 deletion tests/python/testing.py
@@ -1,5 +1,7 @@
# coding: utf-8
from concurrent.futures import ThreadPoolExecutor
import os
import multiprocessing
from typing import Tuple, Union
import urllib
import zipfile
import sys
Expand All @@ -11,6 +13,7 @@
import gc
import xgboost as xgb
import numpy as np
from scipy import sparse
import platform

hypothesis = pytest.importorskip('hypothesis')
Expand Down Expand Up @@ -327,6 +330,127 @@ def make_categorical(
return df, label


@memory.cache
def make_sparse_regression(
    n_samples: int, n_features: int, sparsity: float, as_dense: bool
) -> Tuple[Union[sparse.csr_matrix, np.ndarray], np.ndarray]:
    """Make a random sparse regression dataset.

    Parameters
    ----------
    n_samples:
        Number of rows of the generated matrix.
    n_features:
        Number of columns of the generated matrix.
    sparsity:
        Fraction of entries that are absent; the matrix is generated with
        density ``1.0 - sparsity``.
    as_dense:
        Return the matrix as np.ndarray with missing values filled by NaN

    Returns
    -------
    X, y:
        ``X`` is a CSR matrix, or a dense array with zeros replaced by NaN
        when ``as_dense`` is true.  ``y`` is a 1-dim array of length
        ``n_samples``.

    Notes
    -----
    The columns are split across ``min(cpu_count, n_features)`` workers and
    every worker seeds its own generator from its index, so the generated
    values depend on the number of workers used.
    """

    def finalize(
        csr: sparse.csr_matrix, y: np.ndarray
    ) -> Tuple[Union[sparse.csr_matrix, np.ndarray], np.ndarray]:
        # Shared exit path so that `as_dense` is honoured by every branch.
        if as_dense:
            arr = csr.toarray()
            # Entries absent from the sparse matrix come back as 0; mark them
            # as missing.
            arr[arr == 0] = np.nan
            return arr, y
        return csr, y

    if not hasattr(np.random, "default_rng"):
        # old version of numpy on s390x
        rng = np.random.RandomState(1994)
        X = sparse.random(
            m=n_samples,
            n=n_features,
            density=1.0 - sparsity,
            random_state=rng,
            format="csr",
        )
        y = rng.normal(loc=0.0, scale=1.0, size=n_samples)
        return finalize(X, y)

    # Use multi-thread to speed up the generation, convenient if you use this function
    # for benchmarking.
    n_threads = multiprocessing.cpu_count()
    n_threads = min(n_threads, n_features)

    def random_csc(t_id: int) -> Tuple[sparse.csc_matrix, np.ndarray]:
        # Generate the block of columns owned by worker `t_id` together with
        # that block's contribution to the target.
        rng = np.random.default_rng(1994 * t_id)
        thread_size = n_features // n_threads
        if t_id == n_threads - 1:
            # The last worker also takes the remainder columns.
            n_features_tloc = n_features - t_id * thread_size
        else:
            n_features_tloc = thread_size

        X = sparse.random(
            m=n_samples,
            n=n_features_tloc,
            density=1.0 - sparsity,
            random_state=rng,
        ).tocsc()
        y = np.zeros((n_samples, 1))

        for i in range(X.shape[1]):
            # Number of stored entries in column i; empty columns contribute
            # nothing and do not consume random numbers.
            size = X.indptr[i + 1] - X.indptr[i]
            if size != 0:
                y += X[:, i].toarray() * rng.random((n_samples, 1)) * 0.2

        return X, y

    with ThreadPoolExecutor(max_workers=n_threads) as executor:
        futures = [executor.submit(random_csc, i) for i in range(n_threads)]

    X_results = []
    y_results = []
    for f in futures:
        X, y = f.result()
        X_results.append(X)
        y_results.append(y)

    assert len(y_results) == n_threads

    csr: sparse.csr_matrix = sparse.hstack(X_results, format="csr")
    # Stack the per-worker (n_samples, 1) targets into (n_samples, n_threads)
    # and add them up to obtain the final 1-dim target.
    y = np.asarray(y_results)
    y = y.reshape((y.shape[0], y.shape[1])).T
    y = np.sum(y, axis=1)

    assert csr.shape[0] == n_samples
    assert csr.shape[1] == n_features
    assert y.shape[0] == n_samples

    return finalize(csr, y)


def _sparse_regression_case(name: str, sparsity: float, as_dense: bool) -> TestDataset:
    """Build one squared-error/rmse test case on a 1e5 x 8 sparse matrix."""
    return TestDataset(
        name,
        lambda: make_sparse_regression(int(1e5), 8, sparsity, as_dense),
        "reg:squarederror",
        "rmse",
    )


# Sparse regression datasets at three sparsity levels, in CSR form and (for
# the two lower sparsity levels) as dense arrays with NaN for missing values.
sparse_datasets_strategy = strategies.sampled_from(
    [
        _sparse_regression_case("1e5x8-0.95-csr", 0.95, False),
        _sparse_regression_case("1e5x8-0.5-csr", 0.5, False),
        _sparse_regression_case("1e5x8-0.5-dense", 0.5, True),
        _sparse_regression_case("1e5x8-0.05-csr", 0.05, False),
        _sparse_regression_case("1e5x8-0.05-dense", 0.05, True),
    ]
)

_unweighted_datasets_strategy = strategies.sampled_from(
[
TestDataset(
Expand Down

0 comments on commit 474366c

Please sign in to comment.