Least Squares Image Reconstruction

An example showing how to reconstruct volumes using the least-squares estimate.

This script demonstrates the use of the Conjugate Gradient (CG), LSQR and LSMR methods to reconstruct images from non-uniform k-space data.

import os
import time

import cupy as cp
import numpy as np
from brainweb_dl import get_mri
from matplotlib import pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as psnr

import mrinufft
from mrinufft.extras.optim import loss_l2_reg, loss_l2_AHreg


# NUFFT backend name, read from the MRINUFFT_BACKEND environment variable;
# "cufinufft" is used when the variable is unset.
BACKEND = os.environ.get("MRINUFFT_BACKEND", "cufinufft")

Setup Inputs

# Spiral sampling pattern: 64 shots of 512 samples, 8 revolutions each.
samples_loc = mrinufft.initialize_2D_spiral(Nc=64, Ns=512, nb_revolutions=8)

# Load subject 4 and keep the entry at index 90 along the first axis.
ground_truth = get_mri(sub_id=4)[90]

# Rescale the image so that its root-mean-square magnitude equals one.
rms = np.sqrt(np.mean(abs(ground_truth) ** 2))
ground_truth = ground_truth / rms

# Device-side copy of the reference image, used as input to the operator.
image_gpu = cp.array(ground_truth)

print("image size: ", ground_truth.shape)
image size:  (256, 256)

Setup the NUFFT operator

# Fetch the operator constructor matching the selected backend.
NufftOperator = mrinufft.get_operator(BACKEND)

