Model-based iterative reconstruction

This example demonstrates how to reconstruct a 3D MRI image from undersampled non-Cartesian k-space measurements using MRI-NUFFT and DeepInverse.

We first simulate non-Cartesian acquisitions from a reference MRI volume using a 3D cones trajectory. We then compare several reconstruction approaches:

  1. Adjoint reconstruction, a fast baseline that exhibits undersampling artifacts.

  2. Wavelet-regularized reconstruction solved with FISTA.

  3. Total Variation (TV) reconstruction solved with the DeepInverse PDCP (primal-dual Chambolle-Pock) optimizer.

The goal of this example is to illustrate how MRI-NUFFT physics operators can be coupled with DeepInverse optimization tools to solve model-based MRI inverse problems and to compare different regularization priors. Both complex-valued regularizations and real-valued ones (applied separately to the real and imaginary parts) are supported; the priors used below rely on the real-valued variant.

Imports

import numpy as np
import matplotlib.pyplot as plt
from brainweb_dl import get_mri
from deepinv.optim.data_fidelity import L2
from deepinv.optim.optimizers import optim_builder
from deepinv.optim.prior import WaveletPrior, TVPrior
from mrinufft import get_operator
from mrinufft.trajectories import initialize_3D_cones
from mrinufft.operators import kspace_as_real
import torch
import os

BACKEND = os.environ.get("MRINUFFT_BACKEND", "cufinufft")

Get MRI data and a 3D cones trajectory

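# 3D cones trajectory: 32 * 32 = 1024 shots with 256 samples each
samples_loc = initialize_3D_cones(32 * 32, Ns=256, nb_zigzags=16, width=3)
# Load the BrainWeb volume, downsample it by 2 along each axis for speed,
# flip the first two axes, and move it to the GPU as a complex64 tensor
mri = (
    torch.Tensor(np.ascontiguousarray(get_mri(0)[::2, ::2, ::2][::-1, ::-1]))
    .to(torch.complex64)
    .to("cuda")
)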

Simulate k-space data

fourier_op = get_operator(BACKEND)(
    samples_loc,
    shape=mri.shape,
    density="pipe",
    squeeze_dims=False,
)
y = fourier_op.op(mri)  # Simulate k-space data
# Add complex Gaussian noise scaled to 0.02% of the peak k-space magnitude
noise_level = y.abs().max().item() * 0.0002
y += noise_level * (torch.randn_like(y) + 1j * torch.randn_like(y))
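Conceptually, the NUFFT evaluates the non-uniform discrete Fourier transform (NDFT) of the image at the trajectory points. The NumPy sketch below builds the exact, dense NDFT matrix for a tiny 2D image and checks that its adjoint is consistent with the forward operator. It assumes k-space locations in radians within [-pi, pi) and centred integer pixel coordinates; sign and centring conventions vary between libraries, and a real NUFFT replaces this O(MN) matrix product with a fast approximation. The helper names (ndft_matrix, ndft_forward, ndft_adjoint) are hypothetical.

```python
import numpy as np


def ndft_matrix(locs, shape):
    """Dense encoding matrix E[m, n] = exp(-1j * <k_m, r_n>).

    locs: (M, d) k-space locations in radians, assumed in [-pi, pi).
    shape: image shape (d dimensions); r_n are centred integer pixel coordinates.
    """
    grids = np.meshgrid(*[np.arange(s) - s // 2 for s in shape], indexing="ij")
    coords = np.stack([g.ravel() for g in grids], axis=1)  # (N, d)
    return np.exp(-1j * (locs @ coords.T))  # (M, N)


def ndft_forward(image, E):
    """Image -> non-uniform k-space samples."""
    return E @ image.ravel()


def ndft_adjoint(kspace, E, shape):
    """Non-uniform k-space samples -> image (conjugate transpose, no density weights)."""
    return (E.conj().T @ kspace).reshape(shape)


rng = np.random.default_rng(0)
shape = (8, 8)
locs = rng.uniform(-np.pi, np.pi, size=(50, 2))
locs[0] = 0.0  # include the k-space centre
E = ndft_matrix(locs, shape)

x = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
y = rng.standard_normal(50) + 1j * rng.standard_normal(50)

# Adjoint test: <E x, y> must equal <x, E^H y>
lhs = np.vdot(ndft_forward(x, E), y)
rhs = np.vdot(x, ndft_adjoint(y, E, shape))
print(np.isclose(lhs, rhs))  # True
```

The sample at the k-space centre equals the sum of all pixel values, which is a quick sanity check for any forward model of this kind.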

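Build a real-valued DeepInverse physics operator and a wavelet prior. With viewed_as_real=True, the physics operator works on real tensors that hold the real and imaginary parts separately, so the k-space data are converted with kspace_as_real as well. The wavelet prior is created with is_complex=False and therefore acts on this real-valued representation.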

physics = fourier_op.make_deepinv_phy(viewed_as_real=True)

y_real = kspace_as_real(y).float()

wavelet = WaveletPrior(
    wv="sym8",
    wvdim=3,
    level=3,
    is_complex=False,
)
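To see what a real-valued view means, the NumPy sketch below stacks the real and imaginary parts of a complex vector and builds the equivalent real-valued operator for a complex matrix. The stacking layout (a leading axis of size 2) and the helper names are assumptions for illustration; they are not claimed to match the exact layout used by kspace_as_real or image_as_cpx. The last lines show why a real-valued l1 prior differs from a complex one: it penalizes |Re z| + |Im z| rather than |z|.

```python
import numpy as np


def to_real_view(z):
    """Stack real and imaginary parts along a new leading axis (hypothetical layout)."""
    return np.stack([z.real, z.imag], axis=0)


def to_complex(r):
    """Inverse of to_real_view."""
    return r[0] + 1j * r[1]


def real_operator(A):
    """Real (2m x 2n) matrix acting on [Re x; Im x] exactly as A acts on x."""
    return np.block([[A.real, -A.imag], [A.imag, A.real]])


rng = np.random.default_rng(0)
A = rng.standard_normal((6, 4)) + 1j * rng.standard_normal((6, 4))
x = rng.standard_normal(4) + 1j * rng.standard_normal(4)

# Apply the real operator to the flattened real view, then convert back
y_view = real_operator(A) @ to_real_view(x).ravel()
y = to_complex(y_view.reshape(2, -1))
print(np.allclose(y, A @ x))  # True

# Real-valued vs complex-valued l1 penalty of z = 3 + 4j
z = np.array([3 + 4j])
l1_real_view = np.abs(to_real_view(z)).sum()  # |3| + |4| = 7
l1_complex = np.abs(z).sum()  # |3 + 4j| = 5
```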

Initial reconstruction with adjoint

x_dagger = physics.A_dagger(y_real)
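The adjoint baseline depends on the density-compensation weights requested with density="pipe". We assume this refers to the classic Pipe and Menon fixed-point iteration, w <- w / (C w), where C is a convolution kernel evaluated between sample locations, so that densely sampled regions receive smaller weights. The 1D NumPy sketch below uses a hypothetical triangle kernel on the periodic domain [-0.5, 0.5); it is only an illustration, not MRI-NUFFT's implementation, which works in 3D with its own kernel and normalization. A density-compensated adjoint would then apply the adjoint operator to the weighted data w * y.

```python
import numpy as np


def triangle_kernel(d, width):
    """Hypothetical convolution kernel: 1 at distance 0, 0 beyond `width`."""
    return np.clip(1.0 - np.abs(d) / width, 0.0, None)


def pipe_menon_density(locs, width=0.1, n_iter=20):
    """1D Pipe-Menon iteration w <- w / (C w) on the periodic domain [-0.5, 0.5)."""
    diff = locs[:, None] - locs[None, :]
    diff = (diff + 0.5) % 1.0 - 0.5  # wrap-around distance
    C = triangle_kernel(diff, width)  # (n, n), diagonal entries equal 1
    w = np.ones(len(locs))
    for _ in range(n_iter):
        w = w / (C @ w)
    return w


# Uniform sampling: every point sees the same neighbourhood, so weights are equal
uniform = np.arange(16) / 16 - 0.5
w_uniform = pipe_menon_density(uniform)

# Clustered sampling: the point at 0.01 has close neighbours, -0.4 is isolated
clustered = np.array([-0.4, 0.0, 0.01, 0.02, 0.3])
w_clustered = pipe_menon_density(clustered)
```

An isolated sample only "sees" itself through the kernel, so its weight stays at 1, while samples inside the cluster are down-weighted.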

Wavelet reconstruction with FISTA

The adjoint reconstruction is fast, but it contains artifacts due to undersampling. We therefore solve a regularized inverse problem using a wavelet sparsity prior.

The reconstruction minimizes a data-fidelity term together with a wavelet regularization term:

\[\min_x \frac{1}{2}\|Ax - y\|_2^2 + \lambda \|Wx\|_1\]

where A is the MRI forward operator, y is the measured k-space data, and W is a wavelet transform. The L2 data-fidelity term enforces consistency with the acquired measurements, while the wavelet prior favors images whose wavelet coefficients are sparse.

We use FISTA, an accelerated proximal-gradient algorithm, to solve this optimization problem.
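As an illustration of the iteration itself, the NumPy sketch below implements FISTA for the same objective with a dense real matrix A and an orthonormal matrix W standing in for the wavelet transform (for orthonormal W, the proximal operator of lambda ||W x||_1 is W^T applied to soft-thresholded W x). The momentum rule uses a parameter a > 2 in the style of Chambolle and Dossal; whether it matches DeepInverse's exact use of "a" is an assumption. The matrices, the step size 1/L and the helper names are illustrative only.

```python
import numpy as np


def soft_threshold(v, thresh):
    """Proximal operator of thresh * ||.||_1 (element-wise shrinkage)."""
    return np.sign(v) * np.maximum(np.abs(v) - thresh, 0.0)


def fista(A, y, lam, W=None, n_iter=1000, a=3.0):
    """FISTA for min_x 0.5 * ||A x - y||^2 + lam * ||W x||_1 with orthonormal W."""
    n = A.shape[1]
    W = np.eye(n) if W is None else W
    L = np.linalg.norm(A, 2) ** 2  # Lipschitz constant of the data-fidelity gradient
    step = 1.0 / L
    x = np.zeros(n)
    z = x.copy()  # extrapolated point
    for k in range(1, n_iter + 1):
        grad = A.T @ (A @ z - y)  # gradient of the data-fidelity term at z
        v = z - step * grad
        x_new = W.T @ soft_threshold(W @ v, step * lam)  # prox for orthonormal W
        z = x_new + (k - 1) / (k + a) * (x_new - x)  # momentum step, a > 2
        x = x_new
    return x


rng = np.random.default_rng(0)

# Diagonal A: closed-form solution x_i = soft(d_i * y_i, lam) / d_i^2
d = np.linspace(1.0, 3.0, 20)
y = rng.standard_normal(20)
lam = 0.5
x_fista = fista(np.diag(d), y, lam)
x_exact = soft_threshold(d * y, lam) / d**2

# Orthonormal W with A = I: the solution is W^T soft(W y, lam)
W, _ = np.linalg.qr(rng.standard_normal((20, 20)))
x_w = fista(np.eye(20), y, lam, W=W)
x_w_exact = W.T @ soft_threshold(W @ y, lam)
```

Both test problems have known minimizers, which makes it easy to check that the proximal-gradient step and the momentum update are wired correctly.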

Set up and run the reconstruction algorithm. First, define the data-fidelity term and the algorithm parameters.

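# Data-fidelity term 1/2 ||Ax - y||_2^2
data_fidelity = L2()

# Algorithm parameters
lamb = 1e1  # regularization weight lambda
# Lipschitz constant of the data-fidelity gradient, estimated iteratively;
# if the estimate does not converge, MRI-NUFFT emits a warning
L = fourier_op.get_lipschitz_cst()
stepsize = 0.8 / float(L)  # kept below 1/L as a safety margin

# "a" is the FISTA momentum parameter
params_algo = {"stepsize": stepsize, "lambda": lamb, "a": 3}
max_iter = 100
early_stop = True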
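The step size relies on the Lipschitz constant of the data-fidelity gradient, i.e. the largest eigenvalue of A^H A. A generic way to estimate it is the power iteration sketched below in NumPy, with a relative-change stopping rule. The tol and max_iter names are our own, and the sketch is not MRI-NUFFT's get_lipschitz_cst implementation. When the loop reaches max_iter before the change drops below tol, the estimate may be inaccurate, which is presumably what a "did not converge" warning signals; a step size with a margin below 1/L, such as 0.8/L, guards against a slight underestimate.

```python
import numpy as np


def lipschitz_power_iteration(A, max_iter=100, tol=1e-8, seed=0):
    """Estimate the largest eigenvalue of A^H A by power iteration.

    Returns (estimate, converged).
    """
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(A.shape[1]) + 1j * rng.standard_normal(A.shape[1])
    x /= np.linalg.norm(x)
    est = 0.0
    for _ in range(max_iter):
        z = A.conj().T @ (A @ x)  # apply the normal operator A^H A
        new_est = np.linalg.norm(z)  # ||A^H A x|| with ||x|| = 1
        x = z / new_est
        if abs(new_est - est) <= tol * new_est:
            return new_est, True
        est = new_est
    return est, False


# Test operator with known singular values 5, 3, 2, 1, so ||A^H A|| = 25
rng = np.random.default_rng(1)
U, _ = np.linalg.qr(rng.standard_normal((6, 4)) + 1j * rng.standard_normal((6, 4)))
V, _ = np.linalg.qr(rng.standard_normal((4, 4)) + 1j * rng.standard_normal((4, 4)))
A = U @ np.diag([5.0, 3.0, 2.0, 1.0]) @ V.conj().T

L, converged = lipschitz_power_iteration(A)
stepsize = 0.8 / L
```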

Instantiate the algorithm class to solve the problem.

wavelet_model = optim_builder(
    iteration="FISTA",
    prior=wavelet,
    data_fidelity=data_fidelity,
    early_stop=early_stop,
    max_iter=max_iter,
    params_algo=params_algo,
)
x_wavelet = wavelet_model(y_real, physics)

Total variation reconstruction with PDCP

Next, we reconstruct the image using a Total Variation (TV) prior solved with the Chambolle-Pock primal-dual algorithm (PDCP). TV promotes piecewise-smooth images while preserving sharp edges.

We solve:

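\[\min_x \frac{1}{2}\|Ax - y\|_2^2 + \lambda \operatorname{TV}(x)\]

where TV(x) is the total variation of x, i.e. the l1 norm of its discrete spatial gradient. In the PDCP setup below, the linear operator K is the forward operator A, the data-fidelity term is evaluated on K(x), and the TV term is handled through its proximal operator.
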
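from deepinv.optim import PDCP
from deepinv.optim.data_fidelity import L2Distance


def pdcp_cost_fn(x, data_fidelity, prior, cur_params, y, physics):
    # Objective monitored by PDCP: data fidelity on K(x) = A(x) plus weighted TV
    return data_fidelity(cur_params["K"](x), y) + cur_params["lambda"] * prior(x)


lamb_tv = 50  # TV regularization weight
tv = TVPrior(n_it_max=20)  # the TV proximal operator is itself computed iteratively
# Primal step size; with a dual step of 1, tau * sigma * ||A||^2 = 1 (up to the accuracy of L)
stepsize_pdcp = 1.0 / float(L)

pdcp_model = PDCP(
    K=physics.A,  # linear operator coupling the primal and dual variables
    K_adjoint=physics.A_adjoint,
    data_fidelity=L2Distance(),
    prior=tv,
    lambda_reg=lamb_tv,
    stepsize=stepsize_pdcp,
    stepsize_dual=1.0,
    max_iter=20,
    g_first=False,
    cost_fn=pdcp_cost_fn,
)

x_pdcp_real = pdcp_model(y_real, physics)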
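To make the primal-dual updates concrete, the NumPy sketch below runs the Chambolle-Pock iteration on 1D TV denoising, min_x 1/2 ||x - y||^2 + lambda ||D x||_1, where D is the forward-difference operator. It uses a different splitting from the PDCP call above: here K = D, the dual step handles the TV term through a projection onto [-lambda, lambda], and the primal step handles the quadratic term. It is a small illustration under these assumptions, not DeepInverse's implementation. The step sizes tau = sigma = 0.5 satisfy tau * sigma * ||D||^2 < 1 because ||D||^2 < 4.

```python
import numpy as np


def diff_op(x):
    """Forward differences D x, length n - 1."""
    return np.diff(x)


def diff_adjoint(p):
    """Adjoint D^T p of the forward-difference operator, length n (needs n >= 3)."""
    return np.concatenate(([-p[0]], p[:-1] - p[1:], [p[-1]]))


def tv1d_chambolle_pock(y, lam, n_iter=5000, tau=0.5, sigma=0.5):
    """Chambolle-Pock for min_x 0.5 * ||x - y||^2 + lam * ||D x||_1."""
    x = y.copy()
    x_bar = x.copy()
    p = np.zeros(len(y) - 1)  # dual variable, one entry per difference
    for _ in range(n_iter):
        # Dual step: prox of sigma * F^*, i.e. projection onto the box [-lam, lam]
        p = np.clip(p + sigma * diff_op(x_bar), -lam, lam)
        # Primal step: prox of tau * 0.5 * ||. - y||^2
        x_new = (x - tau * diff_adjoint(p) + tau * y) / (1.0 + tau)
        # Over-relaxation with theta = 1
        x_bar = 2.0 * x_new - x
        x = x_new
    return x


# A clean step: TV moves each plateau towards the other by lam / plateau length
y_step = np.concatenate([np.zeros(10), np.ones(10)])
x_step = tv1d_chambolle_pock(y_step, lam=1.0)
expected_step = np.concatenate([np.full(10, 0.1), np.full(10, 0.9)])

# For a large enough lam, the solution is the constant signal mean(y)
rng = np.random.default_rng(0)
y_noisy = rng.standard_normal(20)
lam_big = 2.0 * np.abs(np.cumsum(y_noisy - y_noisy.mean())).max() + 1.0
x_flat = tv1d_chambolle_pock(y_noisy, lam=lam_big)
```

The step example shows the characteristic TV behaviour: the edge is kept sharp while its contrast is reduced by the regularization.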

Quantitative evaluation

We evaluate the reconstructions with PSNR and SSIM. PSNR measures the reconstruction fidelity with respect to the reference image: higher PSNR means lower pixel-wise error. SSIM measures structural similarity and is often more informative for images because it compares local contrast and structure.

Metrics are computed on magnitude images, since the reconstructions are complex-valued.

from deepinv.loss.metric import PSNR, SSIM

psnr = PSNR(max_pixel=None)
ssim = SSIM()

x_ref = torch.abs(mri).unsqueeze(0).unsqueeze(0)

from mrinufft.operators.autodiff import image_as_cpx


def to_magnitude(x):
    return torch.abs(image_as_cpx(x))


x_adjoint_mag = to_magnitude(x_dagger)
x_wavelet_mag = to_magnitude(x_wavelet)
x_pdcp_mag = to_magnitude(x_pdcp_real)

print(f"Adjoint PSNR: {psnr(x_adjoint_mag, x_ref).item():.2f}")
print(f"Wavelet PSNR: {psnr(x_wavelet_mag, x_ref).item():.2f}")
print(f"TV-PDCP PSNR: {psnr(x_pdcp_mag, x_ref).item():.2f}")

print(f"Adjoint SSIM: {ssim(x_adjoint_mag, x_ref).item():.4f}")
print(f"Wavelet SSIM: {ssim(x_wavelet_mag, x_ref).item():.4f}")
print(f"TV-PDCP SSIM: {ssim(x_pdcp_mag, x_ref).item():.4f}")
Adjoint PSNR: 11.98
Wavelet PSNR: 26.31
TV-PDCP PSNR: 17.74
Adjoint SSIM: 0.4460
Wavelet SSIM: 0.5877
TV-PDCP SSIM: 0.3809
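For reference, PSNR follows the usual definition 10 log10(peak^2 / MSE). The NumPy sketch below computes it for a toy pair of signals; taking the peak from the reference when max_pixel=None is our assumption about the setting used above and may not match DeepInverse's exact convention.

```python
import numpy as np


def psnr(x, x_ref, max_pixel=None):
    """PSNR in dB. Assumption: max_pixel=None uses the peak magnitude of the reference."""
    peak = np.abs(x_ref).max() if max_pixel is None else max_pixel
    mse = np.mean(np.abs(x - x_ref) ** 2)
    return 10.0 * np.log10(peak**2 / mse)


x_ref = np.array([0.0, 1.0])
x = np.array([0.0, 0.9])
# MSE = (0 + 0.1^2) / 2 = 0.005, so PSNR = 10 * log10(1 / 0.005) = 10 * log10(200)
value = psnr(x, x_ref)
print(round(value, 2))  # 23.01
```

Because PSNR is logarithmic, a tenfold reduction of the mean squared error raises it by 10 dB.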

Visualize the reconstructions

We display the same slice of the ground-truth image and of the adjoint, wavelet, and TV-PDCP reconstructions.

slice_idx = mri.shape[-1] // 2 - 5  # slice slightly off-centre along the last axis

fig, axes = plt.subplots(1, 4, figsize=(20, 6))

# Use the magnitude images computed above: the raw reconstructions are real-valued
# views that hold the real and imaginary parts separately
images = [
    (torch.abs(mri[..., slice_idx]).detach().cpu(), "Ground truth"),
    (x_adjoint_mag[0, 0, ..., slice_idx].detach().cpu(), "Adjoint"),
    (x_wavelet_mag[0, 0, ..., slice_idx].detach().cpu(), "Wavelet"),
    (x_pdcp_mag[0, 0, ..., slice_idx].detach().cpu(), "TV-PDCP"),
]

for ax, (image, title) in zip(axes, images):
    ax.imshow(image, cmap="gray")
    ax.set_title(title)
    ax.axis("off")

plt.tight_layout()
plt.show()
Figure: slice comparison (panels: Ground truth, Adjoint, Wavelet, TV-PDCP)

Total running time of the script: (5 minutes 20.243 seconds)

Gallery generated by Sphinx-Gallery