LogisticRegression convert to float64 (for SAG solver) (scikit-learn#13243)

* Remove unused code

* Squash all the PR 9040 commits

initial PR commit

seq_dataset.pyx generated from template

seq_dataset.pyx generated from template scikit-learn#2

rename variables

fused types consistency test for seq_dataset

a

sklearn/utils/tests/test_seq_dataset.py

new if statement

add doc

sklearn/utils/seq_dataset.pyx.tp

minor changes

minor changes

typo fix

check numeric accuracy only up 5th decimal

Address oliver's request for changing test name

add test for make_dataset and rename a variable in test_seq_dataset

* FIX tests

* TST more numerically stable test_sgd.test_tol_parameter

* Added benchmarks to compare SAGA 32b and 64b

* Fixing gael's comments

* fix

* solve some issues

* PEP8

* Address lesteve comments

* fix merging

* avoid using assert_equal

* use all_close

* use explicit ArrayDataset64 and CSRDataset64

* fix: remove unused import

* Use parametrized to cover ArrayDaset-CSRDataset-32-64 matrix

* for consistency use 32 first then 64 + add 64 suffix to variables

* it would be cool if this worked !!!

* more verbose version

* revert SGD changes as much as possible.

* Add solvers back to bench_saga

* make 64 explicit in the naming

* remove checking native python type + add comparison between 32 64

* Add whatsnew with everyone with commits

* simplify a bit the testing

* simplify the parametrize

* update whatsnew

* fix pep8
massich authored and koenvandevelde committed Jul 12, 2019
1 parent 45dd9f5 commit 4c836a7
Showing 17 changed files with 789 additions and 396 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -71,3 +71,8 @@ _configtest.o.d

# Used by mypy
.mypy_cache/

# files generated from a template
sklearn/utils/seq_dataset.pyx
sklearn/utils/seq_dataset.pxd
sklearn/linear_model/sag_fast.pyx
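
The three ignored files are generated at build time from Tempita templates (the corresponding *.pyx.tp files), with one specialization per dtype. A rough sketch of that mechanism, assuming the Cython-vendored Tempita module and an invented miniature template (the real seq_dataset.pyx.tp is considerably larger):

    # Illustrative only: expand a tiny Tempita template into one Cython class
    # per dtype, mimicking how seq_dataset.pyx is produced from its .pyx.tp.
    # The class name and fields here are assumptions, not the actual template.
    from Cython import Tempita

    template = Tempita.Template("""
    {{for name_suffix, c_type in dtypes}}
    cdef class SequentialDataset{{name_suffix}}:
        cdef {{c_type}} *x_data_ptr
    {{endfor}}
    """)

    print(template.substitute(dtypes=[("64", "double"), ("32", "float")]))
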
151 changes: 106 additions & 45 deletions benchmarks/bench_saga.py
@@ -1,11 +1,11 @@
"""Author: Arthur Mensch
"""Author: Arthur Mensch, Nelle Varoquaux
Benchmarks of sklearn SAGA vs lightning SAGA vs Liblinear. Shows the gain
in using multinomial logistic regression in term of learning time.
"""
import json
import time
from os.path import expanduser
import os

from joblib import delayed, Parallel, Memory
import matplotlib.pyplot as plt
@@ -21,7 +21,7 @@


def fit_single(solver, X, y, penalty='l2', single_target=True, C=1,
max_iter=10, skip_slow=False):
max_iter=10, skip_slow=False, dtype=np.float64):
if skip_slow and solver == 'lightning' and penalty == 'l1':
print('skip_slowping l1 logistic regression with solver lightning.')
return
@@ -37,7 +37,8 @@ def fit_single(solver, X, y, penalty='l2', single_target=True, C=1,
multi_class = 'ovr'
else:
multi_class = 'multinomial'

X = X.astype(dtype)
y = y.astype(dtype)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42,
stratify=y)
n_samples = X_train.shape[0]
@@ -69,11 +70,15 @@ def fit_single(solver, X, y, penalty='l2', single_target=True, C=1,
multi_class=multi_class,
C=C,
penalty=penalty,
fit_intercept=False, tol=1e-24,
fit_intercept=False, tol=0,
max_iter=this_max_iter,
random_state=42,
)

# Makes cpu cache even for all fit calls
X_train.max()
t0 = time.clock()

lr.fit(X_train, y_train)
train_time = time.clock() - t0

@@ -106,9 +111,13 @@ def _predict_proba(lr, X):
return softmax(pred)


def exp(solvers, penalties, single_target, n_samples=30000, max_iter=20,
def exp(solvers, penalty, single_target,
n_samples=30000, max_iter=20,
dataset='rcv1', n_jobs=1, skip_slow=False):
mem = Memory(cachedir=expanduser('~/cache'), verbose=0)
dtypes_mapping = {
"float64": np.float64,
"float32": np.float32,
}

if dataset == 'rcv1':
rcv1 = fetch_rcv1()
@@ -151,21 +160,24 @@ def exp(solvers, penalties, single_target, n_samples=30000, max_iter=20,
X = X[:n_samples]
y = y[:n_samples]

cached_fit = mem.cache(fit_single)
out = Parallel(n_jobs=n_jobs, mmap_mode=None)(
delayed(cached_fit)(solver, X, y,
delayed(fit_single)(solver, X, y,
penalty=penalty, single_target=single_target,
dtype=dtype,
C=1, max_iter=max_iter, skip_slow=skip_slow)
for solver in solvers
for penalty in penalties)
for dtype in dtypes_mapping.values())

res = []
idx = 0
for solver in solvers:
for penalty in penalties:
if not (skip_slow and solver == 'lightning' and penalty == 'l1'):
for dtype_name in dtypes_mapping.keys():
for solver in solvers:
if not (skip_slow and
solver == 'lightning' and
penalty == 'l1'):
lr, times, train_scores, test_scores, accuracies = out[idx]
this_res = dict(solver=solver, penalty=penalty,
dtype=dtype_name,
single_target=single_target,
times=times, train_scores=train_scores,
test_scores=test_scores,
@@ -177,68 +189,117 @@
json.dump(res, f)


def plot():
def plot(outname=None):
import pandas as pd
with open('bench_saga.json', 'r') as f:
f = json.load(f)
res = pd.DataFrame(f)
res.set_index(['single_target', 'penalty'], inplace=True)
res.set_index(['single_target'], inplace=True)

grouped = res.groupby(level=['single_target', 'penalty'])
grouped = res.groupby(level=['single_target'])

colors = {'saga': 'blue', 'liblinear': 'orange', 'lightning': 'green'}
colors = {'saga': 'C0', 'liblinear': 'C1', 'lightning': 'C2'}
linestyles = {"float32": "--", "float64": "-"}
alpha = {"float64": 0.5, "float32": 1}

for idx, group in grouped:
single_target, penalty = idx
fig = plt.figure(figsize=(12, 4))
ax = fig.add_subplot(131)

train_scores = group['train_scores'].values
ref = np.min(np.concatenate(train_scores)) * 0.999

for scores, times, solver in zip(group['train_scores'], group['times'],
group['solver']):
scores = scores / ref - 1
ax.plot(times, scores, label=solver, color=colors[solver])
single_target = idx
fig, axes = plt.subplots(figsize=(12, 4), ncols=4)
ax = axes[0]

for scores, times, solver, dtype in zip(group['train_scores'],
group['times'],
group['solver'],
group["dtype"]):
ax.plot(times, scores, label="%s - %s" % (solver, dtype),
color=colors[solver],
alpha=alpha[dtype],
marker=".",
linestyle=linestyles[dtype])
ax.axvline(times[-1], color=colors[solver],
alpha=alpha[dtype],
linestyle=linestyles[dtype])
ax.set_xlabel('Time (s)')
ax.set_ylabel('Training objective (relative to min)')
ax.set_yscale('log')

ax = fig.add_subplot(132)
ax = axes[1]

test_scores = group['test_scores'].values
ref = np.min(np.concatenate(test_scores)) * 0.999
for scores, times, solver, dtype in zip(group['test_scores'],
group['times'],
group['solver'],
group["dtype"]):
ax.plot(times, scores, label=solver, color=colors[solver],
linestyle=linestyles[dtype],
marker=".",
alpha=alpha[dtype])
ax.axvline(times[-1], color=colors[solver],
alpha=alpha[dtype],
linestyle=linestyles[dtype])

for scores, times, solver in zip(group['test_scores'], group['times'],
group['solver']):
scores = scores / ref - 1
ax.plot(times, scores, label=solver, color=colors[solver])
ax.set_xlabel('Time (s)')
ax.set_ylabel('Test objective (relative to min)')
ax.set_yscale('log')

ax = fig.add_subplot(133)
ax = axes[2]
for accuracy, times, solver, dtype in zip(group['accuracies'],
group['times'],
group['solver'],
group["dtype"]):
ax.plot(times, accuracy, label="%s - %s" % (solver, dtype),
alpha=alpha[dtype],
marker=".",
color=colors[solver], linestyle=linestyles[dtype])
ax.axvline(times[-1], color=colors[solver],
alpha=alpha[dtype],
linestyle=linestyles[dtype])

for accuracy, times, solver in zip(group['accuracies'], group['times'],
group['solver']):
ax.plot(times, accuracy, label=solver, color=colors[solver])
ax.set_xlabel('Time (s)')
ax.set_ylabel('Test accuracy')
ax.legend()
name = 'single_target' if single_target else 'multi_target'
name += '_%s' % penalty
plt.suptitle(name)
name += '.png'
if outname is None:
outname = name + '.png'
fig.tight_layout()
fig.subplots_adjust(top=0.9)
plt.savefig(name)
plt.close(fig)

ax = axes[3]
for scores, times, solver, dtype in zip(group['train_scores'],
group['times'],
group['solver'],
group["dtype"]):
ax.plot(np.arange(len(scores)),
scores, label="%s - %s" % (solver, dtype),
marker=".",
alpha=alpha[dtype],
color=colors[solver], linestyle=linestyles[dtype])

ax.set_yscale("log")
ax.set_xlabel('# iterations')
ax.set_ylabel('Objective function')
ax.legend()

plt.savefig(outname)


if __name__ == '__main__':
solvers = ['saga', 'liblinear', 'lightning']
penalties = ['l1', 'l2']
n_samples = [100000, 300000, 500000, 800000, None]
single_target = True
exp(solvers, penalties, single_target, n_samples=None, n_jobs=1,
dataset='20newspaper', max_iter=20)
plot()
for penalty in penalties:
for n_sample in n_samples:
exp(solvers, penalty, single_target,
n_samples=n_sample, n_jobs=1,
dataset='rcv1', max_iter=10)
if n_sample is not None:
outname = "figures/saga_%s_%d.png" % (penalty, n_sample)
else:
outname = "figures/saga_%s_all.png" % (penalty,)
try:
os.makedirs("figures")
except OSError:
pass
plot(outname)
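
The reworked benchmark above times each solver on both float32 and float64 copies of the data. A minimal standalone sketch of that comparison (dataset, sizes, and solver settings here are arbitrary choices, not taken from the benchmark):

    # Hedged sketch: time LogisticRegression(solver='saga') on float32 versus
    # float64 input, which is essentially what fit_single() does per dtype.
    import time

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    X, y = make_classification(n_samples=5000, n_features=50, random_state=42)

    for dtype in (np.float32, np.float64):
        X_cast = X.astype(dtype)
        clf = LogisticRegression(solver='saga', max_iter=50, random_state=42)
        t0 = time.time()
        clf.fit(X_cast, y)
        print("%s: %.2fs, train accuracy %.3f"
              % (dtype.__name__, time.time() - t0, clf.score(X_cast, y)))
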
5 changes: 5 additions & 0 deletions doc/whats_new/v0.21.rst
@@ -162,6 +162,11 @@ Support for Python 3.4 and below has been officially dropped.
:mod:`sklearn.linear_model`
...........................

- |Enhancement| :class:`linear_model.make_dataset` now preserves
``float32`` and ``float64`` dtypes. :issues:`8769` and :issues:`11000` by
:user:`Nelle Varoquaux`_, :user:`Arthur Imbert <Henley13>`,
:user:`Guillaume Lemaitre <glemaitre>`, and :user:`Joan Massich <massich>`

- |Feature| :class:`linear_model.LogisticRegression` and
:class:`linear_model.LogisticRegressionCV` now support Elastic-Net penalty,
with the 'saga' solver. :issue:`11646` by :user:`Nicolas Hug <NicolasHug>`.
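
As a small illustration of the entry above (make_dataset is a private helper; the import path reflects the source layout at this commit, so treat this as a sketch rather than public API):

    # Sketch: with this change, float32 input yields a 32-bit dataset instead
    # of being upcast to float64.
    import numpy as np
    from sklearn.linear_model.base import make_dataset

    rng = np.random.RandomState(0)
    X = rng.rand(10, 3).astype(np.float32)
    y = np.zeros(10, dtype=np.float32)
    sample_weight = np.ones(10, dtype=np.float32)

    dataset, intercept_decay = make_dataset(X, y, sample_weight, random_state=rng)
    print(type(dataset).__name__)  # expected: ArrayDataset32
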
18 changes: 13 additions & 5 deletions sklearn/linear_model/base.py
@@ -32,7 +32,8 @@
from ..utils.extmath import safe_sparse_dot
from ..utils.sparsefuncs import mean_variance_axis, inplace_column_scale
from ..utils.fixes import sparse_lsqr
from ..utils.seq_dataset import ArrayDataset, CSRDataset
from ..utils.seq_dataset import ArrayDataset32, CSRDataset32
from ..utils.seq_dataset import ArrayDataset64, CSRDataset64
from ..utils.validation import check_is_fitted
from ..exceptions import NotFittedError
from ..preprocessing.data import normalize as f_normalize
@@ -76,15 +77,22 @@ def make_dataset(X, y, sample_weight, random_state=None):
"""

rng = check_random_state(random_state)
# seed should never be 0 in SequentialDataset
# seed should never be 0 in SequentialDataset64
seed = rng.randint(1, np.iinfo(np.int32).max)

if X.dtype == np.float32:
CSRData = CSRDataset32
ArrayData = ArrayDataset32
else:
CSRData = CSRDataset64
ArrayData = ArrayDataset64

if sp.issparse(X):
dataset = CSRDataset(X.data, X.indptr, X.indices, y, sample_weight,
seed=seed)
dataset = CSRData(X.data, X.indptr, X.indices, y, sample_weight,
seed=seed)
intercept_decay = SPARSE_INTERCEPT_DECAY
else:
dataset = ArrayDataset(X, y, sample_weight, seed=seed)
dataset = ArrayData(X, y, sample_weight, seed=seed)
intercept_decay = 1.0

return dataset, intercept_decay
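
The same dtype-based dispatch covers sparse input, where the CSR-backed classes are selected instead. A sketch under the same assumptions as the dense example above:

    # Sketch: the sparse branch picks CSRDataset32 or CSRDataset64 from X.dtype.
    import numpy as np
    import scipy.sparse as sp
    from sklearn.linear_model.base import make_dataset

    rng = np.random.RandomState(0)
    X_sparse = sp.random(10, 3, density=0.5, format='csr',
                         random_state=rng, dtype=np.float32)
    y = np.zeros(10, dtype=np.float32)
    sample_weight = np.ones(10, dtype=np.float32)

    dataset, intercept_decay = make_dataset(X_sparse, y, sample_weight,
                                            random_state=rng)
    print(type(dataset).__name__)  # expected: CSRDataset32
    print(intercept_decay)         # SPARSE_INTERCEPT_DECAY for sparse input
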
12 changes: 8 additions & 4 deletions sklearn/linear_model/logistic.py
@@ -964,7 +964,7 @@ def _logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,

elif solver in ['sag', 'saga']:
if multi_class == 'multinomial':
target = target.astype(np.float64)
target = target.astype(X.dtype, copy=False)
loss = 'multinomial'
else:
loss = 'log'
@@ -1486,6 +1486,10 @@ def fit(self, X, y, sample_weight=None):
Returns
-------
self : object
Notes
-----
The SAGA solver supports both float64 and float32 bit arrays.
"""
solver = _check_solver(self.solver, self.penalty, self.dual)

@@ -1520,10 +1524,10 @@
raise ValueError("Tolerance for stopping criteria must be "
"positive; got (tol=%r)" % self.tol)

if solver in ['newton-cg']:
_dtype = [np.float64, np.float32]
else:
if solver in ['lbfgs', 'liblinear']:
_dtype = np.float64
else:
_dtype = [np.float64, np.float32]

X, y = check_X_y(X, y, accept_sparse='csr', dtype=_dtype, order="C",
accept_large_sparse=solver != 'liblinear')
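
In practice the change means the 'saga' (and 'sag') path no longer upcasts float32 input to float64, and the new tests check that the 32-bit fit stays close to the 64-bit one. A small consistency check in that spirit (dataset, solver settings, and the expected magnitude of the difference are assumptions, not values from the test suite):

    # Hedged sketch: fit the same model on float32 and float64 copies of the
    # data and compare the resulting coefficients.
    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    X, y = make_classification(n_samples=1000, n_features=20, random_state=0)

    clf32 = LogisticRegression(solver='saga', max_iter=200, random_state=0)
    clf32.fit(X.astype(np.float32), y)

    clf64 = LogisticRegression(solver='saga', max_iter=200, random_state=0)
    clf64.fit(X.astype(np.float64), y)

    print(clf32.coef_.dtype)  # float32 if the dtype is preserved end to end
    # The difference between the two fits is expected to be small.
    print(np.abs(clf32.coef_ - clf64.coef_).max())
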
