Note
Go to the end to download the full example code or to run this example in your browser via Binder.
Learn Sampling pattern#
A small pytorch example to showcase learning k-space sampling patterns. This example showcases the auto-diff capabilities of the NUFFT operator with respect to the k-space trajectory in mri-nufft.
In this example, we solve the following optimization problem:

\[\mathbf{\hat{K}} = \arg\min_{\mathbf{K}} \left\| \mathcal{F}_\mathbf{K}^* D_\mathbf{K} \mathcal{F}_\mathbf{K} \mathbf{x} - \mathbf{x} \right\|_2^2\]

where \(\mathcal{F}_\mathbf{K}\) is the forward NUFFT operator and \(D_\mathbf{K}\) is the density compensation operator for trajectory \(\mathbf{K}\), and \(\mathbf{x}\) is the MR image, which is also the target image to be reconstructed.
import os
import brainweb_dl as bwdl
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import torch
from mrinufft import get_operator
from mrinufft.trajectories import initialize_2D_radial
from mrinufft.trajectories.projection import project_trajectory
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/mri-nufft/.venv/lib/python3.10/site-packages/cupyx/jit/_interface.py:247: FutureWarning: cupyx.jit.rawkernel is experimental. The interface can change in the future.
cupy._util.experimental('cupyx.jit.rawkernel')
Setup a simple class to learn trajectory#
Note
While we are only learning the sampling trajectory, we still need wrt_data=True to be set in get_operator so that all the gradients are computed correctly.
See [Projector] for more details.
# NUFFT backend name; overridable through the MRINUFFT_BACKEND environment
# variable (defaults to the GPU "gpunufft" backend).
BACKEND = os.environ.get("MRINUFFT_BACKEND", "gpunufft")
# Raise matplotlib's inline-animation size cap so the training animation fits.
plt.rcParams["animation.embed_limit"] = 2**30  # 1GiB is very large.
class Model(torch.nn.Module):
    """Learnable k-space trajectory wrapped around a differentiable NUFFT operator."""

    def __init__(self, inital_trajectory):
        """Set up the trainable trajectory and the NUFFT operator.

        Parameters
        ----------
        inital_trajectory : array-like
            Initial 2D k-space sampling locations (last axis of size 2,
            as required by the reshape in ``forward``).
        """
        super().__init__()
        # The k-space sampling locations are the only learnable parameters.
        self.trajectory = torch.nn.Parameter(
            data=torch.Tensor(inital_trajectory),
            requires_grad=True,
        )
        # wrt_data=True together with wrt_traj=True is required so that
        # autograd computes all gradients of the operator correctly.
        self.operator = get_operator(BACKEND, wrt_data=True, wrt_traj=True)(
            self.trajectory.detach().cpu().numpy(),
            shape=(256, 256),
            density=True,
            squeeze_dims=False,
        )

    def forward(self, x):
        """Simulate acquisition with the current trajectory and reconstruct.

        Returns the density-compensated adjoint reconstruction of ``x``,
        normalized to unit norm.
        """
        # Update the trajectory in the NUFFT operator.
        # Note that the re-computation of density compensation happens internally.
        self.operator.samples = self.trajectory.clone().reshape(-1, 2)
        # A simple acquisition model simulated with a forward NUFFT operator
        kspace = self.operator.op(x)
        # A simple density compensated adjoint operator
        adjoint = self.operator.adj_op(kspace)
        # Normalize so the output scale does not depend on the trajectory.
        return adjoint / torch.linalg.norm(adjoint)
Setup Data and Model#
num_epochs = 50

# Target image: one axial slice of a BrainWeb T1 phantom, flipped upright and
# normalized to unit norm. Casting the complex array into a real torch.Tensor
# triggers the harmless "casting complex to real" warning shown below.
mri_2D = torch.Tensor(np.flipud(bwdl.get_mri(4, "T1")[80, ...]).astype(np.complex64))
mri_2D = mri_2D[None, ...] / torch.linalg.norm(mri_2D)

# Initial trajectory: 32 radial shots of 512 samples each.
init_traj = initialize_2D_radial(32, 512).astype(np.float32)
init_traj += 0.01 * np.random.randn(*init_traj.shape).astype(
    np.float32
)  # Add some noise to the initial trajectory
init_traj = project_trajectory(
    init_traj, max_iter=100, verbose=0, TE_pos=0
)  # Project the initial trajectory to satisfy hardware constraints

