Learn Sampling pattern#

A small PyTorch example showing how to learn k-space sampling patterns. It demonstrates the auto-differentiation capabilities of the NUFFT operator with respect to the k-space trajectory in mri-nufft.

In this example, we solve the following optimization problem:

\[\mathbf{\hat{K}} = \mathrm{arg} \min_{\mathbf{K}} || \mathcal{F}_\mathbf{K}^* D_\mathbf{K} \mathcal{F}_\mathbf{K} \mathbf{x} - \mathbf{x} ||_2^2\]

where \(\mathcal{F}_\mathbf{K}\) is the forward NUFFT operator, \(D_\mathbf{K}\) is the density compensation for trajectory \(\mathbf{K}\), and \(\mathbf{x}\) is the MR image, which is also the target image to be reconstructed.

import os

import brainweb_dl as bwdl
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import torch

from mrinufft import get_operator
from mrinufft.trajectories import initialize_2D_radial
from mrinufft.trajectories.projection import project_trajectory
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/mri-nufft/.venv/lib/python3.10/site-packages/cupyx/jit/_interface.py:247: FutureWarning: cupyx.jit.rawkernel is experimental. The interface can change in the future.
  cupy._util.experimental('cupyx.jit.rawkernel')

Setup a simple class to learn trajectory#

Note

Although only the k-space trajectory of the NUFFT operator is learned here, wrt_data=True must still be set in get_operator so that all the gradients are computed correctly. See [Projector] for more details.

# NUFFT backend name; the MRINUFFT_BACKEND environment variable overrides the
# "gpunufft" default.
BACKEND = os.getenv("MRINUFFT_BACKEND", "gpunufft")

# Raise matplotlib's animation embed limit, presumably so the training
# animation can be embedded in full.
# NOTE(review): matplotlib documents this limit in MB, so 2**30 is far above
# 1 GiB -- confirm the intended unit.
plt.rcParams["animation.embed_limit"] = 2**30


class Model(torch.nn.Module):
    """Learnable k-space trajectory wrapped around a NUFFT operator.

    The trajectory is the only trainable parameter. ``forward`` simulates an
    acquisition with the forward operator and reconstructs it with the
    adjoint operator.

    Parameters
    ----------
    inital_trajectory : array-like
        Initial k-space sample locations, converted with ``torch.Tensor``.
        The last axis is taken as the spatial dimension of the samples.
    shape : tuple of int, optional
        Image shape handed to the NUFFT operator. Defaults to ``(256, 256)``.
    """

    def __init__(self, inital_trajectory, shape=(256, 256)):
        super().__init__()
        self.trajectory = torch.nn.Parameter(
            data=torch.Tensor(inital_trajectory),
            requires_grad=True,
        )
        # The operator is built from a detached NumPy copy of the samples;
        # the live parameter is pushed into it on every forward pass.
        self.operator = get_operator(BACKEND, wrt_data=True, wrt_traj=True)(
            self.trajectory.detach().cpu().numpy(),
            shape=tuple(shape),
            density=True,
            squeeze_dims=False,
        )

    def forward(self, x):
        """Simulate an acquisition of ``x`` and return its normalized adjoint.

        Parameters
        ----------
        x : torch.Tensor
            Image passed to the forward operator.

        Returns
        -------
        torch.Tensor
            Output of the adjoint operator divided by its norm.
        """
        # Update the trajectory in the NUFFT operator.
        # Note that the re-computation of density compensation happens internally.
        n_dims = self.trajectory.shape[-1]
        self.operator.samples = self.trajectory.clone().reshape(-1, n_dims)

        # A simple acquisition model simulated with a forward NUFFT operator
        kspace = self.operator.op(x)

        # A simple density compensated adjoint operator
        adjoint = self.operator.adj_op(kspace)
        return adjoint / torch.linalg.norm(adjoint)

Setup Data and Model#

num_epochs = 50

# Reference image: one BrainWeb T1 slice, flipped vertically.
# torch.from_numpy keeps the complex64 dtype, whereas the legacy torch.Tensor
# constructor builds a real tensor and discards the imaginary part with a
# UserWarning.
mri_slice = np.flipud(bwdl.get_mri(4, "T1")[80, ...]).astype(np.complex64)
mri_2D = torch.from_numpy(np.ascontiguousarray(mri_slice))
# Add a leading axis and scale the image to unit norm.
mri_2D = mri_2D[None, ...] / torch.linalg.norm(mri_2D)

init_traj = initialize_2D_radial(32, 512).astype(np.float32)
# Add some noise to the initial trajectory
init_traj += 0.01 * np.random.randn(*init_traj.shape).astype(np.float32)
# Project the initial trajectory to satisfy hardware constraints
init_traj = project_trajectory(init_traj, max_iter=100, verbose=0, TE_pos=0)
model = Model(init_traj)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
# Linearly scale the learning rate from 1x down to 1e-4x of its initial value
# over the training epochs. The variable name is kept as-is because train()
# refers to it.
schedulder = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1, end_factor=1e-4, total_iters=num_epochs
)


