Note
Go to the end to download the full example code or to run this example in your browser via Binder.
Least Squares Image Reconstruction
An example showing how to reconstruct images using the least-squares estimate.
This script demonstrates the use of the Conjugate Gradient (CG), LSQR and LSMR methods to reconstruct images from non-uniform k-space data.
import os
import time
from tqdm.auto import tqdm
import cupy as cp
import numpy as np
from brainweb_dl import get_mri
from matplotlib import pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as psnr
import mrinufft
from mrinufft.extras.optim import loss_l2_reg, loss_l2_AHreg
# NUFFT backend used throughout the example; override it by setting the
# MRINUFFT_BACKEND environment variable (defaults to "cufinufft").
BACKEND = os.getenv("MRINUFFT_BACKEND", "cufinufft")
Set up the inputs
# 2D spiral sampling trajectory.
samples_loc = mrinufft.initialize_2D_spiral(Nc=64, Ns=512, nb_revolutions=8)

# Reference image: slice 90 (first axis) of the BrainWeb subject-4 volume,
# rescaled so that its root-mean-square magnitude equals one.
ground_truth = get_mri(sub_id=4)[90]
rms = np.sqrt(np.mean(abs(ground_truth) ** 2))
ground_truth = ground_truth / rms

image_gpu = cp.array(ground_truth)  # device copy fed to the NUFFT operator
print("image size: ", ground_truth.shape)
image size: (256, 256)
Set up the NUFFT operator
# Look up the operator factory for the selected backend, then build an
# operator bound to the spiral trajectory and the reference image grid.
NufftOperator = mrinufft.get_operator(BACKEND)
nufft = NufftOperator(samples_loc, shape=ground_truth.shape, squeeze_dims=True)
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/mri-nufft/src/mrinufft/_utils.py:67: UserWarning: Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)
warnings.warn(
Simulate the k-space data and compute the adjoint reconstruction
# Forward model: non-uniform k-space samples of the reference image.
kspace_data_gpu = nufft.op(image_gpu)
# Host (NumPy) copy of the simulated k-space data.
kspace_data = kspace_data_gpu.get()
# Adjoint NUFFT of the data, copied back to the host; it is displayed later
# as a baseline next to the least-squares reconstructions.
adjoint = nufft.adj_op(kspace_data_gpu).get()
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/mri-nufft/.venv/lib/python3.10/site-packages/cufinufft/_plan.py:393: UserWarning: Argument `data` does not satisfy the following requirement: C. Copying array (this may reduce performance)
warnings.warn(f"Argument `{name}` does not satisfy the "
Pseudo-inverse solver
The least-squares solution to the inverse problem can be obtained by solving the following optimization problem:
\[\hat{x} = \arg\min_x \|Ax - b\|_2^2\]
where \(A\) is the NUFFT operator, \(x\) is the image to be
reconstructed, and \(b\) is the k-space data. The optimization problem can
be solved using different iterative solvers, such as Conjugate Gradient (CG),
LSQR and LSMR. The solvers are implemented in the ``pinv_solver()``
method of the NUFFT operator, which takes as input the k-space data, the maximum number of
iterations, and the optimization method to use.
Callback monitoring
We can monitor the convergence of the optimization with a callback function that is called at each iteration. The callback can compute different metrics, such as the residual norm, the PSNR, or the time taken per iteration.
def mixed_cb(*args, **kwargs):
    """Compound callback tracking iteration time and convergence.

    All arguments are forwarded unchanged to ``loss_l2_reg`` and
    ``loss_l2_AHreg``. ``args[0]`` is treated as the current image estimate.
    NOTE(review): presumably a CuPy array supplied by ``pinv_solver`` --
    confirm. NumPy estimates are also accepted here.

    Returns
    -------
    list
        ``[t_start, res, ahres, psnr, t_end]``, where ``res`` and ``ahres``
        are the values returned by ``loss_l2_reg`` and ``loss_l2_AHreg``
        (plotted as ||Ax-b|| and ||A^H(Ax-b)||). The two ``perf_counter``
        timestamps bracket the metric computation, so that
        ``process_cb_results`` can exclude the callback's own overhead from
        the solver timing.
    """
    t_start = time.perf_counter()
    res = loss_l2_reg(*args, **kwargs)
    ahres = loss_l2_AHreg(*args, **kwargs)

    estimate = args[0]
    # Bring device arrays back to the host; host arrays pass through.
    estimate = estimate.get() if hasattr(estimate, "get") else np.asarray(estimate)

    # PSNR compares magnitude images, so the dynamic range must be the peak
    # magnitude of the reference (ground_truth.max() is wrong for complex or
    # signed references).
    reference = abs(ground_truth.squeeze())
    quality = psnr(
        abs(estimate.squeeze()),
        reference,
        data_range=reference.max(),
    )
    return [t_start, res, ahres, quality, time.perf_counter()]
def _to_host(value):
    """Return ``value`` on the host: call ``.get()`` on device arrays, else pass through."""
    return value.get() if hasattr(value, "get") else value


def process_cb_results(cb_results):
    """Turn the raw per-iteration callback records into plottable series.

    Parameters
    ----------
    cb_results : sequence of 5-item sequences
        One ``[t_start, res, ahres, psnr, t_end]`` record per iteration, as
        produced by ``mixed_cb``. ``res``/``ahres`` may be device arrays
        (exposing ``.get()``) or host values.

    Returns
    -------
    dict
        ``"time"``: cumulative solver time in seconds (NumPy array; callback
        overhead excluded, first entry is 0). ``"res"`` and ``"AHres"``: lists
        of host values. ``"psnr"``: tuple of PSNR values. All four are empty
        when ``cb_results`` is empty.
    """
    if not cb_results:
        # zip(*[]) yields nothing, so the unpacking below would raise.
        return {"time": np.zeros(0), "res": [], "AHres": [], "psnr": ()}

    t_start, res, ahres, psnrs, t_end = zip(*cb_results)
    # Solver time of step k = start of callback k minus end of callback k-1.
    # This subtracts the time spent computing the metrics themselves. The
    # first step is anchored on its own start, so it contributes 0.
    prev_end = (t_start[0], *t_end[:-1])
    time_it = np.cumsum(np.asarray(t_start) - np.asarray(prev_end))
    return {
        "time": time_it,
        "res": [_to_host(v) for v in res],
        "AHres": [_to_host(v) for v in ahres],
        "psnr": psnrs,
    }
Run the least-squares minimization for all the solvers
# Solvers to compare, in plotting order.
OPTIM = "cg lsqr lsmr".split()
# Metric key -> axis label; the insertion order fixes the subplot order.
METRICS = dict(
    res=r"$\|Ax-b\|$",
    AHres=r"$\|A^H(Ax-b)\|$",
    psnr="PSNR",
)
# Iteration budget per solver.
MAX_ITER = 1_000
# Run every solver on the same k-space data, keeping the final image and the
# processed convergence history for each.
images = {}
iterations_cb = {}
pg = tqdm(total=MAX_ITER, position=0, leave=True)
for optim in OPTIM:
    image, iter_cb = nufft.pinv_solver(
        kspace_data=kspace_data_gpu,
        max_iter=MAX_ITER,
        callback=mixed_cb,
        optim=optim,
        progressbar=pg,
    )
    images[optim] = image.get().squeeze()  # retrieve image from GPU.
    iterations_cb[optim] = process_cb_results(iter_cb)
# Close the bar explicitly. Otherwise it is only finalised when `pg` is later
# rebound and garbage-collected, and its stale state leaks into the next
# section's output.
pg.close()
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
2%|▏ | 16/1000 [00:00<00:06, 153.19it/s]
3%|▎ | 33/1000 [00:00<00:05, 161.31it/s]
6%|▌ | 55/1000 [00:00<00:05, 185.96it/s]
7%|▋ | 74/1000 [00:00<00:05, 166.88it/s]
9%|▉ | 92/1000 [00:00<00:05, 170.48it/s]
11%|█▏ | 114/1000 [00:00<00:04, 181.24it/s]
13%|█▎ | 133/1000 [00:00<00:04, 180.63it/s]
15%|█▌ | 152/1000 [00:00<00:04, 176.86it/s]
17%|█▋ | 174/1000 [00:00<00:04, 188.74it/s]
20%|█▉ | 196/1000 [00:01<00:04, 197.60it/s]
22%|██▏ | 218/1000 [00:01<00:03, 203.71it/s]
24%|██▍ | 239/1000 [00:01<00:04, 188.40it/s]
26%|██▌ | 261/1000 [00:01<00:03, 196.38it/s]
28%|██▊ | 284/1000 [00:01<00:03, 203.81it/s]
30%|███ | 305/1000 [00:01<00:03, 190.35it/s]
33%|███▎ | 328/1000 [00:01<00:03, 199.53it/s]
35%|███▌ | 350/1000 [00:01<00:03, 205.18it/s]
37%|███▋ | 372/1000 [00:01<00:03, 208.12it/s]
39%|███▉ | 393/1000 [00:02<00:03, 191.99it/s]
42%|████▏ | 416/1000 [00:02<00:02, 200.05it/s]
44%|████▍ | 439/1000 [00:02<00:02, 207.21it/s]
46%|████▌ | 460/1000 [00:02<00:02, 192.76it/s]
48%|████▊ | 483/1000 [00:02<00:02, 201.01it/s]
51%|█████ | 506/1000 [00:02<00:02, 207.73it/s]
53%|█████▎ | 528/1000 [00:02<00:02, 198.66it/s]
55%|█████▍ | 549/1000 [00:02<00:02, 192.32it/s]
57%|█████▋ | 570/1000 [00:02<00:02, 196.41it/s]
59%|█████▉ | 593/1000 [00:03<00:01, 203.59it/s]
61%|██████▏ | 614/1000 [00:03<00:02, 188.75it/s]
64%|██████▎ | 636/1000 [00:03<00:01, 196.06it/s]
66%|██████▌ | 659/1000 [00:03<00:01, 203.07it/s]
68%|██████▊ | 680/1000 [00:03<00:01, 193.96it/s]
70%|███████ | 700/1000 [00:03<00:01, 194.56it/s]
72%|███████▏ | 723/1000 [00:03<00:01, 202.96it/s]
75%|███████▍ | 746/1000 [00:03<00:01, 208.47it/s]
77%|███████▋ | 767/1000 [00:03<00:01, 193.47it/s]
79%|███████▉ | 789/1000 [00:04<00:01, 200.11it/s]
81%|████████ | 812/1000 [00:04<00:00, 206.22it/s]
83%|████████▎ | 833/1000 [00:04<00:00, 191.26it/s]
85%|████████▌ | 854/1000 [00:04<00:00, 195.50it/s]
88%|████████▊ | 876/1000 [00:04<00:00, 201.78it/s]
90%|████████▉ | 898/1000 [00:04<00:00, 201.70it/s]
92%|█████████▏| 919/1000 [00:04<00:00, 192.01it/s]
94%|█████████▍| 941/1000 [00:04<00:00, 199.38it/s]
96%|█████████▌| 962/1000 [00:04<00:00, 201.77it/s]
98%|█████████▊| 983/1000 [00:05<00:00, 188.01it/s]
0%| | 0/1000 [00:00<?, ?it/s]
2%|▏ | 20/1000 [00:00<00:06, 141.64it/s]
4%|▎ | 35/1000 [00:00<00:08, 120.18it/s]
5%|▍ | 49/1000 [00:00<00:07, 127.11it/s]
6%|▋ | 63/1000 [00:00<00:07, 130.20it/s]
8%|▊ | 77/1000 [00:00<00:07, 118.60it/s]
9%|▉ | 92/1000 [00:00<00:07, 126.29it/s]
11%|█ | 107/1000 [00:00<00:06, 132.91it/s]
12%|█▏ | 121/1000 [00:00<00:07, 122.12it/s]
14%|█▎ | 136/1000 [00:01<00:06, 127.21it/s]
15%|█▌ | 151/1000 [00:01<00:06, 132.67it/s]
16%|█▋ | 165/1000 [00:01<00:06, 123.38it/s]
18%|█▊ | 178/1000 [00:01<00:06, 124.79it/s]
19%|█▉ | 194/1000 [00:01<00:06, 132.85it/s]
21%|██ | 209/1000 [00:01<00:05, 135.67it/s]
22%|██▏ | 223/1000 [00:01<00:06, 126.83it/s]
24%|██▍ | 238/1000 [00:01<00:05, 131.38it/s]
25%|██▌ | 252/1000 [00:01<00:05, 132.02it/s]
27%|██▋ | 266/1000 [00:02<00:06, 120.06it/s]
28%|██▊ | 281/1000 [00:02<00:05, 127.41it/s]
30%|██▉ | 296/1000 [00:02<00:05, 133.10it/s]
31%|███ | 310/1000 [00:02<00:05, 126.45it/s]
33%|███▎ | 326/1000 [00:02<00:05, 133.99it/s]
34%|███▍ | 340/1000 [00:02<00:05, 130.62it/s]
35%|███▌ | 354/1000 [00:02<00:05, 122.34it/s]
37%|███▋ | 369/1000 [00:02<00:04, 127.70it/s]
38%|███▊ | 384/1000 [00:02<00:04, 133.46it/s]
40%|████ | 400/1000 [00:03<00:04, 130.97it/s]
41%|████▏ | 414/1000 [00:03<00:04, 129.20it/s]
43%|████▎ | 428/1000 [00:03<00:04, 130.01it/s]
44%|████▍ | 442/1000 [00:03<00:04, 131.74it/s]
46%|████▌ | 456/1000 [00:03<00:04, 119.82it/s]
47%|████▋ | 471/1000 [00:03<00:04, 125.91it/s]
49%|████▊ | 486/1000 [00:03<00:03, 132.13it/s]
0%| | 0/1000 [00:00<?, ?it/s]
1%|▏ | 14/1000 [00:00<00:11, 83.19it/s]
2%|▎ | 25/1000 [00:00<00:10, 93.39it/s]
4%|▎ | 37/1000 [00:00<00:09, 100.42it/s]
5%|▍ | 48/1000 [00:00<00:10, 91.99it/s]
6%|▌ | 61/1000 [00:00<00:09, 102.11it/s]
7%|▋ | 74/1000 [00:00<00:08, 110.51it/s]
9%|▊ | 86/1000 [00:00<00:09, 97.72it/s]
10%|▉ | 97/1000 [00:00<00:08, 101.01it/s]
11%|█ | 108/1000 [00:01<00:08, 102.12it/s]
12%|█▏ | 119/1000 [00:01<00:09, 94.82it/s]
13%|█▎ | 132/1000 [00:01<00:08, 103.79it/s]
15%|█▍ | 146/1000 [00:01<00:07, 111.47it/s]
16%|█▌ | 158/1000 [00:01<00:08, 99.93it/s]
17%|█▋ | 170/1000 [00:01<00:08, 103.08it/s]
18%|█▊ | 181/1000 [00:01<00:07, 104.24it/s]
19%|█▉ | 193/1000 [00:01<00:07, 102.35it/s]
20%|██ | 204/1000 [00:02<00:07, 103.31it/s]
22%|██▏ | 217/1000 [00:02<00:07, 109.73it/s]
23%|██▎ | 230/1000 [00:02<00:06, 114.00it/s]
24%|██▍ | 242/1000 [00:02<00:07, 99.19it/s]
25%|██▌ | 254/1000 [00:02<00:07, 103.29it/s]
27%|██▋ | 266/1000 [00:02<00:06, 106.37it/s]
28%|██▊ | 277/1000 [00:02<00:07, 96.67it/s]
29%|██▉ | 290/1000 [00:02<00:06, 104.11it/s]
30%|███ | 302/1000 [00:02<00:06, 106.53it/s]
31%|███▏ | 313/1000 [00:03<00:07, 96.28it/s]
33%|███▎ | 326/1000 [00:03<00:06, 103.07it/s]
34%|███▎ | 337/1000 [00:03<00:06, 104.48it/s]
35%|███▍ | 348/1000 [00:03<00:06, 98.49it/s]
36%|███▌ | 361/1000 [00:03<00:06, 106.01it/s]
37%|███▋ | 373/1000 [00:03<00:05, 107.75it/s]
38%|███▊ | 384/1000 [00:03<00:06, 98.48it/s]
40%|███▉ | 396/1000 [00:03<00:05, 102.58it/s]
41%|████ | 408/1000 [00:03<00:05, 107.05it/s]
42%|████▏ | 420/1000 [00:04<00:05, 100.99it/s]
43%|████▎ | 432/1000 [00:04<00:05, 104.23it/s]
44%|████▍ | 443/1000 [00:04<00:05, 104.91it/s]
45%|████▌ | 454/1000 [00:04<00:05, 105.21it/s]
46%|████▋ | 465/1000 [00:04<00:05, 97.90it/s]
48%|████▊ | 477/1000 [00:04<00:05, 103.46it/s]
49%|████▉ | 490/1000 [00:04<00:04, 108.42it/s]
50%|█████ | 501/1000 [00:04<00:05, 98.68it/s]
51%|█████ | 512/1000 [00:05<00:04, 101.02it/s]
52%|█████▏ | 523/1000 [00:05<00:04, 103.41it/s]
53%|█████▎ | 534/1000 [00:05<00:04, 96.77it/s]
Display convergence
# One subplot per metric, each plotted against cumulative solver time.
fig, axs = plt.subplots(len(METRICS), 1, sharex=True, figsize=(8, 12))
for ax, (metric, ylabel) in zip(axs, METRICS.items()):
    # Both residual norms ("res" and "AHres") span orders of magnitude.
    if "res" in metric:
        ax.set_yscale("log")
    for optim in OPTIM:
        t = iterations_cb[optim]["time"]
        # Average throughput = completed steps / elapsed time. The mean of
        # per-step rates 1/dt over-weights fast steps and becomes inf when
        # two timestamps coincide.
        n_steps = len(t) - 1
        rate = n_steps / t[-1] if n_steps > 0 and t[-1] > 0 else float("nan")
        ax.plot(
            t,
            iterations_cb[optim][metric],
            marker="o",
            markevery=20,
            label=f"{optim} {rate:.2f}iters/s",
        )
    ax.grid()
    ax.set_ylabel(ylabel)
axs[0].legend()
axs[-1].set_xlabel("time (s)")
fig.tight_layout()
plt.show()

Display images
# One panel per solver, followed by the adjoint NUFFT and the ground truth.
# PSNR compares magnitude images, so its data range is the peak magnitude of
# the reference (ground_truth.max() is wrong for complex/signed references).
gt_mag = abs(ground_truth)
gt_range = gt_mag.max()

fig, axs = plt.subplots(1, len(OPTIM) + 2, figsize=(20, 7))
for ax, optim in zip(axs, OPTIM):
    cb = iterations_cb[optim]
    ax.imshow(abs(images[optim]), cmap="gray", origin="lower")
    ax.axis("off")
    ax.set_title(
        f"{optim} reconstruction\n PSNR: {cb['psnr'][-1]:.2f}dB \n"
        f"{len(cb['time'])} iters ({cb['time'][-1]:.2f}s)"
    )

axs[-1].imshow(gt_mag, cmap="gray", origin="lower")
axs[-1].axis("off")
axs[-1].set_title("Original image")

axs[-2].imshow(abs(adjoint), cmap="gray", origin="lower")
axs[-2].axis("off")
adjoint_psnr = psnr(abs(adjoint), gt_mag, data_range=gt_range)
axs[-2].set_title(f"Adjoint NUFFT \n PSNR: {adjoint_psnr:.2f}dB")

fig.suptitle("Reconstructed images using different optimizers")
fig.tight_layout()
plt.show()

Using a damping regularization term
The least-squares problem can be regularized with a damping term to improve its conditioning. This is done by solving the following optimization problem:
\[\min_x \|Ax - b\|_2^2 + \gamma \|x\|_2^2\]
where \(\gamma\) is the regularization parameter.
# Same solver sweep as above, now with a damping term added to the objective:
#
#     min_x ||Ax - b||_2^2 + gamma * ||x||_2^2
#
# where gamma is controlled through the solver's ``damp`` argument.
# NOTE(review): confirm how pinv_solver scales ``damp`` (some LSQR/LSMR
# implementations penalise damp**2 * ||x||^2).
DAMP = 0.1
images = {}
iterations_cb = {}
pg = tqdm(total=MAX_ITER, position=0, leave=True)
for optim in OPTIM:
    image, iter_cb = nufft.pinv_solver(
        kspace_data=kspace_data_gpu,
        # Same budget as the undamped run and the progress-bar total.
        max_iter=MAX_ITER,
        callback=mixed_cb,
        damp=DAMP,
        optim=optim,
        progressbar=pg,
    )
    images[optim] = image.get().squeeze()  # retrieve image from GPU.
    iterations_cb[optim] = process_cb_results(iter_cb)
# Finalise the bar now rather than whenever it happens to be garbage-collected.
pg.close()
0%| | 0/1000 [00:00<?, ?it/s]
54%|█████▎ | 536/1000 [00:06<00:06, 76.59it/s]
0%| | 0/1000 [00:00<?, ?it/s]
2%|▏ | 22/1000 [00:00<00:04, 216.56it/s]
4%|▍ | 44/1000 [00:00<00:05, 179.01it/s]
7%|▋ | 66/1000 [00:00<00:04, 193.82it/s]
9%|▊ | 87/1000 [00:00<00:04, 199.30it/s]
11%|█ | 108/1000 [00:00<00:04, 183.50it/s]
13%|█▎ | 128/1000 [00:00<00:04, 187.68it/s]
15%|█▍ | 149/1000 [00:00<00:04, 193.79it/s]
17%|█▋ | 170/1000 [00:00<00:04, 198.05it/s]
19%|█▉ | 190/1000 [00:01<00:04, 184.30it/s]
21%|██▏ | 213/1000 [00:01<00:04, 195.27it/s]
24%|██▎ | 235/1000 [00:01<00:03, 200.90it/s]
26%|██▌ | 256/1000 [00:01<00:03, 187.34it/s]
28%|██▊ | 276/1000 [00:01<00:03, 189.07it/s]
30%|██▉ | 298/1000 [00:01<00:03, 195.89it/s]
32%|███▏ | 319/1000 [00:01<00:03, 199.02it/s]
34%|███▍ | 340/1000 [00:01<00:03, 187.35it/s]
36%|███▌ | 362/1000 [00:01<00:03, 195.41it/s]
38%|███▊ | 382/1000 [00:01<00:03, 195.81it/s]
40%|████ | 402/1000 [00:02<00:03, 185.80it/s]
42%|████▏ | 421/1000 [00:02<00:03, 185.12it/s]
44%|████▍ | 443/1000 [00:02<00:02, 194.53it/s]
46%|████▋ | 465/1000 [00:02<00:02, 200.36it/s]
49%|████▊ | 486/1000 [00:02<00:02, 185.27it/s]
51%|█████ | 507/1000 [00:02<00:02, 191.71it/s]
53%|█████▎ | 528/1000 [00:02<00:02, 195.95it/s]
55%|█████▍ | 548/1000 [00:02<00:02, 186.61it/s]
57%|█████▋ | 567/1000 [00:02<00:02, 186.67it/s]
59%|█████▉ | 588/1000 [00:03<00:02, 191.88it/s]
61%|██████ | 610/1000 [00:03<00:01, 197.86it/s]
63%|██████▎ | 630/1000 [00:03<00:01, 185.15it/s]
65%|██████▌ | 652/1000 [00:03<00:01, 192.08it/s]
67%|██████▋ | 674/1000 [00:03<00:01, 198.07it/s]
69%|██████▉ | 694/1000 [00:03<00:01, 194.36it/s]
71%|███████▏ | 714/1000 [00:03<00:01, 185.65it/s]
74%|███████▎ | 735/1000 [00:03<00:01, 192.38it/s]
76%|███████▌ | 758/1000 [00:03<00:01, 201.71it/s]
78%|███████▊ | 779/1000 [00:04<00:01, 188.02it/s]
80%|████████ | 800/1000 [00:04<00:01, 192.23it/s]
82%|████████▏ | 821/1000 [00:04<00:00, 195.42it/s]
84%|████████▍ | 841/1000 [00:04<00:00, 190.71it/s]
86%|████████▌ | 861/1000 [00:04<00:00, 186.24it/s]
88%|████████▊ | 884/1000 [00:04<00:00, 195.94it/s]
91%|█████████ | 906/1000 [00:04<00:00, 201.69it/s]
93%|█████████▎| 927/1000 [00:04<00:00, 184.72it/s]
95%|█████████▍| 948/1000 [00:04<00:00, 189.55it/s]
97%|█████████▋| 970/1000 [00:05<00:00, 195.37it/s]
99%|█████████▉| 991/1000 [00:05<00:00, 198.42it/s]
0%| | 0/1000 [00:00<?, ?it/s]
2%|▏ | 21/1000 [00:00<00:06, 140.14it/s]
4%|▎ | 36/1000 [00:00<00:06, 143.08it/s]
5%|▌ | 51/1000 [00:00<00:06, 137.44it/s]
6%|▋ | 65/1000 [00:00<00:06, 136.01it/s]
8%|▊ | 80/1000 [00:00<00:06, 137.59it/s]
0%| | 0/1000 [00:00<?, ?it/s]
1%|▏ | 14/1000 [00:00<00:09, 106.99it/s]
3%|▎ | 27/1000 [00:00<00:08, 114.94it/s]
4%|▍ | 39/1000 [00:00<00:08, 111.89it/s]
5%|▌ | 51/1000 [00:00<00:08, 110.15it/s]
6%|▋ | 63/1000 [00:00<00:08, 110.86it/s]
8%|▊ | 75/1000 [00:00<00:08, 111.62it/s]
Display convergence
# One subplot per metric, each plotted against cumulative solver time.
fig, axs = plt.subplots(len(METRICS), 1, sharex=True, figsize=(8, 12))
for ax, (metric, ylabel) in zip(axs, METRICS.items()):
    # Both residual norms ("res" and "AHres") span orders of magnitude.
    if "res" in metric:
        ax.set_yscale("log")
    for optim in OPTIM:
        t = iterations_cb[optim]["time"]
        # Average throughput = completed steps / elapsed time. The mean of
        # per-step rates 1/dt over-weights fast steps and becomes inf when
        # two timestamps coincide.
        n_steps = len(t) - 1
        rate = n_steps / t[-1] if n_steps > 0 and t[-1] > 0 else float("nan")
        ax.plot(
            t,
            iterations_cb[optim][metric],
            marker="o",
            markevery=20,
            label=f"{optim} {rate:.2f}iters/s",
        )
    ax.grid()
    ax.set_ylabel(ylabel)
axs[0].legend()
axs[-1].set_xlabel("time (s)")
fig.tight_layout()
plt.show()

Display images
# One panel per solver, followed by the adjoint NUFFT and the ground truth.
# PSNR compares magnitude images, so its data range is the peak magnitude of
# the reference (ground_truth.max() is wrong for complex/signed references).
gt_mag = abs(ground_truth)
gt_range = gt_mag.max()

fig, axs = plt.subplots(1, len(OPTIM) + 2, figsize=(20, 7))
for ax, optim in zip(axs, OPTIM):
    cb = iterations_cb[optim]
    ax.imshow(abs(images[optim]), cmap="gray", origin="lower")
    ax.axis("off")
    ax.set_title(
        f"{optim} reconstruction\n PSNR: {cb['psnr'][-1]:.2f}dB \n"
        f"{len(cb['time'])} iters ({cb['time'][-1]:.2f}s)"
    )

axs[-1].imshow(gt_mag, cmap="gray", origin="lower")
axs[-1].axis("off")
axs[-1].set_title("Original image")

axs[-2].imshow(abs(adjoint), cmap="gray", origin="lower")
axs[-2].axis("off")
adjoint_psnr = psnr(abs(adjoint), gt_mag, data_range=gt_range)
axs[-2].set_title(f"Adjoint NUFFT \n PSNR: {adjoint_psnr:.2f}dB")

fig.suptitle("Reconstructed images using different optimizers")
fig.tight_layout()
plt.show()

Total running time of the script: (0 minutes 23.993 seconds)