Source code for mrinufft.extras.gradient
"""Conjugate gradient optimization algorithm for image reconstruction."""
from mrinufft._array_compat import with_numpy_cupy
from mrinufft._utils import get_array_module
from tqdm import tqdm
[docs]
@with_numpy_cupy
def cg(
    operator,
    kspace_data,
    x_init=None,
    num_iter=10,
    tol=1e-4,
    compute_loss=False,
    progressbar=True,
):
    """
    Perform conjugate gradient (CG) optimization for image reconstruction.

    The image is updated using the gradient of a data consistency term
    (as returned by ``operator.data_consistency``), with a fixed step size of
    ``1 / L`` where ``L = operator.get_lipschitz_cst()``. A search direction
    ("velocity") is updated with the Polak-Ribiere formula, clipped at zero,
    to accelerate convergence. The Hermitian inner product is used so that
    complex-valued data is handled correctly.

    Parameters
    ----------
    operator : object
        The forward operator. It must provide ``get_lipschitz_cst()``,
        ``data_consistency(image, kspace_data)`` and ``op(image)``, as well as
        the attributes ``uses_sense``, ``n_batchs``, ``n_coils``, ``shape``
        and ``squeeze_dims`` (and ``_safe_squeeze`` when ``squeeze_dims`` is
        true).
    kspace_data : numpy.ndarray
        The k-space data to be used for image reconstruction.
    x_init : numpy.ndarray, optional
        An initial guess for the image. If None, an image of zeros with the same
        shape as the expected output is used. Default is None.
    num_iter : int, optional
        The maximum number of iterations to perform. Default is 10.
    tol : float, optional
        The tolerance for convergence. The iterations stop early once the
        norm of the gradient falls to or below this value. Default is 1e-4.
    compute_loss : bool, optional
        If True, record ``||op(image) - kspace_data||^2`` after the initial
        step and after every iteration. Default is False.
    progressbar : bool, optional
        If True, wrap the iterations in a ``tqdm`` progress bar.
        Default is True.

    Returns
    -------
    image : numpy.ndarray
        The reconstructed image after the optimization process.
    loss : numpy.ndarray
        Only returned (as ``(image, loss)``) when ``compute_loss`` is True:
        the recorded loss values.
    """
    lipschitz_cst = operator.get_lipschitz_cst()
    xp = get_array_module(kspace_data)

    # With SENSE the image carries a singleton coil axis; otherwise there is
    # one image per coil.
    if operator.uses_sense:
        init_shape = (operator.n_batchs, 1, *operator.shape)
    else:
        init_shape = (operator.n_batchs, operator.n_coils, *operator.shape)
    image = (
        xp.zeros(init_shape, dtype=kspace_data.dtype)
        if x_init is None
        else x_init.reshape(init_shape)
    )

    def calculate_loss(image):
        residual = operator.op(image) - kspace_data
        return xp.linalg.norm(residual) ** 2

    # First step is plain gradient descent. The search direction is kept
    # unscaled (like every later direction) and the 1/L step is applied
    # separately, so the Polak-Ribiere update combines like-scaled terms.
    grad = operator.data_consistency(image, kspace_data)
    velocity = grad
    image = image - velocity / lipschitz_cst

    loss = [calculate_loss(image)] if compute_loss else None

    iterator = range(num_iter)
    if progressbar:
        iterator = tqdm(iterator)
    for _ in iterator:
        grad_new = operator.data_consistency(image, kspace_data)
        if xp.linalg.norm(grad_new) <= tol:
            break
        # Polak-Ribiere (PR+) coefficient. ``vdot`` flattens its inputs and
        # conjugates the first argument, giving the Hermitian inner product;
        # for complex data only its real part is a meaningful coefficient.
        grad_sq_norm = xp.vdot(grad, grad).real
        if grad_sq_norm > 0:
            beta = xp.vdot(grad_new, grad_new - grad).real / grad_sq_norm
            beta = max(0, beta)
        else:
            # Previous gradient vanished: restart with steepest descent.
            beta = 0
        velocity = grad_new + beta * velocity
        image = image - velocity / lipschitz_cst
        grad = grad_new
        if compute_loss:
            loss.append(calculate_loss(image))

    if operator.squeeze_dims:
        image = operator._safe_squeeze(image)
    return image if loss is None else (image, xp.array(loss))