Source code for mrinufft.extras.optim

"""Iterative least-squares solvers (LSQR, LSMR, CG) and L2 loss functions for NUFFT-based reconstruction."""

from collections.abc import Callable
import numpy as np
from tqdm.auto import tqdm
from numpy.typing import NDArray

from mrinufft._array_compat import get_array_module, with_numpy_cupy, CUPY_AVAILABLE
from mrinufft.operators.base import FourierOperatorBase
from mrinufft._utils import MethodRegister

_optim_docs = dict(
    base_params=r"""
nufft: FourierOperatorBase
    The NUFFT operator representing the forward model.
kspace_data: NDArray
    The right-hand side vector (`kspace` data). Shape is typically (n_batchs,
    n_coils, n_samples).
damp: float, optional
    Damping (regularization) parameter. Default is 0.0 (no regularization).
x0: NDArray or None, optional
    Damping vector. If None, uses zero. Shape is typically (n_batchs, n_coils or 1,
    *nufft.shape).
x_init: NDArray or None, optional
    Initial guess vector. If ommitted, default to x0. Must have same shape as x0.
callback: Callable, optional
    If provided, a callback function will be called at the end of each
    iteration with the current estimate. It should have the following signature
    ``callback(operator, kspace_data, damp, x0)``
max_iter: int, optional
    Maximum number of iterations. Default is 100.
progressbar: bool, optional
    If True (default) display a progress bar to track iterations.
""",
    returns="""
Returns
-------
NDArray:
    Solution vector with shape (n_batchs, n_coils or 1, *nufft.shape), dtype and
    device matching input.
""",
)


# Registry of the optimizers of this module (lsqr, lsmr and cg are added with
# the ``@register_optim`` decorator). ``_optim_docs`` is handed to the
# registry, presumably to fill the ``$base_params`` / ``$returns`` placeholders
# of their docstrings -- confirm in ``mrinufft._utils.MethodRegister``.
register_optim = MethodRegister("optim", _optim_docs)
# Accessor for registered optimizers; presumably looks them up by name
# (e.g. "lsqr") -- see ``MethodRegister.make_getter``.
get_optimizer = register_optim.make_getter()


def _norm_batched_np(x: NDArray) -> NDArray:
    return np.sqrt(np.sum(abs(x) ** 2, axis=tuple(range(1, x.ndim))))


if CUPY_AVAILABLE:
    import cupy as cp

    def _norm_batched_cp(x: NDArray) -> NDArray:
        """Return the L2 norm of every batch entry of a CuPy array.

        All axes after the first (batch) axis are flattened, then the norm is
        taken along the flattened axis.
        """
        n_batch = x.shape[0]
        flat = x.reshape(n_batch, -1)
        return cp.linalg.norm(flat, axis=-1)


def _bc_left(x, y):
    """Broadcast x to y shape, starting from first axis.

    Usefull for applying batch-wise scaling factors. as regular numpy
    broadcasting start from the last axis.

    Parameters
    ----------
    x : NDArray
        Array to broadcast.
    y : NDArray
        Target array.

    Returns
    -------
    NDArray
        Broadcasted array.
    """
    return x.reshape(x.shape + (1,) * (y.ndim - x.ndim))


