# Source code for mrinufft.operators.base

"""
Base Fourier Operator interface.

from https://github.com/CEA-COSMIC/pysap-mri

:author: Pierre-Antoine Comby
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from functools import partial

import numpy as np

from mrinufft._array_compat import with_numpy, with_numpy_cupy, AUTOGRAD_AVAILABLE
from mrinufft._utils import auto_cast, power_method
from mrinufft.density import get_density
from mrinufft.extras import get_smaps
from mrinufft.operators.interfaces.utils import is_cuda_array, is_host_array

if AUTOGRAD_AVAILABLE:
    from mrinufft.operators.autodiff import MRINufftAutoGrad


# Lookup table: name of a real floating-point numpy dtype -> name of the
# complex numpy dtype with the same precision.
DTYPE_R2C = dict(float32="complex64", float64="complex128")


def check_backend(backend_name: str):
    """Check if a specific backend is available.

    Parameters
    ----------
    backend_name: str
        Name of the backend to look up (case-insensitive).

    Returns
    -------
    The availability flag stored for this backend in
    ``FourierOperatorBase.interfaces``.

    Raises
    ------
    ValueError
        If no backend of that name is registered.
    """
    key = backend_name.lower()
    try:
        entry = FourierOperatorBase.interfaces[key]
    except KeyError as err:
        raise ValueError(f"unknown backend: '{key}'") from err
    return entry[0]
def list_backends(available_only=False):
    """Return a list of backend.

    Parameters
    ----------
    available_only: bool, optional
        If True, only return backends that are available. If False, return all
        backends, regardless of whether they are available or not.

    Returns
    -------
    list of str
        Registered backend names, in registration order.
    """
    registry = FourierOperatorBase.interfaces.items()
    return [
        backend
        for backend, (is_available, _unused_cls) in registry
        if not available_only or is_available
    ]
def get_operator(
    backend_name: str, wrt_data: bool = False, wrt_traj: bool = False, *args, **kwargs
):
    """Return an MRI Fourier operator interface using the correct backend.

    Parameters
    ----------
    backend_name: str
        Backend name (case-insensitive). A name of the form
        ``"stacked-<backend>"`` that is not registered on its own resolves to
        the ``"stacked"`` operator with ``backend=<backend>``.
    wrt_data: bool, default False
        if set gradients wrt to data and images will be available.
    wrt_traj: bool, default False
        if set gradients wrt to trajectory will be available.
    *args, **kwargs:
        Arguments to pass to the operator constructor.

    Returns
    -------
    FourierOperator class or instance of class if args or kwargs are given.

    Raises
    ------
    ValueError if the backend is not available.

    Notes
    -----
    ``wrt_data`` and ``wrt_traj`` are positional parameters placed before
    ``*args``: positional constructor arguments must come after them, e.g.
    ``get_operator("finufft", False, False, samples, shape)``, or be passed as
    keywords.
    """
    available = True
    backend_name = backend_name.lower()
    try:
        available, operator = FourierOperatorBase.interfaces[backend_name]
    except KeyError as exc:
        if not backend_name.startswith("stacked-"):
            raise ValueError(f"backend {backend_name} does not exist") from exc
        # Dedicated registered stacked backends (like stacked-cufinufft) have
        # been found by the lookup above; otherwise wrap the generic "stacked"
        # operator. Split only once so hyphenated backend names such as
        # "torchkbnufft-cpu" are kept intact.
        backend = backend_name.split("-", 1)[1]
        operator = partial(get_operator("stacked"), backend=backend)

    if not available:
        raise ValueError(f"backend {backend_name} found, but dependencies are not met.")

    if args or kwargs:
        operator = operator(*args, **kwargs)

    if wrt_data or wrt_traj:
        if isinstance(operator, FourierOperatorBase):
            operator = operator.make_autograd(wrt_data, wrt_traj)
        elif isinstance(operator, partial):
            # Generic stacked case: a ``partial`` has no ``with_autograd``, so
            # go through the wrapped class and keep the arguments bound on it.
            operator = partial(
                operator.func.with_autograd,
                wrt_data,
                wrt_traj,
                *operator.args,
                **operator.keywords,
            )
        else:
            # instance will be created later
            operator = partial(operator.with_autograd, wrt_data, wrt_traj)
    return operator
class FourierOperatorBase(ABC):
    """Base Fourier Operator class.

    Every (Linear) Fourier operator inherits from this class, to ensure that we
    have all the functions rightly implemented as required by ModOpt.

    Subclasses defining a truthy ``backend`` class attribute are registered
    automatically in ``interfaces`` when they are created.
    """

    # Registry shared by every subclass: backend name -> (available, class).
    interfaces: dict[str, tuple] = {}

    autograd_available = False
    _density_method = None
    _grad_wrt_data = False
    _grad_wrt_traj = False

    def __init__(self):
        # ``__init_subclass__`` treats a missing ``available`` attribute as
        # "available"; use the same default here instead of failing with an
        # AttributeError.
        if not getattr(self, "available", True):
            raise RuntimeError(f"'{self.backend}' backend is not available.")
        self._smaps = None
        self._density = None
        self._n_coils = 1

    def __init_subclass__(cls):
        """Register the class in the list of available operators.

        ``available`` may be a value or a callable; a callable is evaluated
        once here, at class-creation time.
        """
        super().__init_subclass__()
        available = getattr(cls, "available", True)
        if callable(available):
            available = available()
        if backend := getattr(cls, "backend", None):
            cls.interfaces[backend] = (available, cls)

    def check_shape(self, *, image=None, ksp=None):
        """
        Validate the shapes of the image or k-space data against operator shapes.

        Parameters
        ----------
        image : np.ndarray, optional
            If passed, the shape of image data will be checked (its trailing
            ``len(self.shape)`` dimensions are compared with ``self.shape``).

        ksp : np.ndarray or object, optional
            If passed, the shape of the k-space data will be checked (its last
            dimension is compared with ``self.n_samples``).

        Raises
        ------
        ValueError
            If the shape of the provided image does not match the expected operator
            shape, or if the number of k-space samples does not match the expected
            number of samples, or if neither argument is given.
        """
        if image is not None:
            image_shape = image.shape[-len(self.shape) :]
            if image_shape != self.shape:
                raise ValueError(
                    f"Image shape {image_shape} is not compatible "
                    f"with the operator shape {self.shape}"
                )

        if ksp is not None:
            kspace_shape = ksp.shape[-1]
            if kspace_shape != self.n_samples:
                raise ValueError(
                    f"Kspace samples {kspace_shape} is not compatible "
                    f"with the operator samples {self.n_samples}"
                )
        if image is None and ksp is None:
            raise ValueError("Nothing to check, provides image or ksp arguments")

    @abstractmethod
    def op(self, data):
        """Compute operator transform.

        Parameters
        ----------
        data: np.ndarray
            input as array.

        Returns
        -------
        result: np.ndarray
            operator transform of the input.
        """
        pass

    @abstractmethod
    def adj_op(self, coeffs):
        """Compute adjoint operator transform.

        Parameters
        ----------
        coeffs: np.ndarray
            input data array.

        Returns
        -------
        results: np.ndarray
            adjoint operator transform.
        """
        pass

    def data_consistency(self, image, obs_data):
        """Compute the gradient data consistency.

        This is the naive implementation using adj_op(op(x)-y).
        Specific backend can (and should!) implement a more efficient version.
        """
        return self.adj_op(self.op(image) - obs_data)

    def with_off_resonance_correction(self, B, C, indices):
        """Return a new operator with Off Resonance Correction."""
        from ..off_resonance import MRIFourierCorrected

        return MRIFourierCorrected(self, B, C, indices)

    def compute_smaps(self, method=None):
        """Compute the sensitivity maps and set it.

        Parameters
        ----------
        method: callable or dict or array
            The method to use to compute the sensitivity maps.
            If an array, it should be of shape (NCoils,XYZ) and will be used as is.
            If a dict, it should have a key 'name', to determine which method to use.
            other items will be used as kwargs.
            If a string, it is resolved with ``get_smaps``.
            If a callable, it is called as
            ``method(samples, shape, density=..., backend=..., **kwargs)`` and
            must return a ``(smaps, SOS)`` pair.
            Note that this callable function should also hold the k-space data
            (use functools.partial).
            If falsy (None), the sensitivity maps are removed.

        Raises
        ------
        ValueError
            If ``method`` cannot be resolved to a callable.
        """
        if is_host_array(method) or is_cuda_array(method):
            self.smaps = method
            return
        if not method:
            self.smaps = None
            return
        kwargs = {}
        if isinstance(method, dict):
            kwargs = method.copy()
            method = kwargs.pop("name")
        if isinstance(method, str):
            method = get_smaps(method)
        if not callable(method):
            raise ValueError(f"Unknown smaps method: {method}")
        self.smaps, self.SOS = method(
            self.samples,
            self.shape,
            density=self.density,
            backend=self.backend,
            **kwargs,
        )

    def make_autograd(self, wrt_data=True, wrt_traj=False):
        """Make a new Operator with autodiff support.

        Parameters
        ----------
        wrt_data : bool, optional
            If the gradient with respect to the data is computed, default is true

        wrt_traj : bool, optional
            If the gradient with respect to the trajectory is computed, default is false

        Returns
        -------
        torch.nn.module
            A NUFFT operator with autodiff capabilities.

        Raises
        ------
        ValueError
            If autograd is not available, or if the backend does not support it.
        """
        if not AUTOGRAD_AVAILABLE:
            raise ValueError("Autograd not available, ensure torch is installed.")
        if not self.autograd_available:
            raise ValueError("Backend does not support auto-differentiation.")

        return MRINufftAutoGrad(self, wrt_data=wrt_data, wrt_traj=wrt_traj)

    def compute_density(self, method=None):
        """Compute the density compensation weights and set it.

        Parameters
        ----------
        method: str or callable or array or dict or bool
            The method to use to compute the density compensation.
            If a string, the method should be registered in the density registry.
            If a callable, it is called as ``method(samples, shape, **kwargs)``.
            If a dict, it should have a key 'name', to determine which method to use.
            other items will be used as kwargs.
            If an array, it should be of shape (Nsamples,) and will be used as is.
            If `True`, the method `pipe` is chosen as default estimation method.
            For `pipe`, the operator's ``backend`` is passed as a keyword
            argument unless one is given explicitly.
            If falsy (None, False), density compensation is disabled.

        Raises
        ------
        ValueError
            If ``method`` cannot be resolved to a callable.
        """
        if isinstance(method, np.ndarray):
            self._density = method
            return None
        if not method:
            self._density = None
            return None
        if method is True:
            method = "pipe"

        kwargs = {}
        if isinstance(method, dict):
            kwargs = method.copy()
            method = kwargs.pop("name")  # must be a string !
        if method == "pipe" and "backend" not in kwargs:
            kwargs["backend"] = self.backend
        if isinstance(method, str):
            method = get_density(method)
        if not callable(method):
            raise ValueError(f"Unknown density method: {method}")

        # Only the first resolved method is stored in ``_density_method``;
        # later calls recompute ``_density`` but do not overwrite it.
        if self._density_method is None:
            self._density_method = lambda samples, shape: method(
                samples,
                shape,
                **kwargs,
            )
        self._density = method(self.samples, self.shape, **kwargs)

    def get_lipschitz_cst(self, max_iter=10, **kwargs):
        """Return the Lipschitz constant of the operator.

        Parameters
        ----------
        max_iter: int
            number of iteration to compute the lipschitz constant.
        **kwargs:
            Extra arguments given to the constructor of the single-coil
            operator (only used when ``n_coils > 1``).

        Returns
        -------
        float
            Spectral Radius

        Notes
        -----
        This uses the Iterative Power Method to compute the largest singular value of a
        minified version of the nufft operator. No coil or B0 compensation is used,
        but includes any computed density.
        """
        if self.n_coils > 1:
            tmp_op = self.__class__(
                self.samples, self.shape, density=self.density, n_coils=1, **kwargs
            )
        else:
            tmp_op = self
        return power_method(max_iter, tmp_op)

    @property
    def uses_sense(self):
        """Return True if the operator uses sensitivity maps."""
        return self._smaps is not None

    @property
    def uses_density(self):
        """Return True if the operator uses density compensation."""
        return getattr(self, "density", None) is not None

    @property
    def ndim(self):
        """Number of dimensions in image space of the operator."""
        return len(self._shape)

    @property
    def shape(self):
        """Shape of the image space of the operator."""
        return self._shape

    @shape.setter
    def shape(self, shape):
        self._shape = tuple(shape)

    @property
    def n_coils(self):
        """Number of coils for the operator."""
        return self._n_coils

    @n_coils.setter
    def n_coils(self, n_coils):
        if n_coils < 1 or int(n_coils) != n_coils:
            raise ValueError(f"n_coils should be a positive integer, {type(n_coils)}")
        self._n_coils = int(n_coils)

    @property
    def smaps(self):
        """Sensitivity maps of the operator."""
        return self._smaps

    @smaps.setter
    def smaps(self, smaps):
        self._check_smaps_shape(smaps)
        self._smaps = smaps

    def _check_smaps_shape(self, smaps):
        """Check the shape of the sensitivity maps.

        Raises
        ------
        ValueError
            If ``smaps`` is not None and its shape is not ``(n_coils, *shape)``.
        """
        if smaps is None:
            # Kept for callers relying on this method to clear the maps.
            self._smaps = None
            return
        expected = (self.n_coils, *self.shape)
        if smaps.shape != expected:
            raise ValueError(
                f"smaps shape is {smaps.shape}, it should be "
                f"(n_coils, *shape): {expected}"
            )

    @property
    def density(self):
        """Density compensation of the operator."""
        return self._density

    @density.setter
    def density(self, density):
        if density is None:
            self._density = None
        elif len(density) != self.n_samples:
            raise ValueError("Density and samples should have the same length")
        else:
            self._density = density

    @property
    def dtype(self):
        """Return floating precision of the operator."""
        return self._dtype

    @dtype.setter
    def dtype(self, dtype):
        self._dtype = np.dtype(dtype)

    @property
    def cpx_dtype(self):
        """Return complex floating precision of the operator."""
        return np.dtype(DTYPE_R2C[str(self.dtype)])

    @property
    def samples(self):
        """Return the samples used by the operator."""
        return self._samples

    @samples.setter
    def samples(self, samples):
        self._samples = samples

    @property
    def n_samples(self):
        """Return the number of samples used by the operator."""
        return self._samples.shape[0]

    @property
    def norm_factor(self):
        """Normalization factor of the operator: sqrt(prod(shape) * 2**ndim)."""
        return np.sqrt(np.prod(self.shape) * (2 ** len(self.shape)))

    def __repr__(self):
        """Return info about the Fourier operator."""
        return (
            f"{self.__class__.__name__}(\n"
            f"  shape: {self.shape}\n"
            f"  n_coils: {self.n_coils}\n"
            f"  n_samples: {self.n_samples}\n"
            f"  uses_sense: {self.uses_sense}\n"
            ")"
        )

    @classmethod
    def with_autograd(cls, wrt_data=True, wrt_traj=False, *args, **kwargs):
        """Return a Fourier operator with autograd capabilities."""
        return cls(*args, **kwargs).make_autograd(wrt_data, wrt_traj)
class FourierOperatorCPU(FourierOperatorBase):
    """Base class for CPU-based NUFFT operator.

    The NUFFT operation will be done sequentially and looped over coils and batches.

    Parameters
    ----------
    samples: np.ndarray
        The samples used by the operator (reshaped to ``(-1, len(shape))``).
    shape: tuple
        The shape of the image space (in 2D or 3D)
    density: bool or np.ndarray
        If True, the density compensation is estimated from the samples.
        If False, no density compensation is applied.
        If np.ndarray, the density compensation is applied from the array.
    n_coils: int
        The number of coils.
    n_batchs: int
        The number of batches.
    n_trans: int
        Number of transforms handed to ``raw_op`` per call. The loops process
        ``n_batchs * n_coils // n_trans`` chunks, so ``n_trans`` is expected to
        divide ``n_batchs * n_coils`` (a remainder would be skipped).
    smaps: np.ndarray
        The sensitivity maps.
    raw_op: object
        An object implementing the NUFFT API. It should be responsible to compute
        a single type 1 /type 2 NUFFT: ``raw_op.op(coeffs, image)`` writes into
        ``coeffs`` and ``raw_op.adj_op(coeffs, image)`` writes into ``image``.
    squeeze_dims: bool
        If True, singleton batch/coil axes are squeezed from the results.
    """

    def __init__(
        self,
        samples,
        shape,
        density=False,
        n_coils=1,
        n_batchs=1,
        n_trans=1,
        smaps=None,
        raw_op=None,
        squeeze_dims=True,
    ):
        super().__init__()

        self.shape = shape

        # we will access the samples by their coordinate first.
        self._samples = samples.reshape(-1, len(shape))
        self.dtype = self.samples.dtype
        if n_coils < 1:
            raise ValueError("n_coils should be ≥ 1")
        self.n_coils = n_coils
        self.n_batchs = n_batchs
        self.n_trans = n_trans
        self.squeeze_dims = squeeze_dims
        # Density Compensation Setup
        self.compute_density(density)
        # Multi Coil Setup
        self.compute_smaps(smaps)

        self.raw_op = raw_op

    @with_numpy
    def op(self, data, ksp=None):
        r"""Non Cartesian MRI forward operator.

        Parameters
        ----------
        data: np.ndarray
            The uniform (2D or 3D) data in image space.
        ksp: np.ndarray, optional
            Output buffer, sliced along its first axis as a
            ``(n_batchs * n_coils, n_samples)`` array. Allocated when None.

        Returns
        -------
        Results array on the same device as data.

        Notes
        -----
        this performs for every coil \ell:
        .. math:: \mathcal{F}\mathcal{S}_\ell x
        """
        self.check_shape(image=data, ksp=ksp)
        data = auto_cast(data, self.cpx_dtype)
        if self.uses_sense:
            ret = self._op_sense(data, ksp)
        else:
            # calibrationless or monocoil.
            ret = self._op_calibless(data, ksp)
        ret /= self.norm_factor
        return self._safe_squeeze(ret)

    def _op_sense(self, data, ksp=None):
        """Forward NUFFT of every smaps-weighted image, n_trans at a time."""
        T, B, C = self.n_trans, self.n_batchs, self.n_coils
        K, XYZ = self.n_samples, self.shape
        dataf = data.reshape((B, *XYZ))
        if ksp is None:
            ksp = np.empty((B * C, K), dtype=self.cpx_dtype)
        for i in range(B * C // T):
            # Flat index j in [i*T, (i+1)*T) maps to coil j % C of batch j // C.
            idx_coils = np.arange(i * T, (i + 1) * T) % C
            idx_batch = np.arange(i * T, (i + 1) * T) // C
            coil_img = self.smaps[idx_coils].copy().reshape((T, *XYZ))
            coil_img *= dataf[idx_batch]
            self._op(coil_img, ksp[i * T : (i + 1) * T])
        ksp = ksp.reshape((B, C, K))
        return ksp

    def _op_calibless(self, data, ksp=None):
        """Forward NUFFT of every (batch, coil) image, n_trans at a time."""
        T, B, C = self.n_trans, self.n_batchs, self.n_coils
        K, XYZ = self.n_samples, self.shape
        if ksp is None:
            ksp = np.empty((B * C, K), dtype=self.cpx_dtype)
        dataf = np.reshape(data, (B * C, *XYZ))
        for i in range((B * C) // T):
            self._op(
                dataf[i * T : (i + 1) * T],
                ksp[i * T : (i + 1) * T],
            )
        ksp = ksp.reshape((B, C, K))
        return ksp

    def _op(self, image, coeffs):
        """Run the raw forward NUFFT, writing the result into ``coeffs``."""
        self.raw_op.op(coeffs, image)

    @with_numpy
    def adj_op(self, coeffs, img=None):
        """Non Cartesian MRI adjoint operator.

        Parameters
        ----------
        coeffs: np.array or GPUArray
            k-space data; its last dimension must be ``n_samples``.
        img: np.ndarray, optional
            Output buffer. Allocated when None.

        Returns
        -------
        Array in the same memory space of coeffs. (ie on cpu or gpu Memory).
        """
        self.check_shape(image=img, ksp=coeffs)
        coeffs = auto_cast(coeffs, self.cpx_dtype)
        if self.uses_sense:
            ret = self._adj_op_sense(coeffs, img)
        else:
            # calibrationless or monocoil.
            ret = self._adj_op_calibless(coeffs, img)
        ret /= self.norm_factor
        return self._safe_squeeze(ret)

    def _adj_op_sense(self, coeffs, img=None):
        """Adjoint NUFFT combined over coils with the conjugate smaps."""
        T, B, C = self.n_trans, self.n_batchs, self.n_coils
        K, XYZ = self.n_samples, self.shape
        if img is None:
            img = np.zeros((B, *XYZ), dtype=self.cpx_dtype)
        coeffs_flat = coeffs.reshape((B * C, K))
        img_batched = np.zeros((T, *XYZ), dtype=self.cpx_dtype)
        for i in range(B * C // T):
            idx_coils = np.arange(i * T, (i + 1) * T) % C
            idx_batch = np.arange(i * T, (i + 1) * T) // C
            self._adj_op(coeffs_flat[i * T : (i + 1) * T], img_batched)
            img_batched *= self.smaps[idx_coils].conj()
            for t, b in enumerate(idx_batch):
                img[b] += img_batched[t]
        img = img.reshape((B, 1, *XYZ))
        return img

    def _adj_op_calibless(self, coeffs, img=None):
        """Adjoint NUFFT of every (batch, coil) k-space, n_trans at a time."""
        T, B, C = self.n_trans, self.n_batchs, self.n_coils
        K, XYZ = self.n_samples, self.shape
        if img is None:
            img = np.empty((B * C, *XYZ), dtype=self.cpx_dtype)
        coeffs_f = np.reshape(coeffs, (B * C, K))
        for i in range((B * C) // T):
            self._adj_op(coeffs_f[i * T : (i + 1) * T], img[i * T : (i + 1) * T])

        img = img.reshape((B, C, *XYZ))
        return img

    def _adj_op(self, coeffs, image):
        """Run the raw adjoint NUFFT (density-weighted), writing into ``image``."""
        if self.density is not None:
            # Weight every row of n_samples values by the density; the reshape
            # is a view on the fresh C-contiguous copy, and the in-place
            # product keeps the complex dtype of ``coeffs``.
            coeffs2 = coeffs.copy()
            weighted = coeffs2.reshape(-1, self.n_samples)
            weighted *= self.density
        else:
            coeffs2 = coeffs
        self.raw_op.adj_op(coeffs2, image)

    @with_numpy_cupy
    def data_consistency(self, image_data, obs_data):
        """Compute the gradient data consistency.

        This mixes the op and adj_op method to perform F_adj(F(x) - y) on a
        per coil basis. By doing the computation coil wise, it uses less
        memory than the naive call to adj_op(op(x)-y).

        Parameters
        ----------
        image_data: array
            Image on which the gradient operation will be evaluated.
            N_coil x Image shape is not using sense.
        obs_data: array
            Observed data.
        """
        # Cast like op/adj_op do, so the accumulation buffers are complex.
        image_data = auto_cast(image_data, self.cpx_dtype)
        obs_data = auto_cast(obs_data, self.cpx_dtype)
        if self.uses_sense:
            grad = self._grad_sense(image_data, obs_data)
        else:
            grad = self._grad_calibless(image_data, obs_data)
        return self._safe_squeeze(grad)

    def _grad_sense(self, image_data, obs_data):
        """Gradient of the data consistency term when smaps are used."""
        T, B, C = self.n_trans, self.n_batchs, self.n_coils
        K, XYZ = self.n_samples, self.shape

        dataf = image_data.reshape((B, *XYZ))
        obs_dataf = obs_data.reshape((B * C, K))
        grad = np.zeros_like(dataf)

        coil_ksp = np.empty((T, K), dtype=self.cpx_dtype)
        for i in range(B * C // T):
            idx_coils = np.arange(i * T, (i + 1) * T) % C
            idx_batch = np.arange(i * T, (i + 1) * T) // C
            coil_img = self.smaps[idx_coils].copy().reshape((T, *XYZ))
            coil_img *= dataf[idx_batch]
            self._op(coil_img, coil_ksp)
            coil_ksp /= self.norm_factor
            coil_ksp -= obs_dataf[i * T : (i + 1) * T]
            self._adj_op(coil_ksp, coil_img)
            coil_img *= self.smaps[idx_coils].conj()
            for t, b in enumerate(idx_batch):
                grad[b] += coil_img[t]
        grad /= self.norm_factor
        return grad

    def _grad_calibless(self, image_data, obs_data):
        """Gradient of the data consistency term, coil by coil."""
        T, B, C = self.n_trans, self.n_batchs, self.n_coils
        K, XYZ = self.n_samples, self.shape

        dataf = image_data.reshape((B * C, *XYZ))
        obs_dataf = obs_data.reshape((B * C, K))
        grad = np.empty_like(dataf)
        ksp = np.empty((T, K), dtype=self.cpx_dtype)
        for i in range(B * C // T):
            self._op(dataf[i * T : (i + 1) * T], ksp)
            ksp /= self.norm_factor
            ksp -= obs_dataf[i * T : (i + 1) * T]
            # _adj_op already applies the density compensation; weighting ksp
            # here too would apply it twice (unlike _grad_sense and adj_op).
            self._adj_op(ksp, grad[i * T : (i + 1) * T])
        grad /= self.norm_factor
        return grad.reshape(B, C, *XYZ)

    def _safe_squeeze(self, arr):
        """Squeeze the first two dimensions of shape of the operator."""
        if self.squeeze_dims:
            try:
                arr = arr.squeeze(axis=1)
            except ValueError:
                pass
            try:
                arr = arr.squeeze(axis=0)
            except ValueError:
                pass
        return arr