Model-based iterative reconstruction

This example demonstrates how to reconstruct an image from non-Cartesian k-space data with a regularization prior, using deepinv.

Imports

import numpy as np
import matplotlib.pyplot as plt
from brainweb_dl import get_mri
from deepinv.optim.prior import WaveletPrior
from deepinv.optim.data_fidelity import L2
from deepinv.optim.optimizers import optim_builder

from mrinufft import get_operator
from mrinufft.trajectories import initialize_3D_cones
import torch
import os

# NUFFT backend name; override it by setting the MRINUFFT_BACKEND environment variable.
BACKEND = os.getenv("MRINUFFT_BACKEND", "cufinufft")

Get MRI data and a 3D cones trajectory

# Non-Cartesian 3D cones trajectory: 32 * 32 shots with Ns=256 samples each.
samples_loc = initialize_3D_cones(32 * 32, Ns=256, nb_zigzags=16, width=3)
# Load the BrainWeb volume and keep every other voxel along each axis for speed.
# The first two axes are then flipped (presumably for display orientation).
phantom = np.ascontiguousarray(get_mri(0)[::2, ::2, ::2][::-1, ::-1])
mri = torch.Tensor(phantom).to(torch.complex64).to("cuda")

Simulate k-space data

# NUFFT operator for the selected backend; density="pipe" asks mrinufft for
# density-compensation weights.
fourier_op = get_operator(BACKEND)(samples_loc, shape=mri.shape, density="pipe")
# Forward model: noiseless k-space samples of the phantom.
y = fourier_op.op(mri)
# Additive noise scaled to 0.02% of the peak k-space magnitude.
noise_level = 0.0002 * y.abs().max().item()
# NOTE(review): if y is complex, randn_like(y) already draws complex samples;
# the expression is kept unchanged so the noise realisation is identical.
noise = torch.randn_like(y) + 1j * torch.randn_like(y)
y += noise_level * noise
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/mri-nufft/src/mrinufft/_utils.py:67: UserWarning: Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)
  warnings.warn(
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/mri-nufft/src/mrinufft/_utils.py:72: UserWarning: Samples will be rescaled to [-0.5, 0.5), assuming they were in [-pi, pi)
  warnings.warn(
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/mri-nufft/src/mrinufft/_array_compat.py:248: UserWarning: data is on gpu, it will be moved to CPU.
  warnings.warn("data is on gpu, it will be moved to CPU.")

Set up the physics and prior

# deepinv physics object built from the NUFFT operator.
physics = fourier_op.make_deepinv_phy()
# Complex-valued 3-D wavelet prior: sym8 filters, 3 decomposition levels.
wavelet = WaveletPrior(wv="sym8", level=3, wvdim=3, is_complex=True)

Initial reconstruction with the pseudo-inverse (physics.A_dagger)

# Baseline reconstruction with deepinv's pseudo-inverse `A_dagger` (no prior).
# The progress bar in the captured output suggests it is solved iteratively.
x_dagger = physics.A_dagger(y)
  0%|          | 0/100 [00:00<?, ?it/s]
  2%|▏         | 2/100 [00:00<00:05, 17.31it/s]
  4%|▍         | 4/100 [00:00<00:05, 17.33it/s]
  6%|▌         | 6/100 [00:00<00:05, 17.33it/s]
  8%|▊         | 8/100 [00:00<00:05, 17.33it/s]
 10%|█         | 10/100 [00:00<00:05, 17.30it/s]
 12%|█▏        | 12/100 [00:00<00:05, 17.30it/s]
 14%|█▍        | 14/100 [00:00<00:04, 17.26it/s]
 16%|█▌        | 16/100 [00:00<00:04, 17.29it/s]
 18%|█▊        | 18/100 [00:01<00:04, 17.27it/s]
 20%|██        | 20/100 [00:01<00:04, 17.29it/s]
 22%|██▏       | 22/100 [00:01<00:04, 17.31it/s]
 24%|██▍       | 24/100 [00:01<00:04, 17.32it/s]
 26%|██▌       | 26/100 [00:01<00:04, 17.32it/s]
 28%|██▊       | 28/100 [00:01<00:05, 14.10it/s]
 30%|███       | 30/100 [00:01<00:05, 13.01it/s]
 32%|███▏      | 32/100 [00:02<00:05, 11.84it/s]
 34%|███▍      | 34/100 [00:02<00:06, 11.00it/s]
 36%|███▌      | 36/100 [00:02<00:06,  9.74it/s]
 38%|███▊      | 38/100 [00:02<00:06,  9.14it/s]
 39%|███▉      | 39/100 [00:02<00:06,  8.91it/s]
 40%|████      | 40/100 [00:03<00:07,  8.53it/s]
 41%|████      | 41/100 [00:03<00:07,  8.24it/s]
 42%|████▏     | 42/100 [00:03<00:07,  8.11it/s]
 43%|████▎     | 43/100 [00:03<00:07,  8.07it/s]
 44%|████▍     | 44/100 [00:03<00:06,  8.06it/s]
 45%|████▌     | 45/100 [00:03<00:06,  8.13it/s]
 47%|████▋     | 47/100 [00:03<00:06,  8.57it/s]
 48%|████▊     | 48/100 [00:04<00:06,  8.20it/s]
 49%|████▉     | 49/100 [00:04<00:06,  7.81it/s]
 50%|█████     | 50/100 [00:04<00:06,  7.72it/s]
 51%|█████     | 51/100 [00:04<00:06,  7.57it/s]
 52%|█████▏    | 52/100 [00:04<00:06,  7.69it/s]
 53%|█████▎    | 53/100 [00:04<00:06,  7.44it/s]
 54%|█████▍    | 54/100 [00:04<00:06,  7.30it/s]
 55%|█████▌    | 55/100 [00:05<00:05,  7.73it/s]
 57%|█████▋    | 57/100 [00:05<00:04,  9.20it/s]
 59%|█████▉    | 59/100 [00:05<00:04, 10.19it/s]
 61%|██████    | 61/100 [00:05<00:03, 11.08it/s]
 63%|██████▎   | 63/100 [00:05<00:03, 12.29it/s]
 65%|██████▌   | 65/100 [00:05<00:02, 13.52it/s]
 67%|██████▋   | 67/100 [00:05<00:02, 14.50it/s]
 69%|██████▉   | 69/100 [00:05<00:02, 15.24it/s]
 71%|███████   | 71/100 [00:06<00:01, 15.78it/s]
 73%|███████▎  | 73/100 [00:06<00:01, 16.18it/s]
 75%|███████▌  | 75/100 [00:06<00:01, 16.47it/s]
 77%|███████▋  | 77/100 [00:06<00:01, 16.67it/s]
 79%|███████▉  | 79/100 [00:06<00:01, 16.81it/s]
 81%|████████  | 81/100 [00:06<00:01, 16.92it/s]
 83%|████████▎ | 83/100 [00:06<00:01, 13.35it/s]
 85%|████████▌ | 85/100 [00:07<00:01, 10.44it/s]
 87%|████████▋ | 87/100 [00:07<00:01,  9.23it/s]
 89%|████████▉ | 89/100 [00:07<00:01,  8.55it/s]
 90%|█████████ | 90/100 [00:07<00:01,  8.43it/s]
 91%|█████████ | 91/100 [00:08<00:01,  8.04it/s]
 92%|█████████▏| 92/100 [00:08<00:01,  7.99it/s]
 93%|█████████▎| 93/100 [00:08<00:00,  7.67it/s]
 94%|█████████▍| 94/100 [00:08<00:00,  7.75it/s]
 95%|█████████▌| 95/100 [00:08<00:00,  7.47it/s]
 96%|█████████▌| 96/100 [00:08<00:00,  7.56it/s]
 97%|█████████▋| 97/100 [00:08<00:00,  7.33it/s]
 98%|█████████▊| 98/100 [00:08<00:00,  7.43it/s]
 99%|█████████▉| 99/100 [00:09<00:00,  7.26it/s]
100%|██████████| 100/100 [00:09<00:00,  7.28it/s]
100%|██████████| 100/100 [00:09<00:00, 10.82it/s]

Set up and run the reconstruction algorithm. First, define the data fidelity term and the algorithm parameters.

data_fidelity = L2()
# Algorithm parameters
lamb = 1e1  # regularization weight, passed to the solver as "lambda"
# Step size is a fraction (0.8) of 1/L, where L is the Lipschitz constant
# estimated by the NUFFT operator. GPU backends such as cufinufft return L as a
# CuPy array, which must be copied to the host with `.get()`. Other backends
# (BACKEND is configurable via MRINUFFT_BACKEND) may return a NumPy/Python value
# without `.get()`, so only call it when it exists.
lipschitz_cst = fourier_op.get_lipschitz_cst()
if hasattr(lipschitz_cst, "get"):
    lipschitz_cst = lipschitz_cst.get()
stepsize = 0.8 * float(1 / lipschitz_cst)
# NOTE(review): "a" is presumably the inertia parameter read by deepinv's FISTA
# iterator — confirm against the deepinv docs.
params_algo = {"stepsize": stepsize, "lambda": lamb, "a": 3}
max_iter = 100
early_stop = True
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/mri-nufft/src/mrinufft/operators/base.py:1075: UserWarning: Lipschitz constant did not converge
  warnings.warn("Lipschitz constant did not converge")

Instantiate the algorithm class to solve the problem.

# FISTA solver combining the L2 data fidelity with the wavelet prior,
# configured by params_algo and stopped after max_iter iterations at most.
wavelet_recon = optim_builder(
    iteration="FISTA",
    data_fidelity=data_fidelity,
    prior=wavelet,
    params_algo=params_algo,
    max_iter=max_iter,
    early_stop=early_stop,
)
# Run the solver on the noisy k-space measurements.
x_wavelet = wavelet_recon(y, physics)
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/mri-nufft/.venv/lib/python3.10/site-packages/cufinufft/_plan.py:402: UserWarning: Argument `data` does not satisfy the following requirement: C. Copying array (this may reduce performance)
  warnings.warn(f"Argument `{name}` does not satisfy the "

Display results

# The operator was built with shape=mri.shape, so every volume lives on the same
# grid. A single index along the LAST spatial axis therefore selects the same
# slice in each. Computing it from `shape[2]` of the 5-D reconstructions would use
# the first spatial axis instead of the one being sliced.
slice_idx = mri.shape[-1] // 2 - 5
panels = [
    ("Ground truth", mri[..., slice_idx]),
    # Reconstructions carry leading (batch, channel) axes; take the first of each.
    ("Adjoint reconstruction", x_dagger[0, 0, ..., slice_idx]),
    ("Reconstruction with wavelet prior", x_wavelet[0, 0, ..., slice_idx]),
]

plt.figure(figsize=(12, 6))
for position, (title, image) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    # Magnitude image, moved to host memory for matplotlib.
    plt.imshow(torch.abs(image).cpu(), cmap="gray")
    plt.title(title)
    plt.axis("off")
plt.show()
Ground truth, Adjoint reconstruction, Reconstruction with wavelet prior

Total running time of the script: (3 minutes 56.023 seconds)

Gallery generated by Sphinx-Gallery