@with_numpy_cupy
def loss_l2_reg(
    image: NDArray,
    operator: FourierOperatorBase,
    kspace_data: NDArray,
    damp: float = 0.0,
    x0: NDArray | None = None,
):
    r"""
    Compute the regularized least squares loss for MRI reconstruction.

    Computes :math:`\sqrt{\|A x - y\|_2^2 + \gamma^2 \|x - x_0\|_2^2}` for each
    batch, i.e. the square root of the objective minimized by ``lsqr`` and
    ``lsmr``. Here A is the measurement operator, x the current image
    estimate, y the acquired k-space data, damp (:math:`\gamma`) a
    regularization parameter and :math:`x_0` a reference image. Without
    regularization this reduces to :math:`\|A x - y\|_2`.

    Parameters
    ----------
    image : NDArray
        Current image estimate. Shape and dtype must be compatible with the
        operator.
    operator : FourierOperatorBase
        The NUFFT (non-uniform FFT) operator used for forward modeling.
    kspace_data : NDArray
        Measured k-space data. Shape must match the output of
        ``operator.op(image)``.
    damp : float, NDArray or None, optional
        Regularization parameter (default 0.0). ``None`` or zero disables the
        regularization. An array provides one value per batch (``lsmr`` passes
        such an array to its callback).
    x0 : NDArray or None, optional
        Reference image for regularization (default None, i.e. zero).

    Returns
    -------
    norm_res : float or NDArray
        The loss value(s), squeezed. If batched, shape = (n_batchs,).

    Notes
    -----
    - Batch dimension is preserved if present.
    - This function can be used as a callback in cg, lsqr or lsmr to keep
      track of the convergence.
    """
    xp = get_array_module(image)
    norm_batched = _norm_batched_cp if xp.__name__ == "cupy" else _norm_batched_np

    residual = operator.op(image).reshape(operator.ksp_full_shape)
    residual -= kspace_data.reshape(operator.ksp_full_shape)
    # Per-batch data-consistency term ||A x - y||_2.
    norm_res = norm_batched(residual.reshape(operator.n_batchs, -1))

    # Decide whether a (non-zero) regularization weight was given. Any numeric
    # type is accepted: a plain ``int`` used to be silently ignored.
    if damp is None:
        regularized = False
    elif isinstance(damp, xp.ndarray):
        regularized = bool(xp.any(damp))
    else:
        regularized = bool(np.any(damp))

    if not regularized:
        return norm_res.squeeze()

    image_ = image.reshape(operator.img_full_shape)
    if x0 is not None:
        image_ = image_ - x0.reshape(operator.img_full_shape)
    norm_damp = norm_batched(image_.reshape(operator.n_batchs, -1))
    # Combine *squared* norms so both terms live in the same units.
    return xp.sqrt(norm_res**2 + damp**2 * norm_damp**2).squeeze()
@with_numpy_cupy
def loss_l2_AHreg(
    image: NDArray,
    operator: FourierOperatorBase,
    kspace_data: NDArray,
    *args,
    **kwargs,
):
    r"""
    Compute the norm of the residual in the image domain.

    Computes :math:`\|A^H (A x - y)\|_2` for each batch, where A is the
    measurement operator, x the current image estimate and y the k-space data.

    Parameters
    ----------
    image : NDArray
        Current image estimate. Shape and dtype must be compatible with the
        operator.
    operator : FourierOperatorBase
        The NUFFT (non-uniform FFT) operator used for forward modeling.
    kspace_data : NDArray
        Measured k-space data. Shape must match the output of
        ``operator.op(image)``.
    *args, **kwargs
        Ignored. They allow this function to be used directly as a solver
        callback, which is called with extra ``damp=`` and ``x0=`` keywords
        (no regularization term is included).

    Returns
    -------
    norm_res : float or NDArray
        The image-domain residual norm(s), squeezed. If batched,
        shape = (n_batchs,).

    Notes
    -----
    - Batch dimension is preserved if present.
    - This function can be used as a callback in cg, lsqr or lsmr to keep
      track of the convergence.
    """
    xp = get_array_module(image)
    norm_batched = _norm_batched_cp if xp.__name__ == "cupy" else _norm_batched_np
    residual = operator.op(image).reshape(operator.ksp_full_shape)
    residual -= kspace_data.reshape(operator.ksp_full_shape)
    img_residual = operator.adj_op(residual).reshape(operator.img_full_shape)
    # Flatten everything but the batch axis before taking the norm.
    img_residual = img_residual.reshape(operator.n_batchs, -1)
    return norm_batched(img_residual).squeeze()
