"""Loader of MRD data."""
from __future__ import annotations
import logging
import os
import functools
from functools import cached_property
import ismrmrd as mrd
import h5py
import numpy as np
from numpy.typing import NDArray
from typing import Any, TYPE_CHECKING, overload
from collections.abc import Generator
from .._meta import LogMixin
if TYPE_CHECKING:
from ..core import Phantom, DynamicData
from ..core import SimConfig
from .utils import b64encode2obj
log = logging.getLogger(__name__)
GenericPath = os.PathLike | str
class MRDLoader(LogMixin):
    """Base class for MRD data loader.

    This is to be used as a context manager.

    It reimplements most of the methods of the mrd.Dataset class, and adds some
    useful wrappers. With this dataloader, you can open the dataset in readonly mode,
    which is not possible with mrd.

    Parameters
    ----------
    filename : GenericPath
        Path to the MRD (HDF5) file.
    dataset_name : str, optional
        Name of the HDF5 group holding the MRD dataset.
    writeable : bool, optional
        Open the file in read/write ("r+") mode instead of read-only ("r").
    swmr : bool, optional
        Open the file in single-writer/multiple-reader mode.
    """

    def __init__(
        self,
        filename: GenericPath,
        dataset_name: str = "dataset",
        writeable: bool = False,
        swmr: bool = False,
    ):
        self._filename = filename
        self._dataset_name = dataset_name
        self._writeable = writeable
        self._swmr = swmr
        # Reentrancy counter: the HDF5 file is only closed when the
        # outermost context manager exits.
        self._level = 0
        self._file: h5py.File | None = None

    def __enter__(self):
        # Track the number of times the dataloader is used as a context manager
        # to close the file only when the last context manager is closed.
        if self._file is None:
            self._file = h5py.File(
                self._filename,
                "r+" if self._writeable else "r",
                libver="latest",
                swmr=self._swmr,
            )
            try:
                header = self.header
            except LookupError:
                log.warning(
                    "No matrix size found in the header."
                    " The header is probably missing."
                )
                self._shape = None
            else:
                matrixSize = header.encoding[0].encodedSpace.matrixSize
                self._shape = matrixSize.x, matrixSize.y, matrixSize.z
        self._level += 1
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
        self._level -= 1
        if self._level == 0 and self._file is not None:
            self._file.close()
            self._file = None

    def iter_frames(
        self,
        start: int | None = None,
        stop: int | None = None,
        step: int | None = None,
        shot_dim: bool = False,
    ) -> Generator[tuple[int, NDArray[np.float32], NDArray[np.complex64]], None, None]:
        """Iterate over kspace frames of the dataset.

        Parameters
        ----------
        start : int, optional
            Start index of the iteration.
        stop : int, optional
            Stop index of the iteration.
        step: int, optional
            Step of the iteration.
        shot_dim: bool, optional
            Return the data reshaped with the shot dimension first.

        Yields
        ------
        tuple[int, np.ndarray, np.ndarray]
            The index of the frame, the trajectory and the kspace data.
        """
        if start is None:
            start = 0
        if stop is None:
            stop = self.n_frames
        if step is None:
            step = 1
        with self:
            # range() yields plain ints, matching the advertised
            # tuple[int, ...] (np.arange would yield numpy integers).
            for i in range(start, stop, step):
                yield i, *self.get_kspace_frame(i, shot_dim=shot_dim)

    def get_kspace_frame(
        self, idx: int, shot_dim: bool = False
    ) -> tuple[NDArray[np.float32], NDArray[np.complex64]]:
        """Get k-space frame trajectory/mask and data.

        Abstract stub: concrete loaders implement it. ``shot_dim`` is part of
        the signature because ``iter_frames`` always forwards it.
        """
        raise NotImplementedError()

    ###########################
    # MRD interfaces methods #
    ###########################

    def _read_image(self, impath: str, imnum: int = 0) -> mrd.Image:
        # Create an image and fill it with the header and attribute string
        # for this image.
        im = mrd.Image(
            self._dataset[impath]["header"][imnum],
            self._dataset[impath]["attributes"][imnum],
        )
        # ismrmrd complex data is stored as pairs named real and imag.
        # TODO do we need to store and reset or the config local to the module?
        cplxcfg = h5py.get_config().complex_names
        h5py.get_config().complex_names = ("real", "imag")
        try:
            # copy the data
            im.data[:] = self._dataset[impath]["data"][imnum]
        finally:
            # Always restore the process-wide h5py config, even if the
            # read above raises.
            h5py.get_config().complex_names = cplxcfg
        return im

    ######################
    # dataset properties #
    ######################

    @cached_property
    def header(self) -> mrd.xsd.ismrmrdHeader:
        """Get the header from the mrd file."""
        return mrd.xsd.CreateFromDocument(self._read_xml_header())

    @property
    def _dataset(self) -> h5py.Dataset:
        """Get MRD dataset."""
        if self._file is not None:
            return self._file[self._dataset_name]
        else:
            raise FileNotFoundError(
                "Dataset not opened. Use the dataloader as a context manager."
            )

    @property
    def n_frames(self) -> int:
        """Number of frames."""
        return self.header.encoding[0].encodingLimits.repetition.maximum

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the volume."""
        try:
            # self._shape only exists after __enter__ has run.
            return self._shape
        except AttributeError as exc:
            raise RuntimeError(
                "You need to run the dataloader as a context manager."
            ) from exc

    @property
    def n_coils(self) -> int:
        """Number of coils."""
        return self.header.acquisitionSystemInformation.receiverChannels

    @property
    def n_acquisition(self) -> int:
        """Number of acquisition in the dataset."""
        return self._dataset["data"].size

    def __len__(self):
        return self.n_frames

    @property
    def n_sample(self) -> int:
        """Number of samples in a single acquisition."""
        return self.header.encoding[0].encodingLimits.kspace_encoding_step_0.maximum

    @property
    def n_shots(self) -> int:
        """Number of shots in a single frame.

        Notes
        -----
        for EPI this is the number of phase encoding lines in the EPI zigzag.
        """
        return self.header.encoding[0].encodingLimits.kspace_encoding_step_1.maximum

    @property
    def engine_model(self) -> str:
        """Get the engine model."""
        return self.header.userParameters.userParameterString[0].value

    @property
    def slice_2d(self) -> bool:
        """Is the acquisition run on 2D slices."""
        return self.header.userParameters.userParameterString[1].value == "True"

    #############
    # Get data  #
    #############

    @functools.lru_cache  # noqa
    def get_phantom(self, imnum: int = 0) -> Phantom:
        """Load the phantom from the dataset."""
        from ..core import Phantom

        return Phantom.from_mrd_dataset(self, imnum)

    @cached_property
    def _all_waveform_infos(self) -> dict[int, dict]:
        # Mapping waveform_id -> metadata, parsed once from the header.
        return parse_waveform_information(self.header)

    def get_dynamic(self, waveform_num: int) -> DynamicData:
        """Get dynamic data."""
        waveform = self._read_waveform(waveform_num)
        wave_info = self._all_waveform_infos[waveform.waveform_id]
        from ..core import DynamicData

        return DynamicData._from_waveform(waveform, wave_info)

    def get_all_dynamic(self) -> list[DynamicData]:
        """Get all dynamic data."""
        from ..core import DynamicData

        all_dyn_data = []
        try:
            n_waves = self._dataset["waveforms"].size
        except Exception as e:
            # Best effort: a dataset without waveforms is valid, return [].
            log.debug(e)
            return []
        for i in range(n_waves):
            waveform = self._read_waveform(i)
            wave_info = self._all_waveform_infos[waveform.waveform_id]
            all_dyn_data.append(DynamicData._from_waveform(waveform, wave_info))
        return all_dyn_data

    @functools.lru_cache  # noqa
    def get_sim_conf(self) -> SimConfig:
        """Parse the sim config."""
        return parse_sim_conf(self.header)

    def _get_image_data(self, name: str, idx: int = 0) -> NDArray[np.complex64] | None:
        # Returns None (with a warning) when the image is absent.
        try:
            image = self._read_image(name, idx).data.astype(np.complex64)
        except LookupError:
            log.warning(f"No {name} found in the dataset.")
            return None
        return image

    def get_smaps(self, resample: bool = True) -> NDArray[np.complex64] | None:
        """Load the sensitivity maps from the dataset.

        Parameters
        ----------
        resample: bool
            If resample is True (default), the sensitivity maps are resampled using the
            affine transformation
            describe in the phantom and sim_conf.

        Returns
        -------
        None: if no Smaps is found
        NDArray: The coils sensitivity maps
        """
        try:
            smaps_im = self._read_image("smaps")
        except KeyError:
            return None
        smaps_affine = get_affine_from_image(smaps_im)
        smaps = smaps_im.data
        if resample:
            from snake.core.transform import apply_affine4d

            # Parse the sim config only when resampling needs it
            # (it was previously computed twice, unconditionally).
            sim_conf = self.get_sim_conf()
            smaps = apply_affine4d(
                smaps,
                smaps_affine,
                sim_conf.fov.affine,
                new_shape=sim_conf.fov.shape,
                use_gpu=True,
            ).squeeze()
        return smaps

    def get_coil_cov(self) -> NDArray | None:
        """Load the coil covariance from the dataset."""
        return self._get_image_data("coil_cov")
class CartesianFrameDataLoader(MRDLoader):
    """Load cartesian MRD files k-space frames iteratively.

    Parameters
    ----------
    filename: source for the MRD file.

    Examples
    --------
    >>> loader = CartesianFrameDataLoader("test.mrd")
    >>> for idx, mask, kspace in loader.iter_frames():
    ...     image = ifft(kspace)
    """

    def get_kspace_frame(
        self, idx: int, shot_dim: bool = False
    ) -> tuple[np.ndarray, np.ndarray]:
        """Get the k-space frame.

        Parameters
        ----------
        idx : int
            Index of the frame to get.
        shot_dim : bool, optional
            Accepted for compatibility with ``MRDLoader.iter_frames`` (which
            always forwards it); a cartesian frame has no meaningful shot
            dimension, so the flag is ignored.

        Returns
        -------
        np.ndarray
            The boolean sampling mask.
        np.ndarray
            The kspace data, coils first.
        """
        kspace = np.zeros((self.n_coils, *self.shape), dtype=np.complex64)
        mask = np.zeros(self.shape, dtype=bool)
        n_acq_per_frame = self.n_acquisition // self.n_frames
        start = idx * n_acq_per_frame
        end = (idx + 1) * n_acq_per_frame
        # Do a single read of the dataset much faster !
        acq = self._dataset["data"][start:end]
        traj = acq["traj"].reshape(-1, 3)
        data = acq["data"].view(np.complex64)
        data = data.reshape(-1, self.n_shots, self.n_coils, self.n_sample)
        data = np.moveaxis(data, 2, 0)  # putting the coil dimension first.
        data = data.reshape(self.n_coils, -1)
        # Trajectory points are integer grid locations; scatter each sample.
        traj_locs: tuple = tuple(np.int32(traj.T))  # type: ignore
        for c in range(self.n_coils):
            kspace[c][traj_locs] = data[c]
        mask[traj_locs] = True
        return mask, kspace
class NonCartesianFrameDataLoader(MRDLoader):
    """Non Cartesian Dataloader.

    Iterate over the acquisition of the MRD file.

    Examples
    --------
    >>> from mrinufft import get_operator
    >>> dataloader = NonCartesianFrameDataLoader("test.mrd")
    >>> for idx, traj, kspace in dataloader.iter_frames():
    ...     nufft = get_operator("finufft")(traj,
    ...         shape=dataloader.shape, n_coils=dataloader.n_coils)
    ...     image = nufft.adj_op(kspace)
    """

    def get_kspace_frame(
        self, idx: int, shot_dim: bool = False
    ) -> tuple[NDArray[np.float32], NDArray[np.complex64]]:
        """Get the k-space frame and the associated trajectory.

        Parameters
        ----------
        idx : int
            Index of the frame to get.
        shot_dim : bool, optional
            Return the data reshaped with the shot dimension first.

        Returns
        -------
        np.ndarray
            The trajectory.
        np.ndarray
            The kspace data.
        """
        n_acq_per_frame = self.n_acquisition // self.n_frames
        start = idx * n_acq_per_frame
        end = (idx + 1) * n_acq_per_frame
        # Do a single read of the dataset much faster !
        acq = self._dataset["data"][start:end]
        traj = acq["traj"].reshape(-1, 3)
        data = acq["data"].view(np.complex64)
        data = data.reshape(-1, self.n_coils, self.n_sample)
        data = np.moveaxis(data, 1, 0)  # putting the coil dimension first.
        data = data.reshape(self.n_coils, -1)
        if shot_dim:
            # One shot per acquisition: expose (n_shots, samples, 3) and
            # (n_coils, n_shots, samples) views.
            return (
                traj.reshape(n_acq_per_frame, -1, 3),
                data.reshape(self.n_coils, n_acq_per_frame, -1),
            )
        return traj, data
def parse_sim_conf(header: mrd.xsd.ismrmrdHeader) -> SimConfig:
    """Parse the header to populate SimConfig from an MRD Header.

    Parameters
    ----------
    header : mrd.xsd.ismrmrdHeader
        Parsed MRD XML header carrying the user parameters written at
        simulation time.

    Returns
    -------
    SimConfig
        The simulation configuration rebuilt from the header.

    Raises
    ------
    ValueError
        If any expected user parameter is missing from the header.
    """
    from ..core import GreConfig, HardwareConfig, SimConfig, FOVConfig

    n_coils = header.acquisitionSystemInformation.receiverChannels
    field = header.acquisitionSystemInformation.systemFieldStrength_T
    TR = header.sequenceParameters.TR[0]
    TE = header.sequenceParameters.TE[0]
    FA = header.sequenceParameters.flipAngle_deg[0]
    seq = GreConfig(TR=TR, TE=TE, FA=FA)

    # Numeric user parameters we expect, and how to coerce each value.
    caster = {
        "gmax": float,
        "smax": float,
        "dwell_time_ms": float,
        "max_sim_time": int,
        "rng_seed": int,
        "TR_eff": float,
    }
    parsed = {
        up.name: caster[up.name](up.value)
        for up in header.userParameters.userParameterDouble
        if up.name in caster
    }
    if set(caster) != set(parsed):
        raise ValueError(
            f"Missing parameters {set(caster.keys()) - set(parsed.keys())}"
        )

    # String user parameters we expect.
    caster_str = {
        "fov_config": str,
    }
    parsed_str = {
        up.name: caster_str[up.name](up.value)
        for up in header.userParameters.userParameterString
        if up.name in caster_str
    }
    if set(caster_str) != set(parsed_str):
        raise ValueError(
            f"Missing parameters {set(caster_str.keys()) - set(parsed_str.keys())}"
        )

    hardware = HardwareConfig(
        gmax=parsed.pop("gmax"),
        smax=parsed.pop("smax"),
        dwell_time_ms=parsed.pop("dwell_time_ms"),
        n_coils=n_coils,
        field=field,
    )
    seq.TR_eff = parsed.pop("TR_eff")
    sim_conf = SimConfig(
        max_sim_time=parsed.pop("max_sim_time"),
        seq=seq,
        hardware=hardware,
        rng_seed=parsed.pop("rng_seed"),
    )
    # SECURITY: eval() executes arbitrary code embedded in the file header.
    # This is only acceptable for locally-produced MRD files; never call
    # this function on untrusted input. FOVConfig is imported above so the
    # evaluated repr can resolve it.
    sim_conf.fov: FOVConfig = eval(parsed_str.pop("fov_config"))
    return sim_conf
def get_affine_from_image(image: mrd.Image) -> NDArray[np.float32]:
    """Extract the affine matrix from the header of an MRD.Image.

    Parameters
    ----------
    image : mrd.Image
        Image whose header provides position, read/phase/slice directions,
        field of view and matrix size.

    Returns
    -------
    NDArray[np.float32]
        4x4 affine matrix: each direction column scaled by the voxel
        resolution of its axis, with the x and y components sign-flipped
        (NOTE(review): looks like an LPS->RAS-style convention — confirm
        against the writer side).
    """
    head = image._head
    # Voxel resolution along each axis (field of view / matrix size).
    res = np.array(head.field_of_view) / np.array(head.matrix_size)
    affine = np.eye(4, dtype=np.float32)
    # Translation column, x and y sign-flipped.
    affine[:3, 3] = -head.position[0], -head.position[1], head.position[2]
    # Orientation columns: read, phase and slice directions, each scaled by
    # the resolution of its axis and with x and y sign-flipped.
    directions = (head.read_dir, head.phase_dir, head.slice_dir)
    for col, direction in enumerate(directions):
        affine[:3, col] = (
            -direction[0] * res[col],
            -direction[1] * res[col],
            direction[2] * res[col],
        )
    return affine