model = Model(init_traj)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
# Linearly decay the learning rate from lr down to lr * 1e-4 over the run.
# NOTE(review): "schedulder" is a typo for "scheduler"; kept as-is because the
# training loop below refers to it by this name.
schedulder = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1, end_factor=1e-4, total_iters=num_epochs
)
model.eval()
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/mri-nufft/examples/GPU/example_learn_samples.py:77: UserWarning: Casting complex values to real discards the imaginary part (Triggered internally at /pytorch/aten/src/ATen/native/Copy.cpp:308.)
mri_2D = torch.Tensor(np.flipud(bwdl.get_mri(4, "T1")[80, ...]).astype(np.complex64))
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/mri-nufft/src/mrinufft/_utils.py:72: UserWarning: Samples will be rescaled to [-0.5, 0.5), assuming they were in [-pi, pi)
warnings.warn(
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/mri-nufft/src/mrinufft/_utils.py:67: UserWarning: Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)
warnings.warn(
Model(
(operator): MRINufftAutoGrad()
)
Training and plotting#
# Initial reconstruction before any training step.
recon = model(mri_2D)

# Figure layout: [0] target image, [1] trajectory, [2] reconstruction, [3] loss.
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
fig.suptitle("Training Starting")
axs = axs.flatten()

axs[0].imshow(np.abs(mri_2D[0]), cmap="gray")
axs[0].axis("off")
axs[0].set_title("MR Image")

# One line-artist list per shot, kept so plot_epoch can update them in place.
# (Comprehension replaces the manual append loop.)
traj_plot = [axs[1].plot(*traj.T, c="b") for traj in init_traj]
axs[1].set_title("Trajectory")

recon_im = axs[2].imshow(np.abs(recon.squeeze().detach().cpu().numpy()), cmap="gray")
axs[2].axis("off")
axs[2].set_title("Reconstruction")

# Empty loss curve, filled in as training progresses.
(loss_curve,) = axs[3].plot([], [])
axs[3].grid()
axs[3].set_xlabel("epochs")
axs[3].set_ylabel("loss")
fig.tight_layout()
def train():
    """Generator train loop: one Adam step per epoch.

    Yields
    ------
    tuple
        ``(reconstruction, trajectory, losses)`` after each epoch — the
        reconstruction as a numpy image, the current trajectory as a numpy
        array, and the loss history so far. Consumed as animation frames
        by FuncAnimation below.
    """
    losses = []
    for i in range(num_epochs):
        out = model(mri_2D)
        loss = torch.norm(out - mri_2D[None])  # Compute loss
        optimizer.zero_grad()  # Zero gradients
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights
        with torch.no_grad():
            # clamp the value of trajectory between [-0.5, 0.5]
            for param in model.parameters():
                # Re-project onto the hardware-constraint set after the
                # unconstrained gradient step, then clamp to k-space bounds.
                # NOTE(review): assigning to .data keeps the constraint
                # enforcement out of the autograd graph; this assumes the
                # trajectory is the model's only parameter — confirm.
                param = project_trajectory(param, max_iter=100, verbose=0)
                model.trajectory.data = param.clamp_(-0.5, 0.5)
        schedulder.step()
        losses.append(loss.item())
        yield (
            out.detach().cpu().numpy().squeeze(),
            model.trajectory.detach().cpu().numpy(),
            losses,
        )
def plot_epoch(data):
    """Redraw the figure from one (image, trajectory, losses) training frame."""
    img, traj, losses = data
    cur_epoch = len(losses)
    # Refresh the reconstruction panel.
    recon_im.set_data(abs(img))
    # Move every per-shot trajectory line to its new positions.
    for line, shot in zip(traj_plot, traj):
        line[0].set_data(*shot.T)
    # Extend the loss curve and rescale its axes to the data seen so far.
    loss_curve.set_xdata(np.arange(cur_epoch))
    loss_curve.set_ydata(losses)
    axs[3].set_xlim(0, cur_epoch)
    axs[3].set_ylim(0, 1.1 * max(losses))
    # Frame counters in the panel titles.
    axs[2].set_title(f"Reconstruction, frame {cur_epoch}/{num_epochs}")
    axs[1].set_title(f"Trajectory, frame {cur_epoch}/{num_epochs}")
    if cur_epoch >= num_epochs:
        fig.suptitle("Training complete !")
    else:
        # Animated dots signal that training is still running.
        fig.suptitle("Training in progress " + "." * (1 + cur_epoch % 3))
# Drive the train() generator through plot_epoch, one animation frame per
# epoch; repeat=False stops after the final epoch.
ani = animation.FuncAnimation(
    fig, plot_epoch, train, save_count=num_epochs, repeat=False
)
plt.show()
/volatile/github-ci-mind-inria/gpu_mind_runner/_work/mri-nufft/mri-nufft/examples/GPU/example_learn_samples.py:107: DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)
axs[0].imshow(np.abs(mri_2D[0]), cmap="gray")
References#
N. Chauffert, P. Weiss, J. Kahn and P. Ciuciu, “A Projection Algorithm for Gradient Waveforms Design in Magnetic Resonance Imaging,” in IEEE Transactions on Medical Imaging, vol. 35, no. 9, pp. 2026-2039, Sept. 2016, doi: 10.1109/TMI.2016.2544251.
Chaithya GR, P. Weiss, G. Daval-Frérot, A. Massire, A. Vignaud and P. Ciuciu, “Optimizing Full 3D SPARKLING Trajectories for High-Resolution Magnetic Resonance Imaging,” in IEEE Transactions on Medical Imaging, vol. 41, no. 8, pp. 2105-2117, Aug. 2022, doi: 10.1109/TMI.2022.3157269.
Chaithya GR, and Philippe Ciuciu. 2023. “Jointly Learning Non-Cartesian k-Space Trajectories and Reconstruction Networks for 2D and 3D MR Imaging through Projection” Bioengineering 10, no. 2: 158. https://doi.org/10.3390/bioengineering10020158
Total running time of the script: (1 minute 22.455 seconds)