Learn straight-line readout pattern

A small PyTorch example to showcase learning k-space sampling patterns. In this example we learn the 2D sampling pattern for a 3D MRI image, assuming straight-line readouts. This example showcases the auto-diff capabilities of the NUFFT operator. The image resolution is kept small to reduce computation time.

Warning

This example only showcases the autodiff capabilities; the learned sampling pattern is not scanner-compliant, as the gradients required to implement it violate the hardware constraints. In practice, a projection \(\Pi_\mathcal{Q}(\mathbf{K})\) onto the scanner constraint set \(\mathcal{Q}\) is recommended (see [Proj]). This is implemented in the proprietary SPARKLING package [Sparks]. Users are encouraged to contact the authors if they want to use it.

Imports

import os
import tempfile
import time

import brainweb_dl as bwdl
import joblib
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image, ImageSequence
from tqdm import tqdm

from mrinufft import get_operator

Set up a simple class to learn the trajectory

Note

While we only learn the NUFFT operator's sampling locations (the trajectory), we still need to pass wrt_data=True so that all the gradients are computed correctly. See [Projector] for more details.

class Model(torch.nn.Module):
    """Learnable 3D sampling pattern made of straight-line readouts.

    Every shot is a straight line covering the whole first k-space axis
    [-0.5, 0.5] with ``num_samples_per_shot`` samples; only its position in
    the two remaining axes is a parameter. A small Cartesian grid of shots is
    frozen around the k-space center, and the other positions are learnable.

    Parameters
    ----------
    num_shots : int
        Total number of readout lines (frozen central + learnable).
    img_size : tuple of int
        Shape of the 3D image. ``img_size[1]`` and ``img_size[2]`` set the
        Cartesian spacing (``1 / N``) of the central grid along the two
        phase-encoding axes.
    factor_cartesian : float, optional
        Approximate fraction of ``num_shots`` placed on the central grid.
    num_samples_per_shot : int, optional
        Number of samples along each readout line (default 128).

    Raises
    ------
    ValueError
        If the central grid would need more than ``num_shots`` shots.
    """

    def __init__(
        self, num_shots, img_size, factor_cartesian=0.3, num_samples_per_shot=128
    ):
        super().__init__()
        self.num_samples_per_shot = num_samples_per_shot
        num_cart_points = int(np.round(np.sqrt(factor_cartesian * num_shots)))
        central = self._make_central_points(num_cart_points, img_size)
        num_free = num_shots - central.shape[0]
        if num_free < 0:
            raise ValueError(
                f"factor_cartesian={factor_cartesian} yields {central.shape[0]} "
                f"central shots, more than num_shots={num_shots}."
            )

        # Central lines are registered on the module but frozen
        # (requires_grad=False), so no gradient is computed for them.
        self.central_points = torch.nn.Parameter(data=central, requires_grad=False)
        # Learnable line positions, initialised uniformly in [-0.5, 0.5)^2.
        self.non_center_points = torch.nn.Parameter(
            data=torch.Tensor(np.random.random((num_free, 2)) - 0.5),
            requires_grad=True,
        )
        # The initial samples are placeholders with the right size:
        # forward() overwrites them with the current trajectory.
        self.operator = get_operator("gpunufft", wrt_data=True, wrt_traj=True)(
            np.random.random(
                (self.get_2D_points().shape[0] * self.num_samples_per_shot, 3)
            )
            - 0.5,
            shape=img_size,
            density=True,
            squeeze_dims=False,
        )

    @staticmethod
    def _make_central_points(num_cart_points, img_size):
        """Return a centred ``(num_cart_points**2, 2)`` Cartesian grid.

        The spacing along each of the two phase-encoding axes is exactly one
        Cartesian step ``1 / img_size[d]`` (d = 1, 2). The grid is symmetric
        around the k-space center, so a single point lies exactly at DC.
        """
        axes = []
        for n in img_size[1:3]:
            # n points spaced 1/N apart span (n - 1)/N in total.
            half_width = (num_cart_points - 1) / (2 * n)
            axes.append(torch.linspace(-half_width, half_width, num_cart_points))
        grid = torch.meshgrid(*axes, indexing="ij")
        return torch.stack(grid, dim=-1).reshape(-1, 2)

    def get_trajectory(self, get_as_shot=False):
        """Return the full 3D trajectory.

        Flat ``(num_shots * num_samples_per_shot, 3)`` by default, or
        ``(num_shots, num_samples_per_shot, 3)`` when ``get_as_shot`` is True.
        """
        samples = self._get_3D_points(self.get_2D_points())
        if not get_as_shot:
            return samples
        return samples.reshape(-1, self.num_samples_per_shot, 3)

    def get_2D_points(self):
        """Return all line positions: the frozen central grid first, then the learnable ones."""
        return torch.vstack([self.central_points, self.non_center_points])

    def _get_3D_points(self, samples2D):
        """Expand each 2D position into a straight readout along the first axis.

        Column 0 of the result runs linearly over [-0.5, 0.5], and columns 1
        and 2 repeat the line's 2D position for every sample.
        """
        line = torch.linspace(
            -0.5,
            0.5,
            self.num_samples_per_shot,
            device=samples2D.device,
            dtype=samples2D.dtype,
        )
        return torch.stack(
            [
                line.repeat(samples2D.shape[0], 1),
                samples2D[:, 0].repeat(self.num_samples_per_shot, 1).T,
                samples2D[:, 1].repeat(self.num_samples_per_shot, 1).T,
            ],
            dim=-1,
        ).reshape(-1, 3)

    def forward(self, x):
        """Simulate an acquisition of ``x`` and return its normalised adjoint.

        The operator's samples are replaced by the current trajectory, so
        (with ``wrt_traj=True``) gradients reach the learnable positions.
        Returns ``|adj_op(op(x))|`` divided by its mean.
        """
        self.operator.samples = self.get_trajectory()
        kspace = self.operator.op(x)
        adjoint = self.operator.adj_op(kspace).abs()
        return adjoint / torch.mean(adjoint)