# Build the operator for the spiral trajectory and the image shape.
nufft = NufftOperator(samples_loc, shape=ground_truth.shape, squeeze_dims=True)
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/venv/lib/python3.10/site-packages/mrinufft/_utils.py:76: UserWarning: Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)
  warnings.warn(

Reconstruct the image using the CG, LSQR and LSMR methods

# Forward NUFFT of the reference image: the simulated k-space measurements.
kspace_data_gpu = nufft.op(image_gpu)
# NumPy copy of the measurements, kept for display purposes.
kspace_data = kspace_data_gpu.get()
# Adjoint NUFFT of the measurements, copied back as a NumPy array.
adjoint_gpu = nufft.adj_op(kspace_data_gpu)
adjoint = adjoint_gpu.get()


def mixed_cb(*args, **kwargs):
    """Gather timing and convergence metrics for one solver iteration.

    Every positional and keyword argument is forwarded unchanged to
    ``loss_l2_reg`` and ``loss_l2_AHreg``. ``args[0]`` is additionally read
    as the current estimate (an object exposing ``.get()``) and compared
    with ``ground_truth`` through the PSNR.

    Returns
    -------
    list
        ``[t_start, l2_loss, AH_loss, psnr_value, t_end]`` where the two
        timestamps are taken right before and right after the metric
        computations.
    """
    t_start = time.perf_counter()

    residual = loss_l2_reg(*args, **kwargs)
    adjoint_residual = loss_l2_AHreg(*args, **kwargs)

    # PSNR is computed on magnitude images, on the host side.
    estimate = abs(args[0].get().squeeze())
    reference = abs(ground_truth.squeeze())
    quality = psnr(estimate, reference, data_range=ground_truth.max())

    t_end = time.perf_counter()
    return [t_start, residual, adjoint_residual, quality, t_end]


def process_cb_results(cb_results):
    """Turn the per-iteration callback records into plottable series.

    Parameters
    ----------
    cb_results
        Sequence of 5-item records
        ``(t_start, residual, adjoint_residual, psnr, t_end)``; the two
        residual entries must expose a ``.get()`` method.

    Returns
    -------
    dict
        ``"time"``: cumulative time spent between callbacks, i.e. from the
        end of the previous callback to the start of the current one (the
        first entry is zero); ``"res"`` and ``"AHres"``: lists of the
        residuals after ``.get()``; ``"psnr"``: tuple of PSNR values.
    """
    starts, residuals, adjoint_residuals, psnr_values, ends = zip(*cb_results)

    # Reference instant for iteration i is the end of callback i-1; the very
    # first iteration is referenced to its own start, giving a zero offset.
    references = (starts[0],) + tuple(ends[:-1])
    elapsed = np.cumsum(np.array(starts) - np.array(references))

    return {
        "time": elapsed,
        "res": [value.get() for value in residuals],
        "AHres": [value.get() for value in adjoint_residuals],
        "psnr": psnr_values,
    }


# Run the least-square minimization for all the solvers:

# Solvers to compare.
OPTIM = ["cg", "lsqr", "lsmr"]
# Metric keys produced by ``process_cb_results``, mapped to their axis labels.
METRICS = dict(
    res=r"$\|Ax-b\|$",
    AHres=r"$\|A^H(Ax-b)\|$",
    psnr="PSNR",
)


# Per-solver reconstructed image and per-solver convergence history.
images = {}
iterations_cb = {}
for optim in OPTIM:
    image, iter_cb = nufft.pinv_solver(
        optim=optim,
        kspace_data=kspace_data_gpu,
        callback=mixed_cb,
        max_iter=1000,
    )
    iterations_cb[optim] = process_cb_results(iter_cb)
    # Bring the reconstruction back from the GPU and drop singleton axes.
    images[optim] = image.get().squeeze()
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/venv/lib/python3.10/site-packages/cufinufft/_plan.py:393: UserWarning: Argument `data` does not satisfy the following requirement: C. Copying array (this may reduce performance)
  warnings.warn(f"Argument `{name}` does not satisfy the "
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/venv/lib/python3.10/site-packages/mrinufft/operators/base.py:934: UserWarning: Lipschitz constant did not converge
  warnings.warn("Lipschitz constant did not converge")

  0%|          | 0/1000 [00:00<?, ?it/s]
  2%|▏         | 23/1000 [00:00<00:04, 212.66it/s]
  4%|▍         | 45/1000 [00:00<00:04, 207.74it/s]
  7%|▋         | 69/1000 [00:00<00:04, 220.68it/s]
  9%|▉         | 92/1000 [00:00<00:04, 208.39it/s]
 11%|█▏        | 113/1000 [00:00<00:04, 205.91it/s]
 14%|█▍        | 140/1000 [00:00<00:03, 226.45it/s]
 17%|█▋        | 167/1000 [00:00<00:03, 235.62it/s]
 19%|█▉        | 194/1000 [00:00<00:03, 244.18it/s]
 22%|██▏       | 224/1000 [00:00<00:02, 259.99it/s]
 25%|██▌       | 254/1000 [00:01<00:02, 271.64it/s]
 28%|██▊       | 282/1000 [00:01<00:02, 259.99it/s]
 31%|███▏      | 313/1000 [00:01<00:02, 273.16it/s]
 34%|███▍      | 343/1000 [00:01<00:02, 279.36it/s]
 37%|███▋      | 372/1000 [00:01<00:02, 271.90it/s]
 40%|████      | 400/1000 [00:01<00:02, 272.13it/s]
 43%|████▎     | 429/1000 [00:01<00:02, 277.22it/s]
 46%|████▌     | 458/1000 [00:01<00:01, 280.40it/s]
 49%|████▊     | 487/1000 [00:01<00:01, 266.33it/s]
 52%|█████▏    | 518/1000 [00:02<00:01, 276.18it/s]
 55%|█████▍    | 547/1000 [00:02<00:01, 279.46it/s]
 58%|█████▊    | 576/1000 [00:02<00:01, 273.86it/s]
 60%|██████    | 604/1000 [00:02<00:01, 270.87it/s]
 63%|██████▎   | 634/1000 [00:02<00:01, 276.49it/s]
 66%|██████▋   | 664/1000 [00:02<00:01, 282.62it/s]
 69%|██████▉   | 693/1000 [00:02<00:01, 265.51it/s]
 72%|███████▏  | 723/1000 [00:02<00:01, 273.82it/s]
 75%|███████▌  | 753/1000 [00:02<00:00, 280.65it/s]
 78%|███████▊  | 782/1000 [00:02<00:00, 271.94it/s]
 81%|████████  | 810/1000 [00:03<00:00, 269.78it/s]
 84%|████████▍ | 840/1000 [00:03<00:00, 277.50it/s]
 87%|████████▋ | 870/1000 [00:03<00:00, 281.50it/s]
 90%|████████▉ | 899/1000 [00:03<00:00, 269.27it/s]
 93%|█████████▎| 930/1000 [00:03<00:00, 280.19it/s]
 96%|█████████▌| 960/1000 [00:03<00:00, 285.81it/s]
 99%|█████████▉| 989/1000 [00:03<00:00, 277.29it/s]
100%|██████████| 1000/1000 [00:03<00:00, 265.43it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]
  1%|          | 11/1000 [00:00<00:09, 107.01it/s]
  3%|▎         | 26/1000 [00:00<00:07, 131.47it/s]
  4%|▍         | 40/1000 [00:00<00:07, 124.67it/s]
  6%|▌         | 56/1000 [00:00<00:06, 135.05it/s]
  7%|▋         | 72/1000 [00:00<00:06, 140.57it/s]
  9%|▊         | 87/1000 [00:00<00:06, 131.16it/s]
 10%|█         | 103/1000 [00:00<00:06, 138.97it/s]
 12%|█▏        | 118/1000 [00:00<00:06, 141.68it/s]
 13%|█▎        | 134/1000 [00:00<00:06, 137.16it/s]
 15%|█▍        | 148/1000 [00:01<00:06, 137.74it/s]
 16%|█▋        | 163/1000 [00:01<00:05, 140.50it/s]
 18%|█▊        | 179/1000 [00:01<00:05, 144.72it/s]
 19%|█▉        | 194/1000 [00:01<00:05, 136.28it/s]
 21%|██        | 209/1000 [00:01<00:05, 139.59it/s]
 22%|██▎       | 225/1000 [00:01<00:05, 143.08it/s]
 24%|██▍       | 240/1000 [00:01<00:05, 134.83it/s]
 26%|██▌       | 256/1000 [00:01<00:05, 140.80it/s]
 27%|██▋       | 272/1000 [00:01<00:05, 143.66it/s]
 29%|██▉       | 288/1000 [00:02<00:04, 144.06it/s]
 30%|███       | 303/1000 [00:02<00:04, 145.22it/s]
 32%|███▏      | 319/1000 [00:02<00:04, 147.19it/s]
 34%|███▎      | 335/1000 [00:02<00:04, 148.86it/s]
 35%|███▌      | 350/1000 [00:02<00:04, 139.59it/s]
 37%|███▋      | 366/1000 [00:02<00:04, 143.53it/s]
 38%|███▊      | 382/1000 [00:02<00:04, 146.67it/s]
 40%|███▉      | 397/1000 [00:02<00:04, 138.45it/s]
 41%|████▏     | 413/1000 [00:02<00:04, 142.71it/s]
 43%|████▎     | 429/1000 [00:03<00:03, 145.03it/s]
 44%|████▍     | 444/1000 [00:03<00:04, 137.64it/s]
 46%|████▌     | 460/1000 [00:03<00:03, 142.16it/s]
 48%|████▊     | 475/1000 [00:03<00:03, 143.93it/s]
 49%|████▉     | 491/1000 [00:03<00:03, 147.33it/s]
 50%|█████     | 502/1000 [00:03<00:03, 139.95it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]
  1%|          | 11/1000 [00:00<00:09, 104.28it/s]
  2%|▏         | 24/1000 [00:00<00:08, 115.34it/s]
  4%|▎         | 36/1000 [00:00<00:09, 102.13it/s]
  5%|▍         | 49/1000 [00:00<00:08, 110.55it/s]
  6%|▌         | 62/1000 [00:00<00:08, 114.44it/s]
  7%|▋         | 74/1000 [00:00<00:08, 109.26it/s]
  9%|▊         | 87/1000 [00:00<00:07, 114.32it/s]
 10%|█         | 101/1000 [00:00<00:07, 121.71it/s]
 11%|█▏        | 114/1000 [00:01<00:07, 114.05it/s]
 13%|█▎        | 127/1000 [00:01<00:07, 115.78it/s]
 14%|█▍        | 140/1000 [00:01<00:07, 118.05it/s]
 15%|█▌        | 152/1000 [00:01<00:07, 110.09it/s]
 16%|█▋        | 164/1000 [00:01<00:07, 110.85it/s]
 18%|█▊        | 177/1000 [00:01<00:07, 114.56it/s]
 19%|█▉        | 190/1000 [00:01<00:06, 115.85it/s]
 20%|██        | 202/1000 [00:01<00:07, 108.21it/s]
 22%|██▏       | 215/1000 [00:01<00:07, 112.02it/s]
 23%|██▎       | 227/1000 [00:02<00:06, 113.70it/s]
 24%|██▍       | 239/1000 [00:02<00:07, 106.39it/s]
 25%|██▌       | 252/1000 [00:02<00:06, 111.26it/s]
 26%|██▋       | 264/1000 [00:02<00:06, 113.20it/s]
 28%|██▊       | 276/1000 [00:02<00:06, 106.52it/s]
 29%|██▉       | 288/1000 [00:02<00:06, 109.95it/s]
 30%|███       | 300/1000 [00:02<00:06, 112.44it/s]
 31%|███       | 312/1000 [00:02<00:06, 104.95it/s]
 32%|███▏      | 324/1000 [00:02<00:06, 108.85it/s]
 34%|███▍      | 338/1000 [00:03<00:05, 115.49it/s]
 35%|███▌      | 351/1000 [00:03<00:05, 119.00it/s]
 36%|███▋      | 364/1000 [00:03<00:05, 114.33it/s]
 38%|███▊      | 378/1000 [00:03<00:05, 119.63it/s]
 39%|███▉      | 391/1000 [00:03<00:05, 119.52it/s]
 40%|████      | 404/1000 [00:03<00:05, 111.13it/s]
 42%|████▏     | 416/1000 [00:03<00:05, 113.01it/s]
 43%|████▎     | 429/1000 [00:03<00:04, 115.50it/s]
 44%|████▍     | 441/1000 [00:03<00:05, 108.02it/s]
 45%|████▌     | 453/1000 [00:04<00:04, 110.83it/s]
 46%|████▋     | 465/1000 [00:04<00:04, 112.90it/s]
 48%|████▊     | 477/1000 [00:04<00:04, 108.87it/s]
 49%|████▉     | 488/1000 [00:04<00:04, 108.66it/s]
 50%|█████     | 500/1000 [00:04<00:04, 110.85it/s]
 51%|█████▏    | 513/1000 [00:04<00:04, 114.82it/s]
 52%|█████▎    | 525/1000 [00:04<00:04, 105.51it/s]
 54%|█████▎    | 537/1000 [00:04<00:04, 109.16it/s]
 55%|█████▌    | 550/1000 [00:04<00:03, 113.70it/s]
 55%|█████▌    | 551/1000 [00:04<00:04, 112.03it/s]

Display Convergence

# One row per metric, all sharing the time axis.
fig, axs = plt.subplots(len(METRICS), 1, sharex=True, figsize=(8, 12))
for ax, (metric, ylabel) in zip(axs, METRICS.items()):
    # Metrics whose key contains "res" (both residual norms) use a log scale.
    if "res" in metric:
        ax.set_yscale("log")
    for optim in OPTIM:
        history = iterations_cb[optim]
        elapsed = history["time"]
        # Mean of the inverse time step between consecutive iterations.
        rate = np.mean(1 / np.diff(elapsed))
        ax.plot(
            elapsed,
            history[metric],
            marker="o",
            markevery=20,
            label=f"{optim} {rate:.2f}iters/s",
        )
    ax.grid()
    ax.set_ylabel(ylabel)
axs[0].legend()
axs[-1].set_xlabel("time (s)")
fig.tight_layout()
plt.show()
example pinv

Display images

# One panel per solver, plus the adjoint reconstruction and the reference.
fig, axs = plt.subplots(1, len(OPTIM) + 2, figsize=(20, 7))

for ax, optim in zip(axs, OPTIM):
    history = iterations_cb[optim]
    final_psnr = history["psnr"][-1]
    n_iters = len(history["time"])
    total_time = history["time"][-1]
    ax.imshow(abs(images[optim]), cmap="gray", origin="lower")
    ax.axis("off")
    ax.set_title(
        f"{optim} reconstruction\n PSNR: {final_psnr:.2f}dB \n"
        f"{n_iters} iters ({total_time:.2f}s)"
    )

# Second-to-last panel: plain adjoint NUFFT of the k-space data.
adjoint_ax = axs[-2]
adjoint_psnr = psnr(
    abs(adjoint), abs(ground_truth), data_range=ground_truth.max()
)
adjoint_ax.imshow(abs(adjoint), cmap="gray", origin="lower")
adjoint_ax.axis("off")
adjoint_ax.set_title(f"Adjoint NUFFT \n PSNR: {adjoint_psnr:.2f}dB")

# Last panel: the reference image.
reference_ax = axs[-1]
reference_ax.imshow(abs(ground_truth), cmap="gray", origin="lower")
reference_ax.axis("off")
reference_ax.set_title("Original image")

fig.suptitle("Reconstructed images using different optimizers")
fig.tight_layout()
plt.show()
Reconstructed images using different optimizers, cg reconstruction  PSNR: 26.39dB  1000 iters (1.69s), lsqr reconstruction  PSNR: 26.54dB  503 iters (2.50s), lsmr reconstruction  PSNR: 26.53dB  552 iters (3.59s), Adjoint NUFFT   PSNR: -18.57dB, Original image

Using a damping regularization term

The least-squares problem can be regularized using a damping term to improve the conditioning of the problem. This is done by solving the following optimization problem:

\[\min_x \|Ax - b\|_2^2 + \gamma \|x\|_2^2\]

where \(\gamma\) is the regularization parameter.

# Start from fresh containers, then run every solver again with damping.
images = {}
iterations_cb = {}
for optim in OPTIM:
    image, iter_cb = nufft.pinv_solver(
        optim=optim,
        kspace_data=kspace_data_gpu,
        callback=mixed_cb,
        damp=0.1,
        max_iter=1000,
    )
    iterations_cb[optim] = process_cb_results(iter_cb)
    # Bring the reconstruction back from the GPU and drop singleton axes.
    images[optim] = image.get().squeeze()
  0%|          | 0/1000 [00:00<?, ?it/s]
  2%|▏         | 24/1000 [00:00<00:04, 234.06it/s]
  5%|▌         | 51/1000 [00:00<00:03, 254.13it/s]
  8%|▊         | 78/1000 [00:00<00:03, 258.74it/s]
 10%|█         | 104/1000 [00:00<00:03, 242.00it/s]
 13%|█▎        | 131/1000 [00:00<00:03, 251.33it/s]
 16%|█▌        | 159/1000 [00:00<00:03, 260.28it/s]
 19%|█▊        | 186/1000 [00:00<00:03, 255.93it/s]
 21%|██        | 212/1000 [00:00<00:03, 257.00it/s]
 24%|██▍       | 241/1000 [00:00<00:02, 266.67it/s]
 27%|██▋       | 270/1000 [00:01<00:02, 273.49it/s]
 30%|██▉       | 298/1000 [00:01<00:02, 263.42it/s]
 33%|███▎      | 327/1000 [00:01<00:02, 270.29it/s]
 36%|███▌      | 355/1000 [00:01<00:02, 271.49it/s]
 38%|███▊      | 383/1000 [00:01<00:02, 255.01it/s]
 41%|████      | 411/1000 [00:01<00:02, 260.19it/s]
 44%|████▍     | 439/1000 [00:01<00:02, 264.02it/s]
 47%|████▋     | 467/1000 [00:01<00:01, 268.55it/s]
 49%|████▉     | 494/1000 [00:01<00:01, 256.18it/s]
 52%|█████▏    | 522/1000 [00:02<00:01, 262.88it/s]
 55%|█████▌    | 551/1000 [00:02<00:01, 268.61it/s]
 58%|█████▊    | 578/1000 [00:02<00:01, 252.75it/s]
 60%|██████    | 605/1000 [00:02<00:01, 257.38it/s]
 63%|██████▎   | 634/1000 [00:02<00:01, 265.00it/s]
 66%|██████▌   | 661/1000 [00:02<00:01, 260.09it/s]
 69%|██████▉   | 688/1000 [00:02<00:01, 255.85it/s]
 72%|███████▏  | 717/1000 [00:02<00:01, 263.55it/s]
 75%|███████▍  | 746/1000 [00:02<00:00, 269.14it/s]
 77%|███████▋  | 773/1000 [00:02<00:00, 254.90it/s]
 80%|████████  | 802/1000 [00:03<00:00, 262.79it/s]
 83%|████████▎ | 831/1000 [00:03<00:00, 268.80it/s]
 86%|████████▌ | 859/1000 [00:03<00:00, 262.56it/s]
 89%|████████▊ | 887/1000 [00:03<00:00, 266.40it/s]
 92%|█████████▏| 917/1000 [00:03<00:00, 275.93it/s]
 95%|█████████▍| 946/1000 [00:03<00:00, 275.32it/s]
 97%|█████████▋| 974/1000 [00:03<00:00, 260.68it/s]
100%|██████████| 1000/1000 [00:03<00:00, 262.35it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]
  1%|▏         | 14/1000 [00:00<00:07, 139.08it/s]
  3%|▎         | 28/1000 [00:00<00:08, 119.58it/s]
  4%|▍         | 43/1000 [00:00<00:07, 130.40it/s]
  6%|▌         | 58/1000 [00:00<00:06, 136.25it/s]
  7%|▋         | 72/1000 [00:00<00:07, 124.41it/s]
  9%|▊         | 87/1000 [00:00<00:07, 130.33it/s]
 10%|█         | 102/1000 [00:00<00:06, 134.92it/s]
 11%|█         | 106/1000 [00:00<00:06, 130.44it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]
  1%|          | 8/1000 [00:00<00:12, 79.19it/s]
  2%|▏         | 20/1000 [00:00<00:09, 99.75it/s]
  3%|▎         | 32/1000 [00:00<00:09, 105.25it/s]
  4%|▍         | 43/1000 [00:00<00:09, 96.43it/s]
  6%|▌         | 55/1000 [00:00<00:09, 101.73it/s]
  7%|▋         | 66/1000 [00:00<00:09, 103.16it/s]
  8%|▊         | 77/1000 [00:00<00:09, 97.52it/s]
  9%|▉         | 89/1000 [00:00<00:08, 102.18it/s]
  9%|▉         | 91/1000 [00:00<00:09, 100.10it/s]

Display Convergence

# One row per metric, all sharing the time axis.
fig, axs = plt.subplots(len(METRICS), 1, sharex=True, figsize=(8, 12))
for ax, (metric, ylabel) in zip(axs, METRICS.items()):
    # Metrics whose key contains "res" (both residual norms) use a log scale.
    if "res" in metric:
        ax.set_yscale("log")
    for optim in OPTIM:
        history = iterations_cb[optim]
        elapsed = history["time"]
        # Mean of the inverse time step between consecutive iterations.
        rate = np.mean(1 / np.diff(elapsed))
        ax.plot(
            elapsed,
            history[metric],
            marker="o",
            markevery=20,
            label=f"{optim} {rate:.2f}iters/s",
        )
    ax.grid()
    ax.set_ylabel(ylabel)
axs[0].legend()
axs[-1].set_xlabel("time (s)")
fig.tight_layout()
plt.show()
example pinv

Display images

# One panel per solver, plus the adjoint reconstruction and the reference.
fig, axs = plt.subplots(1, len(OPTIM) + 2, figsize=(20, 7))

for ax, optim in zip(axs, OPTIM):
    history = iterations_cb[optim]
    final_psnr = history["psnr"][-1]
    n_iters = len(history["time"])
    total_time = history["time"][-1]
    ax.imshow(abs(images[optim]), cmap="gray", origin="lower")
    ax.axis("off")
    ax.set_title(
        f"{optim} reconstruction\n PSNR: {final_psnr:.2f}dB \n"
        f"{n_iters} iters ({total_time:.2f}s)"
    )

# Second-to-last panel: plain adjoint NUFFT of the k-space data.
adjoint_ax = axs[-2]
adjoint_psnr = psnr(
    abs(adjoint), abs(ground_truth), data_range=ground_truth.max()
)
adjoint_ax.imshow(abs(adjoint), cmap="gray", origin="lower")
adjoint_ax.axis("off")
adjoint_ax.set_title(f"Adjoint NUFFT \n PSNR: {adjoint_psnr:.2f}dB")

# Last panel: the reference image.
reference_ax = axs[-1]
reference_ax.imshow(abs(ground_truth), cmap="gray", origin="lower")
reference_ax.axis("off")
reference_ax.set_title("Original image")

fig.suptitle("Reconstructed images using different optimizers")
fig.tight_layout()
plt.show()
Reconstructed images using different optimizers, cg reconstruction  PSNR: 26.01dB  1000 iters (1.64s), lsqr reconstruction  PSNR: 26.40dB  107 iters (0.57s), lsmr reconstruction  PSNR: 26.40dB  92 iters (0.66s), Adjoint NUFFT   PSNR: -18.57dB, Original image

Total running time of the script: (0 minutes 21.450 seconds)

Gallery generated by Sphinx-Gallery