Least Squares Image Reconstruction#

An example showing how to reconstruct an image using the least-squares estimate.

This script demonstrates the use of the Conjugate Gradient (CG), LSQR, and LSMR methods to reconstruct images from non-uniform k-space data.

import os
import time
from tqdm.auto import tqdm
import cupy as cp
import numpy as np
from brainweb_dl import get_mri
from matplotlib import pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as psnr

import mrinufft
from mrinufft.extras.optim import loss_l2_reg, loss_l2_AHreg

BACKEND = os.environ.get("MRINUFFT_BACKEND", "cufinufft")

Setup Inputs

samples_loc = mrinufft.initialize_2D_spiral(Nc=64, Ns=512, nb_revolutions=8)
ground_truth = get_mri(sub_id=4)
ground_truth = ground_truth[90]
# Normalize the ground truth image
ground_truth = ground_truth / np.sqrt(np.mean(abs(ground_truth) ** 2))
image_gpu = cp.array(ground_truth)  # convert to cupy array for GPU processing

print("image size: ", ground_truth.shape)
image size:  (256, 256)

Setup the NUFFT operator

NufftOperator = mrinufft.get_operator(BACKEND)  # get the operator

nufft = NufftOperator(
    samples_loc,
    shape=ground_truth.shape,
    squeeze_dims=True,
)  # create the NUFFT operator
UserWarning: Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)

Generate the k-space data and the adjoint reconstruction

kspace_data_gpu = nufft.op(image_gpu)  # get the k-space data
kspace_data = kspace_data_gpu.get()  # convert back to numpy array for display
adjoint = nufft.adj_op(kspace_data_gpu).get()  # adjoint NUFFT

Pseudo-inverse solver#

The least-squares solution to the inverse problem can be obtained by solving the following optimization problem:

\[\min_x \|Ax - b\|_2^2\]

where \(A\) is the NUFFT operator, \(x\) is the image to be reconstructed, and \(b\) is the k-space data. The problem can be solved with different iterative solvers, such as Conjugate Gradient (CG), LSQR, and LSMR. These solvers are exposed through the operator's pinv_solver() method, which takes as input the k-space data, the maximum number of iterations, and the optimization method to use.
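As a small illustration outside of mri-nufft, the same three solvers exist in SciPy; the dense toy operator `A` below is an assumption standing in for the NUFFT. CG requires a Hermitian positive-definite system, so it is applied to the normal equations \(A^T A x = A^T b\), while LSQR and LSMR minimize \(\|Ax - b\|_2\) on \(A\) directly:

```python
import numpy as np
from scipy.sparse.linalg import cg, lsqr, lsmr

rng = np.random.default_rng(0)
A = rng.standard_normal((50, 20))  # toy full-rank forward operator (stand-in for the NUFFT)
x_true = rng.standard_normal(20)
b = A @ x_true  # noiseless "k-space" data

# CG needs a symmetric positive-definite matrix: solve A^T A x = A^T b
x_cg, info = cg(A.T @ A, A.T @ b, maxiter=200)

# LSQR and LSMR work on A directly, which is better conditioned than A^T A
x_lsqr = lsqr(A, b, iter_lim=200)[0]
x_lsmr = lsmr(A, b, maxiter=200)[0]
```

On this noiseless toy problem all three recover `x_true`; with noisy data they converge to the same least-squares solution but along different paths, which is what the convergence plots below compare.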

Callback monitoring#

We can monitor the convergence of the optimization by using a callback function that is called at each iteration. The callback can compute different metrics, such as the residual norm, the PSNR, or the time taken per iteration.

def mixed_cb(*args, **kwargs):
    """A compound callback function, to track iterations time and convergence."""
    return [
        time.perf_counter(),
        loss_l2_reg(*args, **kwargs),
        loss_l2_AHreg(*args, **kwargs),
        psnr(
            abs(args[0].get().squeeze()),
            abs(ground_truth.squeeze()),
            data_range=ground_truth.max(),
        ),
        time.perf_counter(),
    ]


def process_cb_results(cb_results):
    """Unpack the per-iteration callback results into timings and metrics."""
    t_start, r, rH, psnrs, t_end = list(zip(*cb_results))
    # Shift the end times so that t_start[i] - t_end[i] measures the time
    # spent in iteration i (between the end of one callback and the start of the next).
    t_end = (t_start[0], *t_end[:-1])
    time_it = np.cumsum(np.array(t_start) - np.array(t_end))
    r = [rr.get() for rr in r]  # move loss values from GPU to host
    rH = [rr.get() for rr in rH]

    return {"time": time_it, "res": r, "AHres": rH, "psnr": psnrs}

Run the least-square minimization for all the solvers#

OPTIM = ["cg", "lsqr", "lsmr"]
METRICS = {
    "res": r"$\|Ax-b\|$",
    "AHres": r"$\|A^H(Ax-b)\|$",
    "psnr": "PSNR",
}

MAX_ITER = 1000

images = dict()
iterations_cb = dict()
pg = tqdm(total=MAX_ITER, position=0, leave=True)
for optim in OPTIM:
    image, iter_cb = nufft.pinv_solver(
        kspace_data=kspace_data_gpu,
        max_iter=MAX_ITER,
        callback=mixed_cb,
        optim=optim,
        progressbar=pg,
    )
    images[optim] = image.get().squeeze()  # retrieve image from GPU.
    iterations_cb[optim] = process_cb_results(iter_cb)

Display Convergence#

fig, axs = plt.subplots(len(METRICS), 1, sharex=True, figsize=(8, 12))
for i, metric in enumerate(METRICS):
    for optim in OPTIM:
        if "res" in metric:
            axs[i].set_yscale("log")
        axs[i].plot(
            iterations_cb[optim]["time"],
            iterations_cb[optim][metric],
            marker="o",
            markevery=20,
            label=f"{optim} {np.mean(1/np.diff(iterations_cb[optim]['time'])):.2f}iters/s",
        )
    axs[i].grid()
    axs[i].set_ylabel(METRICS[metric])
axs[0].legend()
axs[-1].set_xlabel("time (s)")
fig.tight_layout()
plt.show()
(Figure: convergence of each metric over time for the CG, LSQR, and LSMR solvers)

Display images#

fig, axs = plt.subplots(1, len(OPTIM) + 2, figsize=(20, 7))

for i, optim in enumerate(OPTIM):
    axs[i].imshow(abs(images[optim]), cmap="gray", origin="lower")
    axs[i].axis("off")
    axs[i].set_title(
        f"{optim} reconstruction\n PSNR: {iterations_cb[optim]['psnr'][-1]:.2f}dB \n"
        f"{len(iterations_cb[optim]['time'])} iters ({iterations_cb[optim]['time'][-1]:.2f}s)"
    )

axs[-1].imshow(abs(ground_truth), cmap="gray", origin="lower")
axs[-1].axis("off")
axs[-1].set_title("Original image")
axs[-2].imshow(
    abs(adjoint),
    cmap="gray",
    origin="lower",
)
axs[-2].axis("off")
axs[-2].set_title(
    f"Adjoint NUFFT \n PSNR: {psnr(abs(adjoint), abs(ground_truth), data_range=ground_truth.max()):.2f}dB"
)

fig.suptitle("Reconstructed images using different optimizers")
fig.tight_layout()
plt.show()
(Figure: reconstructed images using different optimizers. cg: PSNR 26.39 dB, 1000 iters (2.27 s); lsqr: PSNR 26.54 dB, 489 iters (2.42 s); lsmr: PSNR 26.54 dB, 537 iters (3.61 s); adjoint NUFFT: PSNR -18.57 dB; original image.)

Using a damping regularization term#

The least-squares problem can be regularized with a damping term to improve its conditioning. This amounts to solving the following optimization problem:

\[\min_x \|Ax - b\|_2^2 + \gamma \|x\|_2^2\]

where \(\gamma\) is the regularization parameter.
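As a sketch of what the damping term does, using SciPy and a random toy operator as assumptions (whether mri-nufft's `damp` argument uses exactly the same scaling is not asserted here): SciPy's `lsqr` minimizes \(\|Ax-b\|_2^2 + \mathrm{damp}^2\|x\|_2^2\), which is equivalent to an ordinary least-squares problem on an augmented operator:

```python
import numpy as np
from scipy.sparse.linalg import lsqr

rng = np.random.default_rng(1)
A = rng.standard_normal((30, 40))  # underdetermined toy operator (assumption)
b = rng.standard_normal(30)
gamma = 0.1  # regularization weight

# lsqr minimizes ||Ax - b||^2 + damp^2 ||x||^2, so damp = sqrt(gamma)
x_damp = lsqr(A, b, damp=np.sqrt(gamma))[0]

# Equivalent formulation: stack sqrt(gamma) * I under A and append zeros to b,
# so the damped problem becomes a plain least-squares problem.
A_aug = np.vstack([A, np.sqrt(gamma) * np.eye(40)])
b_aug = np.concatenate([b, np.zeros(40)])
x_aug = np.linalg.lstsq(A_aug, b_aug, rcond=None)[0]
```

The augmented operator has full column rank even when \(A\) does not, which is why damping improves conditioning and speeds up convergence.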


images = dict()
iterations_cb = dict()

pg = tqdm(total=MAX_ITER, position=0, leave=True)
for optim in OPTIM:
    image, iter_cb = nufft.pinv_solver(
        kspace_data=kspace_data_gpu,
        max_iter=MAX_ITER,
        callback=mixed_cb,
        damp=0.1,
        optim=optim,
        progressbar=pg,
    )
    images[optim] = image.get().squeeze()  # retrieve image from GPU.
    iterations_cb[optim] = process_cb_results(iter_cb)

Display Convergence#

fig, axs = plt.subplots(len(METRICS), 1, sharex=True, figsize=(8, 12))
for i, metric in enumerate(METRICS):
    for optim in OPTIM:
        if "res" in metric:
            axs[i].set_yscale("log")
        axs[i].plot(
            iterations_cb[optim]["time"],
            iterations_cb[optim][metric],
            marker="o",
            markevery=20,
            label=f"{optim} {np.mean(1/np.diff(iterations_cb[optim]['time'])):.2f}iters/s",
        )
    axs[i].grid()
    axs[i].set_ylabel(METRICS[metric])
axs[0].legend()
axs[-1].set_xlabel("time (s)")
fig.tight_layout()
plt.show()
(Figure: convergence of each metric over time for the CG, LSQR, and LSMR solvers with damping)

Display images#

fig, axs = plt.subplots(1, len(OPTIM) + 2, figsize=(20, 7))

for i, optim in enumerate(OPTIM):
    axs[i].imshow(abs(images[optim]), cmap="gray", origin="lower")
    axs[i].axis("off")
    axs[i].set_title(
        f"{optim} reconstruction\n PSNR: {iterations_cb[optim]['psnr'][-1]:.2f}dB \n"
        f"{len(iterations_cb[optim]['time'])} iters ({iterations_cb[optim]['time'][-1]:.2f}s)"
    )

axs[-1].imshow(abs(ground_truth), cmap="gray", origin="lower")
axs[-1].axis("off")
axs[-1].set_title("Original image")
axs[-2].imshow(
    abs(adjoint),
    cmap="gray",
    origin="lower",
)
axs[-2].axis("off")
axs[-2].set_title(
    f"Adjoint NUFFT \n PSNR: {psnr(abs(adjoint), abs(ground_truth), data_range=ground_truth.max()):.2f}dB"
)

fig.suptitle("Reconstructed images using different optimizers")
fig.tight_layout()
plt.show()
(Figure: reconstructed images using different optimizers with damping. cg: PSNR 26.01 dB, 1000 iters (2.25 s); lsqr: PSNR 26.40 dB, 92 iters (0.40 s); lsmr: PSNR 26.40 dB, 80 iters (0.46 s); adjoint NUFFT: PSNR -18.57 dB; original image.)

Total running time of the script: (0 minutes 23.993 seconds)

Gallery generated by Sphinx-Gallery