def _sym_ortho(a: NDArray, b: NDArray) -> tuple[NDArray, NDArray, NDArray]:
    """
    Stable, entry-wise implementation of a Givens rotation.

    Compute ``c``, ``s`` and ``r`` such that ``c * a + s * b = r`` and
    ``-s * a + c * b = 0`` with ``r >= 0``, independently for every entry of
    the (broadcast) real inputs. In the solvers each entry corresponds to one
    batch, so a zero in one batch no longer changes the rotation applied to
    the other batches.

    Parameters
    ----------
    a, b : NDArray or scalar
        Real values to rotate. Python scalars are cast to the dtype of the
        array argument(s).

    Returns
    -------
    c, s, r : NDArray
        Cosine, sine and resulting norm, with the broadcast shape of the
        inputs.

    Notes
    -----
    The routine 'SymOrtho' was added for numerical stability. This is
    recommended by S.-C. Choi in [1]_. It removes the unpleasant potential of
    ``1/eps`` in some important places (see, for example text following
    "Compute the next plane rotation Qk" in minres.py).

    References
    ----------
    .. [1] S.-C. Choi, "Iterative Methods for Singular Linear Equations
           and Least-Squares Problems", Dissertation,
           http://www.stanford.edu/group/SOL/dissertations/sou-cheng-choi-thesis.pdf
    """
    arrays = [val for val in (a, b) if hasattr(val, "dtype")]
    xp = get_array_module(arrays[0] if arrays else a)
    dtype = xp.result_type(*arrays) if arrays else None
    a = xp.asarray(a, dtype=dtype)
    b = xp.asarray(b, dtype=dtype)

    abs_a = xp.abs(a)
    abs_b = xp.abs(b)
    b_dominant = abs_b > abs_a
    # Ratio of the smaller to the larger magnitude (|tau| <= 1). Wherever the
    # denominator is zero the numerator is zero too, so dividing by 1 there
    # gives tau = 0 and reproduces the b == 0 / a == 0 special cases.
    num = xp.where(b_dominant, a, b)
    den = xp.where(b_dominant, b, a)
    tau = num / xp.where(den == 0, 1, den)
    scale = xp.sqrt(1 + tau * tau)
    # ``lead`` is s when |b| > |a| and c otherwise; ``follow`` is the other one.
    lead = xp.where(b_dominant, xp.sign(b), xp.sign(a)) / scale
    follow = lead * tau
    c = xp.where(b_dominant, follow, lead)
    s = xp.where(b_dominant, lead, follow)
    # r = b / s (resp. a / c) = |b| * scale (resp. |a| * scale) for real inputs.
    r = xp.where(b_dominant, abs_b, abs_a) * scale
    return c, s, r