Utility function to plot the state of the model

def plot_state(mri_2D, traj, recon, loss=None, save_name=None, i=None, slice_idx=11):
    """Plot the image, the trajectory, the reconstruction and optionally the loss.

    Parameters
    ----------
    mri_2D : array-like
        Reference volume; ``mri_2D[0][..., slice_idx]`` is displayed.
    traj : numpy.ndarray
        Trajectory. When its last axis has size 3 it is expected as
        ``(num_shots, num_samples_per_shot, 3)``; otherwise it is drawn as a
        2D point cloud.
    recon : torch.Tensor
        Model output; ``recon[0][0][..., slice_idx]`` is displayed.
    loss : sequence of float, optional
        Loss history. When given, a 2x2 layout with a loss panel is used,
        otherwise a 1x3 layout.
    save_name : str, optional
        If given, the figure is saved to this path and closed; otherwise it
        is shown.
    i : int, optional
        Training step. Up to step 50 the 3D view rotates and the markers grow
        with ``i``. After step 50, only the 2D position of each line is drawn.
    slice_idx : int, optional
        Index, along the last axis, of the slice displayed (default 11).
    """
    fig_grid = (2, 2) if loss is not None else (1, 3)
    # figsize is (width, height): 5 inches per column and per row.
    fig, axs = plt.subplots(*fig_grid, figsize=tuple(5 * n for n in fig_grid[::-1]))
    axs = axs.flatten()
    axs[0].imshow(np.abs(mri_2D[0][..., slice_idx]), cmap="gray")
    axs[0].axis("off")
    axs[0].set_title("MR Image")
    if traj.shape[-1] == 3:
        if i is not None and i > 50:
            # Readouts run along axis 0: draw only each line's position in
            # axes 1 and 2, taken from its first sample.
            axs[1].scatter(*traj.T[1:3, 0], s=10, color="blue")
        else:
            fig_kwargs = {}
            plt_kwargs = {"s": 1, "alpha": 0.2}
            if i is not None:
                # Animate the first 50 steps: rotate the camera and make the
                # markers progressively larger and more opaque.
                progress = i / 50
                fig_kwargs["azim"] = progress * 60 - 60
                fig_kwargs["elev"] = 30 - progress * 30
                plt_kwargs["alpha"] = 0.2 + 0.8 * progress
                plt_kwargs["s"] = 1 + 9 * progress
            axs[1].remove()
            axs[1] = fig.add_subplot(*fig_grid, 2, projection="3d", **fig_kwargs)
            # One scatter call for every sample of every shot: a single
            # collection instead of one artist per shot, which is far cheaper
            # to draw.
            axs[1].scatter(*traj.reshape(-1, 3).T, color="blue", **plt_kwargs)
    else:
        axs[1].scatter(*traj.T, s=10)
    axs[1].set_title("Trajectory")
    axs[2].imshow(
        np.abs(recon[0][0][..., slice_idx].detach().cpu().numpy()), cmap="gray"
    )
    axs[2].axis("off")
    axs[2].set_title("Reconstruction")
    if loss is not None:
        axs[3].plot(loss)
        axs[3].grid(True)
        axs[3].set_title("Loss")
    if save_name is not None:
        fig.savefig(save_name, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()

Set up the model and optimizer

# Load a BrainWeb T1 volume (presumably subject 4), reorient it, and keep every
# 8th voxel along each axis so the NUFFT stays cheap.
cart_data = (
    np.flipud(bwdl.get_mri(4, "T1"))
    .T[::8, ::8, ::8]
    .astype(np.complex64)
)
# 253 straight-line shots; with Model's default factor_cartesian=0.3, some of
# them sit on a frozen central grid.
model = Model(num_shots=253, img_size=cart_data.shape)
# Adam receives every parameter. Frozen ones (requires_grad=False) never get a
# gradient, so only the learnable line positions are updated.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)

Set up the data

# cart_data was cast from the BrainWeb volume with .astype(np.complex64), so its
# imaginary part is presumably zero. Take the real part explicitly instead of
# letting torch.Tensor() silently discard it, which emits a
# "Casting complex values to real" UserWarning.
mri_3D = torch.from_numpy(np.ascontiguousarray(cart_data.real))[None]
# Normalise to unit mean: the model's output is normalised the same way.
mri_3D = mri_3D / torch.mean(mri_3D)
model.eval()
recon = model(mri_3D)
plot_state(mri_3D, model.get_trajectory(True).detach().cpu().numpy(), recon)
MR Image, Reconstruction, Trajectory
/volatile/github-ci-mind-inria/action-runner/_work/mri-nufft/mri-nufft/examples/GPU/example_learn_straight_line_readouts.py:174: UserWarning: Casting complex values to real discards the imaginary part (Triggered internally at ../aten/src/ATen/native/Copy.cpp:305.)
  mri_3D = torch.Tensor(cart_data)[None]

Start the training loop

losses = []
image_files = []
model.train()
with tqdm(range(100), unit="steps") as tqdms:
    for i in tqdms:
        out = model(mri_3D)
        # mri_3D[None] adds a leading axis so the target lines up with the
        # operator output (squeeze_dims=False) -- confirm shapes if changed.
        loss = torch.nn.functional.mse_loss(out, mri_3D[None])
        numpy_loss = loss.detach().cpu().numpy()
        tqdms.set_postfix({"loss": numpy_loss})
        losses.append(numpy_loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            # Clamp the value of trajectory between [-0.5, 0.5]
            for param in model.parameters():
                param.clamp_(-0.5, 0.5)
        # Generate images for gif, in the platform's temporary directory.
        hashed = joblib.hash((i, "learn_line", time.time()))
        filename = os.path.join(tempfile.gettempdir(), f"{hashed}.png")
        plot_state(
            mri_3D,
            model.get_trajectory(True).detach().cpu().numpy(),
            out,
            losses,
            save_name=filename,
            i=i,
        )
        image_files.append(filename)

# Make a GIF of all images, then release the file handles and delete the
# temporary frames.
imgs = [Image.open(img) for img in image_files]
try:
    imgs[0].save(
        "mrinufft_learn_2d_sampling_pattern.gif",
        save_all=True,
        append_images=imgs[1:],
        optimize=False,
        duration=2,
        loop=0,
    )
finally:
    for img in imgs:
        img.close()
for frame in image_files:
    os.remove(frame)

# sphinx_gallery_thumbnail_path = 'generated/autoexamples/GPU/images/mrinufft_learn_2d_sampling_pattern.gif'
  0%|          | 0/100 [00:00<?, ?steps/s]
  0%|          | 0/100 [00:00<?, ?steps/s, loss=0.31269866]
  1%|          | 1/100 [00:03<05:32,  3.36s/steps, loss=0.31269866]
  1%|          | 1/100 [00:03<05:32,  3.36s/steps, loss=0.28621608]
  2%|▏         | 2/100 [00:06<05:11,  3.18s/steps, loss=0.28621608]
  2%|▏         | 2/100 [00:06<05:11,  3.18s/steps, loss=0.28055453]
  3%|▎         | 3/100 [00:09<05:00,  3.10s/steps, loss=0.28055453]
  3%|▎         | 3/100 [00:09<05:00,  3.10s/steps, loss=0.26690418]
  4%|▍         | 4/100 [00:12<05:05,  3.18s/steps, loss=0.26690418]
  4%|▍         | 4/100 [00:12<05:05,  3.18s/steps, loss=0.2570128]
  5%|▌         | 5/100 [00:15<04:56,  3.12s/steps, loss=0.2570128]
  5%|▌         | 5/100 [00:15<04:56,  3.12s/steps, loss=0.25114173]
  6%|▌         | 6/100 [00:18<04:49,  3.08s/steps, loss=0.25114173]
  6%|▌         | 6/100 [00:18<04:49,  3.08s/steps, loss=0.24944021]
  7%|▋         | 7/100 [00:22<04:53,  3.16s/steps, loss=0.24944021]
  7%|▋         | 7/100 [00:22<04:53,  3.16s/steps, loss=0.24832918]
  8%|▊         | 8/100 [00:25<04:46,  3.11s/steps, loss=0.24832918]
  8%|▊         | 8/100 [00:25<04:46,  3.11s/steps, loss=0.24599953]
  9%|▉         | 9/100 [00:28<04:40,  3.08s/steps, loss=0.24599953]
  9%|▉         | 9/100 [00:28<04:40,  3.08s/steps, loss=0.24378936]
 10%|█         | 10/100 [00:31<04:42,  3.14s/steps, loss=0.24378936]
 10%|█         | 10/100 [00:31<04:42,  3.14s/steps, loss=0.2422308]
 11%|█         | 11/100 [00:34<04:36,  3.10s/steps, loss=0.2422308]
 11%|█         | 11/100 [00:34<04:36,  3.10s/steps, loss=0.24135594]
 12%|█▏        | 12/100 [00:37<04:36,  3.14s/steps, loss=0.24135594]
 12%|█▏        | 12/100 [00:37<04:36,  3.14s/steps, loss=0.23864205]
 13%|█▎        | 13/100 [00:40<04:30,  3.10s/steps, loss=0.23864205]
 13%|█▎        | 13/100 [00:40<04:30,  3.10s/steps, loss=0.23628193]
 14%|█▍        | 14/100 [00:43<04:24,  3.08s/steps, loss=0.23628193]
 14%|█▍        | 14/100 [00:43<04:24,  3.08s/steps, loss=0.23433003]
 15%|█▌        | 15/100 [00:46<04:26,  3.14s/steps, loss=0.23433003]
 15%|█▌        | 15/100 [00:46<04:26,  3.14s/steps, loss=0.23285046]
 16%|█▌        | 16/100 [00:49<04:20,  3.10s/steps, loss=0.23285046]
 16%|█▌        | 16/100 [00:50<04:20,  3.10s/steps, loss=0.23205149]
 17%|█▋        | 17/100 [00:52<04:15,  3.08s/steps, loss=0.23205149]
 17%|█▋        | 17/100 [00:53<04:15,  3.08s/steps, loss=0.23136263]
 18%|█▊        | 18/100 [00:56<04:17,  3.14s/steps, loss=0.23136263]
 18%|█▊        | 18/100 [00:56<04:17,  3.14s/steps, loss=0.23030923]
 19%|█▉        | 19/100 [00:59<04:11,  3.10s/steps, loss=0.23030923]
 19%|█▉        | 19/100 [00:59<04:11,  3.10s/steps, loss=0.22902341]
 20%|██        | 20/100 [01:02<04:09,  3.12s/steps, loss=0.22902341]
 20%|██        | 20/100 [01:02<04:09,  3.12s/steps, loss=0.22816998]
 21%|██        | 21/100 [01:05<04:01,  3.05s/steps, loss=0.22816998]
 21%|██        | 21/100 [01:05<04:01,  3.05s/steps, loss=0.22701749]
 22%|██▏       | 22/100 [01:08<03:54,  3.01s/steps, loss=0.22701749]
 22%|██▏       | 22/100 [01:08<03:54,  3.01s/steps, loss=0.225749]
 23%|██▎       | 23/100 [01:11<03:54,  3.05s/steps, loss=0.225749]
 23%|██▎       | 23/100 [01:11<03:54,  3.05s/steps, loss=0.22442329]
 24%|██▍       | 24/100 [01:14<03:48,  3.01s/steps, loss=0.22442329]
 24%|██▍       | 24/100 [01:14<03:48,  3.01s/steps, loss=0.22332573]
 25%|██▌       | 25/100 [01:17<03:44,  2.99s/steps, loss=0.22332573]
 25%|██▌       | 25/100 [01:17<03:44,  2.99s/steps, loss=0.22279465]
 26%|██▌       | 26/100 [01:20<03:45,  3.04s/steps, loss=0.22279465]
 26%|██▌       | 26/100 [01:20<03:45,  3.04s/steps, loss=0.22274435]
 27%|██▋       | 27/100 [01:23<03:39,  3.01s/steps, loss=0.22274435]
 27%|██▋       | 27/100 [01:23<03:39,  3.01s/steps, loss=0.22271453]
 28%|██▊       | 28/100 [01:26<03:39,  3.05s/steps, loss=0.22271453]
 28%|██▊       | 28/100 [01:26<03:39,  3.05s/steps, loss=0.22240941]
 29%|██▉       | 29/100 [01:29<03:34,  3.02s/steps, loss=0.22240941]
 29%|██▉       | 29/100 [01:29<03:34,  3.02s/steps, loss=0.22207557]
 30%|███       | 30/100 [01:32<03:29,  3.00s/steps, loss=0.22207557]
 30%|███       | 30/100 [01:32<03:29,  3.00s/steps, loss=0.22172217]
 31%|███       | 31/100 [01:35<03:30,  3.06s/steps, loss=0.22172217]
 31%|███       | 31/100 [01:35<03:30,  3.06s/steps, loss=0.22087069]
 32%|███▏      | 32/100 [01:38<03:25,  3.03s/steps, loss=0.22087069]
 32%|███▏      | 32/100 [01:38<03:25,  3.03s/steps, loss=0.21991867]
 33%|███▎      | 33/100 [01:41<03:21,  3.00s/steps, loss=0.21991867]
 33%|███▎      | 33/100 [01:41<03:21,  3.00s/steps, loss=0.21887887]
 34%|███▍      | 34/100 [01:44<03:21,  3.05s/steps, loss=0.21887887]
 34%|███▍      | 34/100 [01:44<03:21,  3.05s/steps, loss=0.21813084]
 35%|███▌      | 35/100 [01:47<03:15,  3.01s/steps, loss=0.21813084]
 35%|███▌      | 35/100 [01:47<03:15,  3.01s/steps, loss=0.21815285]
 36%|███▌      | 36/100 [01:50<03:10,  2.98s/steps, loss=0.21815285]
 36%|███▌      | 36/100 [01:50<03:10,  2.98s/steps, loss=0.21809953]
 37%|███▋      | 37/100 [01:53<03:11,  3.03s/steps, loss=0.21809953]
 37%|███▋      | 37/100 [01:53<03:11,  3.03s/steps, loss=0.21817195]
 38%|███▊      | 38/100 [01:56<03:05,  2.99s/steps, loss=0.21817195]
 38%|███▊      | 38/100 [01:56<03:05,  2.99s/steps, loss=0.21790914]
 39%|███▉      | 39/100 [01:59<03:04,  3.03s/steps, loss=0.21790914]
 39%|███▉      | 39/100 [01:59<03:04,  3.03s/steps, loss=0.21778516]
 40%|████      | 40/100 [02:02<02:59,  2.99s/steps, loss=0.21778516]
 40%|████      | 40/100 [02:02<02:59,  2.99s/steps, loss=0.21772687]
 41%|████      | 41/100 [02:05<02:55,  2.97s/steps, loss=0.21772687]
 41%|████      | 41/100 [02:05<02:55,  2.97s/steps, loss=0.21764515]
 42%|████▏     | 42/100 [02:08<02:55,  3.02s/steps, loss=0.21764515]
 42%|████▏     | 42/100 [02:08<02:55,  3.02s/steps, loss=0.21743998]
 43%|████▎     | 43/100 [02:11<02:50,  2.98s/steps, loss=0.21743998]
 43%|████▎     | 43/100 [02:11<02:50,  2.98s/steps, loss=0.21715567]
 44%|████▍     | 44/100 [02:14<02:45,  2.96s/steps, loss=0.21715567]
 44%|████▍     | 44/100 [02:14<02:45,  2.96s/steps, loss=0.21692038]
 45%|████▌     | 45/100 [02:17<02:45,  3.01s/steps, loss=0.21692038]
 45%|████▌     | 45/100 [02:17<02:45,  3.01s/steps, loss=0.21661372]
 46%|████▌     | 46/100 [02:20<02:41,  2.98s/steps, loss=0.21661372]
 46%|████▌     | 46/100 [02:20<02:41,  2.98s/steps, loss=0.21635765]
 47%|████▋     | 47/100 [02:23<02:40,  3.02s/steps, loss=0.21635765]
 47%|████▋     | 47/100 [02:23<02:40,  3.02s/steps, loss=0.21628879]
 48%|████▊     | 48/100 [02:26<02:35,  2.99s/steps, loss=0.21628879]
 48%|████▊     | 48/100 [02:26<02:35,  2.99s/steps, loss=0.21594298]
 49%|████▉     | 49/100 [02:29<02:31,  2.96s/steps, loss=0.21594298]
 49%|████▉     | 49/100 [02:29<02:31,  2.96s/steps, loss=0.21561931]
 50%|█████     | 50/100 [02:32<02:30,  3.01s/steps, loss=0.21561931]
 50%|█████     | 50/100 [02:32<02:30,  3.01s/steps, loss=0.21537787]
 51%|█████     | 51/100 [02:35<02:26,  2.98s/steps, loss=0.21537787]
 51%|█████     | 51/100 [02:35<02:26,  2.98s/steps, loss=0.21502337]
 52%|█████▏    | 52/100 [02:35<01:47,  2.23s/steps, loss=0.21502337]
 52%|█████▏    | 52/100 [02:35<01:47,  2.23s/steps, loss=0.21496297]
 53%|█████▎    | 53/100 [02:36<01:20,  1.71s/steps, loss=0.21496297]
 53%|█████▎    | 53/100 [02:36<01:20,  1.71s/steps, loss=0.21477088]
 54%|█████▍    | 54/100 [02:36<01:01,  1.34s/steps, loss=0.21477088]
 54%|█████▍    | 54/100 [02:36<01:01,  1.34s/steps, loss=0.21469645]
 55%|█████▌    | 55/100 [02:37<00:48,  1.08s/steps, loss=0.21469645]
 55%|█████▌    | 55/100 [02:37<00:48,  1.08s/steps, loss=0.21477848]
 56%|█████▌    | 56/100 [02:37<00:39,  1.11steps/s, loss=0.21477848]
 56%|█████▌    | 56/100 [02:37<00:39,  1.11steps/s, loss=0.21473406]
 57%|█████▋    | 57/100 [02:38<00:36,  1.18steps/s, loss=0.21473406]
 57%|█████▋    | 57/100 [02:38<00:36,  1.18steps/s, loss=0.2143364]
 58%|█████▊    | 58/100 [02:38<00:30,  1.36steps/s, loss=0.2143364]
 58%|█████▊    | 58/100 [02:39<00:30,  1.36steps/s, loss=0.2138579]
 59%|█████▉    | 59/100 [02:39<00:27,  1.51steps/s, loss=0.2138579]
 59%|█████▉    | 59/100 [02:39<00:27,  1.51steps/s, loss=0.21346204]
 60%|██████    | 60/100 [02:39<00:24,  1.64steps/s, loss=0.21346204]
 60%|██████    | 60/100 [02:40<00:24,  1.64steps/s, loss=0.21322244]
 61%|██████    | 61/100 [02:40<00:22,  1.75steps/s, loss=0.21322244]
 61%|██████    | 61/100 [02:40<00:22,  1.75steps/s, loss=0.21317327]
 62%|██████▏   | 62/100 [02:40<00:20,  1.83steps/s, loss=0.21317327]
 62%|██████▏   | 62/100 [02:41<00:20,  1.83steps/s, loss=0.21306983]
 63%|██████▎   | 63/100 [02:41<00:19,  1.89steps/s, loss=0.21306983]
 63%|██████▎   | 63/100 [02:41<00:19,  1.89steps/s, loss=0.21305107]
 64%|██████▍   | 64/100 [02:41<00:18,  1.93steps/s, loss=0.21305107]
 64%|██████▍   | 64/100 [02:41<00:18,  1.93steps/s, loss=0.21300307]
 65%|██████▌   | 65/100 [02:42<00:20,  1.74steps/s, loss=0.21300307]
 65%|██████▌   | 65/100 [02:42<00:20,  1.74steps/s, loss=0.21302323]
 66%|██████▌   | 66/100 [02:43<00:18,  1.83steps/s, loss=0.21302323]
 66%|██████▌   | 66/100 [02:43<00:18,  1.83steps/s, loss=0.21279402]
 67%|██████▋   | 67/100 [02:43<00:17,  1.89steps/s, loss=0.21279402]
 67%|██████▋   | 67/100 [02:43<00:17,  1.89steps/s, loss=0.21258767]
 68%|██████▊   | 68/100 [02:44<00:16,  1.93steps/s, loss=0.21258767]
 68%|██████▊   | 68/100 [02:44<00:16,  1.93steps/s, loss=0.21236463]
 69%|██████▉   | 69/100 [02:44<00:15,  1.95steps/s, loss=0.21236463]
 69%|██████▉   | 69/100 [02:44<00:15,  1.95steps/s, loss=0.21205866]
 70%|███████   | 70/100 [02:45<00:15,  1.98steps/s, loss=0.21205866]
 70%|███████   | 70/100 [02:45<00:15,  1.98steps/s, loss=0.21181977]
 71%|███████   | 71/100 [02:45<00:14,  1.99steps/s, loss=0.21181977]
 71%|███████   | 71/100 [02:45<00:14,  1.99steps/s, loss=0.21181841]
 72%|███████▏  | 72/100 [02:46<00:13,  2.00steps/s, loss=0.21181841]
 72%|███████▏  | 72/100 [02:46<00:13,  2.00steps/s, loss=0.21166171]
 73%|███████▎  | 73/100 [02:46<00:15,  1.78steps/s, loss=0.21166171]
 73%|███████▎  | 73/100 [02:46<00:15,  1.78steps/s, loss=0.21122287]
 74%|███████▍  | 74/100 [02:47<00:13,  1.87steps/s, loss=0.21122287]
 74%|███████▍  | 74/100 [02:47<00:13,  1.87steps/s, loss=0.21124454]
 75%|███████▌  | 75/100 [02:47<00:12,  1.93steps/s, loss=0.21124454]
 75%|███████▌  | 75/100 [02:47<00:12,  1.93steps/s, loss=0.21133342]
 76%|███████▌  | 76/100 [02:48<00:12,  1.99steps/s, loss=0.21133342]
 76%|███████▌  | 76/100 [02:48<00:12,  1.99steps/s, loss=0.2110769]
 77%|███████▋  | 77/100 [02:48<00:11,  2.03steps/s, loss=0.2110769]
 77%|███████▋  | 77/100 [02:48<00:11,  2.03steps/s, loss=0.21085785]
 78%|███████▊  | 78/100 [02:49<00:10,  2.04steps/s, loss=0.21085785]
 78%|███████▊  | 78/100 [02:49<00:10,  2.04steps/s, loss=0.21065724]
 79%|███████▉  | 79/100 [02:49<00:10,  2.06steps/s, loss=0.21065724]
 79%|███████▉  | 79/100 [02:49<00:10,  2.06steps/s, loss=0.21041058]
 80%|████████  | 80/100 [02:50<00:09,  2.07steps/s, loss=0.21041058]
 80%|████████  | 80/100 [02:50<00:09,  2.07steps/s, loss=0.21011266]
 81%|████████  | 81/100 [02:50<00:09,  2.08steps/s, loss=0.21011266]
 81%|████████  | 81/100 [02:50<00:09,  2.08steps/s, loss=0.20978431]
 82%|████████▏ | 82/100 [02:51<00:09,  1.84steps/s, loss=0.20978431]
 82%|████████▏ | 82/100 [02:51<00:09,  1.84steps/s, loss=0.20988365]
 83%|████████▎ | 83/100 [02:51<00:08,  1.91steps/s, loss=0.20988365]
 83%|████████▎ | 83/100 [02:51<00:08,  1.91steps/s, loss=0.2097596]
 84%|████████▍ | 84/100 [02:52<00:08,  1.96steps/s, loss=0.2097596]
 84%|████████▍ | 84/100 [02:52<00:08,  1.96steps/s, loss=0.2096154]
 85%|████████▌ | 85/100 [02:52<00:07,  2.00steps/s, loss=0.2096154]
 85%|████████▌ | 85/100 [02:52<00:07,  2.00steps/s, loss=0.20961627]
 86%|████████▌ | 86/100 [02:53<00:06,  2.03steps/s, loss=0.20961627]
 86%|████████▌ | 86/100 [02:53<00:06,  2.03steps/s, loss=0.20949975]
 87%|████████▋ | 87/100 [02:53<00:06,  2.04steps/s, loss=0.20949975]
 87%|████████▋ | 87/100 [02:53<00:06,  2.04steps/s, loss=0.20952103]
 88%|████████▊ | 88/100 [02:54<00:05,  2.06steps/s, loss=0.20952103]
 88%|████████▊ | 88/100 [02:54<00:05,  2.06steps/s, loss=0.20932287]
 89%|████████▉ | 89/100 [02:54<00:05,  2.07steps/s, loss=0.20932287]
 89%|████████▉ | 89/100 [02:54<00:05,  2.07steps/s, loss=0.2092399]
 90%|█████████ | 90/100 [02:55<00:04,  2.07steps/s, loss=0.2092399]
 90%|█████████ | 90/100 [02:55<00:04,  2.07steps/s, loss=0.20924813]
 91%|█████████ | 91/100 [02:55<00:04,  1.84steps/s, loss=0.20924813]
 91%|█████████ | 91/100 [02:55<00:04,  1.84steps/s, loss=0.20932367]
 92%|█████████▏| 92/100 [02:56<00:04,  1.91steps/s, loss=0.20932367]
 92%|█████████▏| 92/100 [02:56<00:04,  1.91steps/s, loss=0.20941375]
 93%|█████████▎| 93/100 [02:56<00:03,  1.96steps/s, loss=0.20941375]
 93%|█████████▎| 93/100 [02:56<00:03,  1.96steps/s, loss=0.20942785]
 94%|█████████▍| 94/100 [02:57<00:03,  2.00steps/s, loss=0.20942785]
 94%|█████████▍| 94/100 [02:57<00:03,  2.00steps/s, loss=0.20929511]
 95%|█████████▌| 95/100 [02:57<00:02,  2.02steps/s, loss=0.20929511]
 95%|█████████▌| 95/100 [02:57<00:02,  2.02steps/s, loss=0.2091367]
 96%|█████████▌| 96/100 [02:58<00:01,  2.04steps/s, loss=0.2091367]
 96%|█████████▌| 96/100 [02:58<00:01,  2.04steps/s, loss=0.20907779]
 97%|█████████▋| 97/100 [02:58<00:01,  2.05steps/s, loss=0.20907779]
 97%|█████████▋| 97/100 [02:58<00:01,  2.05steps/s, loss=0.2089521]
 98%|█████████▊| 98/100 [02:59<00:00,  2.05steps/s, loss=0.2089521]
 98%|█████████▊| 98/100 [02:59<00:00,  2.05steps/s, loss=0.2088416]
 99%|█████████▉| 99/100 [02:59<00:00,  2.06steps/s, loss=0.2088416]
 99%|█████████▉| 99/100 [02:59<00:00,  2.06steps/s, loss=0.20872626]
100%|██████████| 100/100 [03:00<00:00,  1.82steps/s, loss=0.20872626]
100%|██████████| 100/100 [03:00<00:00,  1.80s/steps, loss=0.20872626]
example learn_samples

Trained trajectory

# Run the trained model once more and display the final state with the loss curve.
model.eval()
final_traj = model.get_trajectory(get_as_shot=True).detach().cpu().numpy()
recon = model(mri_3D)
plot_state(mri_3D, final_traj, recon, loss=losses)
plt.show()
MR Image, Reconstruction, Loss, Trajectory

References

[Proj]

N. Chauffert, P. Weiss, J. Kahn and P. Ciuciu, “A Projection Algorithm for Gradient Waveforms Design in Magnetic Resonance Imaging,” in IEEE Transactions on Medical Imaging, vol. 35, no. 9, pp. 2026-2039, Sept. 2016, doi: 10.1109/TMI.2016.2544251.

[Sparks]

G. R. Chaithya, P. Weiss, G. Daval-Frérot, A. Massire, A. Vignaud and P. Ciuciu, “Optimizing Full 3D SPARKLING Trajectories for High-Resolution Magnetic Resonance Imaging,” in IEEE Transactions on Medical Imaging, vol. 41, no. 8, pp. 2105-2117, Aug. 2022, doi: 10.1109/TMI.2022.3157269.

[Projector]

Chaithya GR, and Philippe Ciuciu. 2023. “Jointly Learning Non-Cartesian k-Space Trajectories and Reconstruction Networks for 2D and 3D MR Imaging through Projection” Bioengineering 10, no. 2: 158. https://doi.org/10.3390/bioengineering10020158

Total running time of the script: (3 minutes 8.269 seconds)

Gallery generated by Sphinx-Gallery