# Source code for snake.core.engine.nufft

"""Acquisition engine using nufft."""

from collections.abc import Sequence

import ismrmrd as mrd
import numpy as np
from mrinufft.operators import FourierOperatorBase, get_operator
from numpy.typing import NDArray

from snake.core.phantom import DynamicData, Phantom
from snake.core.simulation import SimConfig

from .base import BaseAcquisitionEngine
from .utils import get_phantom_state


class NufftAcquisitionEngine(BaseAcquisitionEngine):
    """Acquisition engine using nufft.

    Each shot is simulated by applying a non-uniform FFT operator, obtained
    from ``mrinufft.get_operator``, to the phantom state of that shot.
    """

    __engine_name__ = "NUFFT"
    __mp_mode__ = "spawn"
    # NOTE(review): presumably selects between ``_job_model_simple`` and
    # ``_job_model_T2s``; the dispatch is not visible in this class -- confirm
    # against the base class.
    model: str = "simple"
    # Not read anywhere in this class; presumably consumed by the base class.
    snr: float = np.inf
    slice_2d: bool = False

    def _job_trajectories(
        self,
        dataset: mrd.Dataset,
        hdr: mrd.xsd.ismrmrdHeader,
        sim_conf: SimConfig,
        shot_idx: Sequence[int] | int,
    ) -> NDArray:
        """Get Non Cartesian trajectories from the dataset.

        Parameters
        ----------
        dataset : mrd.Dataset
            Dataset whose ``"data"`` table holds the acquisitions.
        hdr : mrd.xsd.ismrmrdHeader
            Header of the dataset (not used by this implementation).
        sim_conf : SimConfig
            Simulation configuration (not used by this implementation).
        shot_idx : Sequence[int] | int
            Index, or sequence of indices, of the shots to read.

        Returns
        -------
        NDArray
            The trajectories, as a ``float32`` array of shape
            ``(len(shot_idx), n_samples, ndim)``.
        """
        if not isinstance(shot_idx, Sequence):
            shot_idx = [shot_idx]

        acquisitions = dataset._dataset["data"]
        # The sample count and dimensionality are read from the first
        # acquisition only and applied to every requested shot.
        head = acquisitions[0]["head"]
        n_samples = head["number_of_samples"]
        ndim = head["trajectory_dimensions"]

        trajectories = np.zeros((len(shot_idx), n_samples, ndim), dtype=np.float32)
        for i, s in enumerate(shot_idx):
            trajectories[i] = acquisitions[s]["traj"].reshape(n_samples, ndim)
        return trajectories

    @staticmethod
    def _init_model_nufft(
        samples: NDArray,
        sim_conf: SimConfig,
        backend: str,
        slice_2d: bool = False,
    ) -> FourierOperatorBase:
        """Initialize the nufft operator.

        Parameters
        ----------
        samples : NDArray
            Initial sampling locations; callers overwrite them for each shot.
        sim_conf : SimConfig
            Simulation configuration providing ``shape`` and
            ``hardware.n_coils``.
        backend : str
            Name of the NUFFT backend passed to ``get_operator``.
        slice_2d : bool, optional
            If True, the operator is built on ``sim_conf.shape[:-1]`` (the
            last axis is dropped).

        Returns
        -------
        FourierOperatorBase
            The operator, created without sensitivity maps or density
            compensation.

        Raises
        ------
        ValueError
            If ``slice_2d`` is requested together with a "stacked" backend.
        """
        kwargs: dict[str, str] = {}
        if "stacked" in backend:
            if slice_2d:
                raise ValueError("Stacked NUFFT does not support 2D slice")
            kwargs["z_index"] = "auto"

        shape_ = sim_conf.shape[:-1] if slice_2d else sim_conf.shape
        nufft = get_operator(backend)(
            samples,  # will be updated in the loop
            shape=shape_,
            n_coils=sim_conf.hardware.n_coils,
            smaps=None,
            density=False,
            squeeze_dims=False,
            **kwargs,
        )
        return nufft

    @staticmethod
    def _job_model_T2s(
        phantom: Phantom,
        dyn_datas: list[DynamicData],
        sim_conf: SimConfig,
        trajectories: NDArray,
        nufft_backend: str,
        slice_2d: bool = False,
    ) -> np.ndarray:
        """Acquire k-space data with T2s relaxation effect.

        Parameters
        ----------
        phantom : Phantom
            Phantom to acquire; one NUFFT batch is used per entry of
            ``phantom.masks``.
        dyn_datas : list[DynamicData]
            Dynamic data forwarded to ``get_phantom_state``.
        sim_conf : SimConfig
            Simulation configuration.
        trajectories : NDArray
            Array of shape ``(chunk_size, n_samples, ndim)``.
        nufft_backend : str
            Name of the NUFFT backend.
        slice_2d : bool, optional
            If True, each shot is acquired on a single 2D slice whose
            location is derived from the last coordinate of its first sample.

        Returns
        -------
        np.ndarray
            ``complex64`` array of shape ``(chunk_size, n_coils, n_samples)``.
        """
        chunk_size, n_samples, _ = trajectories.shape
        final_ksp = np.zeros(
            (chunk_size, sim_conf.hardware.n_coils, n_samples), dtype=np.complex64
        )
        nufft = NufftAcquisitionEngine._init_model_nufft(
            trajectories[0],
            sim_conf,
            backend=nufft_backend,
            slice_2d=slice_2d,
        )
        # The echo is the sample of the first shot with the smallest squared
        # norm. The sum must run over the coordinate axis only: reducing over
        # every axis yields a scalar and the argmin could not locate a sample.
        echo_idx = np.argmin(np.sum(np.abs(trajectories[0]) ** 2, axis=-1))
        t2s_decay = BaseAcquisitionEngine._job_get_T2s_decay(
            sim_conf.hardware.dwell_time_ms, echo_idx, n_samples, phantom
        )
        nufft.n_batchs = len(phantom.masks)  # number of tissues.

        for i, traj in enumerate(trajectories):
            phantom_state, smaps = get_phantom_state(
                phantom, dyn_datas, i, sim_conf, aggregate=False
            )
            if slice_2d:
                # NOTE(review): ``_job_model_simple`` truncates with ``int``
                # where this model uses ``round`` -- confirm which is intended.
                slice_loc = round((traj[0, -1] + 0.5) * sim_conf.shape[-1])
                nufft.samples = traj[:, :2]
                if smaps is not None:
                    nufft.smaps = smaps[..., slice_loc]
                phantom_state = phantom_state[:, None, ..., slice_loc]
            else:
                phantom_state = phantom_state[:, None, ...]
                nufft.samples = traj
            ksp = nufft.op(phantom_state)
            # Weight each tissue (k) by its decay along the samples (j), then
            # sum over tissues; coils (i) are kept.
            final_ksp[i] = np.einsum("kij, kj-> ij", ksp, t2s_decay)
        return final_ksp

    @staticmethod
    def _job_model_simple(
        phantom: Phantom,
        dyn_datas: list[DynamicData],
        sim_conf: SimConfig,
        trajectories: NDArray,
        nufft_backend: str,
        slice_2d: bool = False,
    ) -> np.ndarray:
        """Acquire k-space data. No T2s decay.

        Parameters
        ----------
        phantom : Phantom
            Phantom to acquire.
        dyn_datas : list[DynamicData]
            Dynamic data forwarded to ``get_phantom_state``.
        sim_conf : SimConfig
            Simulation configuration.
        trajectories : NDArray
            Array of shape ``(chunk_size, n_samples, ndim)``.
        nufft_backend : str
            Name of the NUFFT backend.
        slice_2d : bool, optional
            If True, each shot is acquired on a single 2D slice whose
            location is derived from the last coordinate of its first sample.

        Returns
        -------
        np.ndarray
            ``complex64`` array of shape ``(chunk_size, n_coils, n_samples)``.
        """
        chunk_size, n_samples, _ = trajectories.shape
        final_ksp = np.zeros(
            (chunk_size, sim_conf.hardware.n_coils, n_samples), dtype=np.complex64
        )
        nufft = NufftAcquisitionEngine._init_model_nufft(
            trajectories[0],
            sim_conf,
            backend=nufft_backend,
            slice_2d=slice_2d,
        )
        # A single batch is processed per shot; this does not change between
        # shots, so it is set once.
        nufft.n_batchs = 1

        for i, traj in enumerate(trajectories):
            phantom_state, smaps = get_phantom_state(phantom, dyn_datas, i, sim_conf)
            if slice_2d:
                # NOTE(review): ``_job_model_T2s`` uses ``round`` where this
                # model truncates with ``int`` -- confirm which is intended.
                slice_loc = int((traj[0, -1] + 0.5) * sim_conf.shape[-1])
                nufft.samples = traj[:, :2]
                if smaps is not None:
                    nufft.smaps = smaps[..., slice_loc]
                phantom_state = phantom_state[None, ..., slice_loc]
            else:
                nufft.samples = traj
                phantom_state = phantom_state[None, ...]
            final_ksp[i] = nufft.op(phantom_state)
        return final_ksp

    def _write_chunk_data(
        self, dataset: mrd.Dataset, chunk: Sequence[int], chunk_data: NDArray
    ) -> None:
        """Write the k-space data of a chunk of shots back into the dataset.

        Parameters
        ----------
        dataset : mrd.Dataset
            Dataset whose ``"data"`` table is updated in place.
        chunk : Sequence[int]
            Indices of the shots that ``chunk_data`` corresponds to.
        chunk_data : NDArray
            K-space data of the chunk. It is reinterpreted as ``float32``
            and reshaped to the shape of the stored ``"data"`` field.
        """
        shot_idx = np.asarray(chunk)
        # Read-modify-write: fetch the selected acquisitions, replace their
        # data field, then store them back.
        acq_chunk = dataset._dataset["data"][shot_idx]
        chunk_data = chunk_data.view(np.float32)
        acq_chunk["data"] = chunk_data.reshape(acq_chunk["data"].shape)
        dataset._dataset["data"][shot_idx] = acq_chunk