@register_optim
@with_numpy_cupy
def lsqr(
    operator: FourierOperatorBase,
    kspace_data: NDArray,
    damp: float = 0.0,
    atol: float = 1e-6,
    btol: float = 1e-6,
    conlim: float = 1e8,
    max_iter: int = 100,
    x0: NDArray | None = None,
    x_init: NDArray | None = None,
    callback: Callable | None = None,
    progressbar: bool = True,
):
    r"""
    Solve a general regularized linear least-squares problem using the LSQR algorithm.

    Solves problems of the form

    .. math::

        \arg\min \|A x - b\|_2^2 + \gamma^2 \|x - x0\|_2^2

    Stop iterating if:

    - numerical convergence is reached: :math:`\|Ax-b\| <= atol \|A\| * \|x\| + btol * \|b\|`
    - estimation of the conditioning of the problem diverge: ``cond(A)>=conlim``
    - Maximum number of iteration reached.

    Parameters
    ----------
    $base_params
    atol : float, optional
        Stopping tolerance on the absolute error. Default is 1e-6.
    btol : float, optional
        Stopping tolerance on the relative error. Default is 1e-6.
    conlim : float, optional
        Limit on condition number. Iteration stops if condition exceeds this
        value. Default is 1e8.

    $returns

    Notes
    -----
    - Iterations start from ``x_init`` (``x0`` when omitted) and the residual
      is computed from that starting point; ``x_init`` is never modified.
      LSQR centres its damping on the starting point, so with ``damp != 0``
      and ``x_init`` different from ``x0`` the penalty effectively becomes
      :math:`\gamma^2 \|x - x_{init}\|_2^2`.
    - Zero norms are handled per batch, so one batch reaching an exact
      solution does not disturb (or NaN-poison) the others.

    References
    ----------
    .. [1] Paige, C. C., & Saunders, M. A. (1982). LSQR: An algorithm for
           sparse linear equations and sparse least squares.
           ACM Transactions on Mathematical Software, 8(1), 43-71.
    .. [2] S.-C. Choi, "Iterative Methods for Singular Linear Equations
           and Least-Squares Problems", Dissertation,
           http://www.stanford.edu/group/SOL/dissertations/sou-cheng-choi-thesis.pdf
    .. [3] https://github.com/scipy/scipy/blob/v1.16.2/scipy/sparse/linalg/_isolve/lsqr.py
    """
    xp = get_array_module(kspace_data)
    norm_batched = _norm_batched_cp if xp.__name__ == "cupy" else _norm_batched_np

    def _nonzero(den):
        # Replace exact zeros by one so per-batch divisions never give NaN/inf;
        # in those batches the numerators are zero as well.
        return xp.where(den == 0, 1, den)

    ctol = 1 / conlim if conlim > 0 else 0
    eps = xp.finfo(kspace_data.dtype).eps

    kspace_data = kspace_data.reshape(operator.ksp_full_shape)
    u = kspace_data.copy()
    bnorm = norm_batched(u)

    # Starting point (copied: x is updated in place below).
    if x0 is not None:
        x0 = x0.reshape(operator.img_full_shape)
    if x_init is not None:
        x = xp.copy(x_init.reshape(operator.img_full_shape))
    elif x0 is not None:
        x = xp.copy(x0)
    else:
        x = xp.zeros(operator.img_full_shape, dtype=operator.cpx_dtype)

    # LSQR solves for the correction dx = x - x_start, hence u = b - A x_start.
    beta = bnorm.copy()
    if x0 is not None or x_init is not None:
        u -= operator.op(x).reshape(operator.ksp_full_shape)
        beta = norm_batched(u)

    if xp.all(beta > 0):
        u /= _bc_left(beta, u)
        v = operator.adj_op(u).reshape(operator.img_full_shape)
        alpha = norm_batched(v)
    else:
        v = xp.copy(x)
        alpha = xp.zeros(v.shape[0])

    if xp.any((alpha * beta) == 0):
        # Some batch is already solved by the starting point.
        if operator.squeeze_dims:
            x = operator._safe_squeeze(x)
        return x

    v /= _bc_left(alpha, v)
    w = xp.copy(v)

    rhobar = alpha
    phibar = beta
    ddnorm = res2 = xxnorm = z = anorm = 0.0
    dampsq = damp**2
    cs2 = -1
    sn2 = 0.0
    istop = 0
    callback_returns = []

    for _ in tqdm(range(max_iter), disable=not progressbar):
        # Golub-Kahan bidiagonalization, normalized batch by batch:
        #   beta * u = A v - alpha * u,   alpha * v = A^H u - beta * v
        u *= -_bc_left(alpha, u)
        u += operator.op(v).reshape(operator.ksp_full_shape)
        beta = norm_batched(u)
        u /= _bc_left(_nonzero(beta), u)

        anorm = xp.sqrt(anorm**2 + alpha**2 + beta**2 + dampsq)

        v *= -_bc_left(beta, v)
        v += operator.adj_op(u).reshape(operator.img_full_shape)
        alpha = norm_batched(v)
        v /= _bc_left(_nonzero(alpha), v)

        # Eliminate the damping parameter with a plane rotation.
        if damp:
            rhobar1 = xp.sqrt(rhobar**2 + dampsq)
            cs1 = rhobar / rhobar1
            sn1 = damp / rhobar1
            psi = sn1 * phibar
            phibar = cs1 * phibar
        else:
            rhobar1 = rhobar
            psi = 0.0

        # Use a plane rotation to eliminate the subdiagonal element (beta)
        # of the lower-bidiagonal matrix, giving an upper-bidiagonal matrix.
        cs, sn, rho = _sym_ortho(rhobar1, beta)
        theta = sn * alpha
        rhobar = -cs * alpha
        phi = cs * phibar
        phibar = sn * phibar
        tau = sn * phi

        # Update x and w (rho == 0 only in an exactly solved batch, where the
        # numerators vanish too: that batch is left unchanged).
        rho_safe = _nonzero(rho)
        t1 = phi / rho_safe
        t2 = -theta / rho_safe
        dk = w / _bc_left(rho_safe, w)
        x += _bc_left(t1, w) * w
        w *= _bc_left(t2, w)
        w += v
        ddnorm += norm_batched(dk) ** 2

        # Use a plane rotation on the right to eliminate the super-diagonal
        # element (theta) of the upper-bidiagonal matrix, then use the result
        # to estimate norm(x).
        delta = sn2 * rho
        gambar = -cs2 * rho
        rhs = phi - delta * z
        zbar = rhs / _nonzero(gambar)
        xnorm = xp.sqrt(xxnorm + zbar**2)
        gamma = xp.sqrt(gambar**2 + theta**2)
        gamma_safe = _nonzero(gamma)
        cs2 = gambar / gamma_safe
        sn2 = theta / gamma_safe
        z = rhs / gamma_safe
        xxnorm += z**2

        # Test for convergence. First, estimate the condition of the matrix
        # Abar, and the norms of rbar and Abar'rbar.
        # (scipy additionally estimates r1norm = ||b - Ax||; it is not used
        # here and is therefore not computed.)
        acond = anorm * xp.sqrt(ddnorm)
        res1 = phibar**2
        res2 += psi**2
        rnorm = xp.sqrt(res1 + res2)
        arnorm = alpha * xp.abs(tau)

        # Now use these norms to estimate certain other quantities,
        # some of which will be small near a solution.
        test1 = rnorm / bnorm
        test2 = arnorm / (anorm * rnorm + eps)
        test3 = 1 / (acond + eps)
        test1_scaled = test1 / (1 + anorm * xnorm / bnorm)
        rtol = btol + atol * anorm * xnorm / bnorm

        if callback:
            callback_returns.append(
                callback(x, operator, kspace_data, damp=damp, x0=x0)
            )

        # The first three tests guard against extremely small values of
        # atol, btol or ctol (the user may have set any of them to 0); they
        # are equivalent to the normal tests with atol = eps, btol = eps,
        # conlim = 1/eps. A test must hold for every batch to stop.
        if xp.all(1 + test3 <= 1):
            istop = 6
        elif xp.all(1 + test2 <= 1):
            istop = 5
        elif xp.all(1 + test1_scaled <= 1):
            istop = 4
        # Allow for tolerances set by the user.
        elif xp.all(test3 <= ctol):
            istop = 3
        elif xp.all(test2 <= atol):
            istop = 2
        elif xp.all(test1 <= rtol):
            istop = 1
        if istop:
            break

    if operator.squeeze_dims:
        x = operator._safe_squeeze(x)
    if callback_returns:
        return x, callback_returns
    return x
