Source code for mrinufft.operators.interfaces.torchkbnufft

"""Pytorch MRI Nufft Operators."""

from mrinufft.operators.base import FourierOperatorBase
from mrinufft._array_compat import with_torch
from mrinufft._utils import proper_trajectory
from mrinufft.operators.interfaces.utils import (
    is_cuda_tensor,
)
import numpy as np


TORCH_AVAILABLE = True
try:
    import torchkbnufft as tkbn
    import torch
except ImportError:
    TORCH_AVAILABLE = False

CUPY_AVAILABLE = True
try:
    import cupy as cp
except ImportError:
    CUPY_AVAILABLE = False


class MRITorchKbNufft(FourierOperatorBase):
    """MRI Transform Operator using Torch NUFFT.

    This class provides a Non-Uniform Fast Fourier Transform (NUFFT) operator
    for MRI data, utilizing the torchkbnufft library for performing the
    computations. It supports both CPU and GPU computations.

    Parameters
    ----------
    samples : Tensor or ndarray
        The sample locations of shape ``Nsamples x N_dimensions``. They are
        cast to float32 and rescaled with ``proper_trajectory(...,
        normalize="pi")`` before being stored on the operator's device.
    shape : tuple
        Shape of the image space.
    density : bool or Tensor, optional
        Density compensation support. Default is False.

        - If a Tensor, it will be used for density.
        - If True, the density compensation will be automatically estimated
          using the fixed point method.
        - If False, density compensation will not be used.
    n_coils : int, optional
        Number of coils. Default is 1.
    n_batchs : int, optional
        Number of batches. Default is 1.
    smaps : Tensor, optional
        Sensitivity maps. Default is None.
    eps : float, optional
        Accepted for interface compatibility; currently unused.
    squeeze_dims : bool, optional
        If True, tries to remove singleton dimensions for batch and coils.
        Default is True.
    use_gpu : bool, optional
        Whether to use the GPU. Default is False.
    osf : int, optional
        Oversampling factor. Accepted for interface compatibility; currently
        not forwarded to torchkbnufft (no ``grid_size`` is passed).
    **kwargs : dict
        Additional keyword arguments (currently unused).
    """

    available = TORCH_AVAILABLE
    autograd_available = False

    def __init__(
        self,
        samples,
        shape,
        density=False,
        n_coils=1,
        n_batchs=1,
        smaps=None,
        eps=1e-6,
        squeeze_dims=True,
        use_gpu=False,
        osf=2,
        **kwargs,
    ):
        super().__init__()
        self.device = "cuda" if use_gpu else "cpu"

        # proper_trajectory operates on NumPy arrays, so bring torch samples
        # back to host memory first.
        if isinstance(samples, torch.Tensor):
            if is_cuda_tensor(samples):
                samples = samples.cpu()
            samples = samples.numpy()
        samples = proper_trajectory(
            samples.astype(np.float32, copy=False), normalize="pi"
        )
        self.samples = torch.tensor(samples).to(self.device)
        self.dtype = None
        self.shape = shape
        self.n_coils = n_coils
        self.n_batchs = n_batchs
        self.squeeze_dims = squeeze_dims
        # compute_density is inherited from FourierOperatorBase; presumably
        # it sets ``self.density`` (None when disabled) -- verify in base.
        self.compute_density(density)

        if isinstance(smaps, torch.Tensor):
            self.smaps = smaps
        else:
            self.compute_smaps(smaps)
        if self.smaps is not None:
            # as_tensor avoids the extra copy and the UserWarning that
            # torch.tensor() emits when handed an existing tensor.
            self.smaps = torch.as_tensor(self.smaps, device=self.device)

        self._tkb_op = tkbn.KbNufft(im_size=self.shape).to(self.device)
        self._tkb_adj_op = tkbn.KbNufftAdjoint(im_size=self.shape).to(self.device)

    @with_torch
    def op(self, data, out=None):
        """Forward operation.

        Parameters
        ----------
        data: Tensor
            Image data; reshaped to ``(n_batchs, 1 or n_coils, *shape)``
            (a single channel when sensitivity maps are used).
        out: Tensor, optional
            Only used for shape checking; the result is not written into it.

        Returns
        -------
        Tensor:
            Non-uniform Fourier transform of the input image, divided by
            ``norm_factor`` and with singleton batch/coil axes squeezed when
            ``squeeze_dims`` is set.
        """
        self.check_shape(image=data, ksp=out)
        B, C, XYZ = self.n_batchs, self.n_coils, self.shape
        data = data.reshape((B, 1 if self.uses_sense else C, *XYZ))
        data = data.to(self.device, copy=False)
        if self.smaps is not None:
            # Keep the sensitivity maps in the same dtype as the data.
            self.smaps = self.smaps.to(data.dtype, copy=False)
        kdata = self._tkb_op.forward(
            image=data, omega=self.samples.t(), smaps=self.smaps
        )
        kdata /= self.norm_factor
        return self._safe_squeeze(kdata)

    @with_torch
    def adj_op(self, coeffs, out=None):
        """Backward Operation.

        Parameters
        ----------
        coeffs: Tensor
            K-space data; reshaped to ``(n_batchs, n_coils, n_samples)``.
        out: Tensor, optional
            Only used for shape checking; the result is not written into it.

        Returns
        -------
        Tensor
            Adjoint NUFFT of the (density-weighted, if enabled) coefficients,
            divided by ``norm_factor`` and squeezed like :meth:`op`.
        """
        self.check_shape(image=out, ksp=coeffs)
        B, C, K, XYZ = self.n_batchs, self.n_coils, self.n_samples, self.shape
        coeffs = coeffs.reshape((B, C, K))
        coeffs = coeffs.to(self.device, copy=False)
        if self.smaps is not None:
            self.smaps = self.smaps.to(coeffs.dtype, copy=False)
        weights = self._density_weights(coeffs)
        if weights is not None:
            coeffs = coeffs * weights
        img = self._tkb_adj_op.forward(
            data=coeffs, omega=self.samples.t(), smaps=self.smaps
        )
        img = img.reshape((B, 1 if self.uses_sense else C, *XYZ))
        img /= self.norm_factor
        return self._safe_squeeze(img)

    def _density_weights(self, like):
        """Return the density weights as a tensor matching ``like``, or None.

        ``if self.density:`` cannot be used: the truth value of a
        multi-element array/tensor is ambiguous and raises.
        """
        density = self.density
        if density is None or isinstance(density, bool):
            return None
        weights = torch.as_tensor(density, device=like.device)
        if like.is_complex():
            # Match the data dtype so real float64 weights do not upcast
            # complex64 k-space data to complex128.
            weights = weights.to(like.dtype)
        return weights

    def _safe_squeeze(self, arr):
        """Squeeze the first two dimensions of shape of the operator.

        Axis 1 (coils) is squeezed first, then axis 0 (batch). Either step
        is skipped when it raises ``ValueError`` (NumPy behaviour for a
        non-singleton axis); torch leaves such axes unchanged instead.
        """
        if self.squeeze_dims:
            try:
                arr = arr.squeeze(axis=1)
            except ValueError:
                pass
            try:
                arr = arr.squeeze(axis=0)
            except ValueError:
                pass
        return arr

    @with_torch
    def data_consistency(self, data, obs_data):
        """Compute the data consistency.

        Parameters
        ----------
        data: Tensor
            Image data
        obs_data: Tensor
            Observed data

        Returns
        -------
        Tensor
            The data consistency error in image space, i.e.
            ``adj_op(op(data) - obs_data)``.
        """
        obs_data = obs_data.to(self.device, copy=False)
        return self.adj_op(self.op(data) - obs_data)

    @classmethod
    @with_torch
    def pipe(
        cls,
        kspace_loc,
        volume_shape,
        num_iterations=10,
        osf=2,
        normalize=True,
        use_gpu=False,
        **kwargs,
    ):
        """Compute the density compensation weights for a given set of kspace locations.

        Parameters
        ----------
        kspace_loc: Tensor
            the kspace locations, of shape ``Nsamples x N_dimensions``
        volume_shape: tuple
            the volume shape
        num_iterations: int default 10
            the number of iterations for density estimation
        osf: float or int
            The oversampling factor the volume shape
        normalize: bool
            Whether to normalize the density compensation.
            We normalize such that the energy of the density-compensated
            PSF is 1.
        use_gpu: bool, default False
            Whether to use the GPU

        Returns
        -------
        Tensor
            The density compensation weights, squeezed.
        """
        volume_shape = tuple(
            int(s) for s in (np.asarray(volume_shape) * osf).astype(int)
        )
        grid_op = MRITorchKbNufft(
            samples=kspace_loc,
            shape=volume_shape,
            osf=1,
            use_gpu=use_gpu,
            **kwargs,
        )
        # torchkbnufft expects a (n_dims, n_samples) trajectory in radians;
        # reuse the operator's rescaled samples (the same ``omega`` that
        # op/adj_op use) instead of the raw (n_samples, n_dims) input.
        density_comp = tkbn.calc_density_compensation_function(
            ktraj=grid_op.samples.t(),
            im_size=volume_shape,
            num_iterations=num_iterations,
        )
        if normalize:
            # torchkbnufft needs complex input (or a trailing dim of size 2).
            spike = torch.zeros(
                volume_shape, dtype=torch.complex64, device=grid_op.device
            )
            mid_loc = tuple(v // 2 for v in volume_shape)
            spike[mid_loc] = 1
            kspace = grid_op.op(spike)
            # Apply the weights so the *compensated* PSF ends up with unit
            # energy after the division below.
            weights = density_comp.reshape(kspace.shape).to(kspace.device)
            psf = grid_op.adj_op(kspace * weights)
            density_comp /= torch.norm(psf)
        return density_comp.squeeze()
class TorchKbNUFFTcpu(MRITorchKbNufft):
    """Torch NUFFT operator pinned to the CPU.

    Thin specialization of ``MRITorchKbNufft`` (torchkbnufft backend) that
    always forces ``use_gpu=False``; every other positional or keyword
    argument is forwarded unchanged to the parent constructor.
    """

    backend = "torchkbnufft-cpu"

    def __init__(self, *args, **kwargs):
        # Supplying ``use_gpu`` here as well is a duplicate keyword and
        # raises TypeError, as in the parent call.
        super().__init__(*args, use_gpu=False, **kwargs)
class TorchKbNUFFTgpu(MRITorchKbNufft):
    """Torch NUFFT operator pinned to the GPU.

    Thin specialization of ``MRITorchKbNufft`` (torchkbnufft backend) that
    always forces ``use_gpu=True``; every other positional or keyword
    argument is forwarded unchanged to the parent constructor.
    """

    backend = "torchkbnufft-gpu"

    def __init__(self, *args, **kwargs):
        # Supplying ``use_gpu`` here as well is a duplicate keyword and
        # raises TypeError, as in the parent call.
        super().__init__(*args, use_gpu=True, **kwargs)