From d18be5b18adcf2d34da20acc197f879fb7a076c4 Mon Sep 17 00:00:00 2001 From: SELEE Date: Thu, 19 May 2022 17:25:25 +0900 Subject: [PATCH] FIX Update randomized SVD benchmark (#23373) --- benchmarks/bench_plot_randomized_svd.py | 2 +- sklearn/utils/extmath.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/benchmarks/bench_plot_randomized_svd.py b/benchmarks/bench_plot_randomized_svd.py index c16e21c3e0568..896b29ef471dd 100644 --- a/benchmarks/bench_plot_randomized_svd.py +++ b/benchmarks/bench_plot_randomized_svd.py @@ -153,7 +153,7 @@ def get_data(dataset_name): elif dataset_name == "rcv1": X = fetch_rcv1().data elif dataset_name == "CIFAR": - if handle_missing_dataset(CIFAR_FOLDER) == "skip": + if handle_missing_dataset(CIFAR_FOLDER) == 0: return X1 = [unpickle("%sdata_batch_%d" % (CIFAR_FOLDER, i + 1)) for i in range(5)] X = np.vstack(X1) diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index 2521990e6cc68..4438f67fb5729 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -216,9 +216,6 @@ def randomized_range_finder( # Generating normal random vectors with shape: (A.shape[1], size) Q = random_state.normal(size=(A.shape[1], size)) - if A.dtype.kind == "f": - # Ensure f32 is preserved as f32 - Q = Q.astype(A.dtype, copy=False) # Deal with "auto" mode if power_iteration_normalizer == "auto": @@ -243,6 +240,11 @@ def randomized_range_finder( # Sample the range of A using by linear projection of Q # Extract an orthonormal basis Q, _ = linalg.qr(safe_sparse_dot(A, Q), mode="economic") + + if hasattr(A, "dtype") and A.dtype.kind == "f": + # Ensure f32 is preserved as f32 + Q = Q.astype(A.dtype, copy=False) + return Q