@register_optim
@with_numpy_cupy
def lsmr(
    operator: FourierOperatorBase,
    kspace_data: NDArray,
    damp: float = 0.0,
    atol: float = 1e-6,
    btol: float = 1e-6,
    conlim: float = 1e8,
    max_iter: int = 100,
    x0: NDArray | None = None,
    x_init: NDArray | None = None,
    callback: Callable | None = None,
    progressbar: bool = True,
):
    r"""
    Solve a general regularized linear least-squares problem using the LSMR algorithm.

    Solves problems of the form

    .. math::

        \arg\min \|A x - b\|_2^2 + \gamma^2 \|x - x0\|_2^2

    Stop iterating if:

    - numerical convergence is reached: :math:`\|Ax-b\| <= atol \|A\| * \|x\| + btol * \|b\|`
    - estimation of the conditioning of the problem diverge: ``cond(A)>=conlim``
    - Maximum number of iteration reached.

    Parameters
    ----------
    $base_params
    atol : float, optional
        Stopping tolerance on the absolute error. Default is 1e-6.
    btol : float, optional
        Stopping tolerance on the relative error. Default is 1e-6.
    conlim : float, optional
        Limit on condition number. Iteration stops if condition exceeds this
        value. Default is 1e8.

    $returns

    References
    ----------
    .. [1] D. C.-L. Fong and M. A. Saunders,
           "LSMR: An iterative algorithm for sparse least-squares problems",
           SIAM J. Sci. Comput., vol. 33, pp. 2950-2971, 2011.
           :arxiv:`1006.0758`
    .. [2] LSMR Software, https://web.stanford.edu/group/SOL/software/lsmr/

    Notes
    -----
    - LSMR is generally more stable than LSQR, notably in term of image
      residual norm :math:`\|A^H(Ax-b)\|`, and is similar to the MINRES
      algorithm for least squares problems.
    - It usually converges faster than LSQR and can stop in fewer iterations.
    - Iterations start from ``x_init`` (``x0`` when omitted) and the residual
      is computed from that starting point; ``x_init`` is never modified.
    - The damping passed to ``callback`` is an array with one value per batch.
    """
    xp = get_array_module(kspace_data)
    norm_batched = _norm_batched_cp if xp.__name__ == "cupy" else _norm_batched_np

    def _nonzero(den):
        # Replace exact zeros by one so per-batch divisions never give NaN/inf;
        # in those batches the numerators are zero as well.
        return xp.where(den == 0, 1, den)

    ctol = 1 / conlim if conlim > 0 else 0

    img_coil_dim = operator.n_coils if not operator.uses_sense else 1

    def AT(y):
        return operator.adj_op(y).reshape(
            operator.n_batchs, img_coil_dim, *operator.shape
        )

    def A(x):
        return operator.op(x).reshape(
            operator.n_batchs, operator.n_coils, operator.n_samples
        )

    kspace_data = kspace_data.reshape(
        (operator.n_batchs, operator.n_coils, operator.n_samples)
    )
    u = kspace_data.copy()
    normb = norm_batched(u)

    # Starting point (copied: x is updated in place below).
    if x0 is not None:
        x0 = x0.reshape(operator.img_full_shape)
    if x_init is not None:
        x = xp.copy(x_init.reshape(operator.img_full_shape))
    elif x0 is not None:
        x = xp.copy(x0)
    else:
        x = xp.zeros(operator.img_full_shape, dtype=operator.cpx_dtype)

    # LSMR solves for the correction dx = x - x_start, hence u = b - A x_start.
    beta = normb.copy()
    if x0 is not None or x_init is not None:
        u -= A(x)
        beta = norm_batched(u)

    if xp.all(beta > 0):
        u /= _bc_left(beta, u)
        v = AT(u)
        alpha = norm_batched(v)
    else:
        v = xp.copy(x)
        alpha = xp.zeros(v.shape[0])

    if xp.any((alpha * beta) == 0):
        # Some batch is already solved by the starting point.
        if operator.squeeze_dims:
            x = operator._safe_squeeze(x)
        return x

    v /= _bc_left(alpha, v)

    # One damping value per batch, in the precision of the data norms.
    damp = xp.full(operator.n_batchs, damp, dtype=alpha.dtype)

    # Initialize variables for the 1st iteration.
    zetabar = alpha * beta
    alphabar = alpha
    rho = 1
    rhobar = 1
    cbar = 1
    sbar = 0

    h = v.copy()
    hbar = xp.zeros(v.shape, operator.cpx_dtype)

    # Initialize variables for estimation of ||r||.
    betadd = beta
    betad = 0
    rhodold = 1
    tautildeold = 0
    thetatilde = 0
    zeta = 0
    d = 0

    # Initialize variables for estimation of ||A|| and cond(A).
    normA2 = alpha * alpha
    maxrbar = 0
    minrbar = 1e100

    istop = 0
    callback_returns = []

    for itn in tqdm(range(1, max_iter + 1), disable=not progressbar):
        # Golub-Kahan bidiagonalization, normalized batch by batch.
        u *= -_bc_left(alpha, u)
        u += A(v)
        beta = norm_batched(u)
        u /= _bc_left(_nonzero(beta), u)

        v *= -_bc_left(beta, v)
        v += AT(u)
        alpha = norm_batched(v)
        v /= _bc_left(_nonzero(alpha), v)

        # Construct rotation Qhat_{k,2k+1} (damping).
        chat, shat, alphahat = _sym_ortho(alphabar, damp)

        # Use a plane rotation (Q_i) to turn B_i to R_i.
        rhoold = rho
        c, s, rho = _sym_ortho(alphahat, beta)
        thetanew = s * alpha
        alphabar = c * alpha

        # Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar.
        rhobarold = rhobar
        zetaold = zeta
        thetabar = sbar * rho
        rhotemp = cbar * rho
        cbar, sbar, rhobar = _sym_ortho(cbar * rho, thetanew)
        zeta = cbar * zetabar
        zetabar = -sbar * zetabar

        # Update h, hbar, x. Zero denominators only occur in an exactly
        # solved batch (with zero numerators): that batch is left unchanged.
        hbar *= _bc_left(
            -(thetabar * rho / _nonzero(rhoold * rhobarold)), hbar
        )
        hbar += h
        x += _bc_left(zeta / _nonzero(rho * rhobar), hbar) * hbar
        h *= _bc_left(-(thetanew / _nonzero(rho)), h)
        h += v

        # Estimate of ||r||.
        # Apply rotation Qhat_{k,2k+1}.
        betaacute = chat * betadd
        betacheck = -shat * betadd
        # Apply rotation Q_{k,k+1}.
        betahat = c * betaacute
        betadd = -s * betaacute
        # Apply rotation Qtilde_{k-1}; betad = betad_{k-1} here.
        thetatildeold = thetatilde
        ctildeold, stildeold, rhotildeold = _sym_ortho(rhodold, thetabar)
        thetatilde = stildeold * rhobar
        rhodold = ctildeold * rhobar
        betad = -stildeold * betad + ctildeold * betahat
        # betad = betad_k and rhodold = rhod_k here.
        tautildeold = (zetaold - thetatildeold * tautildeold) / rhotildeold
        taud = (zeta - thetatilde * tautildeold) / rhodold
        d = d + betacheck * betacheck
        normr = xp.sqrt(d + (betad - taud) ** 2 + betadd * betadd)

        # Estimate ||A||.
        normA2 = normA2 + beta * beta
        normA = xp.sqrt(normA2)
        normA2 = normA2 + alpha * alpha

        # Estimate cond(A); the extrema are pooled over the batch dimension
        # to get a better estimate.
        maxrbar = xp.max(xp.maximum(maxrbar, rhobarold))
        if itn > 1:
            minrbar = xp.min(xp.minimum(minrbar, rhobarold))
        condA = xp.mean(
            xp.maximum(maxrbar, rhotemp) / xp.minimum(minrbar, rhotemp)
        )

        # Compute norms for convergence testing.
        normar = abs(zetabar)
        normx = norm_batched(x)

        # Now use these norms to estimate certain other quantities,
        # some of which will be small near a solution.
        test1 = normr / normb
        denom = normA * normr
        test2 = xp.where(denom != 0, normar / _nonzero(denom), xp.inf)
        test3 = 1 / condA
        test1_scaled = test1 / (1 + normA * normx / normb)
        rtol = btol + atol * normA * normx / normb

        if callback:
            callback_returns.append(
                callback(x, operator, kspace_data, damp=damp, x0=x0)
            )

        # The first three tests guard against extremely small values of
        # atol, btol or ctol (the user may have set any of them to 0); they
        # are equivalent to the normal tests with atol = eps, btol = eps,
        # conlim = 1/eps. A test must hold for every batch to stop.
        if xp.all(1 + test3 <= 1):
            istop = 6
        elif xp.all(1 + test2 <= 1):
            istop = 5
        elif xp.all(1 + test1_scaled <= 1):
            istop = 4
        # Allow for tolerances set by the user.
        elif xp.all(test3 <= ctol):
            istop = 3
        elif xp.all(test2 <= atol):
            istop = 2
        elif xp.all(test1 <= rtol):
            istop = 1
        if istop:
            break

    if operator.squeeze_dims:
        x = operator._safe_squeeze(x)
    if callback_returns:
        return x, callback_returns
    return x
