Source code for mrinufft.operators.base

"""
Base Fourier Operator interface.

from https://github.com/CEA-COSMIC/pysap-mri

:author: Pierre-Antoine Comby
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from functools import partial
from typing import ClassVar, Callable
import numpy as np
from numpy.typing import NDArray

from mrinufft._array_compat import (
    with_numpy,
    with_numpy_cupy,
    AUTOGRAD_AVAILABLE,
    CUPY_AVAILABLE,
)
from mrinufft._utils import auto_cast, power_method
from mrinufft.density import get_density
from mrinufft.extras import get_smaps
from mrinufft.operators.interfaces.utils import is_cuda_array, is_host_array


if AUTOGRAD_AVAILABLE:
    from mrinufft.operators.autodiff import MRINufftAutoGrad
if CUPY_AVAILABLE:
    import cupy as cp


# Mapping between numpy float and complex types.
DTYPE_R2C = {"float32": "complex64", "float64": "complex128"}


def check_backend(backend_name: str):
    """Tell whether a registered backend is flagged as available.

    The name is lower-cased before being looked up in
    ``FourierOperatorBase.interfaces``.

    Parameters
    ----------
    backend_name: str
        Name of the backend to look up (case-insensitive).

    Returns
    -------
    The availability flag stored as first item of the registry entry.

    Raises
    ------
    ValueError
        If no backend is registered under that name.
    """
    backend_name = backend_name.lower()
    registry = FourierOperatorBase.interfaces
    try:
        entry = registry[backend_name]
    except KeyError as err:
        raise ValueError(f"unknown backend: '{backend_name}'") from err
    return entry[0]
def list_backends(available_only=False):
    """List the names of the registered backends.

    Parameters
    ----------
    available_only: bool, optional
        If True, keep only the backends whose availability flag is truthy.
        If False (default), every registered backend name is returned.

    Returns
    -------
    list
        Backend names, in the iteration order of the registry.
    """
    selected = []
    for name, (is_available, _) in FourierOperatorBase.interfaces.items():
        if available_only and not is_available:
            continue
        selected.append(name)
    return selected
def get_operator(
    backend_name: str, wrt_data: bool = False, wrt_traj: bool = False, *args, **kwargs
):
    """Return an MRI Fourier operator interface using the correct backend.

    Parameters
    ----------
    backend_name: str
        Backend name (case-insensitive). A name of the form
        ``stacked-<backend>`` that is not registered as such is resolved to
        the ``stacked`` operator wrapping ``<backend>``.
    wrt_data: bool, default False
        if set gradients wrt to data and images will be available.
    wrt_traj: bool, default False
        if set gradients wrt to trajectory will be available.
    *args, **kwargs:
        Arguments to pass to the operator constructor.

    Returns
    -------
    FourierOperator
        class or instance of class if args or kwargs are given.

    Raises
    ------
    ValueError if the backend is not available.
    """
    available = True
    backend_name = backend_name.lower()
    try:
        available, operator = FourierOperatorBase.interfaces[backend_name]
    except KeyError as exc:
        if not backend_name.startswith("stacked-"):
            raise ValueError(f"backend {backend_name} does not exist") from exc
        # Fall back on the generic stacked operator. Dedicated registered
        # stacked backends (like stacked-cufinufft) were found by the lookup
        # above. Split only once so that hyphenated backend names
        # (e.g. "torchkbnufft-cpu") are kept whole.
        backend = backend_name.split("-", 1)[1]
        operator = get_operator("stacked")
        operator = partial(operator, backend=backend)

    if not available:
        raise ValueError(
            f"backend {backend_name} found, but dependencies are not met."
            f" ``pip install mri-nufft[{backend_name}]`` may solve the issue."
        )

    if args or kwargs:
        operator = operator(*args, **kwargs)

    # if autograd:
    if wrt_data or wrt_traj:
        if isinstance(operator, FourierOperatorBase):
            operator = operator.make_autograd(wrt_data, wrt_traj)
        elif isinstance(operator, partial):
            # Stacked fallback: a ``partial`` object has no ``with_autograd``
            # attribute, so reach for the wrapped class and forward the
            # arguments that were already bound.
            operator = partial(
                operator.func.with_autograd,
                wrt_data,
                wrt_traj,
                *operator.args,
                **operator.keywords,
            )
        else:
            # instance will be created later
            operator = partial(operator.with_autograd, wrt_data, wrt_traj)
    return operator
class FourierOperatorBase(ABC):
    """Base Fourier Operator class.

    Every (Linear) Fourier operator inherits from this class,
    to ensure that we have all the functions rightly implemented
    as required by ModOpt.

    Subclasses defining a truthy ``backend`` class attribute are registered
    in ``interfaces`` as ``backend -> (available, cls)`` when they are created.
    """

    # Registry shared by all subclasses: backend name -> (available, class).
    interfaces: dict[str, tuple] = {}
    autograd_available = False
    _density_method = None
    _grad_wrt_data = False
    _grad_wrt_traj = False

    backend: ClassVar[str]
    available: ClassVar[bool]

    def __init__(self):
        # Resolve availability the same way as the registration hook, so that
        # a callable (or missing) ``available`` attribute is handled alike.
        if not self._backend_available():
            raise RuntimeError(f"'{self.backend}' backend is not available.")
        self._smaps = None
        self._density = None
        self._n_coils = 1

    def __init_subclass__(cls):
        """Register the class in the list of available operators."""
        super().__init_subclass__()
        available = cls._backend_available()
        if backend := getattr(cls, "backend", None):
            cls.interfaces[backend] = (available, cls)

    @classmethod
    def _backend_available(cls):
        """Return the availability of the backend, calling it if callable.

        A class without ``available`` attribute is considered available.
        """
        available = getattr(cls, "available", True)
        if callable(available):
            available = available()
        return available

    def check_shape(self, *, image=None, ksp=None):
        """
        Validate the shapes of the image or k-space data against operator shapes.

        Parameters
        ----------
        image : np.ndarray, optional
            If passed, the shape of image data will be checked.

        ksp : np.ndarray or object, optional
            If passed, the shape of the k-space data will be checked.

        Raises
        ------
        ValueError
            If the shape of the provided image does not match the expected operator
            shape, or if the number of k-space samples does not match the expected
            number of samples, or if neither ``image`` nor ``ksp`` is provided.
        """
        if image is not None:
            # Only the trailing dimensions are compared; leading ones
            # are left unchecked.
            image_shape = image.shape[-len(self.shape) :]
            if image_shape != self.shape:
                raise ValueError(
                    f"Image shape {image_shape} is not compatible "
                    f"with the operator shape {self.shape}"
                )

        if ksp is not None:
            kspace_shape = ksp.shape[-1]
            if kspace_shape != self.n_samples:
                raise ValueError(
                    f"Kspace samples {kspace_shape} is not compatible "
                    f"with the operator samples {self.n_samples}"
                )
        if image is None and ksp is None:
            raise ValueError("Nothing to check, provides image or ksp arguments")

    @abstractmethod
    def op(self, data):
        """Compute operator transform.

        Parameters
        ----------
        data: np.ndarray
            input as array.

        Returns
        -------
        result: np.ndarray
            operator transform of the input.
        """
        pass

    @abstractmethod
    def adj_op(self, coeffs):
        """Compute adjoint operator transform.

        Parameters
        ----------
        coeffs: np.ndarray
            input data array.

        Returns
        -------
        results: np.ndarray
            adjoint operator transform.
        """
        pass

    def data_consistency(self, image_data, obs_data):
        """Compute the gradient data consistency.

        This is the naive implementation using adj_op(op(x)-y).
        Specific backend can (and should!) implement a more efficient version.
        """
        return self.adj_op(self.op(image_data) - obs_data)

    def with_off_resonance_correction(self, B, C, indices):
        """Return a new operator with Off Resonance Correction."""
        from .off_resonance import MRIFourierCorrected

        return MRIFourierCorrected(self, B, C, indices)

    def compute_smaps(self, method: NDArray | Callable | str | dict | None = None):
        """Compute the sensitivity maps and set it.

        Parameters
        ----------
        method: callable or dict or array
            The method to use to compute the sensitivity maps.
            If an array, it should be of shape (NCoils,XYZ) and will be used as is.
            If a dict, it should have a key 'name', to determine which method to use.
            other items will be used as kwargs.
            If a callable, it should take the samples and the shape as input.
            Note that this callable function should also hold the k-space data
            (use functools.partial)
            If falsy (e.g. None), the sensitivity maps are reset to None.

        Raises
        ------
        ValueError
            If the method does not resolve to a callable.
        """
        if is_host_array(method) or is_cuda_array(method):
            self.smaps = method
            return
        if not method:
            self.smaps = None
            return
        kwargs = {}
        if isinstance(method, dict):
            # Work on a copy: the caller's dict must not lose its 'name' key.
            kwargs = method.copy()
            method = kwargs.pop("name")
        if isinstance(method, str):
            method = get_smaps(method)
        if not callable(method):
            raise ValueError(f"Unknown smaps method: {method}")
        self.smaps, self.SOS = method(
            self.samples,
            self.shape,
            density=self.density,
            backend=self.backend,
            **kwargs,
        )

    def make_autograd(self, wrt_data=True, wrt_traj=False):
        """Make a new Operator with autodiff support.

        Parameters
        ----------
        wrt_data : bool, optional
            If the gradient with respect to the data is computed, default is true

        wrt_traj : bool, optional
            If the gradient with respect to the trajectory is computed, default is false

        Returns
        -------
        torch.nn.module
            A NUFFT operator with autodiff capabilities.

        Raises
        ------
        ValueError
            If autograd is not available.
        """
        if not AUTOGRAD_AVAILABLE:
            raise ValueError("Autograd not available, ensure torch is installed.")
        if not self.autograd_available:
            raise ValueError("Backend does not support auto-differentiation.")

        from mrinufft.operators.autodiff import MRINufftAutoGrad

        return MRINufftAutoGrad(self, wrt_data=wrt_data, wrt_traj=wrt_traj)

    def compute_density(self, method=None):
        """Compute the density compensation weights and set it.

        Parameters
        ----------
        method: str or callable or array or dict or bool
            The method to use to compute the density compensation.

            - If a string, the method should be registered in the density registry.
            - If a callable, it should take the samples and the shape as input.
            - If a dict, it should have a key 'name', to determine which method to use.
              other items will be used as kwargs.
            - If an array, it should be of shape (Nsamples,) and will be used as is.
            - If `True`, the method `pipe` is chosen as default estimation method.
            - If falsy (e.g. None or False), the density is reset to None.

        Raises
        ------
        ValueError
            If the method does not resolve to a callable.

        Notes
        -----
        The "pipe" method is only available for the following backends:
        `tensorflow`, `finufft`, `cufinufft`, `gpunufft`, `torchkbnufft-cpu`
        and `torchkbnufft-gpu`.
        """
        if isinstance(method, np.ndarray) or (
            CUPY_AVAILABLE and isinstance(method, cp.ndarray)
        ):
            self.density = method
            return None
        if not method:
            self._density = None
            return None
        if method is True:
            method = "pipe"

        kwargs = {}
        if isinstance(method, dict):
            kwargs = method.copy()
            method = kwargs.pop("name")  # must be a string !
        if method == "pipe" and "backend" not in kwargs:
            kwargs["backend"] = self.backend
        if isinstance(method, str):
            method = get_density(method)
        if not callable(method):
            raise ValueError(f"Unknown density method: {method}")
        if self._density_method is None:
            # Remember the first resolved method (with its kwargs) so that it
            # can be re-applied to other samples later on.
            self._density_method = lambda samples, shape: method(
                samples,
                shape,
                **kwargs,
            )
        self.density = method(self.samples, self.shape, **kwargs)

    def get_lipschitz_cst(self, max_iter=10, **kwargs):
        """Return the Lipschitz constant of the operator.

        Parameters
        ----------
        max_iter: int
            number of iteration to compute the lipschitz constant.
        **kwargs:
            Extra arguments given to the constructor of the temporary
            single-coil operator (only used when ``n_coils > 1``).

        Returns
        -------
        float
            Spectral Radius

        Notes
        -----
        This uses the Iterative Power Method to compute the largest singular value of a
        minified version of the nufft operator.
        No coil or B0 compensation is used, but includes any computed density.
        """
        if self.n_coils > 1:
            tmp_op = self.__class__(
                self.samples, self.shape, density=self.density, n_coils=1, **kwargs
            )
        else:
            tmp_op = self
        return power_method(max_iter, tmp_op)

    def cg(self, kspace_data, compute_loss=False, **kwargs):
        """Conjugate Gradient method to solve the inverse problem.

        Parameters
        ----------
        kspace_data: np.ndarray
            The k-space data to reconstruct.
        compute_loss: bool
            Whether to compute the loss at each iteration.
            If True, loss is calculated and returned, otherwise, it's skipped.
        **kwargs:
            Extra arguments to pass to the conjugate gradient method.

        Returns
        -------
        np.ndarray
            Reconstructed image
        np.ndarray, optional
            array of loss at each iteration, if compute_loss is True.
        """
        from ..extras.gradient import cg

        return cg(
            operator=self, kspace_data=kspace_data, compute_loss=compute_loss, **kwargs
        )

    @property
    def uses_sense(self):
        """Return True if the operator uses sensitivity maps."""
        return self._smaps is not None

    @property
    def uses_density(self):
        """Return True if the operator uses density compensation."""
        return getattr(self, "density", None) is not None

    @property
    def ndim(self):
        """Number of dimensions in image space of the operator."""
        return len(self._shape)

    @property
    def shape(self):
        """Shape of the image space of the operator."""
        return self._shape

    @shape.setter
    def shape(self, shape):
        # Stored as a tuple so it can be compared with array ``.shape``.
        self._shape = tuple(shape)

    @property
    def n_coils(self):
        """Number of coils for the operator."""
        return self._n_coils

    @n_coils.setter
    def n_coils(self, n_coils):
        if n_coils < 1 or not int(n_coils) == n_coils:
            raise ValueError(f"n_coils should be a positive integer, {type(n_coils)}")
        self._n_coils = int(n_coils)

    @property
    def smaps(self):
        """Sensitivity maps of the operator."""
        return self._smaps

    @smaps.setter
    def smaps(self, new_smaps):
        self._check_smaps_shape(new_smaps)
        self._smaps = new_smaps

    def _check_smaps_shape(self, smaps):
        """Check the shape of the sensitivity maps.

        Passing None resets the stored sensitivity maps.

        Raises
        ------
        ValueError
            If ``smaps.shape`` differs from ``(n_coils, *shape)``.
        """
        if smaps is None:
            self._smaps = None
        elif smaps.shape != (self.n_coils, *self.shape):
            raise ValueError(
                f"smaps shape is {smaps.shape}, it should be "
                f"(n_coils, *shape): {(self.n_coils, *self.shape)}"
            )

    @property
    def density(self):
        """Density compensation of the operator."""
        return self._density

    @density.setter
    def density(self, new_density):
        if new_density is None:
            self._density = None
        elif len(new_density) != self.n_samples:
            raise ValueError("Density and samples should have the same length")
        else:
            self._density = new_density

    @property
    def dtype(self):
        """Return floating precision of the operator."""
        return self._dtype

    @dtype.setter
    def dtype(self, new_dtype):
        self._dtype = np.dtype(new_dtype)

    @property
    def cpx_dtype(self):
        """Return complex floating precision of the operator."""
        return np.dtype(DTYPE_R2C[str(self.dtype)])

    @property
    def samples(self):
        """Return the samples used by the operator."""
        return self._samples

    @samples.setter
    def samples(self, new_samples):
        self._samples = new_samples

    @property
    def n_samples(self):
        """Return the number of samples used by the operator."""
        return self._samples.shape[0]

    @property
    def norm_factor(self):
        """Normalization factor of the operator."""
        return np.sqrt(np.prod(self.shape) * (2 ** len(self.shape)))

    def __repr__(self):
        """Return info about the Fourier operator."""
        return (
            f"{self.__class__.__name__}(\n"
            f"  shape: {self.shape}\n"
            f"  n_coils: {self.n_coils}\n"
            f"  n_samples: {self.n_samples}\n"
            f"  uses_sense: {self.uses_sense}\n"
            ")"
        )

    @classmethod
    def with_autograd(cls, wrt_data=True, wrt_traj=False, *args, **kwargs):
        """Return a Fourier operator with autograd capabilities.

        The operator is built as ``cls(*args, **kwargs)`` and then wrapped
        through ``make_autograd(wrt_data, wrt_traj)``.
        """
        return cls(*args, **kwargs).make_autograd(wrt_data, wrt_traj)
class FourierOperatorCPU(FourierOperatorBase):
    """Base class for CPU-based NUFFT operator.

    The NUFFT operation will be done sequentially and looped over coils and batches.

    Parameters
    ----------
    samples: np.ndarray
        The samples used by the operator.
    shape: tuple
        The shape of the image space (in 2D or 3D)
    density: bool or np.ndarray
        If True, the density compensation is estimated from the samples.
        If False, no density compensation is applied.
        If np.ndarray, the density compensation is applied from the array.
    n_coils: int
        The number of coils.
    n_batchs: int
        The number of batches.
    n_trans: int
        The number of transforms handed to ``raw_op`` in a single call.
        ``n_batchs * n_coils`` has to be a multiple of it.
    smaps: np.ndarray
        The sensitivity maps.
    raw_op: object
        An object implementing the NUFFT API. It should be responsible to compute a
        single type 1 /type 2 NUFFT.
    squeeze_dims: bool
        If True, singleton batch and coil dimensions are removed from the
        returned arrays.
    """

    def __init__(
        self,
        samples,
        shape,
        density=False,
        n_coils=1,
        n_batchs=1,
        n_trans=1,
        smaps=None,
        raw_op=None,
        squeeze_dims=True,
    ):
        super().__init__()

        self.shape = shape

        # we will access the samples by their coordinate first.
        self._samples = samples.reshape(-1, len(shape))
        self.dtype = self.samples.dtype
        if n_coils < 1:
            raise ValueError("n_coils should be ≥ 1")
        self.n_coils = n_coils
        self.n_batchs = n_batchs
        self.n_trans = n_trans
        self.squeeze_dims = squeeze_dims

        # Density Compensation Setup
        self.compute_density(density)
        # Multi Coil Setup
        self.compute_smaps(smaps)

        self.raw_op = raw_op

    def _n_chunks(self):
        """Return the number of ``raw_op`` calls needed to process all coils/batches.

        Raises
        ------
        ValueError
            If ``n_batchs * n_coils`` is not a multiple of ``n_trans``: the
            trailing transforms would otherwise be silently skipped and left
            uninitialized in the output buffers.
        """
        total = self.n_batchs * self.n_coils
        n_chunks, remainder = divmod(total, self.n_trans)
        if remainder:
            raise ValueError(
                f"n_batchs * n_coils ({total}) should be a multiple of "
                f"n_trans ({self.n_trans})"
            )
        return n_chunks

    @with_numpy
    def op(self, data, ksp=None):
        r"""Non Cartesian MRI forward operator.

        Parameters
        ----------
        data: np.ndarray
            The uniform (2D or 3D) data in image space.
        ksp: np.ndarray, optional
            Output buffer for the k-space data. Allocated if not provided.

        Returns
        -------
        Results array on the same device as data.

        Notes
        -----
        this performs for every coil \ell:
        .. math:: \mathcal{F}\mathcal{S}_\ell x
        """
        self.check_shape(image=data, ksp=ksp)
        # sense
        data = auto_cast(data, self.cpx_dtype)

        if self.uses_sense:
            ret = self._op_sense(data, ksp)
        # calibrationless or monocoil.
        else:
            ret = self._op_calibless(data, ksp)
        ret /= self.norm_factor

        ret = self._safe_squeeze(ret)
        return ret

    def _op_sense(self, data, ksp=None):
        """Forward operator with sensitivity maps, ``n_trans`` coil images at a time."""
        T, B, C = self.n_trans, self.n_batchs, self.n_coils
        K, XYZ = self.n_samples, self.shape
        n_chunks = self._n_chunks()
        dataf = data.reshape((B, *XYZ))
        if ksp is None:
            ksp = np.empty((B * C, K), dtype=self.cpx_dtype)
        else:
            # The loop below slices the buffer along a flat (batch * coil) axis.
            ksp = ksp.reshape((B * C, K))
        for i in range(n_chunks):
            # Flat transform index -> (coil, batch) indices.
            idx_coils = np.arange(i * T, (i + 1) * T) % C
            idx_batch = np.arange(i * T, (i + 1) * T) // C
            coil_img = self.smaps[idx_coils].copy().reshape((T, *XYZ))
            coil_img *= dataf[idx_batch]
            self._op(coil_img, ksp[i * T : (i + 1) * T])
        ksp = ksp.reshape((B, C, K))
        return ksp

    def _op_calibless(self, data, ksp=None):
        """Forward operator without sensitivity maps (one image per coil)."""
        T, B, C = self.n_trans, self.n_batchs, self.n_coils
        K, XYZ = self.n_samples, self.shape
        n_chunks = self._n_chunks()
        if ksp is None:
            ksp = np.empty((B * C, K), dtype=self.cpx_dtype)
        else:
            # The loop below slices the buffer along a flat (batch * coil) axis.
            ksp = ksp.reshape((B * C, K))
        dataf = np.reshape(data, (B * C, *XYZ))
        for i in range(n_chunks):
            self._op(
                dataf[i * T : (i + 1) * T],
                ksp[i * T : (i + 1) * T],
            )
        ksp = ksp.reshape((B, C, K))
        return ksp

    def _op(self, image, coeffs):
        # raw_op takes the output buffer first.
        self.raw_op.op(coeffs, image)

    @with_numpy
    def adj_op(self, coeffs, img=None):
        """Non Cartesian MRI adjoint operator.

        Parameters
        ----------
        coeffs: np.array or GPUArray
            The k-space data.
        img: np.ndarray, optional
            Output buffer for the image. Allocated if not provided.

        Returns
        -------
        Array in the same memory space of coeffs. (ie on cpu or gpu Memory).
        """
        self.check_shape(image=img, ksp=coeffs)
        coeffs = auto_cast(coeffs, self.cpx_dtype)
        if self.uses_sense:
            ret = self._adj_op_sense(coeffs, img)
        # calibrationless or monocoil.
        else:
            ret = self._adj_op_calibless(coeffs, img)
        ret /= self.norm_factor
        return self._safe_squeeze(ret)

    def _adj_op_sense(self, coeffs, img=None):
        """Adjoint operator with sensitivity maps, coils are summed per batch."""
        T, B, C = self.n_trans, self.n_batchs, self.n_coils
        K, XYZ = self.n_samples, self.shape
        n_chunks = self._n_chunks()
        if img is None:
            img = np.zeros((B, *XYZ), dtype=self.cpx_dtype)
        else:
            # The accumulation below indexes the buffer by batch.
            img = img.reshape((B, *XYZ))
        coeffs_flat = coeffs.reshape((B * C, K))
        img_batched = np.zeros((T, *XYZ), dtype=self.cpx_dtype)
        for i in range(n_chunks):
            # Flat transform index -> (coil, batch) indices.
            idx_coils = np.arange(i * T, (i + 1) * T) % C
            idx_batch = np.arange(i * T, (i + 1) * T) // C
            self._adj_op(coeffs_flat[i * T : (i + 1) * T], img_batched)
            img_batched *= self.smaps[idx_coils].conj()
            for t, b in enumerate(idx_batch):
                img[b] += img_batched[t]
        img = img.reshape((B, 1, *XYZ))
        return img

    def _adj_op_calibless(self, coeffs, img=None):
        """Adjoint operator without sensitivity maps (one image per coil)."""
        T, B, C = self.n_trans, self.n_batchs, self.n_coils
        K, XYZ = self.n_samples, self.shape
        n_chunks = self._n_chunks()
        if img is None:
            img = np.empty((B * C, *XYZ), dtype=self.cpx_dtype)
        else:
            # The loop below slices the buffer along a flat (batch * coil) axis.
            img = img.reshape((B * C, *XYZ))
        coeffs_f = np.reshape(coeffs, (B * C, K))
        for i in range(n_chunks):
            self._adj_op(coeffs_f[i * T : (i + 1) * T], img[i * T : (i + 1) * T])

        img = img.reshape((B, C, *XYZ))
        return img

    def _adj_op(self, coeffs, image):
        """Apply the density compensation (if any) and the raw adjoint NUFFT.

        ``coeffs`` itself is never modified: the weights are applied to a copy.
        """
        if self.density is not None:
            coeffs2 = coeffs.copy()
            # View the fresh (contiguous) copy as rows of n_samples values and
            # weight every row in place; works for flat and 2D inputs alike
            # and keeps the dtype of ``coeffs``.
            weighted = coeffs2.reshape(-1, self.n_samples)
            weighted *= self.density
        else:
            coeffs2 = coeffs
        self.raw_op.adj_op(coeffs2, image)

    @with_numpy_cupy
    def data_consistency(self, image_data, obs_data):
        """Compute the gradient data consistency.

        This mixes the op and adj_op method to perform F_adj(F(x-y))
        on a per coil basis. By doing the computation coil wise,
        it uses less memory than the naive call to adj_op(op(x)-y)

        Parameters
        ----------
        image_data: array
            Image on which the gradient operation will be evaluated.
            N_coil x Image shape is not using sense.
        obs_data: array
            Observed data.
        """
        if self.uses_sense:
            return self._safe_squeeze(self._grad_sense(image_data, obs_data))
        return self._safe_squeeze(self._grad_calibless(image_data, obs_data))

    def _grad_sense(self, image_data, obs_data):
        """Data consistency gradient with sensitivity maps."""
        T, B, C = self.n_trans, self.n_batchs, self.n_coils
        K, XYZ = self.n_samples, self.shape
        n_chunks = self._n_chunks()

        dataf = image_data.reshape((B, *XYZ))
        obs_dataf = obs_data.reshape((B * C, K))
        grad = np.zeros_like(dataf)

        coil_ksp = np.empty((T, K), dtype=self.cpx_dtype)
        for i in range(n_chunks):
            # Flat transform index -> (coil, batch) indices.
            idx_coils = np.arange(i * T, (i + 1) * T) % C
            idx_batch = np.arange(i * T, (i + 1) * T) // C
            coil_img = self.smaps[idx_coils].copy().reshape((T, *XYZ))
            coil_img *= dataf[idx_batch]
            self._op(coil_img, coil_ksp)
            coil_ksp /= self.norm_factor
            coil_ksp -= obs_dataf[i * T : (i + 1) * T]
            # coil_img is reused as output buffer for the adjoint.
            self._adj_op(coil_ksp, coil_img)
            coil_img *= self.smaps[idx_coils].conj()
            for t, b in enumerate(idx_batch):
                grad[b] += coil_img[t]
        grad /= self.norm_factor
        return grad

    def _grad_calibless(self, image_data, obs_data):
        """Data consistency gradient without sensitivity maps."""
        T, B, C = self.n_trans, self.n_batchs, self.n_coils
        K, XYZ = self.n_samples, self.shape
        n_chunks = self._n_chunks()

        dataf = image_data.reshape((B * C, *XYZ))
        obs_dataf = obs_data.reshape((B * C, K))
        grad = np.empty_like(dataf)
        ksp = np.empty((T, K), dtype=self.cpx_dtype)
        for i in range(n_chunks):
            self._op(dataf[i * T : (i + 1) * T], ksp)
            ksp /= self.norm_factor
            ksp -= obs_dataf[i * T : (i + 1) * T]
            # The density compensation is applied inside _adj_op; weighting
            # the residual here as well would apply it twice.
            self._adj_op(ksp, grad[i * T : (i + 1) * T])
        grad /= self.norm_factor
        return grad.reshape(B, C, *XYZ)

    def _safe_squeeze(self, arr):
        """Squeeze the first two dimensions of shape of the operator.

        Axis 1 then axis 0 are removed when they have length one; an axis
        that cannot be squeezed is left untouched. Nothing is done if
        ``squeeze_dims`` is False.
        """
        if self.squeeze_dims:
            try:
                arr = arr.squeeze(axis=1)
            except ValueError:
                pass
            try:
                arr = arr.squeeze(axis=0)
            except ValueError:
                pass
        return arr