Source code for mrinufft.operators.subspace

"""Subspace NUFFT Operator wrapper."""

from mrinufft._array_compat import _get_device, _to_interface
from mrinufft._utils import ARRAY_LIBS, get_array_module


from .base import FourierOperatorBase


[docs] class MRISubspace(FourierOperatorBase): """Fourier Operator with subspace projection. This is a wrapper around the Fourier Operator to project data onto a low-rank subspace (e.g., dynamic and multi-contrast MRI). Parameters ---------- fourier_op: object of class FourierBase the fourier operator to wrap subspace_basis : np.ndarray Low rank subspace basis of shape ``(K, T)``, where K is the rank of the subspace and T is the number of time frames or contrasts in the original image series. Also supports Cupy arrays and Torch tensors. use_gpu : bool, optional Whether to use the GPU. Default is False. Ignored if the Fourier operator internally use only GPU (e.g., Cupy) or CPU (e.g., Numpy) Notes ----- This extension adds a new axis for both image and k-space data: * Image: ``(B, C, XYZ)`` -> ``(B, S, C, XYZ)`` * K-Space: ``(B, C, K)`` -> ``(B, T, C, K)`` with ``S`` representing the subspace index and ``T`` representing time domain or contrast space (for dynamic and multi-contrast MR, respectively). Similarly, k-space trajectory is expected to have the following shape: ``(<N_frames or N_contrasts>, N_shots, N_samples, dim)``. The flatten version is also accepted: ``(<N_frames or N_contrasts>, N_shots * N_samples, dim)`` """ def __init__(self, fourier_op, subspace_basis, use_gpu=False): self._fourier_op = fourier_op self.backend = fourier_op.backend self.n_batchs = fourier_op.n_batchs self.n_coils = fourier_op.n_coils self.shape = fourier_op.shape self.smaps = fourier_op.smaps self.autograd_available = fourier_op.autograd_available self.subspace_basis = subspace_basis self.n_coeffs, self.n_frames = self.subspace_basis.shape
[docs] def op(self, data, *args): """ Compute Forward Operation time/contrast-domain backprojection. Parameters ---------- data: numpy.ndarray N-D subspace-projected image. Also supports Cupy arrays and Torch tensors. Returns ------- numpy.ndarray Time/contrast-domain k-space data. """ xp = get_array_module(data) device = _get_device(data) subspace_basis = _to_interface(self.subspace_basis, xp, device) # if required, move subspace index axis to leftmost position if self.n_batchs != 1 or data.shape[0] == 1: # non-squeezed data data_d = data.swapaxes(0, 1) else: data_d = data # enforce data contiguity if xp.__name__ == "torch": data_d = data_d.contiguous() else: data_d = xp.ascontiguousarray(data_d) # perform computation y = 0.0 for idx in range(self.n_coeffs): # select basis element basis_element = subspace_basis[idx] # actual transform _y = self._fourier_op.op(data_d[idx], *args) _y = _y.reshape(*_y.shape[:-1], self.n_frames, -1) # back-project on time domain y += basis_element.conj() * _y.swapaxes(-2, -1) y = y[None, ...].swapaxes(0, -1)[..., 0] # bring back time/contrast axis to original position (B, T, ...) if self.n_batchs != 1 or data.shape[0] == 1: # non-squeezed data y = y.swapaxes(0, 1) return y
[docs] def adj_op(self, coeffs, *args): """ Compute Adjoint Operation with subspace projection. Parameters ---------- coeffs: numpy.ndarray Time/contrast-domain k-space data. Also supports Cupy arrays and Torch tensors. Returns ------- numpy.ndarray Inverse Fourier transform of the subspace-projected k-space. """ xp = get_array_module(coeffs) device = _get_device(coeffs) subspace_basis = _to_interface(self.subspace_basis, xp, device) # if required, move time/contrast axis to leftmost position if self.n_batchs != 1 or coeffs.shape[0] == 1: # non-squeezed data coeffs_d = coeffs.swapaxes(0, 1) else: coeffs_d = coeffs coeffs_d = coeffs_d[..., None].swapaxes(0, -1)[0, ...] # perform computation y = [] for idx in range(self.n_coeffs): # select basis element basis_element = subspace_basis[idx] # project on current subspace basis element _coeffs_d = basis_element * coeffs_d _coeffs_d = _coeffs_d.swapaxes(-2, -1).reshape(*coeffs_d.shape[:-2], -1) # enforce data contiguity if xp.__name__ == "torch": _coeffs_d = _coeffs_d.contiguous() else: _coeffs_d = xp.ascontiguousarray(_coeffs_d) # actual transform y.append(self._fourier_op.adj_op(_coeffs_d, *args)) # stack coefficients y = xp.stack(y, axis=0) # bring back subspace index to original position (B, S, ...) if self.n_batchs != 1 or coeffs.shape[0] == 1: y = y.swapaxes(0, 1) return y
@property def n_samples(self): """Return the number of samples used by the operator.""" return self._fourier_op.n_samples
def _get_arraylib_from_operator( fourier_op, use_gpu ): # maybe that is usefull for MRIFourierCorrected constructor? LUT = { "MRIBartNUFFT": ("numpy", "numpy"), "MRICufiNUFFT": ("cupy", "cupy"), "MRIfinufft": ("numpy", "numpy"), "MRIGpuNUFFT": ("cupy", "cupy"), "MRInfft": ("numpy", "numpy"), "MRIPynufft": ("numpy", "numpy"), "MRISigpyNUFFT": ("numpy", "cupy"), "MRITensorflowNUFFT": ("tensorflow", "tensorflow"), "MRITorchKbNufft": ("torch", "torch"), "TorchKbNUFFTcpu": ("torch", "torch"), "TorchKbNUFFTgpu": ("torch", "torch"), } return ARRAY_LIBS[LUT[fourier_op.__class__.__name__][use_gpu]][0]