Skip to content

Commit

Permalink
FEA Allow linear estimators to accept any dtype by converting to floa…
Browse files Browse the repository at this point in the history
…t32 by default
  • Loading branch information
dantegd committed May 14, 2024
1 parent 68d4336 commit d8576f6
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 3 deletions.
2 changes: 2 additions & 0 deletions python/cuml/linear_model/linear_regression.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,8 @@ class LinearRegression(LinearPredictMixin,

X_m, n_rows, self.n_features_in_, self.dtype = \
input_to_cuml_array(X,
convert_to_dtype=(np.float32 if convert_dtype
else None),
check_dtype=[np.float32, np.float64],
deepcopy=need_explicit_copy)
_X_ptr = X_m.ptr
Expand Down
2 changes: 2 additions & 0 deletions python/cuml/linear_model/ridge.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@ class Ridge(UniversalBase,
cdef uintptr_t _X_ptr, _y_ptr, _sample_weight_ptr
X_m, n_rows, self.n_features_in_, self.dtype = \
input_to_cuml_array(X, deepcopy=True,
convert_to_dtype=(np.float32 if convert_dtype
else None),
check_dtype=[np.float32, np.float64])
_X_ptr = X_m.ptr
self.feature_names_in_ = X_m.index
Expand Down
5 changes: 4 additions & 1 deletion python/cuml/solvers/cd.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,10 @@ class CD(Base,
"""
cdef uintptr_t sample_weight_ptr
X_m, n_rows, self.n_cols, self.dtype = \
input_to_cuml_array(X, check_dtype=[np.float32, np.float64])
input_to_cuml_array(X,
convert_to_dtype=(np.float32 if convert_dtype
else None),
check_dtype=[np.float32, np.float64])

y_m, *_ = \
input_to_cuml_array(y, check_dtype=self.dtype,
Expand Down
6 changes: 5 additions & 1 deletion python/cuml/solvers/qn.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,11 @@ class QN(Base,
# Handle dense inputs
else:
X_m, n_rows, self.n_cols, self.dtype = input_to_cuml_array(
X, check_dtype=[np.float32, np.float64], order='K'
X,
convert_to_dtype=(np.float32 if convert_dtype
else None),
check_dtype=[np.float32, np.float64],
order='K'
)

y_m, _, _, _ = input_to_cuml_array(
Expand Down
5 changes: 4 additions & 1 deletion python/cuml/solvers/sgd.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,10 @@ class SGD(Base,

"""
X_m, n_rows, self.n_cols, self.dtype = \
input_to_cuml_array(X, check_dtype=[np.float32, np.float64])
input_to_cuml_array(X,
convert_to_dtype=(np.float32 if convert_dtype
else None),
check_dtype=[np.float32, np.float64])

y_m, _, _, _ = \
input_to_cuml_array(y, check_dtype=self.dtype,
Expand Down

0 comments on commit d8576f6

Please sign in to comment.