Note
Go to the end to download the full example code or to run this example in your browser via Binder.
Model-based iterative reconstruction
This example demonstrates how to reconstruct an image from non-Cartesian k-space data with a regularization prior, using deepinv.
Imports
import numpy as np
import matplotlib.pyplot as plt
from brainweb_dl import get_mri
from deepinv.optim.prior import WaveletPrior
from deepinv.optim.data_fidelity import L2
from deepinv.optim.optimizers import optim_builder
from mrinufft import get_operator
from mrinufft.trajectories import initialize_3D_cones
import torch
import os
# Name of the NUFFT backend to use: the MRINUFFT_BACKEND environment variable
# takes precedence; "cufinufft" is the fallback when it is not set.
BACKEND = os.getenv("MRINUFFT_BACKEND", default="cufinufft")
Get MRI data, generate a 3D cones trajectory, and simulate k-space data
# Sampling locations from the 3D cones generator
# (first argument: 32 * 32 = 1024; Ns=256).
samples_loc = initialize_3D_cones(32 * 32, Ns=256, nb_zigzags=16, width=3)

# Load the MRI volume, keep every second voxel along each of the three axes
# for speed, then reverse the first two axes. `np.ascontiguousarray` turns the
# strided/reversed view into a contiguous array before it is handed to torch.
# The result is converted to complex64 and placed on the CUDA device.
mri = torch.Tensor(
    np.ascontiguousarray(get_mri(0)[::2, ::2, ::2][::-1, ::-1])
).to(device="cuda", dtype=torch.complex64)
Simulate k-space data
# Instantiate the operator class registered under BACKEND for this trajectory
# and image shape, with density="pipe".
fourier_op = get_operator(BACKEND)(samples_loc, shape=mri.shape, density="pipe")

# Forward model: simulate the k-space measurements of the volume.
y = fourier_op.op(mri)

# Noise amplitude is set relative to the largest k-space magnitude.
noise_level = 0.0002 * y.abs().max().item()

