Learning sampling pattern with decimation#

An example using PyTorch to learn k-space sampling patterns with multi-resolution decimation.

This example showcases the auto-differentiation capabilities of the NUFFT operator with respect to the k-space trajectory in MRI-nufft.

Hereafter we learn the k-space sample locations \(\mathbf{K}\) using the following cost function:

\[\mathbf{\hat{K}} = \arg\min_{\mathbf{K}} \| \mathcal{F}_\mathbf{K}^* D_\mathbf{K} \mathcal{F}_\mathbf{K} \mathbf{x} - \mathbf{x} \|_2^2\]

where \(\mathcal{F}_\mathbf{K}\) is the forward NUFFT operator, \(D_\mathbf{K}\) is the density compensator for trajectory \(\mathbf{K}\), and \(\mathbf{x}\) is the MR image which is also the target image to be reconstructed.
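In code, this cost amounts to comparing the adjoint reconstruction of simulated k-space data with the original image. A minimal sketch, assuming a hypothetical operator object with MRI-nufft-style `op`/`adj_op` methods (density compensation folded into the adjoint):

```python
import torch


def self_consistency_loss(nufft_op, x):
    """Sketch of ||F_K^* D_K F_K x - x||_2^2 for one trajectory K.

    `nufft_op` is a hypothetical operator exposing `op` (forward NUFFT)
    and `adj_op` (density-compensated adjoint); `x` is the target image.
    """
    kspace = nufft_op.op(x)          # F_K x: simulate the acquisition
    recon = nufft_op.adj_op(kspace)  # F_K^* D_K (...): adjoint reconstruction
    return torch.nn.functional.mse_loss(recon.abs(), x)
```

The `Model` class below implements the same idea, with an additional normalization of the reconstruction by its mean.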

Additionally, to converge faster, we learn the trajectory in a multi-resolution fashion: we first optimize the locations of an 8x-decimated trajectory, whose samples are called control points. After a fixed number of iterations (30 per resolution level in this example), the control points are upscaled by a factor of 2. Note that the NUFFT operator always holds a linearly interpolated version of the control points as the k-space sampling trajectory.
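The upscaling step relies on `torch.nn.functional.interpolate`, which expects the interpolated axis last. As a standalone illustration (shapes are hypothetical: shots x samples x dimensions):

```python
import torch

# 2 shots, 4 control points per shot, 2D coordinates (illustrative shapes)
control = torch.randn(2, 4, 2)

# interpolate works on the last axis, so move the sample axis there and back
upscaled = torch.nn.functional.interpolate(
    control.moveaxis(1, -1),  # -> (2, 2, 4)
    scale_factor=2,
    mode="linear",
    align_corners=True,       # keep the first and last control points fixed
).moveaxis(-1, 1)             # -> (2, 8, 2)

print(upscaled.shape)  # torch.Size([2, 8, 2])
```

With `align_corners=True`, the endpoints of each shot are preserved exactly, so upscaling refines the trajectory without moving its extremities.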

Note

This example can run on a Binder instance, as it uses a purely CPU-based backend (finufft) and is restricted to a 2D, single-coil toy case.

Warning

This example only showcases the auto-differentiation capabilities; the learned sampling pattern is not scanner-compliant, as the gradients required to implement it violate hardware constraints. In practice, a projection \(\Pi_\mathcal{Q}(\mathbf{K})\) onto the scanner constraint set \(\mathcal{Q}\) is recommended (see [Cha+16]). This is implemented in the proprietary SPARKLING package [Cha+22]. Users are encouraged to contact the authors if they wish to use it.
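A real projection must handle gradient-amplitude and slew-rate limits. As a crude stand-in, this example later merely clamps the sample locations to the k-space box; that simplification can be sketched as a hypothetical helper:

```python
import torch


def project_box(K, kmax=0.5):
    """Crude 'projection': clamp sample locations to [-kmax, kmax].

    This only enforces the field-of-view box, not the gradient or
    slew-rate constraints of an actual scanner.
    """
    return K.clamp(-kmax, kmax)
```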

import time

import brainweb_dl as bwdl
import joblib
import matplotlib.pyplot as plt
import numpy as np
import tempfile as tmp
import torch
from PIL import Image, ImageSequence
from tqdm import tqdm

from mrinufft import get_operator
from mrinufft.trajectories import initialize_2D_radial

Utils#

Model class#

Note

While we are only learning the trajectory (and not the image), we still need to create the operator with wrt_data=True to have all the gradients computed correctly. See [GRC23] for more details.

class Model(torch.nn.Module):
    def __init__(
        self,
        initial_trajectory,
        img_size=(256, 256),
        start_decim=8,
        interpolation_mode="linear",
    ):
        super().__init__()
        # Learnable control points: a `start_decim`-times decimated trajectory.
        self.control = torch.nn.Parameter(
            data=torch.Tensor(initial_trajectory[:, ::start_decim]),
            requires_grad=True,
        )
        self.current_decim = start_decim
        self.interpolation_mode = interpolation_mode
        sample_points = initial_trajectory.reshape(-1, initial_trajectory.shape[-1])
        self.operator = get_operator("finufft", wrt_data=True, wrt_traj=True)(
            sample_points,
            shape=img_size,
            squeeze_dims=False,
        )
        self.img_size = img_size

    def _interpolate(self, traj, factor=2):
        """Torch interpolate function to upsample the trajectory"""
        return torch.nn.functional.interpolate(
            traj.moveaxis(1, -1),
            scale_factor=factor,
            mode=self.interpolation_mode,
            align_corners=True,
        ).moveaxis(-1, 1)

    def get_trajectory(self):
        """Function to get trajectory, which is interpolated version of control points."""
        traj = self.control.clone()
        for i in range(np.log2(self.current_decim).astype(int)):
            traj = self._interpolate(traj)

        return traj.reshape(-1, traj.shape[-1])

    def upscale(self, factor=2):
        """Upscale the model.

        The number of control points is multiplied by `factor` through interpolation.
        """
        self.control = torch.nn.Parameter(
            data=self._interpolate(self.control, factor),
            requires_grad=True,
        )
        self.current_decim /= factor

    def forward(self, x):
        traj = self.get_trajectory()
        self.operator.samples = traj

        # Simulate the acquisition process
        kspace = self.operator.op(x)

        adjoint = self.operator.adj_op(kspace).abs()
        return adjoint / torch.mean(adjoint)

State plotting#

def plot_state(axs, image, traj, recon, control_points=None, loss=None, save_name=None):
    axs = axs.flatten()
    # Upper left reference image
    axs[0].imshow(np.abs(image[0]), cmap="gray")
    axs[0].axis("off")
    axs[0].set_title("MR Image")

    # Upper right trajectory
    axs[1].scatter(*traj.T, s=0.5)
    if control_points is not None:
        axs[1].scatter(*control_points.T, s=1, color="r")
        axs[1].legend(
            ["Trajectory", "Control points"], loc="right", bbox_to_anchor=(2, 0.6)
        )
    axs[1].grid(True)
    axs[1].set_title("Trajectory")
    axs[1].set_xlim(-0.5, 0.5)
    axs[1].set_ylim(-0.5, 0.5)
    axs[1].set_aspect("equal")

    # Bottom left: reconstructed image
    axs[2].imshow(np.abs(recon[0][0].detach().cpu().numpy()), cmap="gray")
    axs[2].axis("off")
    axs[2].set_title("Reconstruction")

    # Bottom right: loss evolution
    if loss is not None:
        axs[3].plot(loss)
        axs[3].set_ylim(0, None)
        axs[3].grid(True)
        axs[3].set_title("Loss")
        plt.subplots_adjust(hspace=0.3)

    # Save & close
    if save_name is not None:
        plt.savefig(save_name, bbox_inches="tight")
        plt.close()
    else:
        plt.show()

Optimizer upscaling#

def upsample_optimizer(optimizer, new_optimizer, factor=2):
    """Upsample the optimizer."""
    for old_group, new_group in zip(optimizer.param_groups, new_optimizer.param_groups):
        for old_param, new_param in zip(old_group["params"], new_group["params"]):
            # Interpolate optimizer states
            if old_param in optimizer.state:
                for key in optimizer.state[old_param].keys():
                    if isinstance(optimizer.state[old_param][key], torch.Tensor):
                        old_state = optimizer.state[old_param][key]
                        if old_state.ndim == 0:
                            new_state = old_state
                        else:
                            new_state = torch.nn.functional.interpolate(
                                old_state.moveaxis(1, -1),
                                scale_factor=factor,
                                mode="linear",
                            ).moveaxis(-1, 1)
                        new_optimizer.state[new_param][key] = new_state
                    else:
                        new_optimizer.state[new_param][key] = optimizer.state[
                            old_param
                        ][key]
    return new_optimizer

Data preparation#

A single image to train the model over. Note that in practice we would use a whole dataset instead (e.g. fastMRI).

volume = np.flip(bwdl.get_mri(4, "T1"), axis=(0, 1, 2))
image = torch.from_numpy(volume[-80, ...].astype(np.float32))[None]
image = image / torch.mean(image)

A basic radial trajectory with an acceleration factor of 8.

AF = 8
initial_traj = initialize_2D_radial(image.shape[1] // AF, image.shape[2]).astype(
    np.float32
)

Trajectory learning#

Initialisation#

model = Model(initial_traj, img_size=image.shape[1:])
model = model.eval()
UserWarning: Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)

The image obtained before learning the sampling pattern is highly degraded because of the acceleration factor and the simplicity of the trajectory.

initial_recons = model(image)

fig, axs = plt.subplots(1, 3, figsize=(9, 3))
plot_state(axs, image, initial_traj, initial_recons)
(Figure: MR Image, Trajectory, Reconstruction)
UserWarning: Argument `data` does not satisfy the following requirement: C. Copying array (this may reduce performance)

Training loop#

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model.train()

losses = []
image_files = []
while model.current_decim >= 1:
    with tqdm(range(30), unit="steps") as tqdms:
        for i in tqdms:
            out = model(image)
            loss = torch.nn.functional.mse_loss(out, image[None])
            numpy_loss = loss.detach().cpu().numpy()

            tqdms.set_postfix({"loss": numpy_loss})
            losses.append(numpy_loss)
            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
            with torch.no_grad():
                # Clamp the value of trajectory between [-0.5, 0.5]
                for param in model.parameters():
                    param.clamp_(-0.5, 0.5)
            # Generate images for gif
            filename = f"{tmp.NamedTemporaryFile().name}.png"
            plt.clf()
            fig, axs = plt.subplots(2, 2, figsize=(10, 10), num=1)
            plot_state(
                axs,
                image,
                model.get_trajectory().detach().cpu().numpy(),
                out,
                model.control.detach().cpu().numpy(),
                losses,
                save_name=filename,
            )
            image_files.append(filename)
        if model.current_decim == 1:
            break
        else:
            model.upscale()
            optimizer = upsample_optimizer(
                optimizer, torch.optim.Adam(model.parameters(), lr=1e-3)
            )
[Training output truncated: four resolution levels of 30 optimization steps each. The loss decreases from 0.867 to 0.155 at the coarsest level, then to 0.106 and 0.075 over the next two levels; after the final upscaling it spikes transiently (up to 0.436) before settling around 0.078.]
# Make a GIF of all images.
imgs = [Image.open(img) for img in image_files]
imgs[0].save(
    "mrinufft_learn_traj_multires.gif",
    save_all=True,
    append_images=imgs[1:],
    optimize=False,
    duration=2,  # per-frame duration in milliseconds
    loop=0,  # loop forever
)
Figure: animation of the learned trajectory and reconstruction over the optimization (learn_samples example).

Results#

model.eval()
final_recons = model(image)
final_traj = model.get_trajectory().detach().cpu().numpy()
fig, axs = plt.subplots(1, 3, figsize=(9, 3))
plot_state(axs, image, final_traj, final_recons)
plt.show()
Figure: MR Image, Trajectory, Reconstruction (final learned state).

The learned trajectory improves the reconstruction quality compared to the initial trajectory shown earlier. Of course, the reconstructed image is still far from perfect, as the optimization was kept short so the documentation builds quickly. To improve the results, one can start by training for more than just 5 iterations per decimation level. Density compensation should also be used, even though it was skipped here to keep the example CPU-friendly. Check out Learn Sampling pattern to know more.
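As a reminder of the multi-resolution scheme described above, upscaling the control points by a factor of 2 amounts to resampling them on a twice-finer grid with linear interpolation. The standalone NumPy sketch below illustrates the idea; `upscale_control_points` is a hypothetical helper written for this example, not part of the mri-nufft API.

```python
import numpy as np


def upscale_control_points(points, factor=2):
    """Linearly interpolate (n, dim) control points onto a grid
    `factor` times finer, as done at each decimation-level change."""
    n, dim = points.shape
    old_t = np.linspace(0.0, 1.0, n)
    new_t = np.linspace(0.0, 1.0, factor * n)
    # Interpolate each spatial dimension independently.
    return np.stack(
        [np.interp(new_t, old_t, points[:, d]) for d in range(dim)], axis=1
    )


# A toy 2D "trajectory" with 4 control points.
coarse = np.array([[0.0, 0.0], [0.25, 0.5], [0.5, 0.5], [1.0, 0.0]])
fine = upscale_control_points(coarse)
print(fine.shape)  # (8, 2)
```

The endpoints of the coarse trajectory are preserved exactly, while the new intermediate samples lie on the segments joining consecutive control points, which is consistent with the NUFFT operator holding a linearly interpolated version of the control points.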

References#

[Cha+16]

N. Chauffert, P. Weiss, J. Kahn and P. Ciuciu, “A Projection Algorithm for Gradient Waveforms Design in Magnetic Resonance Imaging,” in IEEE Transactions on Medical Imaging, vol. 35, no. 9, pp. 2026-2039, Sept. 2016, doi: 10.1109/TMI.2016.2544251.

[Cha+22]

G. R. Chaithya, P. Weiss, G. Daval-Frérot, A. Massire, A. Vignaud and P. Ciuciu, “Optimizing Full 3D SPARKLING Trajectories for High-Resolution Magnetic Resonance Imaging,” in IEEE Transactions on Medical Imaging, vol. 41, no. 8, pp. 2105-2117, Aug. 2022, doi: 10.1109/TMI.2022.3157269.

[GRC23]

G. R. Chaithya and P. Ciuciu, “Jointly Learning Non-Cartesian k-Space Trajectories and Reconstruction Networks for 2D and 3D MR Imaging through Projection,” Bioengineering, vol. 10, no. 2, p. 158, 2023, doi: 10.3390/bioengineering10020158.

Total running time of the script: (1 minute 28.349 seconds)

Gallery generated by Sphinx-Gallery