"""Engines are responsible for the acquisition of Kspace."""
from __future__ import annotations
import gc
import logging
import multiprocessing as mp
import os
from collections.abc import Mapping, Sequence
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from multiprocessing.managers import SharedMemoryManager
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, ClassVar
import ismrmrd as mrd
import numpy as np
from numpy.typing import NDArray
from tqdm.auto import tqdm
from typing_extensions import dataclass_transform
from snake._meta import EnvConfig, MetaDCRegister, batched
from ...mrd_utils import MRDLoader, make_base_mrd
from ..handlers import AbstractHandler, HandlerList
from ..parallel import ArrayProps
from ..phantom import DynamicData, Phantom, PropTissueEnum
from ..sampling import BaseSampler
from ..simulation import SimConfig
from .utils import get_noise
GenericPath = Path | str
@dataclass_transform(kw_only_default=True)
class MetaEngine(MetaDCRegister):
    """MetaClass for engines."""

    dunder_name = "engine"


class BaseAcquisitionEngine(metaclass=MetaEngine):
"""Base acquisition engine.
    Specific steps can be overridden in subclasses.
Parameters
----------
    model : str, optional
        The signal model to use ("simple" or "T2s"), by default "simple".
    snr : float, optional
        The signal-to-noise ratio, by default np.inf (no added noise).
    slice_2d : bool, optional
        Whether to acquire the data in 2D, slice by slice, by default False.
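
    Examples
    --------
    A minimal sketch, assuming an engine subclass registered under the
    (hypothetical) name ``"EPI"``, with ``sampler``, ``phantom`` and
    ``sim_conf`` built elsewhere with the snake API::

        engine = BaseAcquisitionEngine.__registry__["EPI"](model="simple", snr=1000.0)
        engine("acquisition.mrd", sampler, phantom, sim_conf, n_workers=4)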
"""
__engine_name__: ClassVar[str]
    __registry__: ClassVar[dict[str, type[BaseAcquisitionEngine]]]
    __mp_mode__: ClassVar[str] = "spawn"  # start method for workers; "spawn" assumed as a safe default
    log: ClassVar[logging.Logger]
model: str = "simple"
snr: float = np.inf
slice_2d: bool = False
def _get_chunk_list(
self,
data_loader: MRDLoader,
    ) -> Sequence[int]:
        """Return the shot indices to acquire, in order."""
        return range(data_loader.n_acquisition)
    def _job_trajectories(
        self,
        dataset: mrd.Dataset,
        hdr: mrd.xsd.ismrmrdHeader,
        sim_conf: SimConfig,
        chunk: Sequence[int],
    ) -> NDArray:
        """Load the trajectories for the given chunk of shots."""
        raise NotImplementedError
@staticmethod
def _job_get_T2s_decay(
dwell_time_ms: float,
echo_idx: int,
n_samples: int,
phantom: Phantom,
) -> NDArray:
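        # Sample times relative to the echo, in ms. The weight of tissue i at
        # sample s is exp(-(s - echo_idx) * dwell_time_ms / T2s[i]), giving an
        # array of shape (n_tissues, n_samples).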
t = dwell_time_ms * (np.arange(n_samples, dtype=np.float32) - echo_idx)
return np.exp(-t[None, :] / phantom.props[:, PropTissueEnum.T2s, None])
    @staticmethod
    def _job_model_T2s(
        phantom: Phantom,
        dyn_datas: list[DynamicData],
        sim_conf: SimConfig,
        trajectories: NDArray,  # (Chunksize, N, 3)
        *args: Any,
        **kwargs: Any,
    ) -> NDArray:
        """Acquire the k-space chunk with the T2* signal model."""
        raise NotImplementedError
    @staticmethod
    def _job_model_simple(
        phantom: Phantom,
        dyn_datas: list[DynamicData],
        sim_conf: SimConfig,
        trajectories: NDArray,  # (Chunksize, N, 3)
        *args: Any,
        **kwargs: Any,
    ) -> NDArray:
        """Acquire the k-space chunk with the simple (no-decay) signal model."""
        raise NotImplementedError
    def _write_chunk_data(
        self, dataset: mrd.Dataset, chunk: Sequence[int], chunk_data: NDArray
    ) -> None:
        """Write the acquired k-space chunk back to the dataset."""
        raise NotImplementedError
def _acquire_ksp_job(
self,
filename: GenericPath,
chunk: Sequence[int],
tmp_dir: str,
shared_phantom_props: (
tuple[str, ArrayProps, ArrayProps, ArrayProps, ArrayProps | None] | None
) = None,
        **kwargs: Any,
) -> str:
"""Entry point for worker.
This handles the io part (Read dataset, write partial k-space),
and dispatch to specialized functions
for getting the k-space.
"""
# https://github.com/h5py/h5py/issues/712#issuecomment-562980532
# We know that we are going to read the dataset in read-only mode in
# this function and use the main process to write the data.
# This is an alternative to using swmr mode, that I could not get to work.
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
with MRDLoader(filename, swmr=True) as data_loader:
hdr = data_loader.header
# Get the Phantom, SimConfig, and all ...
sim_conf = data_loader.get_sim_conf()
ddatas = data_loader.get_all_dynamic()
# sim_conf = SimConfig.from_mrd_dataset(dataset)
for d in ddatas: # only keep the dynamic data that are in the chunk
d.data = d.data[:, chunk]
trajs = self._job_trajectories(data_loader, hdr, sim_conf, chunk)
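            # Dispatch on ``self.model``: e.g. model="simple" resolves to the
            # subclass implementation of ``_job_model_simple``.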
_job_model = getattr(self, f"_job_model_{self.model}")
if shared_phantom_props is None:
phantom = data_loader.get_phantom()
ksp = _job_model(phantom, ddatas, sim_conf, trajs, **kwargs)
else:
with Phantom.from_shared_memory(*shared_phantom_props) as phantom:
ksp = _job_model(phantom, ddatas, sim_conf, trajs, **kwargs)
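        # Stage the result on disk: the parent process reads it back, adds
        # noise, and writes it to the MRD file, so HDF5 writes stay in a
        # single process (see the file-locking notes above).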
chunk_file = os.path.join(tmp_dir, f"partial_{chunk[0]}-{chunk[-1]}.npy")
np.save(chunk_file, ksp)
return chunk_file
def __call__(
self,
filename: GenericPath,
sampler: BaseSampler,
phantom: Phantom,
sim_conf: SimConfig,
handlers: list[AbstractHandler] | HandlerList | None = None,
coil_cov: NDArray | None = None,
worker_chunk_size: int = 0,
n_workers: int = 0,
**kwargs: Any,
    ) -> None:
"""Perform the acquisition and fill the dataset.
Parameters
----------
filename : GenericPath
The path to the MRD file.
sampler : BaseSampler
The sampler to use.
phantom : Phantom
The phantom to use.
sim_conf : SimConfig
The simulation configuration.
handlers : list[AbstractHandler] | HandlerList | None, optional
The handlers to use, by default None.
coil_cov : NDArray | None, optional
The coil covariance matrix, by default None.
        worker_chunk_size : int, optional
            The size of the chunk to process for each worker, by default 0 (auto):
            one frame's worth of shots. Each worker processes one chunk at a time.
        n_workers : int, optional
            The number of workers to use, by default 0 (auto): half of the CPU
            count, which usually corresponds to the number of physical cores on
            the machine.
        resample_early : bool, optional
            Whether to resample the phantom early, by default False.
kwargs : Any
Additional keyword arguments, passed down to internal implementation.
Notes
-----
This function is the main entry point for the acquisition engine.
It will create the base dataset, and then dispatch the work to the workers.
Specific modeling steps are implemented in subclasses' methods `_job_model_T2s`
and `_job_model_simple`.
"""
if self.slice_2d: # Update the correct TR_eff
sim_conf.TR_eff = sampler.TR_vol_ms
self.log.warning("Using 2D acquisition, the TR_eff is updated to TR_vol")
else:
sim_conf.TR_eff = sim_conf.seq.TR
# Create the base dataset
make_base_mrd(
filename,
sampler,
phantom,
sim_conf,
handlers,
coil_cov,
self.model,
self.slice_2d,
)
# Guesstimate the workload
if worker_chunk_size <= 0:
            # Default to one frame's worth of shots per chunk.
worker_chunk_size = sampler.get_next_frame(sim_conf).shape[0]
if n_workers <= 0:
n_workers = mp.cpu_count() // 2
with MRDLoader(filename) as data_loader:
sim_conf = data_loader.get_sim_conf()
phantom = data_loader.get_phantom()
shot_idxs = self._get_chunk_list(data_loader)
chunk_list = list(batched(shot_idxs, worker_chunk_size))
ideal_phantom = phantom.contrast(sim_conf=sim_conf, aggregate=True)
            coil_cov = data_loader.get_coil_cov()
            if coil_cov is None:
                coil_cov = np.eye(sim_conf.hardware.n_coils)
if self.snr > 0:
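                # Scale the coil covariance so that the target SNR equals the
                # mean phantom signal energy divided by the noise variance.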
energy = np.mean(ideal_phantom**2)
coil_cov = coil_cov * energy / self.snr
del ideal_phantom
# https://github.com/h5py/h5py/issues/712#issuecomment-562980532
# We know that we are going to read the dataset in read-only mode
# and use the main process (here) to write the data.
# This is an alternative to using swmr mode, that I could not get to work.
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
if n_workers > 1:
Executor = ProcessPoolExecutor(
n_workers, mp_context=mp.get_context(self.__mp_mode__)
)
else:
Executor = ThreadPoolExecutor(max_workers=1)
with (
SharedMemoryManager() as smm,
Executor as executor,
tqdm(total=len(shot_idxs)) as pbar,
MRDLoader(filename, writeable=True) as data_loader,
TemporaryDirectory(
dir=EnvConfig["SNAKE_TMP_DIR"], prefix="snake-"
) as tmp_chunk_dir,
):
# data_loader._file.swmr_mode = True
phantom_props, _ = phantom.in_shared_memory(smm)
# TODO: also put the smaps in shared memory
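            # One job per chunk: each worker computes its partial k-space,
            # saves it as a .npy file, and returns the file path.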
futures = {
executor.submit(
self._acquire_ksp_job,
filename,
chunk_id,
tmp_dir=tmp_chunk_dir,
shared_phantom_props=phantom_props,
slice_2d=self.slice_2d,
**kwargs,
): chunk_id
for chunk_id in chunk_list
}
for future in as_completed(futures):
chunk = futures[future]
try:
f_chunk = str(future.result())
except Exception as exc:
self.log.error(f"Error in chunk {min(chunk)}-{max(chunk)}")
raise exc
else:
                    pbar.update(len(chunk))
chunk_ksp = np.load(f_chunk)
# Add noise
if self.snr > 0:
noise = get_noise(chunk_ksp, coil_cov, sim_conf.rng)
chunk_ksp += noise
self._write_chunk_data(
data_loader,
chunk,
chunk_ksp,
)
os.remove(f_chunk)
gc.collect()