Source code for mrinufft.operators.autodiff

"""Torch autodifferentiation for MRI-NUFFT."""

import torch
import numpy as np
from .._utils import NP2TORCH


class _NUFFT_OP(torch.autograd.Function):
    """
    Autograd support for the forward NUFFT operator (``op``).

    The trajectory gradient is computed with an efficient approximation of the
    NUFFT Jacobian matrices.

    References
    ----------
    Wang G, Fessler J A. "Efficient approximation of Jacobian matrices involving a
    non-uniform fast Fourier transform (NUFFT)."
    IEEE Transactions on Computational Imaging, 2023, 9: 43-54.
    """

    @staticmethod
    def forward(ctx, x, traj, nufft_op):
        """Forward image -> k-space."""
        ctx.save_for_backward(x)
        ctx.nufft_op = nufft_op
        return nufft_op.op(x)

    @staticmethod
    def backward(ctx, dy):
        """Backward image -> k-space."""
        x = ctx.saved_tensors[0]
        grad_data = None
        grad_traj = None
        if ctx.nufft_op._grad_wrt_data:
            grad_data = ctx.nufft_op.adj_op(dy)
        if ctx.nufft_op._grad_wrt_traj:
            im_size = x.size()[1:]
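            # gpunufft works with samples normalized to [-0.5, 0.5), so the
            # pixel grid r below carries an extra 2*pi; finufft/cufinufft
            # already use radians and need no factor.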
            factor = 1
            if ctx.nufft_op.backend in ["gpunufft"]:
                factor *= np.pi * 2
            r = [
                torch.linspace(-size / 2, size / 2 - 1, size) * factor
                for size in im_size
            ]
            grid_r = torch.meshgrid(*r, indexing="ij")
            grid_r = torch.stack(grid_r, dim=0).type_as(x)[:, None]
            grid_x = x * grid_r  # Weight the image by the grid r, per spatial axis

            nufft_dx_dom = torch.cat(
                [ctx.nufft_op.op(grid_x[i, ...]) for i in range(grid_x.size(0))],
                dim=0,
            )
            grad_traj = -1j * torch.conj(dy) * nufft_dx_dom
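            # Sum over the coil axis and reorder to (n_samples, ndim) so the
            # gradient matches the samples layout.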
            grad_traj = torch.transpose(
                torch.sum(grad_traj, dim=1),
                0,
                1,
            ).to(NP2TORCH[ctx.nufft_op.dtype])
        return grad_data, grad_traj, None


class _NUFFT_ADJOP(torch.autograd.Function):
    """Autograd support for adj_op nufft function."""

    @staticmethod
    def forward(ctx, y, traj, nufft_op):
        """Forward kspace -> image."""
        ctx.save_for_backward(y)
        ctx.nufft_op = nufft_op
        return nufft_op.adj_op(y)

    @staticmethod
    def backward(ctx, dx):
        """Backward kspace -> image."""
        y = ctx.saved_tensors[0]
        grad_data = None
        grad_traj = None
        if ctx.nufft_op._grad_wrt_data:
            grad_data = ctx.nufft_op.op(dx)
        if ctx.nufft_op._grad_wrt_traj:
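            # Switch to the trajectory-gradient plan (prepared by
            # _make_plan_grad for finufft/cufinufft); toggled back below.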
            ctx.nufft_op.toggle_grad_traj()
            im_size = dx.size()[2:]
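            # Same pixel-grid unit conversion as in _NUFFT_OP.backward.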
            factor = 1
            if ctx.nufft_op.backend in ["gpunufft"]:
                factor *= np.pi * 2
            r = [
                torch.linspace(-size / 2, size / 2 - 1, size) * factor
                for size in im_size
            ]
            grid_r = torch.meshgrid(*r, indexing="ij")
            grid_r = torch.stack(grid_r, dim=0).type_as(dx)[:, None]
            grid_dx = torch.conj(dx) * grid_r
            inufft_dx_dom = torch.cat(
                [ctx.nufft_op.op(grid_dx[i, ...]) for i in range(grid_dx.size(0))],
                dim=0,
            )
            grad_traj = 1j * y * inufft_dx_dom
            grad_traj = torch.transpose(torch.sum(grad_traj, dim=1), 0, 1).to(
                NP2TORCH[ctx.nufft_op.dtype]
            )
            ctx.nufft_op.toggle_grad_traj()
        return grad_data, grad_traj, None


class MRINufftAutoGrad(torch.nn.Module):
    """
    Wrap the NUFFT operator to support torch autodiff.

    Parameters
    ----------
    nufft_op: Classic non-differentiable MRI-NUFFT operator.
    wrt_data: If True, compute gradients with respect to the data.
    wrt_traj: If True, compute gradients with respect to the trajectory.
    """

    def __init__(self, nufft_op, wrt_data=True, wrt_traj=False):
        super().__init__()
        if (wrt_data or wrt_traj) and nufft_op.squeeze_dims:
            raise ValueError("Squeezing dimensions is not supported for autodiff.")
        self.nufft_op = nufft_op
        self.nufft_op._grad_wrt_traj = wrt_traj
        if wrt_traj and self.nufft_op.backend in ["finufft", "cufinufft"]:
            self.nufft_op._make_plan_grad()
        self.nufft_op._grad_wrt_data = wrt_data
        if wrt_traj:
            # The samples are initialized as a torch tensor purely for autodiff
            # purposes. They can also be converted later to an nn.Parameter, in
            # which case they are used for updates as well.
            self._samples_torch = torch.Tensor(self.nufft_op.samples)
            self._samples_torch.requires_grad = True

    def op(self, x):
        r"""Compute the forward image -> k-space."""
        return _NUFFT_OP.apply(x, self.samples, self.nufft_op)

    def adj_op(self, kspace):
        r"""Compute the adjoint k-space -> image."""
        return _NUFFT_ADJOP.apply(kspace, self.samples, self.nufft_op)

    @property
    def samples(self):
        """Get the samples."""
        try:
            return self._samples_torch
        except AttributeError:
            return self.nufft_op.samples

    @samples.setter
    def samples(self, value):
        self._samples_torch = value
        self.nufft_op.samples = value.detach().cpu().numpy()

    def __getattr__(self, name):
        """Forward all other attributes to the nufft_op."""
        return getattr(self.nufft_op, name)
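

# A minimal usage sketch (illustrative only, not part of this module). It
# assumes the public ``mrinufft.get_operator`` factory, an installed "finufft"
# backend, and trajectories normalized to [-0.5, 0.5); shapes and backend are
# examples, adapt them to your setup.
def _autodiff_usage_example(shape=(64, 64)):
    """Differentiate an l2 k-space loss w.r.t. both image and trajectory."""
    from mrinufft import get_operator

    # Random 2D sample locations, flattened to (n_samples, ndim).
    samples = np.random.uniform(-0.5, 0.5, (4096, 2)).astype(np.float32)
    nufft = get_operator("finufft", wrt_data=True, wrt_traj=True)(
        samples, shape=shape, squeeze_dims=False
    )
    # Leading coil axis assumed; see im_size = x.size()[1:] in _NUFFT_OP.
    image = torch.randn(1, *shape, dtype=torch.complex64, requires_grad=True)
    kspace = nufft.op(image)
    loss = torch.abs(kspace).pow(2).sum()
    loss.backward()
    # image.grad comes from _NUFFT_OP's adjoint; nufft.samples.grad from the
    # Jacobian approximation w.r.t. the trajectory.
    return image.grad, nufft.samples.grad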