# Add the scaled random perturbation in place (two independent draws, combined
# as `first + 1j * second`).
y += (torch.randn_like(y) + 1j * torch.randn_like(y)) * noise_level
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/venv/lib/python3.10/site-packages/mrinufft/_utils.py:67: UserWarning: Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)
warnings.warn(
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/venv/lib/python3.10/site-packages/mrinufft/_utils.py:72: UserWarning: Samples will be rescaled to [-0.5, 0.5), assuming they were in [-pi, pi)
warnings.warn(
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/venv/lib/python3.10/site-packages/mrinufft/_array_compat.py:248: UserWarning: data is on gpu, it will be moved to CPU.
warnings.warn("data is on gpu, it will be moved to CPU.")
Set up the physics and the prior
# Wrap the NUFFT operator as a deepinv physics object.
physics = fourier_op.make_deepinv_phy()

# Wavelet prior: "sym8" wavelet, wvdim=3, 3 decomposition levels, flagged as
# operating on complex-valued data.
wavelet = WaveletPrior(wv="sym8", wvdim=3, level=3, is_complex=True)
Initial reconstruction with adjoint
# Baseline reconstruction obtained directly from the measurements, used below
# for comparison with the regularized result.
# NOTE(review): `A_dagger` usually denotes a pseudo-inverse rather than the
# plain adjoint — confirm against the deepinv physics API and adjust the
# "adjoint" wording in the heading/plot title if needed.
x_dagger = physics.A_dagger(y)
0%| | 0/100 [00:00<?, ?it/s]
3%|▎ | 3/100 [00:00<00:03, 29.18it/s]
7%|▋ | 7/100 [00:00<00:02, 32.62it/s]
11%|█ | 11/100 [00:00<00:02, 33.96it/s]
15%|█▌ | 15/100 [00:00<00:02, 34.55it/s]
19%|█▉ | 19/100 [00:00<00:02, 34.91it/s]
23%|██▎ | 23/100 [00:00<00:02, 35.16it/s]
27%|██▋ | 27/100 [00:00<00:02, 35.18it/s]
31%|███ | 31/100 [00:00<00:01, 35.31it/s]
35%|███▌ | 35/100 [00:01<00:01, 35.41it/s]
39%|███▉ | 39/100 [00:01<00:01, 35.44it/s]
43%|████▎ | 43/100 [00:01<00:01, 35.42it/s]
47%|████▋ | 47/100 [00:01<00:01, 35.44it/s]
51%|█████ | 51/100 [00:01<00:01, 33.16it/s]
55%|█████▌ | 55/100 [00:01<00:01, 27.67it/s]
58%|█████▊ | 58/100 [00:01<00:01, 27.93it/s]
61%|██████ | 61/100 [00:01<00:01, 24.79it/s]
64%|██████▍ | 64/100 [00:02<00:01, 22.53it/s]
67%|██████▋ | 67/100 [00:02<00:01, 19.48it/s]
70%|███████ | 70/100 [00:02<00:01, 17.22it/s]
72%|███████▏ | 72/100 [00:02<00:01, 16.53it/s]
74%|███████▍ | 74/100 [00:02<00:01, 15.96it/s]
76%|███████▌ | 76/100 [00:02<00:01, 15.28it/s]
78%|███████▊ | 78/100 [00:03<00:01, 14.93it/s]
80%|████████ | 80/100 [00:03<00:01, 15.91it/s]
82%|████████▏ | 82/100 [00:03<00:01, 16.62it/s]
84%|████████▍ | 84/100 [00:03<00:01, 14.93it/s]
86%|████████▌ | 86/100 [00:03<00:00, 14.12it/s]
88%|████████▊ | 88/100 [00:03<00:00, 13.40it/s]
90%|█████████ | 90/100 [00:03<00:00, 13.27it/s]
92%|█████████▏| 92/100 [00:04<00:00, 13.18it/s]
94%|█████████▍| 94/100 [00:04<00:00, 13.35it/s]
96%|█████████▌| 96/100 [00:04<00:00, 12.98it/s]
98%|█████████▊| 98/100 [00:04<00:00, 12.74it/s]
100%|██████████| 100/100 [00:04<00:00, 12.62it/s]
100%|██████████| 100/100 [00:04<00:00, 20.99it/s]
Set up and run the reconstruction algorithm. First, define the data fidelity term and the algorithm parameters.
# Data fidelity term.
data_fidelity = L2()

# Algorithm parameters.
lamb = 10.0
# Step size: 0.8 times the inverse of the operator's Lipschitz constant.
# `.get()` is called on the returned value — presumably a CuPy array being
# copied to the host; confirm when using a non-GPU backend.
stepsize = 0.8 * float(1 / fourier_op.get_lipschitz_cst().get())
params_algo = {
    "stepsize": stepsize,
    "lambda": lamb,
    "a": 3,
}

# Iteration budget and early-stopping switch passed to the optimizer builder.
max_iter = 100
early_stop = True
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/venv/lib/python3.10/site-packages/mrinufft/operators/base.py:1061: UserWarning: Lipschitz constant did not converge
warnings.warn("Lipschitz constant did not converge")
Instantiate the algorithm class to solve the problem.
# Build the FISTA solver from the prior, the data fidelity term and the
# parameters defined above.
wavelet_recon = optim_builder(
    iteration="FISTA",
    data_fidelity=data_fidelity,
    prior=wavelet,
    params_algo=params_algo,
    max_iter=max_iter,
    early_stop=early_stop,
)

# Run the solver on the simulated measurements.
x_wavelet = wavelet_recon(y, physics)
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/venv/lib/python3.10/site-packages/cufinufft/_plan.py:393: UserWarning: Argument `data` does not satisfy the following requirement: C. Copying array (this may reduce performance)
warnings.warn(f"Argument `{name}` does not satisfy the "
Display results
# Show one slice (five positions before the middle of the last axis) of the
# ground truth and of both reconstructions, side by side.
#
# The slice index is derived from the size of the axis that is actually being
# indexed, i.e. the LAST one (`shape[-1]`). The reconstructions are indexed as
# `[0, 0, ..., k]`, so they carry two extra leading axes (presumably batch and
# channel); `shape[2]` on them would be the first spatial axis and only gives
# the right slice when the first and last spatial extents happen to be equal.
plt.figure(figsize=(12, 6))

plt.subplot(1, 3, 1)
plt.imshow(torch.abs(mri[..., mri.shape[-1] // 2 - 5]).cpu(), cmap="gray")
plt.title("Ground truth")
plt.axis("off")

plt.subplot(1, 3, 2)
plt.imshow(
    torch.abs(x_dagger[0, 0, ..., x_dagger.shape[-1] // 2 - 5]).cpu(),
    cmap="gray",
)
plt.title("Adjoint reconstruction")
plt.axis("off")

plt.subplot(1, 3, 3)
plt.imshow(
    torch.abs(x_wavelet[0, 0, ..., x_wavelet.shape[-1] // 2 - 5]).cpu(),
    cmap="gray",
)
plt.title("Reconstruction with wavelet prior")
plt.axis("off")

plt.show()

Total running time of the script: (4 minutes 54.568 seconds)