"""Off Resonance correction Operator wrapper.
Based on the implementation of Guillaume Daval-Frérot in pysap-mri:
https://github.com/CEA-COSMIC/pysap-mri/blob/master/mri/operators/fourier/orc_wrapper.py
"""
import math
import numpy as np
from .._array_compat import CUPY_AVAILABLE, AUTOGRAD_AVAILABLE, with_numpy_cupy
from .._utils import get_array_module
from .base import FourierOperatorBase
from .interfaces.utils import is_cuda_array
if CUPY_AVAILABLE:
import cupy as cp
if AUTOGRAD_AVAILABLE:
import torch
[docs]
@with_numpy_cupy
def get_interpolators_from_fieldmap(
    b0_map, readout_time, n_time_segments=6, n_bins=(40, 10), mask=None, r2star_map=None
):
    r"""Approximate ``exp(-2j*pi*fieldmap*readout_time) ≈ Σ B_n(t)C_n(r)``.

    ``B_n(t)`` are the ``n_time_segments`` temporal coefficients and
    ``C_n(r)`` the ``n_time_segments`` spatial coefficients of the expansion.
    The matrix B has shape ``(n_time_segments, len(readout_time))``
    and C has shape ``(n_time_segments, *b0_map.shape)``.
    Only B (and the segment centers) are computed here.

    From Sigpy: https://github.com/mikgroup/sigpy
    and MIRT (mri_exp_approx.m): https://web.eecs.umich.edu/~fessler/code/

    Parameters
    ----------
    b0_map : np.ndarray
        Static field inhomogeneities map.
        ``b0_map`` and ``readout_time`` should have reciprocal units.
        Also supports Cupy arrays and Torch tensors.
    readout_time : np.ndarray
        Readout time in ``[s]`` of shape ``(n_shots, n_pts)`` or
        ``(n_shots * n_pts,)``. It is flattened internally.
        Also supports Cupy arrays and Torch tensors.
    n_time_segments : int, optional
        Number of time segments. The default is ``6``.
    n_bins : int | Sequence[int], optional
        Number of histogram bins to use for ``(B0, T2*)``.
        The default is ``(40, 10)``.
        If it is not a list or tuple, ``n_bins = (n_bins, 10)`` is assumed.
        For a real fieldmap (B0 only), ``n_bins[1]`` is ignored.
    mask : np.ndarray, optional
        Boolean mask of the region of interest
        (e.g., corresponding to the imaged object).
        It excludes the background fieldmap values from the histogram
        computation. Must have the same shape as ``b0_map``.
        The default is ``None`` (use the whole map).
        Also supports Cupy arrays and Torch tensors.
    r2star_map : np.ndarray, optional
        Effective transverse relaxation map (R2*).
        ``r2star_map`` and ``readout_time`` should have reciprocal units.
        Must have the same shape as ``b0_map``.
        The default is ``None`` (purely imaginary field).
        Also supports Cupy arrays and Torch tensors.

    Notes
    -----
    The total field map used to calculate the field coefficients is
    ``field_map = R2*_map + 1j * B0_map``. If R2* is not provided,
    the field is purely imaginary: ``field_map = 1j * B0_map``.

    Returns
    -------
    B : np.ndarray
        Temporal interpolator of shape ``(n_time_segments, len(t))``,
        stored as ``complex64``.
    tl : np.ndarray
        Time segment centers of shape ``(n_time_segments,)``,
        stored as ``float32``.
    """
    # a scalar bin count only sets the B0 axis; the R2* axis defaults to 10
    bins = list(n_bins) if isinstance(n_bins, (list, tuple)) else [n_bins, 10]

    # array module matching the input field map
    xp = get_array_module(b0_map)
    use_r2star = r2star_map is not None

    # single precision everywhere, flat time axis
    b0_map = xp.asarray(b0_map, dtype=xp.float32)
    readout_time = xp.asarray(readout_time, dtype=xp.float32).ravel()
    mask = (
        xp.ones_like(b0_map, dtype=bool)
        if mask is None
        else xp.asarray(mask, dtype=bool)
    )
    if use_r2star:
        r2star_map = xp.asarray(r2star_map, dtype=xp.float32)

    # complex field (Hz to radians / s), cast to complex64
    field_map = xp.asarray(
        _get_complex_fieldmap(b0_map, r2star_map), dtype=xp.complex64
    )

    # histogram of the field values inside the mask
    samples = field_map[mask].ravel()
    if use_r2star:
        # 2D histogram: column 0 is the imaginary (B0) part, column 1 the real (R2*) part
        samples = xp.stack((samples.imag, samples.real), axis=1)
        counts, edges = xp.histogramdd(samples, bins=bins)
        edges = list(edges)
        # bin centers: right edges shifted back by half a bin width
        centers = [edge[1:] - (edge[1] - edge[0]) / 2 for edge in edges]
        # complex bin centers on the (K1, K2) grid, then transposed like the counts
        rates = _outer_sum(1j * centers[0], centers[1]).T
        counts = counts.T
    else:
        counts, edges = xp.histogram(samples.imag, bins=bins[0])
        # bin centers: right edges shifted back by half a bin width
        centers = edges[1:] - (edges[1] - edges[0]) / 2
        rates = 1j * centers

    # flatten histogram values and centers
    counts = counts.ravel()
    rates = rates.ravel()

    # time segment centers in [s], evenly spread over the readout
    tl = xp.linspace(
        readout_time.min(), readout_time.max(), n_time_segments, dtype=xp.float32
    )

    # histogram-weighted least-squares fit of the exponentials on the bin centers
    segment_basis = xp.exp(-tl[:, None, ...] @ rates[None, ...])
    weights = xp.diag(counts**0.5)
    projector = xp.linalg.pinv(weights @ segment_basis.T) @ weights

    # temporal basis evaluated at every readout sample
    exact = xp.exp(-rates[:, None, ...] * readout_time[None, ...])
    B = (projector @ exact).astype(xp.complex64)
    return B, tl