model.eval()
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/mri-nufft/examples/GPU/example_learn_samples.py:77: UserWarning: Casting complex values to real discards the imaginary part (Triggered internally at /pytorch/aten/src/ATen/native/Copy.cpp:308.)
  mri_2D = torch.Tensor(np.flipud(bwdl.get_mri(4, "T1")[80, ...]).astype(np.complex64))
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/mri-nufft/src/mrinufft/_utils.py:72: UserWarning: Samples will be rescaled to [-0.5, 0.5), assuming they were in [-pi, pi)
  warnings.warn(
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/mri-nufft/src/mrinufft/_utils.py:67: UserWarning: Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)
  warnings.warn(

Model(
  (operator): MRINufftAutoGrad()
)

Training and plotting#

# Reconstruction obtained with the initial, untrained trajectory.
recon = model(mri_2D)

fig, axs = plt.subplots(2, 2, figsize=(10, 10))
fig.suptitle("Training Starting")
axs = axs.flatten()

# Convert to NumPy before calling np.abs: applying NumPy ufuncs directly to a
# torch tensor raises a DeprecationWarning under NumPy 2.0.
axs[0].imshow(np.abs(mri_2D[0].detach().cpu().numpy()), cmap="gray")
axs[0].axis("off")
axs[0].set_title("MR Image")

# One line artist list per entry of the trajectory; plot_epoch updates these
# in place during training.
traj_plot = [axs[1].plot(*traj.T, c="b") for traj in init_traj]
axs[1].set_title("Trajectory")

recon_im = axs[2].imshow(np.abs(recon.squeeze().detach().cpu().numpy()), cmap="gray")
axs[2].axis("off")
axs[2].set_title("Reconstruction")

# Empty loss curve, filled in epoch by epoch.
(loss_curve,) = axs[3].plot([], [])
axs[3].grid()
axs[3].set_xlabel("epochs")
axs[3].set_ylabel("loss")

fig.tight_layout()


def train():
    """Run the training loop as a generator, one epoch per iteration.

    Yields
    ------
    tuple
        ``(reconstruction, trajectory, losses)``: the model output of the
        epoch as a squeezed NumPy array (computed before the optimizer step),
        the updated trajectory as a NumPy array, and the list of loss values
        so far. The same list object is yielded at every epoch.
    """
    losses = []
    for _ in range(num_epochs):
        out = model(mri_2D)
        loss = torch.norm(out - mri_2D[None])  # Compute loss

        optimizer.zero_grad()  # Zero gradients
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights
        with torch.no_grad():
            # Project the updated trajectory, then clamp it to [-0.5, 0.5].
            # The trajectory is addressed directly instead of looping over
            # model.parameters(), so an additional parameter could never be
            # written into model.trajectory by mistake.
            # NOTE(review): the initial projection passes TE_pos=0 while this
            # call relies on the default -- confirm that is intended.
            projected = project_trajectory(model.trajectory, max_iter=100, verbose=0)
            model.trajectory.data = projected.clamp_(-0.5, 0.5)
        schedulder.step()
        losses.append(loss.item())
        yield (
            out.detach().cpu().numpy().squeeze(),
            model.trajectory.detach().cpu().numpy(),
            losses,
        )


def plot_epoch(data):
    """Redraw the figure for one training epoch.

    Parameters
    ----------
    data : tuple
        ``(img, traj, losses)`` as yielded by ``train``: the current
        reconstruction, the current trajectory and the losses so far.
    """
    img, traj, losses = data
    n_done = len(losses)

    # Reconstruction panel
    recon_im.set_data(abs(img))
    axs[2].set_title(f"Reconstruction, frame {n_done}/{num_epochs}")

    # Trajectory panel: update every line artist in place
    for artists, shot in zip(traj_plot, traj):
        artists[0].set_data(*shot.T)
    axs[1].set_title(f"Trajectory, frame {n_done}/{num_epochs}")

    # Loss panel
    loss_curve.set_data(np.arange(n_done), losses)
    axs[3].set_xlim(0, n_done)
    axs[3].set_ylim(0, 1.1 * max(losses))

    # Figure title: cycling dots while training, final message at the end
    if n_done >= num_epochs:
        title = "Training complete !"
    else:
        title = "Training in progress " + "." * (1 + n_done % 3)
    fig.suptitle(title)


# Drive the training from the animation: train is passed as the frame source
# and each value it yields is handed to plot_epoch.
ani = animation.FuncAnimation(
    fig,
    func=plot_epoch,
    frames=train,
    save_count=num_epochs,
    repeat=False,
)
plt.show()
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/mri-nufft/examples/GPU/example_learn_samples.py:107: DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)
  axs[0].imshow(np.abs(mri_2D[0]), cmap="gray")

References#

[Proj]

N. Chauffert, P. Weiss, J. Kahn and P. Ciuciu, “A Projection Algorithm for Gradient Waveforms Design in Magnetic Resonance Imaging,” in IEEE Transactions on Medical Imaging, vol. 35, no. 9, pp. 2026-2039, Sept. 2016, doi: 10.1109/TMI.2016.2544251.

[Sparks]

Chaithya GR, P. Weiss, G. Daval-Frérot, A. Massire, A. Vignaud and P. Ciuciu, “Optimizing Full 3D SPARKLING Trajectories for High-Resolution Magnetic Resonance Imaging,” in IEEE Transactions on Medical Imaging, vol. 41, no. 8, pp. 2105-2117, Aug. 2022, doi: 10.1109/TMI.2022.3157269.

[Projector]

Chaithya GR, and Philippe Ciuciu. 2023. “Jointly Learning Non-Cartesian k-Space Trajectories and Reconstruction Networks for 2D and 3D MR Imaging through Projection” Bioengineering 10, no. 2: 158. https://doi.org/10.3390/bioengineering10020158

Total running time of the script: (1 minutes 22.455 seconds)

Gallery generated by Sphinx-Gallery