From 0f5534d02009f49b27321d8b8057376df44e410a Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 20 May 2022 11:57:59 +0800 Subject: [PATCH 1/6] Test sparse dataset. --- tests/python-gpu/test_gpu_updaters.py | 17 ++++ tests/python/test_updaters.py | 17 ++++ tests/python/testing.py | 113 +++++++++++++++++++++++++- 3 files changed, 146 insertions(+), 1 deletion(-) diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index 8748ddcbdf91..c6def8563355 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -53,6 +53,23 @@ def test_gpu_hist(self, param, num_rounds, dataset): note(result) assert tm.non_increasing(result["train"][dataset.metric]) + @given(tm.sparse_datasets_strategy) + @settings(deadline=None, print_blob=True) + def test_sparse(self, dataset): + param = {"tree_method": "hist", "max_bin": 64} + hist_result = train_result(param, dataset.get_dmat(), 16) + note(hist_result) + assert tm.non_increasing(hist_result['train'][dataset.metric]) + + param = {"tree_method": "gpu_hist", "max_bin": 64} + approx_result = train_result(param, dataset.get_dmat(), 16) + note(approx_result) + assert tm.non_increasing(approx_result['train'][dataset.metric]) + + np.testing.assert_allclose( + hist_result["train"]["rmse"], approx_result["train"]["rmse"], rtol=1e-2 + ) + @given(strategies.integers(10, 400), strategies.integers(3, 8), strategies.integers(1, 2), strategies.integers(4, 7)) @settings(deadline=None, print_blob=True) diff --git a/tests/python/test_updaters.py b/tests/python/test_updaters.py index fa02b009a0f5..251439cdfed6 100644 --- a/tests/python/test_updaters.py +++ b/tests/python/test_updaters.py @@ -99,6 +99,23 @@ def test_hist(self, param, hist_param, num_rounds, dataset): note(result) assert tm.non_increasing(result['train'][dataset.metric]) + @given(tm.sparse_datasets_strategy) + @settings(deadline=None, print_blob=True) + def test_sparse(self, dataset): + param = {"tree_method": "hist", "max_bin": 64} + hist_result = train_result(param, dataset.get_dmat(), 16) + note(hist_result) + assert tm.non_increasing(hist_result['train'][dataset.metric]) + + param = {"tree_method": "approx", "max_bin": 64} + approx_result = train_result(param, dataset.get_dmat(), 16) + note(approx_result) + assert tm.non_increasing(approx_result['train'][dataset.metric]) + + np.testing.assert_allclose( + hist_result["train"]["rmse"], approx_result["train"]["rmse"] + ) + def test_hist_categorical(self): # hist must be same as exact on all-categorial data dpath = 'demo/data/' diff --git a/tests/python/testing.py b/tests/python/testing.py index 29947f227f86..093fdb07c164 100644 --- a/tests/python/testing.py +++ b/tests/python/testing.py @@ -1,5 +1,7 @@ -# coding: utf-8 +from concurrent.futures import ThreadPoolExecutor import os +import psutil +from typing import Tuple, Union import urllib import zipfile import sys @@ -11,6 +13,7 @@ import gc import xgboost as xgb import numpy as np +from scipy import sparse import platform hypothesis = pytest.importorskip('hypothesis') @@ -327,6 +330,114 @@ def make_categorical( return df, label +@memory.cache +def make_sparse_regression( + n_samples: int, n_features: int, sparsity: float, as_dense: bool +) -> Tuple[Union[sparse.csr_matrix], np.ndarray]: + """Make sparse matrix. + + Parameters + ---------- + + as_dense: + + Return the matrix as np.ndarray with missing values filled by NaN + + """ + # Use multi-thread to speed up the generation, convenient if you use this function + # for benchmarking. 
+ n_threads = psutil.cpu_count(logical=False) + n_threads = min(n_threads, n_features) + + def random_csc(t_id: int) -> sparse.csc_matrix: + rng = np.random.default_rng(1994 * t_id) + thread_size = n_features // n_threads + if t_id == n_threads - 1: + n_features_tloc = n_features - t_id * thread_size + else: + n_features_tloc = thread_size + + X = sparse.random( + m=n_samples, + n=n_features_tloc, + density=1.0 - sparsity, + random_state=rng, + ).tocsc() + y = np.zeros((n_samples, 1)) + + for i in range(X.shape[1]): + size = X.indptr[i + 1] - X.indptr[i] + if size != 0: + y += X[:, i].toarray() * rng.random((n_samples, 1)) * 0.2 + + return X, y + + futures = [] + with ThreadPoolExecutor(max_workers=n_threads) as executor: + for i in range(n_threads): + futures.append(executor.submit(random_csc, i)) + + X_results = [] + y_results = [] + for f in futures: + X, y = f.result() + X_results.append(X) + y_results.append(y) + + assert len(y_results) == n_threads + + csr: sparse.csr_matrix = sparse.hstack(X_results, format="csr") + y = np.asarray(y_results) + y = y.reshape((y.shape[0], y.shape[1])).T + y = np.sum(y, axis=1) + + assert csr.shape[0] == n_samples + assert csr.shape[1] == n_features + assert y.shape[0] == n_samples + + if as_dense: + arr = csr.toarray() + arr[arr == 0] = np.nan + return arr, y + + return csr, y + + +sparse_datasets_strategy = strategies.sampled_from( + [ + TestDataset( + "1e5x8-0.95-csr", + lambda: make_sparse_regression(int(1e5), 8, 0.95, False), + "reg:squarederror", + "rmse", + ), + TestDataset( + "1e5x8-0.5-csr", + lambda: make_sparse_regression(int(1e5), 8, 0.5, False), + "reg:squarederror", + "rmse", + ), + TestDataset( + "1e5x8-0.5-dense", + lambda: make_sparse_regression(int(1e5), 8, 0.5, True), + "reg:squarederror", + "rmse", + ), + TestDataset( + "1e5x8-0.05-csr", + lambda: make_sparse_regression(int(1e5), 8, 0.05, False), + "reg:squarederror", + "rmse", + ), + TestDataset( + "1e5x8-0.05-dense", + lambda: make_sparse_regression(int(1e5), 8, 0.05, True), + "reg:squarederror", + "rmse", + ), + ] +) + _unweighted_datasets_strategy = strategies.sampled_from( [ TestDataset( From c1fafaa5f70edff87c737371d6db4d8117a54d50 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 20 May 2022 12:05:41 +0800 Subject: [PATCH 2/6] Use mp instead. --- tests/python/testing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/testing.py b/tests/python/testing.py index 093fdb07c164..6be0a1d5efce 100644 --- a/tests/python/testing.py +++ b/tests/python/testing.py @@ -1,6 +1,6 @@ from concurrent.futures import ThreadPoolExecutor import os -import psutil +import multiprocessing from typing import Tuple, Union import urllib import zipfile @@ -346,7 +346,7 @@ def make_sparse_regression( """ # Use multi-thread to speed up the generation, convenient if you use this function # for benchmarking. - n_threads = psutil.cpu_count(logical=False) + n_threads = multiprocessing.cpu_count() n_threads = min(n_threads, n_features) def random_csc(t_id: int) -> sparse.csc_matrix: From 82ed1b62269bd793c0a7cbf7eea27524e9d76ea8 Mon Sep 17 00:00:00 2001 From: fis Date: Mon, 23 May 2022 16:35:38 +0800 Subject: [PATCH 3/6] Fix var name. 
--- tests/python-gpu/test_gpu_updaters.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index c6def8563355..860b459294c2 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -62,12 +62,12 @@ def test_sparse(self, dataset): assert tm.non_increasing(hist_result['train'][dataset.metric]) param = {"tree_method": "gpu_hist", "max_bin": 64} - approx_result = train_result(param, dataset.get_dmat(), 16) - note(approx_result) - assert tm.non_increasing(approx_result['train'][dataset.metric]) + gpu_hist_result = train_result(param, dataset.get_dmat(), 16) + note(gpu_hist_result) + assert tm.non_increasing(gpu_hist_result['train'][dataset.metric]) np.testing.assert_allclose( - hist_result["train"]["rmse"], approx_result["train"]["rmse"], rtol=1e-2 + hist_result["train"]["rmse"], gpu_hist_result["train"]["rmse"], rtol=1e-2 ) @given(strategies.integers(10, 400), strategies.integers(3, 8), From e91ecb4dfde8484aeb4315ef96a7d27bdb279e12 Mon Sep 17 00:00:00 2001 From: fis Date: Mon, 23 May 2022 16:54:28 +0800 Subject: [PATCH 4/6] Old numpy version on s390x --- tests/python/testing.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/python/testing.py b/tests/python/testing.py index 6be0a1d5efce..a006f7eeac0e 100644 --- a/tests/python/testing.py +++ b/tests/python/testing.py @@ -350,7 +350,10 @@ def make_sparse_regression( n_threads = min(n_threads, n_features) def random_csc(t_id: int) -> sparse.csc_matrix: - rng = np.random.default_rng(1994 * t_id) + if hasattr(np.random, "default_rng"): + rng = np.random.default_rng(1994 * t_id) + else: + rng = np.random.RandomState(1994 * t_id) thread_size = n_features // n_threads if t_id == n_threads - 1: n_features_tloc = n_features - t_id * thread_size From de6da13cb30ac6ca42b77f4b68e409c8401904db Mon Sep 17 00:00:00 2001 From: fis Date: Mon, 23 May 2022 17:00:27 +0800 Subject: [PATCH 5/6] Revert "Old numpy version on s390x" This reverts commit e91ecb4dfde8484aeb4315ef96a7d27bdb279e12. 
--- tests/python/testing.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/python/testing.py b/tests/python/testing.py index a006f7eeac0e..6be0a1d5efce 100644 --- a/tests/python/testing.py +++ b/tests/python/testing.py @@ -350,10 +350,7 @@ def make_sparse_regression( n_threads = min(n_threads, n_features) def random_csc(t_id: int) -> sparse.csc_matrix: - if hasattr(np.random, "default_rng"): - rng = np.random.default_rng(1994 * t_id) - else: - rng = np.random.RandomState(1994 * t_id) + rng = np.random.default_rng(1994 * t_id) thread_size = n_features // n_threads if t_id == n_threads - 1: n_features_tloc = n_features - t_id * thread_size From 20e35fdbbe4ccc51052894c8a6070f020854b2d9 Mon Sep 17 00:00:00 2001 From: fis Date: Mon, 23 May 2022 17:04:55 +0800 Subject: [PATCH 6/6] s390x --- tests/python/testing.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/python/testing.py b/tests/python/testing.py index 6be0a1d5efce..1cfaca876e87 100644 --- a/tests/python/testing.py +++ b/tests/python/testing.py @@ -344,6 +344,19 @@ def make_sparse_regression( Return the matrix as np.ndarray with missing values filled by NaN """ + if not hasattr(np.random, "default_rng"): + # old version of numpy on s390x + rng = np.random.RandomState(1994) + X = sparse.random( + m=n_samples, + n=n_features, + density=1.0 - sparsity, + random_state=rng, + format="csr", + ) + y = rng.normal(loc=0.0, scale=1.0, size=n_samples) + return X, y + # Use multi-thread to speed up the generation, convenient if you use this function # for benchmarking. n_threads = multiprocessing.cpu_count()
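
For readers who want to exercise the new sparse-data path outside the test suite, below is a minimal, single-threaded sketch of what make_sparse_regression and the test_sparse cases above do: generate a random CSR regression dataset at a given sparsity, then check that the histogram-based tree methods reach comparable training RMSE. The helper name make_sparse_regression_sketch, the linear target construction, the dataset size, and the rtol=1e-2 tolerance for the hist/approx comparison are illustrative assumptions, not part of the patches; the patches themselves use a multi-threaded generator, compare hist against approx at the default tolerance, and reserve rtol=1e-2 for the hist/gpu_hist comparison.

import numpy as np
import xgboost as xgb
from scipy import sparse


def make_sparse_regression_sketch(n_samples, n_features, sparsity, seed=1994):
    """Single-threaded stand-in for the multi-threaded generator in testing.py."""
    rng = np.random.default_rng(seed)
    # density = 1 - sparsity, matching the convention used in the patch.
    X = sparse.random(
        m=n_samples,
        n=n_features,
        density=1.0 - sparsity,
        random_state=rng,
        format="csr",
    )
    # Target is a random linear combination of the (sparse) feature columns.
    weights = rng.random(n_features) * 0.2
    y = np.asarray(X @ weights).ravel()
    return X, y


X, y = make_sparse_regression_sketch(int(1e5), 8, sparsity=0.95)
dtrain = xgb.DMatrix(X, label=y)

rmse = {}
for tree_method in ("hist", "approx"):
    evals_result = {}
    xgb.train(
        {"tree_method": tree_method, "max_bin": 64, "objective": "reg:squarederror"},
        dtrain,
        num_boost_round=16,
        evals=[(dtrain, "train")],
        evals_result=evals_result,
        verbose_eval=False,
    )
    rmse[tree_method] = evals_result["train"]["rmse"]

# Both histogram-based methods should fit the same sparse data comparably well.
np.testing.assert_allclose(rmse["hist"], rmse["approx"], rtol=1e-2)

Because the target is a deterministic function of the sampled features, the training RMSE should be non-increasing over boosting rounds, which is the same property the tests assert via tm.non_increasing.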