@register_optim
@with_numpy_cupy
def cg(
    operator: FourierOperatorBase,
    kspace_data: NDArray,
    damp: float = 0.0,
    x0: NDArray | None = None,
    x_init: NDArray | None = None,
    max_iter: int = 10,
    tol: float = 1e-4,
    progressbar: bool = True,
    callback: Callable | None = None,
):
    r"""
    Perform (nonlinear) conjugate gradient optimization for image reconstruction.

    The image is moved along a search direction (``velocity``) built from the
    gradient of the data-consistency term, with a fixed step ``1 / L`` where
    ``L = operator.get_lipschitz_cst()``. Successive gradients are combined
    with the Polak-Ribiere formula (clipped at 0), using the real part of the
    Hermitian inner product so that complex images are handled correctly.

    Parameters
    ----------
    $base_params
    tol: float
        Tolerance for convergence check: iterations stop once the norm of the
        gradient is at most ``tol``.

    $returns

    References
    ----------
    https://en.m.wikipedia.org/wiki/Nonlinear_conjugate_gradient_method
    """
    lipschitz_cst = operator.get_lipschitz_cst()
    xp = get_array_module(kspace_data)

    image = (
        xp.zeros(operator.img_full_shape, dtype=kspace_data.dtype)
        if x_init is None
        else x_init.reshape(operator.img_full_shape)
    )
    if x0 is not None:
        x0 = x0.reshape(operator.img_full_shape)

    def _gradient(img):
        # Gradient of the data-consistency term (presumably A^H (A x - y),
        # as returned by ``operator.data_consistency``) plus the penalty term.
        grad = operator.data_consistency(img, kspace_data).reshape(
            operator.img_full_shape
        )
        if damp:
            # NOTE(review): the penalty gradient is weighted by ``damp`` while
            # lsqr/lsmr weight the penalty by ``damp**2``; confirm intent.
            grad += damp * (img if x0 is None else img - x0)
        return grad

    # First step: plain gradient descent with step 1/L.
    grad = _gradient(image)
    velocity = grad
    image = image - velocity / lipschitz_cst

    callbacks_results = []
    for _ in tqdm(range(max_iter), disable=not progressbar):
        grad_new = _gradient(image)
        if xp.linalg.norm(grad_new) <= tol:
            break
        # Polak-Ribiere coefficient; vdot conjugates its first argument so the
        # real inner product is used for complex images.
        beta = xp.real(xp.vdot(grad_new, grad_new - grad)) / xp.real(
            xp.vdot(grad, grad)
        )
        beta = xp.maximum(beta, 0)
        velocity = grad_new + beta * velocity
        image = image - velocity / lipschitz_cst
        grad = grad_new
        if callback:
            callbacks_results.append(
                callback(
                    image,
                    operator,
                    kspace_data,
                    damp=damp,
                    x0=x0,
                )
            )

    if operator.squeeze_dims:
        image = operator._safe_squeeze(image)
    if callbacks_results:
        return image, callbacks_results
    return image