def _outer_sum(xx, yy):
    """Return the broadcast outer sum of ``xx`` and ``yy``.

    ``xx`` gets a singleton axis inserted at position 1 and ``yy`` one at
    position 0, so for 1-D inputs ``out[i, j] = xx[i] + yy[j]``.
    """
    as_rows = xx[:, None, ...]
    as_cols = yy[None, ...]
    return as_rows + as_cols
[docs]
class MRIFourierCorrected(FourierOperatorBase):
    """Fourier Operator with B0 Inhomogeneities compensation.

    This is a wrapper around the Fourier Operator to compensate for the
    B0 inhomogeneities in the k-space. The off-resonance term is expanded as
    a sum of ``n_interpolators`` products of a temporal coefficient ``B[n]``
    and a spatial coefficient ``C[n]``; each term is applied around one call
    to the wrapped operator.

    Parameters
    ----------
    fourier_op : FourierOperatorBase
        Operator to wrap. Its ``n_coils``, ``shape``, ``smaps`` and
        ``autograd_available`` attributes are copied, and its ``op`` /
        ``adj_op`` methods are called once per time segment.
    b0_map : np.ndarray
        Static field inhomogeneities map. Required: the spatial coefficients
        are always derived from it, even when ``B`` and ``tl`` are provided.
        ``b0_map`` and ``readout_time`` should have reciprocal units.
        Also supports Cupy arrays and Torch tensors.
    readout_time : np.ndarray, optional
        Readout time in ``[s]`` of shape ``(n_shots, n_pts)`` or
        ``(n_shots * n_pts,)``. Required unless both ``B`` and ``tl`` are given.
        Also supports Cupy arrays and Torch tensors.
    n_time_segments : int, optional
        Number of time segments. The default is ``6``.
    n_bins : int | Sequence[int], optional
        Number of histogram bins to use for ``(B0, T2*)``.
        The default is ``(40, 10)``.
        If it is a scalar, assume ``n_bins = (n_bins, 10)``.
        For real fieldmap (B0 only), ``n_bins[1]`` is ignored.
    mask : np.ndarray, optional
        Boolean mask of the region of interest
        (e.g., corresponding to the imaged object).
        This is used to exclude the background fieldmap values
        from histogram computation.
        The default is ``None`` (use the whole map).
        Also supports Cupy arrays and Torch tensors.
    r2star_map : np.ndarray, optional
        Effective transverse relaxation map (R2*).
        ``r2star_map`` and ``readout_time`` should have reciprocal units.
        Must have same shape as ``b0_map``.
        The default is ``None`` (purely imaginary field).
        Also supports Cupy arrays and Torch tensors.
    B : np.ndarray, optional
        Temporal interpolator of shape ``(n_time_segments, len(readout_time))``.
        Used as is when ``tl`` is provided as well; otherwise it is computed.
    tl : np.ndarray, optional
        Time segment centers of shape ``(n_time_segments,)``.
        Used as is when ``B`` is provided as well; otherwise it is computed.
        Also supports Cupy arrays and Torch tensors.
    backend : str, optional
        The backend to use for computations. Either 'cpu', 'gpu' or 'torch'.
        The default is `cpu`.

    Raises
    ------
    RuntimeError
        If the 'gpu' backend is requested without Cupy, or the 'torch'
        backend without Torch.
    ValueError
        If ``backend`` is not supported, if ``b0_map`` is missing, or if
        neither ``readout_time`` nor the pair ``(B, tl)`` is provided.

    Notes
    -----
    The total field map used to calculate the field coefficients is
    ``field_map = R2*_map + 1j * B0_map``. If R2* is not provided,
    the field is purely imaginary: ``field_map = 1j * B0_map``.
    """

    def __init__(
        self,
        fourier_op,
        b0_map=None,
        readout_time=None,
        n_time_segments=6,
        n_bins=(40, 10),
        mask=None,
        r2star_map=None,
        B=None,
        tl=None,
        backend="cpu",
    ):
        if backend == "gpu" and not CUPY_AVAILABLE:
            raise RuntimeError("Cupy is required for gpu computations.")
        if backend == "torch" and not AUTOGRAD_AVAILABLE:
            raise RuntimeError("Torch is required for torch computations.")

        # One single chain: with separate ``if`` statements the 'torch' choice
        # fell through to the final ``else`` and was rejected as unsupported.
        if backend == "torch":
            self.xp = torch
        elif backend == "gpu":
            self.xp = cp
        elif backend == "cpu":
            self.xp = np
        else:
            raise ValueError("Unsupported backend.")

        # The field map is needed below for the spatial coefficients no matter
        # how the temporal interpolators are obtained, so fail early and clearly.
        if b0_map is None:
            raise ValueError(
                "b0_map is required to compute the spatial coefficients."
            )
        interpolators_given = B is not None and tl is not None
        if not interpolators_given and readout_time is None:
            raise ValueError("Please either provide fieldmap and t or B and tl")

        self._fourier_op = fourier_op
        self.n_coils = fourier_op.n_coils
        self.shape = fourier_op.shape
        self.smaps = fourier_op.smaps
        self.autograd_available = fourier_op.autograd_available

        if interpolators_given:
            self.B = self.xp.asarray(B)
            self.tl = self.xp.asarray(tl)
        else:
            b0_map = self.xp.asarray(b0_map)
            self.B, self.tl = get_interpolators_from_fieldmap(
                b0_map,
                readout_time,
                n_time_segments,
                n_bins,
                mask,
                r2star_map,
            )
        self.n_interpolators = self.B.shape[0]

        # create spatial interpolator
        field_map = _get_complex_fieldmap(b0_map, r2star_map)
        if is_cuda_array(b0_map):
            # CUDA input: keep only the field map and rebuild each spatial
            # coefficient on demand instead of storing all of them.
            self.C = None
            self.field_map = field_map
        else:
            self.C = _get_spatial_coefficients(field_map, self.tl)
            self.field_map = None

    def _spatial_coefficient(self, idx):
        """Return the spatial coefficient of time segment ``idx``.

        Uses the precomputed ``self.C`` when available, otherwise evaluates
        ``exp(-field_map * tl[idx])`` from the stored field map.
        """
        if self.C is not None:
            return self.C[idx]
        return self.xp.exp(-self.field_map * self.tl[idx].item())

    def op(self, data, *args):
        """Compute Forward Operation with off-resonance effect.

        Parameters
        ----------
        data : numpy.ndarray
            N-D input image.
            Also supports Cupy arrays and Torch tensors.
        *args
            Extra positional arguments forwarded to the wrapped operator's
            ``op``.

        Returns
        -------
        numpy.ndarray
            Masked distorted N-D k-space, accumulated over all time segments.
        """
        y = 0.0
        data_d = self.xp.asarray(data)
        for idx in range(self.n_interpolators):
            y += self.B[idx] * self._fourier_op.op(
                self._spatial_coefficient(idx) * data_d, *args
            )
        return y

    def adj_op(self, coeffs, *args):
        """Compute Adjoint Operation with off-resonance effect.

        Parameters
        ----------
        coeffs : numpy.ndarray
            Masked distorted N-D k-space.
            Also supports Cupy arrays and Torch tensors.
        *args
            Extra positional arguments forwarded to the wrapped operator's
            ``adj_op``.

        Returns
        -------
        numpy.ndarray
            Inverse Fourier transform of the distorted input k-space,
            accumulated over all time segments.
        """
        y = 0.0
        # ``asarray`` (not ``array``): matches ``op``, avoids a needless copy
        # and exists in every supported array module.
        coeffs_d = self.xp.asarray(coeffs)
        for idx in range(self.n_interpolators):
            y += self.xp.conj(self._spatial_coefficient(idx)) * self._fourier_op.adj_op(
                self.xp.conj(self.B[idx]) * coeffs_d, *args
            )
        return y

    @staticmethod
    def get_spatial_coefficients(field_map, tl):
        """Compute spatial coefficients for field approximation.

        Parameters
        ----------
        field_map : np.ndarray
            Total field map used to calculate the field coefficients is
            ``field_map = R2*_map + 1j * B0_map``.
            Also supports Cupy arrays and Torch tensors.
        tl : np.ndarray
            Time segment centers of shape ``(n_time_segments,)``.
            Also supports Cupy arrays and Torch tensors.

        Returns
        -------
        C : np.ndarray
            Off-resonance phase map at each time segment center of shape
            ``(n_time_segments, *field_map.shape)``.
            Array module is the same as input field_map.
        """
        return _get_spatial_coefficients(field_map, tl)
