Note
Go to the end to download the full example code or to run this example in your browser via Binder.
Least Squares Image Reconstruction
An example showing how to reconstruct volumes using the least-squares estimate.
This script demonstrates how to use the Conjugate Gradient (CG), LSQR and LSMR methods to reconstruct images from non-uniform k-space data.
import os
import time
import cupy as cp
import numpy as np
from brainweb_dl import get_mri
from matplotlib import pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as psnr
import mrinufft
from mrinufft.extras.optim import loss_l2_reg, loss_l2_AHreg
# Name of the NUFFT backend handed to ``mrinufft.get_operator``. It can be
# overridden through the MRINUFFT_BACKEND environment variable; the default is
# "cufinufft", and the rest of this script moves data with CuPy.
BACKEND = os.getenv("MRINUFFT_BACKEND", default="cufinufft")
Set up the inputs
# Non-Cartesian sampling pattern: a 2D spiral with Nc=64 (presumably shots)
# of Ns=512 samples each, making 8 revolutions.
samples_loc = mrinufft.initialize_2D_spiral(Nc=64, Ns=512, nb_revolutions=8)
# Reference volume from BrainWeb (subject 4); keep index 90 along the first
# axis to obtain a single 2D slice.
ground_truth = get_mri(sub_id=4)
ground_truth = ground_truth[90]
# Normalize the ground truth image to unit RMS magnitude, so PSNR / residual
# values are on a predictable scale.
ground_truth = ground_truth / np.sqrt(np.mean(abs(ground_truth) ** 2))
image_gpu = cp.array(ground_truth)  # convert to cupy array for GPU processing
print("image size: ", ground_truth.shape)
image size:  (256, 256)
Set up the NUFFT operator
# Resolve the operator class registered under BACKEND.
NufftOperator = mrinufft.get_operator(BACKEND)  # get the operator
# Bind the operator to the spiral trajectory and the 2D image shape.
# NOTE(review): squeeze_dims=True is assumed to drop singleton (coil/batch)
# axes from op/adj_op outputs — confirm against the mrinufft docs.
# No density compensation is requested here, so the adjoint is unweighted.
nufft = NufftOperator(
    samples_loc,
    shape=ground_truth.shape,
    squeeze_dims=True,
)  # create the NUFFT operator
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/venv/lib/python3.10/site-packages/mrinufft/_utils.py:76: UserWarning: Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)
  warnings.warn(
Reconstruct the image using the CG, LSQR and LSMR methods
# Simulate the acquisition: forward NUFFT of the ground truth, kept on the GPU
# because the solvers below consume it directly.
kspace_data_gpu = nufft.op(image_gpu)  # get the k-space data
# NOTE(review): kspace_data is not used in the visible part of this example.
kspace_data = kspace_data_gpu.get()  # convert back to numpy array for display
# Plain adjoint reconstruction (host copy), used as the baseline image.
adjoint = nufft.adj_op(kspace_data_gpu).get()  # adjoint NUFFT
def mixed_cb(*args, **kwargs):
    """Record timing and convergence metrics at one solver iteration.

    Meant to be used as the ``callback`` of ``nufft.pinv_solver``. All
    positional and keyword arguments are forwarded untouched to
    ``loss_l2_reg`` and ``loss_l2_AHreg``. ``args[0]`` is assumed to be the
    current iterate as a CuPy array (``.get()`` is called on it).

    Returns
    -------
    list
        ``[t_start, l2_residual, normal_residual, psnr_value, t_end]``. The two
        ``time.perf_counter`` stamps bracket the metric computation so the
        time spent inside the callback can later be excluded.
    """
    t_start = time.perf_counter()
    l2_residual = loss_l2_reg(*args, **kwargs)
    normal_residual = loss_l2_AHreg(*args, **kwargs)
    estimate_mag = abs(args[0].get().squeeze())
    reference_mag = abs(ground_truth.squeeze())
    psnr_value = psnr(estimate_mag, reference_mag, data_range=ground_truth.max())
    t_end = time.perf_counter()
    return [t_start, l2_residual, normal_residual, psnr_value, t_end]
def _to_host(value):
    """Return ``value.get()`` for device (e.g. CuPy) values, else ``value`` itself."""
    getter = getattr(value, "get", None)
    return getter() if callable(getter) else value


def process_cb_results(cb_results):
    """Turn the raw per-iteration records of ``mixed_cb`` into plottable series.

    Parameters
    ----------
    cb_results : iterable of sequences
        One record per callback invocation, each laid out as
        ``[t_start, res, AHres, psnr, t_end]`` (the list ``mixed_cb`` returns).

    Returns
    -------
    dict
        ``"time"``: cumulative solver time (s) at each record, excluding the
        time spent inside the callbacks; the first entry is 0.
        ``"res"`` / ``"AHres"``: residual norms, moved to host memory when they
        expose ``.get()`` (CuPy) and kept as-is otherwise.
        ``"psnr"``: tuple of PSNR values.
        Empty input yields empty series instead of raising on unpacking.
    """
    records = list(cb_results)
    if not records:
        return {"time": np.zeros(0), "res": [], "AHres": [], "psnr": ()}
    t_start, res, ah_res, psnrs, t_end = zip(*records)
    # The solver resumes when the previous callback returns. The first record
    # has no predecessor, so its own start is used and it contributes 0 s.
    resumed = (t_start[0], *t_end[:-1])
    time_it = np.cumsum(np.asarray(t_start) - np.asarray(resumed))
    return {
        "time": time_it,
        "res": [_to_host(v) for v in res],
        "AHres": [_to_host(v) for v in ah_res],
        "psnr": psnrs,
    }
# Run the least-squares minimization for all the solvers:
# Solver names forwarded to ``nufft.pinv_solver(optim=...)``.
OPTIM = ["cg", "lsqr", "lsmr"]
# Keys of the dicts built by process_cb_results -> y-axis label for plotting.
METRICS = {
    "res": r"$\|Ax-b\|$",
    "AHres": r"$\|A^H(Ax-b)\|$",
    "psnr": "PSNR",
}
images = dict()  # solver name -> reconstructed image (host array)
iterations_cb = dict()  # solver name -> processed callback metrics
for optim in OPTIM:
    # Unregularized solve capped at 1000 iterations. The solver calls mixed_cb
    # (presumably once per iteration) and returns the collected records as
    # iter_cb. The progress output shows that lsqr/lsmr can stop before the cap.
    image, iter_cb = nufft.pinv_solver(
        kspace_data=kspace_data_gpu,
        max_iter=1000,
        callback=mixed_cb,
        optim=optim,
    )
    images[optim] = image.get().squeeze()  # retrieve image from GPU.
    iterations_cb[optim] = process_cb_results(iter_cb)
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/venv/lib/python3.10/site-packages/cufinufft/_plan.py:393: UserWarning: Argument `data` does not satisfy the following requirement: C. Copying array (this may reduce performance)
  warnings.warn(f"Argument `{name}` does not satisfy the "
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/venv/lib/python3.10/site-packages/mrinufft/operators/base.py:934: UserWarning: Lipschitz constant did not converge
  warnings.warn("Lipschitz constant did not converge")
  0%|          | 0/1000 [00:00<?, ?it/s]
  2%|▏         | 23/1000 [00:00<00:04, 212.66it/s]
  4%|▍         | 45/1000 [00:00<00:04, 207.74it/s]
  7%|▋         | 69/1000 [00:00<00:04, 220.68it/s]
  9%|▉         | 92/1000 [00:00<00:04, 208.39it/s]
 11%|█▏        | 113/1000 [00:00<00:04, 205.91it/s]
 14%|█▍        | 140/1000 [00:00<00:03, 226.45it/s]
 17%|█▋        | 167/1000 [00:00<00:03, 235.62it/s]
 19%|█▉        | 194/1000 [00:00<00:03, 244.18it/s]
 22%|██▏       | 224/1000 [00:00<00:02, 259.99it/s]
 25%|██▌       | 254/1000 [00:01<00:02, 271.64it/s]
 28%|██▊       | 282/1000 [00:01<00:02, 259.99it/s]
 31%|███▏      | 313/1000 [00:01<00:02, 273.16it/s]
 34%|███▍      | 343/1000 [00:01<00:02, 279.36it/s]
 37%|███▋      | 372/1000 [00:01<00:02, 271.90it/s]
 40%|████      | 400/1000 [00:01<00:02, 272.13it/s]
 43%|████▎     | 429/1000 [00:01<00:02, 277.22it/s]
 46%|████▌     | 458/1000 [00:01<00:01, 280.40it/s]
 49%|████▊     | 487/1000 [00:01<00:01, 266.33it/s]
 52%|█████▏    | 518/1000 [00:02<00:01, 276.18it/s]
 55%|█████▍    | 547/1000 [00:02<00:01, 279.46it/s]
 58%|█████▊    | 576/1000 [00:02<00:01, 273.86it/s]
 60%|██████    | 604/1000 [00:02<00:01, 270.87it/s]
 63%|██████▎   | 634/1000 [00:02<00:01, 276.49it/s]
 66%|██████▋   | 664/1000 [00:02<00:01, 282.62it/s]
 69%|██████▉   | 693/1000 [00:02<00:01, 265.51it/s]
 72%|███████▏  | 723/1000 [00:02<00:01, 273.82it/s]
 75%|███████▌  | 753/1000 [00:02<00:00, 280.65it/s]
 78%|███████▊  | 782/1000 [00:02<00:00, 271.94it/s]
 81%|████████  | 810/1000 [00:03<00:00, 269.78it/s]
 84%|████████▍ | 840/1000 [00:03<00:00, 277.50it/s]
 87%|████████▋ | 870/1000 [00:03<00:00, 281.50it/s]
 90%|████████▉ | 899/1000 [00:03<00:00, 269.27it/s]
 93%|█████████▎| 930/1000 [00:03<00:00, 280.19it/s]
 96%|█████████▌| 960/1000 [00:03<00:00, 285.81it/s]
 99%|█████████▉| 989/1000 [00:03<00:00, 277.29it/s]
100%|██████████| 1000/1000 [00:03<00:00, 265.43it/s]
  0%|          | 0/1000 [00:00<?, ?it/s]
  1%|          | 11/1000 [00:00<00:09, 107.01it/s]
  3%|▎         | 26/1000 [00:00<00:07, 131.47it/s]
  4%|▍         | 40/1000 [00:00<00:07, 124.67it/s]
  6%|▌         | 56/1000 [00:00<00:06, 135.05it/s]
  7%|▋         | 72/1000 [00:00<00:06, 140.57it/s]
  9%|▊         | 87/1000 [00:00<00:06, 131.16it/s]
 10%|█         | 103/1000 [00:00<00:06, 138.97it/s]
 12%|█▏        | 118/1000 [00:00<00:06, 141.68it/s]
 13%|█▎        | 134/1000 [00:00<00:06, 137.16it/s]
 15%|█▍        | 148/1000 [00:01<00:06, 137.74it/s]
 16%|█▋        | 163/1000 [00:01<00:05, 140.50it/s]
 18%|█▊        | 179/1000 [00:01<00:05, 144.72it/s]
 19%|█▉        | 194/1000 [00:01<00:05, 136.28it/s]
 21%|██        | 209/1000 [00:01<00:05, 139.59it/s]
 22%|██▎       | 225/1000 [00:01<00:05, 143.08it/s]
 24%|██▍       | 240/1000 [00:01<00:05, 134.83it/s]
 26%|██▌       | 256/1000 [00:01<00:05, 140.80it/s]
 27%|██▋       | 272/1000 [00:01<00:05, 143.66it/s]
 29%|██▉       | 288/1000 [00:02<00:04, 144.06it/s]
 30%|███       | 303/1000 [00:02<00:04, 145.22it/s]
 32%|███▏      | 319/1000 [00:02<00:04, 147.19it/s]
 34%|███▎      | 335/1000 [00:02<00:04, 148.86it/s]
 35%|███▌      | 350/1000 [00:02<00:04, 139.59it/s]
 37%|███▋      | 366/1000 [00:02<00:04, 143.53it/s]
 38%|███▊      | 382/1000 [00:02<00:04, 146.67it/s]
 40%|███▉      | 397/1000 [00:02<00:04, 138.45it/s]
 41%|████▏     | 413/1000 [00:02<00:04, 142.71it/s]
 43%|████▎     | 429/1000 [00:03<00:03, 145.03it/s]
 44%|████▍     | 444/1000 [00:03<00:04, 137.64it/s]
 46%|████▌     | 460/1000 [00:03<00:03, 142.16it/s]
 48%|████▊     | 475/1000 [00:03<00:03, 143.93it/s]
 49%|████▉     | 491/1000 [00:03<00:03, 147.33it/s]
 50%|█████     | 502/1000 [00:03<00:03, 139.95it/s]
  0%|          | 0/1000 [00:00<?, ?it/s]
  1%|          | 11/1000 [00:00<00:09, 104.28it/s]
  2%|▏         | 24/1000 [00:00<00:08, 115.34it/s]
  4%|▎         | 36/1000 [00:00<00:09, 102.13it/s]
  5%|▍         | 49/1000 [00:00<00:08, 110.55it/s]
  6%|▌         | 62/1000 [00:00<00:08, 114.44it/s]
  7%|▋         | 74/1000 [00:00<00:08, 109.26it/s]
  9%|▊         | 87/1000 [00:00<00:07, 114.32it/s]
 10%|█         | 101/1000 [00:00<00:07, 121.71it/s]
 11%|█▏        | 114/1000 [00:01<00:07, 114.05it/s]
 13%|█▎        | 127/1000 [00:01<00:07, 115.78it/s]
 14%|█▍        | 140/1000 [00:01<00:07, 118.05it/s]
 15%|█▌        | 152/1000 [00:01<00:07, 110.09it/s]
 16%|█▋        | 164/1000 [00:01<00:07, 110.85it/s]
 18%|█▊        | 177/1000 [00:01<00:07, 114.56it/s]
 19%|█▉        | 190/1000 [00:01<00:06, 115.85it/s]
 20%|██        | 202/1000 [00:01<00:07, 108.21it/s]
 22%|██▏       | 215/1000 [00:01<00:07, 112.02it/s]
 23%|██▎       | 227/1000 [00:02<00:06, 113.70it/s]
 24%|██▍       | 239/1000 [00:02<00:07, 106.39it/s]
 25%|██▌       | 252/1000 [00:02<00:06, 111.26it/s]
 26%|██▋       | 264/1000 [00:02<00:06, 113.20it/s]
 28%|██▊       | 276/1000 [00:02<00:06, 106.52it/s]
 29%|██▉       | 288/1000 [00:02<00:06, 109.95it/s]
 30%|███       | 300/1000 [00:02<00:06, 112.44it/s]
 31%|███       | 312/1000 [00:02<00:06, 104.95it/s]
 32%|███▏      | 324/1000 [00:02<00:06, 108.85it/s]
 34%|███▍      | 338/1000 [00:03<00:05, 115.49it/s]
 35%|███▌      | 351/1000 [00:03<00:05, 119.00it/s]
 36%|███▋      | 364/1000 [00:03<00:05, 114.33it/s]
 38%|███▊      | 378/1000 [00:03<00:05, 119.63it/s]
 39%|███▉      | 391/1000 [00:03<00:05, 119.52it/s]
 40%|████      | 404/1000 [00:03<00:05, 111.13it/s]
 42%|████▏     | 416/1000 [00:03<00:05, 113.01it/s]
 43%|████▎     | 429/1000 [00:03<00:04, 115.50it/s]
 44%|████▍     | 441/1000 [00:03<00:05, 108.02it/s]
 45%|████▌     | 453/1000 [00:04<00:04, 110.83it/s]
 46%|████▋     | 465/1000 [00:04<00:04, 112.90it/s]
 48%|████▊     | 477/1000 [00:04<00:04, 108.87it/s]
 49%|████▉     | 488/1000 [00:04<00:04, 108.66it/s]
 50%|█████     | 500/1000 [00:04<00:04, 110.85it/s]
 51%|█████▏    | 513/1000 [00:04<00:04, 114.82it/s]
 52%|█████▎    | 525/1000 [00:04<00:04, 105.51it/s]
 54%|█████▎    | 537/1000 [00:04<00:04, 109.16it/s]
 55%|█████▌    | 550/1000 [00:04<00:03, 113.70it/s]
 55%|█████▌    | 551/1000 [00:04<00:04, 112.03it/s]
Display convergence
# One row per metric, with all solvers overlaid against cumulative solver time.
fig, axs = plt.subplots(len(METRICS), 1, sharex=True, figsize=(8, 12))
for i, metric in enumerate(METRICS):
    # Both residual norms ("res", "AHres") span orders of magnitude, while PSNR
    # does not. The scale belongs to the axis, so set it once per row.
    if "res" in metric:
        axs[i].set_yscale("log")
    for optim in OPTIM:
        times = iterations_cb[optim]["time"]
        # Throughput = iterations completed / elapsed solver time. Averaging the
        # per-step rates 1/dt over-weights fast steps, and it turns into inf
        # (dt == 0) or nan (a single sample) on degenerate runs.
        elapsed = times[-1] - times[0] if len(times) > 1 else 0.0
        rate = (len(times) - 1) / elapsed if elapsed > 0 else float("nan")
        axs[i].plot(
            times,
            iterations_cb[optim][metric],
            marker="o",
            markevery=20,
            label=f"{optim} {rate:.2f}iters/s",
        )
    axs[i].grid()
    axs[i].set_ylabel(METRICS[metric])
axs[0].legend()
axs[-1].set_xlabel("time (s)")
fig.tight_layout()
plt.show()

Display images
# Layout: one panel per solver, then the adjoint baseline, then the ground truth.
fig, axs = plt.subplots(1, len(OPTIM) + 2, figsize=(20, 7))
for i, optim in enumerate(OPTIM):
    axs[i].imshow(abs(images[optim]), cmap="gray", origin="lower")
    axs[i].axis("off")
    # Title reports the PSNR of the last callback record, the number of
    # records (presumably one per iteration) and the cumulative solver time
    # at the last record.
    axs[i].set_title(
        f"{optim} reconstruction\n PSNR: {iterations_cb[optim]['psnr'][-1]:.2f}dB \n"
        f"{len(iterations_cb[optim]['time'])} iters ({iterations_cb[optim]['time'][-1]:.2f}s)"
    )
# Last panel: reference image.
axs[-1].imshow(abs(ground_truth), cmap="gray", origin="lower")
axs[-1].axis("off")
axs[-1].set_title("Original image")
# Second-to-last panel: plain adjoint NUFFT of the simulated data.
axs[-2].imshow(
    abs(adjoint),
    cmap="gray",
    origin="lower",
)
axs[-2].axis("off")
# NOTE(review): the adjoint is not rescaled before computing its PSNR, so the
# value also penalizes any global scale mismatch with the ground truth.
axs[-2].set_title(
    f"Adjoint NUFFT \n PSNR: {psnr(abs(adjoint), abs(ground_truth), data_range=ground_truth.max()):.2f}dB"
)
fig.suptitle("Reconstructed images using different optimizers")
fig.tight_layout()
plt.show()

Using a damping regularization term
The least-squares problem can be regularized with a damping (Tikhonov) term to improve its conditioning. This amounts to solving the following optimization problem:
\[\min_x \|Ax - b\|_2^2 + \gamma \|x\|_2^2\]
where \(\gamma\) is the regularization parameter.
images = dict()  # reset: solver name -> damped reconstruction (host array)
iterations_cb = dict()  # reset: solver name -> processed callback metrics
for optim in OPTIM:
    # Same solves as the unregularized run, but with damp=0.1.
    # NOTE(review): whether ``damp`` plays the role of gamma or sqrt(gamma) in
    # the objective depends on the pinv_solver implementation — confirm in the
    # mrinufft docs.
    image, iter_cb = nufft.pinv_solver(
        kspace_data=kspace_data_gpu,
        max_iter=1000,
        callback=mixed_cb,
        damp=0.1,
        optim=optim,
    )
    images[optim] = image.get().squeeze()  # retrieve image from GPU.
    iterations_cb[optim] = process_cb_results(iter_cb)
  0%|          | 0/1000 [00:00<?, ?it/s]
  2%|▏         | 24/1000 [00:00<00:04, 234.06it/s]
  5%|▌         | 51/1000 [00:00<00:03, 254.13it/s]
  8%|▊         | 78/1000 [00:00<00:03, 258.74it/s]
 10%|█         | 104/1000 [00:00<00:03, 242.00it/s]
 13%|█▎        | 131/1000 [00:00<00:03, 251.33it/s]
 16%|█▌        | 159/1000 [00:00<00:03, 260.28it/s]
 19%|█▊        | 186/1000 [00:00<00:03, 255.93it/s]
 21%|██        | 212/1000 [00:00<00:03, 257.00it/s]
 24%|██▍       | 241/1000 [00:00<00:02, 266.67it/s]
 27%|██▋       | 270/1000 [00:01<00:02, 273.49it/s]
 30%|██▉       | 298/1000 [00:01<00:02, 263.42it/s]
 33%|███▎      | 327/1000 [00:01<00:02, 270.29it/s]
 36%|███▌      | 355/1000 [00:01<00:02, 271.49it/s]
 38%|███▊      | 383/1000 [00:01<00:02, 255.01it/s]
 41%|████      | 411/1000 [00:01<00:02, 260.19it/s]
 44%|████▍     | 439/1000 [00:01<00:02, 264.02it/s]
 47%|████▋     | 467/1000 [00:01<00:01, 268.55it/s]
 49%|████▉     | 494/1000 [00:01<00:01, 256.18it/s]
 52%|█████▏    | 522/1000 [00:02<00:01, 262.88it/s]
 55%|█████▌    | 551/1000 [00:02<00:01, 268.61it/s]
 58%|█████▊    | 578/1000 [00:02<00:01, 252.75it/s]
 60%|██████    | 605/1000 [00:02<00:01, 257.38it/s]
 63%|██████▎   | 634/1000 [00:02<00:01, 265.00it/s]
 66%|██████▌   | 661/1000 [00:02<00:01, 260.09it/s]
 69%|██████▉   | 688/1000 [00:02<00:01, 255.85it/s]
 72%|███████▏  | 717/1000 [00:02<00:01, 263.55it/s]
 75%|███████▍  | 746/1000 [00:02<00:00, 269.14it/s]
 77%|███████▋  | 773/1000 [00:02<00:00, 254.90it/s]
 80%|████████  | 802/1000 [00:03<00:00, 262.79it/s]
 83%|████████▎ | 831/1000 [00:03<00:00, 268.80it/s]
 86%|████████▌ | 859/1000 [00:03<00:00, 262.56it/s]
 89%|████████▊ | 887/1000 [00:03<00:00, 266.40it/s]
 92%|█████████▏| 917/1000 [00:03<00:00, 275.93it/s]
 95%|█████████▍| 946/1000 [00:03<00:00, 275.32it/s]
 97%|█████████▋| 974/1000 [00:03<00:00, 260.68it/s]
100%|██████████| 1000/1000 [00:03<00:00, 262.35it/s]
  0%|          | 0/1000 [00:00<?, ?it/s]
  1%|▏         | 14/1000 [00:00<00:07, 139.08it/s]
  3%|▎         | 28/1000 [00:00<00:08, 119.58it/s]
  4%|▍         | 43/1000 [00:00<00:07, 130.40it/s]
  6%|▌         | 58/1000 [00:00<00:06, 136.25it/s]
  7%|▋         | 72/1000 [00:00<00:07, 124.41it/s]
  9%|▊         | 87/1000 [00:00<00:07, 130.33it/s]
 10%|█         | 102/1000 [00:00<00:06, 134.92it/s]
 11%|█         | 106/1000 [00:00<00:06, 130.44it/s]
  0%|          | 0/1000 [00:00<?, ?it/s]
  1%|          | 8/1000 [00:00<00:12, 79.19it/s]
  2%|▏         | 20/1000 [00:00<00:09, 99.75it/s]
  3%|▎         | 32/1000 [00:00<00:09, 105.25it/s]
  4%|▍         | 43/1000 [00:00<00:09, 96.43it/s]
  6%|▌         | 55/1000 [00:00<00:09, 101.73it/s]
  7%|▋         | 66/1000 [00:00<00:09, 103.16it/s]
  8%|▊         | 77/1000 [00:00<00:09, 97.52it/s]
  9%|▉         | 89/1000 [00:00<00:08, 102.18it/s]
  9%|▉         | 91/1000 [00:00<00:09, 100.10it/s]
Display convergence
# One row per metric, with all solvers overlaid against cumulative solver time.
fig, axs = plt.subplots(len(METRICS), 1, sharex=True, figsize=(8, 12))
for i, metric in enumerate(METRICS):
    # Both residual norms ("res", "AHres") span orders of magnitude, while PSNR
    # does not. The scale belongs to the axis, so set it once per row.
    if "res" in metric:
        axs[i].set_yscale("log")
    for optim in OPTIM:
        times = iterations_cb[optim]["time"]
        # Throughput = iterations completed / elapsed solver time. Averaging the
        # per-step rates 1/dt over-weights fast steps, and it turns into inf
        # (dt == 0) or nan (a single sample) on degenerate runs.
        elapsed = times[-1] - times[0] if len(times) > 1 else 0.0
        rate = (len(times) - 1) / elapsed if elapsed > 0 else float("nan")
        axs[i].plot(
            times,
            iterations_cb[optim][metric],
            marker="o",
            markevery=20,
            label=f"{optim} {rate:.2f}iters/s",
        )
    axs[i].grid()
    axs[i].set_ylabel(METRICS[metric])
axs[0].legend()
axs[-1].set_xlabel("time (s)")
fig.tight_layout()
plt.show()

Display images
# Layout: one panel per (damped) solver, then the adjoint baseline, then the
# ground truth.
fig, axs = plt.subplots(1, len(OPTIM) + 2, figsize=(20, 7))
for i, optim in enumerate(OPTIM):
    axs[i].imshow(abs(images[optim]), cmap="gray", origin="lower")
    axs[i].axis("off")
    # Title reports the PSNR of the last callback record, the number of
    # records (presumably one per iteration) and the cumulative solver time
    # at the last record.
    axs[i].set_title(
        f"{optim} reconstruction\n PSNR: {iterations_cb[optim]['psnr'][-1]:.2f}dB \n"
        f"{len(iterations_cb[optim]['time'])} iters ({iterations_cb[optim]['time'][-1]:.2f}s)"
    )
# Last panel: reference image.
axs[-1].imshow(abs(ground_truth), cmap="gray", origin="lower")
axs[-1].axis("off")
axs[-1].set_title("Original image")
# Second-to-last panel: plain adjoint NUFFT of the simulated data (unchanged
# by the damping, shown again for comparison).
axs[-2].imshow(
    abs(adjoint),
    cmap="gray",
    origin="lower",
)
axs[-2].axis("off")
# NOTE(review): the adjoint is not rescaled before computing its PSNR, so the
# value also penalizes any global scale mismatch with the ground truth.
axs[-2].set_title(
    f"Adjoint NUFFT \n PSNR: {psnr(abs(adjoint), abs(ground_truth), data_range=ground_truth.max()):.2f}dB"
)
fig.suptitle("Reconstructed images using different optimizers")
fig.tight_layout()
plt.show()

Total running time of the script: (0 minutes 21.450 seconds)