ENH add verbosity to newton-cg solver (#27526)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
3 people committed Apr 12, 2024
1 parent 556e0cf commit 3ee60a7
Showing 6 changed files with 229 additions and 14 deletions.
5 changes: 5 additions & 0 deletions doc/whats_new/v1.5.rst
@@ -264,6 +264,11 @@ Changelog
:mod:`sklearn.linear_model`
...........................

- |Enhancement| Solver `"newton-cg"` in :class:`linear_model.LogisticRegression` and
:class:`linear_model.LogisticRegressionCV` now emits information when `verbose` is
set to positive values.
:pr:`27526` by :user:`Christian Lorentzen <lorentzenchr>`.
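A quick, hedged illustration of the new behaviour (the toy dataset below is only for demonstration and is not part of the commit):

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=200, n_features=5, random_state=0)
# verbose >= 1 prints the outer Newton-CG iterations and convergence checks;
# verbose >= 2 additionally prints the inner CG and line-search details.
clf = LogisticRegression(solver="newton-cg", verbose=2, max_iter=100).fit(X, y)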

- |Fix| :class:`linear_model.ElasticNet`, :class:`linear_model.ElasticNetCV`,
:class:`linear_model.Lasso` and :class:`linear_model.LassoCV` now explicitly don't
accept large sparse data formats. :pr:`27576` by :user:`Stefanie Senger
2 changes: 1 addition & 1 deletion sklearn/linear_model/_glm/_newton_solver.py
@@ -230,7 +230,7 @@ def line_search(self, X, y, sample_weight):
is_verbose = self.verbose >= 2
if is_verbose:
print(" Backtracking Line Search")
print(f" eps=10 * finfo.eps={eps}")
print(f" eps=16 * finfo.eps={eps}")

for i in range(21): # until and including t = beta**20 ~ 1e-6
self.coef = self.coef_old + t * self.coef_newton
9 changes: 8 additions & 1 deletion sklearn/linear_model/_logistic.py
@@ -477,7 +477,14 @@ def _logistic_regression_path(
l2_reg_strength = 1.0 / (C * sw_sum)
args = (X, target, sample_weight, l2_reg_strength, n_threads)
w0, n_iter_i = _newton_cg(
hess, func, grad, w0, args=args, maxiter=max_iter, tol=tol
grad_hess=hess,
func=func,
grad=grad,
x0=w0,
args=args,
maxiter=max_iter,
tol=tol,
verbose=verbose,
)
elif solver == "newton-cholesky":
l2_reg_strength = 1.0 / (C * sw_sum)
2 changes: 1 addition & 1 deletion sklearn/linear_model/tests/test_logistic.py
@@ -1156,7 +1156,7 @@ def test_logreg_predict_proba_multinomial():
[
(
"newton-cg",
"newton-cg failed to converge. Increase the number of iterations.",
"newton-cg failed to converge.* Increase the number of iterations.",
),
(
"liblinear",
98 changes: 87 additions & 11 deletions sklearn/utils/optimize.py
@@ -27,7 +27,9 @@ class _LineSearchError(RuntimeError):
pass


def _line_search_wolfe12(f, fprime, xk, pk, gfk, old_fval, old_old_fval, **kwargs):
def _line_search_wolfe12(
f, fprime, xk, pk, gfk, old_fval, old_old_fval, verbose=0, **kwargs
):
"""
Same as line_search_wolfe1, but fall back to line_search_wolfe2 if
suitable step length is not found, and raise an exception if a
@@ -39,24 +41,44 @@ def _line_search_wolfe12(f, fprime, xk, pk, gfk, old_fval, old_old_fval, **kwarg
If no suitable step size is found.
"""
is_verbose = verbose >= 2
eps = 16 * np.finfo(np.asarray(old_fval).dtype).eps
if is_verbose:
print(" Line Search")
print(f" eps=16 * finfo.eps={eps}")
print(" try line search wolfe1")

ret = line_search_wolfe1(f, fprime, xk, pk, gfk, old_fval, old_old_fval, **kwargs)

if is_verbose:
_not_ = "not " if ret[0] is None else ""
print(" wolfe1 line search was " + _not_ + "successful")

if ret[0] is None:
# Have a look at the line_search method of our NewtonSolver class. We borrow
# the logic from there
# Deal with relative loss differences around machine precision.
args = kwargs.get("args", tuple())
fval = f(xk + pk, *args)
eps = 16 * np.finfo(np.asarray(old_fval).dtype).eps
tiny_loss = np.abs(old_fval * eps)
loss_improvement = fval - old_fval
check = np.abs(loss_improvement) <= tiny_loss
if is_verbose:
print(
" check loss |improvement| <= eps * |loss_old|:"
f" {np.abs(loss_improvement)} <= {tiny_loss} {check}"
)
if check:
# 2.1 Check sum of absolute gradients as alternative condition.
sum_abs_grad_old = scipy.linalg.norm(gfk, ord=1)
grad = fprime(xk + pk, *args)
sum_abs_grad = scipy.linalg.norm(grad, ord=1)
check = sum_abs_grad < sum_abs_grad_old
if is_verbose:
print(
" check sum(|gradient|) < sum(|gradient_old|): "
f"{sum_abs_grad} < {sum_abs_grad_old} {check}"
)
if check:
ret = (
1.0, # step size
@@ -72,17 +94,22 @@ def _line_search_wolfe12(f, fprime, xk, pk, gfk, old_fval, old_old_fval, **kwarg
# TODO: It seems that the new check for the sum of absolute gradients above
# catches all cases that, earlier, ended up here. In fact, our tests never
# trigger this "if branch" here and we can consider to remove it.
if is_verbose:
print(" last resort: try line search wolfe2")
ret = line_search_wolfe2(
f, fprime, xk, pk, gfk, old_fval, old_old_fval, **kwargs
)
if is_verbose:
_not_ = "not " if ret[0] is None else ""
print(" wolfe2 line search was " + _not_ + "successful")

if ret[0] is None:
raise _LineSearchError()

return ret
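Taken on its own, the fallback above reduces to the following hedged sketch; the helper name accept_unit_step is hypothetical and used here only for illustration:

import numpy as np
import scipy.linalg

def accept_unit_step(old_fval, fval, gfk_old, gfk_new):
    # Mirrors the two checks above: the loss change is at machine-precision
    # level relative to the old loss, and the l1 norm of the gradient still
    # decreases for a unit step along the Newton direction.
    eps = 16 * np.finfo(np.asarray(old_fval).dtype).eps
    if np.abs(fval - old_fval) > np.abs(old_fval * eps):
        return False
    return scipy.linalg.norm(gfk_new, ord=1) < scipy.linalg.norm(gfk_old, ord=1)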


def _cg(fhess_p, fgrad, maxiter, tol):
def _cg(fhess_p, fgrad, maxiter, tol, verbose=0):
"""
Solve iteratively the linear system 'fhess_p . xsupi = fgrad'
with a conjugate gradient descent.
@@ -107,30 +134,51 @@ def _cg(fhess_p, fgrad, maxiter, tol):
xsupi : ndarray of shape (n_features,) or (n_features + 1,)
Estimated solution.
"""
eps = 16 * np.finfo(np.float64).eps
xsupi = np.zeros(len(fgrad), dtype=fgrad.dtype)
ri = np.copy(fgrad)
ri = np.copy(fgrad) # residual = fgrad - fhess_p @ xsupi
psupi = -ri
i = 0
dri0 = np.dot(ri, ri)
# We also track of |p_i|^2.
# We also keep track of |p_i|^2.
psupi_norm2 = dri0
is_verbose = verbose >= 2

while i <= maxiter:
if np.sum(np.abs(ri)) <= tol:
if is_verbose:
print(
f" Inner CG solver iteration {i} stopped with\n"
f" sum(|residuals|) <= tol: {np.sum(np.abs(ri))} <= {tol}"
)
break

Ap = fhess_p(psupi)
# check curvature
curv = np.dot(psupi, Ap)
if 0 <= curv <= 16 * np.finfo(np.float64).eps * psupi_norm2:
if 0 <= curv <= eps * psupi_norm2:
# See https://arxiv.org/abs/1803.02924, Algo 1 Capped Conjugate Gradient.
if is_verbose:
print(
f" Inner CG solver iteration {i} stopped with\n"
f" tiny_|p| = eps * ||p||^2, eps = {eps}, "
f"squred L2 norm ||p||^2 = {psupi_norm2}\n"
f" curvature <= tiny_|p|: {curv} <= {eps * psupi_norm2}"
)
break
elif curv < 0:
if i > 0:
if is_verbose:
print(
f" Inner CG solver iteration {i} stopped with negative "
f"curvature, curvature = {curv}"
)
break
else:
# fall back to steepest descent direction
xsupi += dri0 / curv * psupi
if is_verbose:
print(" Inner CG solver iteration 0 fell back to steepest descent")
break
alphai = dri0 / curv
xsupi += alphai * psupi
@@ -142,7 +190,11 @@ def _cg(fhess_p, fgrad, maxiter, tol):
psupi_norm2 = dri1 + betai**2 * psupi_norm2
i = i + 1
dri0 = dri1 # update np.dot(ri,ri) for next time.

if is_verbose and i > maxiter:
print(
f" Inner CG solver stopped reaching maxiter={i - 1} with "
f"sum(|residuals|) = {np.sum(np.abs(ri))}"
)
return xsupi
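The inner solver only ever needs Hessian-vector products, never an explicit Hessian. A hedged toy call against a small SPD matrix (purely illustrative; _cg is a private helper and its signature may change):

import numpy as np
from sklearn.utils.optimize import _cg

H = np.array([[3.0, 1.0], [1.0, 2.0]])  # toy SPD Hessian
g = np.array([1.0, -1.0])               # toy gradient

# Returns an approximate Newton direction, i.e. roughly -H^{-1} g;
# verbose=2 activates the per-iteration messages added above.
direction = _cg(lambda p: H @ p, g, maxiter=10, tol=1e-10, verbose=2)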


Expand All @@ -157,6 +209,7 @@ def _newton_cg(
maxinner=200,
line_search=True,
warn=True,
verbose=0,
):
"""
Minimization of scalar function of one or more variables using the
@@ -210,6 +263,10 @@
if line_search:
old_fval = func(x0, *args)
old_old_fval = None
else:
old_fval = 0

is_verbose = verbose > 0

# Outer loop: our Newton iteration
while k < maxiter:
@@ -218,7 +275,13 @@
fgrad, fhess_p = grad_hess(xk, *args)

absgrad = np.abs(fgrad)
if np.max(absgrad) <= tol:
max_absgrad = np.max(absgrad)
check = max_absgrad <= tol
if is_verbose:
print(f"Newton-CG iter = {k}")
print(" Check Convergence")
print(f" max |gradient| <= tol: {max_absgrad} <= {tol} {check}")
if check:
break

maggrad = np.sum(absgrad)
@@ -227,14 +290,22 @@

# Inner loop: solve the Newton update by conjugate gradient, to
# avoid inverting the Hessian
xsupi = _cg(fhess_p, fgrad, maxiter=maxinner, tol=termcond)
xsupi = _cg(fhess_p, fgrad, maxiter=maxinner, tol=termcond, verbose=verbose)

alphak = 1.0

if line_search:
try:
alphak, fc, gc, old_fval, old_old_fval, gfkp1 = _line_search_wolfe12(
func, grad, xk, xsupi, fgrad, old_fval, old_old_fval, args=args
func,
grad,
xk,
xsupi,
fgrad,
old_fval,
old_old_fval,
verbose=verbose,
args=args,
)
except _LineSearchError:
warnings.warn("Line Search failed")
@@ -245,9 +316,14 @@

if warn and k >= maxiter:
warnings.warn(
"newton-cg failed to converge. Increase the number of iterations.",
(
f"newton-cg failed to converge at loss = {old_fval}. Increase the"
" number of iterations."
),
ConvergenceWarning,
)
elif is_verbose:
print(f" Solver did converge at loss = {old_fval}.")
return xk, k
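Putting the pieces together, a minimal, hedged sketch of driving the private _newton_cg helper directly on a strictly convex quadratic (everything below is illustrative and not part of the commit):

import numpy as np
from sklearn.utils.optimize import _newton_cg

A = np.array([[3.0, 1.0], [1.0, 2.0]])
b = np.array([1.0, -1.0])

def func(x):
    # Strictly convex quadratic, minimised at A^{-1} b.
    return 0.5 * x @ A @ x - b @ x

def grad(x):
    return A @ x - b

def grad_hess(x):
    # Gradient plus a callable computing Hessian-vector products.
    return grad(x), lambda p: A @ p

x_opt, n_iter = _newton_cg(grad_hess, func, grad, x0=np.zeros(2), tol=1e-8, verbose=1)
# x_opt should be close to np.linalg.solve(A, b); verbose=1 prints the
# outer "Newton-CG iter" convergence checks added in this commit.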