def _get_complex_fieldmap(b0_map, r2star_map=None):
    """Combine B0 and optional R2* maps into ``2*pi*(R2* + 1j*B0)``.

    Without ``r2star_map`` the result is purely imaginary: ``2*pi*1j*B0``.
    ``r2star_map`` is cast to ``float32`` in the array module of ``b0_map``.
    """
    xp = get_array_module(b0_map)
    two_pi = 2 * math.pi
    if r2star_map is None:
        return two_pi * 1j * b0_map
    decay = xp.asarray(r2star_map, dtype=xp.float32)
    return two_pi * (decay + 1j * b0_map)
def _get_spatial_coefficients(field_map, tl):
    """Evaluate ``exp(-tl[n] * field_map)`` for every time segment center.

    The segment axis is moved to the front and the result is returned as
    ``complex64`` with NaN and infinite entries replaced by zero.
    """
    xp = get_array_module(field_map)

    # broadcast the segment centers along a new trailing axis
    coeffs = xp.exp(-tl * field_map[..., None])

    # (..., n_time_segments) -> (n_time_segments, ...): prepend a dummy axis,
    # swap it with the segment axis, then drop the dummy (now trailing) axis
    coeffs = coeffs[None, ...]
    coeffs = coeffs.swapaxes(0, -1)
    coeffs = coeffs[..., 0]

    coeffs = xp.asarray(coeffs, dtype=xp.complex64)

    # clean-up of non-finite values
    return xp.nan_to_num(coeffs, nan=0.0, posinf=0.0, neginf=0.0)