"""Module to create phantom for simulation."""
from __future__ import annotations
import contextlib
import hashlib
import json
import logging
import os
from pathlib import Path
from collections.abc import Generator
from copy import deepcopy
from dataclasses import dataclass, field
from multiprocessing.managers import SharedMemoryManager
from multiprocessing.shared_memory import SharedMemory
from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING:
from _typeshed import GenericPath
from snake.mrd_utils.loader import MRDLoader
import ismrmrd as mrd
import numpy as np
from nibabel.nifti1 import Nifti1Image
from numpy.typing import NDArray
from snake._meta import ThreeFloats, ThreeInts
from ..smaps import get_smaps
from ..parallel import ArrayProps, array_from_shm, array_to_shm, run_parallel
from ..simulation import SimConfig
from .contrast import _contrast_gre
from .utils import PropTissueEnum, TissueFile, resize_tissues
from ..transform import apply_affine4d, serialize_array, unserialize_array
log = logging.getLogger(__name__)
SNAKE_CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "snake-fmri")
@dataclass
class Phantom:
"""A Phantom consist of all spatial maps that are used in the simulation.
It is a dataclass that contains the tissue masks, properties, labels, and
spatial maps of the phantom.
The tissue masks are a 3D array with the shape (n_tissues, x, y, z), where
n_tissues is the number of tissues.
The properties are a 2D array with the shape (n_tissues, n_properties), where
n_properties is the number of properties.
The labels are a 1D array with the shape (n_tissues,) containing the names
of the tissues.
The sensitivity maps are a 4D array with the shape (n_coils, x, y, z), where
n_coils is the number of coils.
The affine matrix is a 2D array with the shape (4, 4) containing the affine
transformation matrix.
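
    Examples
    --------
    A minimal, illustrative construction (array shapes and property values are
    arbitrary placeholders, not a realistic phantom)::

        import numpy as np

        masks = np.zeros((2, 8, 8, 8), dtype=np.float32)  # (n_tissues, x, y, z)
        labels = np.array(["gm", "wm"])
        props = np.ones((2, 5), dtype=np.float32)  # (n_tissues, n_properties)
        phantom = Phantom("toy", masks, labels, props)
        phantom.anat_shape  # (8, 8, 8)
        phantom.n_tissues  # 2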
"""
# TODO Add field map inhomogeneity in the phantom
name: str
masks: NDArray[np.float32]
    labels: NDArray[np.str_]
props: NDArray[np.float32]
smaps: NDArray[np.complex64] | None = None
affine: NDArray[np.float32] = field(
default_factory=lambda: np.eye(4, dtype=np.float32)
)
def add_tissue(
self,
tissue_name: str,
mask: NDArray[np.float32],
props: NDArray[np.float32],
phantom_name: str | None = None,
) -> Phantom:
"""Add a tissue to the phantom. Creates a new Phantom object."""
        masks = np.concatenate((self.masks, mask[None, ...]), axis=0)
        labels = np.concatenate((self.labels, np.array([tissue_name])))
        props = np.concatenate((self.props, np.atleast_2d(props)), axis=0)
        return Phantom(
            phantom_name or self.name,
            masks,
            labels,
            props,
            smaps=self.smaps,
            affine=self.affine,
        )
@property
def labels_idx(self) -> dict[str, int]:
"""Get the index of the labels."""
return {label: i for i, label in enumerate(self.labels)}
def make_smaps(
        self,
        n_coils: int | None = None,
        sim_conf: SimConfig | None = None,
        antenna: str = "birdcage",
) -> None:
"""Get coil sensitivity maps for the phantom."""
if n_coils is None and sim_conf is not None:
n_coils = sim_conf.hardware.n_coils
elif sim_conf is None and n_coils is None:
raise ValueError("Either n_coils or sim_conf must be provided.")
if n_coils == 1:
log.warning("Only one coil, no need for smaps.")
elif n_coils > 1 and self.smaps is None:
self.smaps = get_smaps(self.anat_shape, n_coils=n_coils, antenna=antenna)
log.debug(f"Created smaps for {n_coils} coils.")
elif self.smaps is not None:
log.warning("Smaps already exists.")
@classmethod
def from_brainweb(
cls,
sub_id: int,
sim_conf: SimConfig,
tissue_file: str | TissueFile = TissueFile.tissue_1T5,
tissue_select: list[str] | None = None,
tissue_ignore: list[str] | None = None,
output_res: float | ThreeFloats = 0.5,
cache_dir: str | None = None,
) -> Phantom:
"""Get the Brainweb Phantom.
Parameters
----------
sub_id: int
Subject ID of the brainweb dataset.
sim_conf: SimConfig
Simulation configuration.
tissue_file: str
File with the tissue properties.
tissue_select: list[str]
List of tissues to select.
tissue_ignore: list[str]
List of tissues to ignore.
output_res: float
Resolution of the output phantom.
Returns
-------
Phantom
The phantom object.
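
        Examples
        --------
        Illustrative call (requires the ``brainweb_dl`` package and an
        already-built SimConfig; the argument values are placeholders)::

            phantom = Phantom.from_brainweb(sub_id=4, sim_conf=sim_conf, output_res=1.0)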
"""
if cache_dir is None:
cache_dir = os.environ.get("SNAKE_CACHE_DIR", SNAKE_CACHE_DIR)
        # Use a stable digest as the cache key (the built-in hash() is salted
        # per process, so it would never hit the cache across runs).
        phantom_hash = hashlib.md5(
            json.dumps(
                dict(
                    sub_id=sub_id,
                    tissue_file=str(tissue_file),
                    tissue_select=tissue_select,
                    tissue_ignore=tissue_ignore,
                    output_res=output_res,
                ),
                sort_keys=True,
            ).encode()
        ).hexdigest()
        phantom_file = None
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
            phantom_file = os.path.join(cache_dir, f"phantom_{phantom_hash}.mrd")
            if os.path.exists(phantom_file):
                log.debug(f"Loading phantom from cache: {phantom_file}")
                return cls.from_mrd_dataset(phantom_file)
from brainweb_dl import BrainWebTissuesV2, get_mri
if tissue_ignore and tissue_select:
raise ValueError("Only one of tissue_select or tissue_ignore can be used.")
if tissue_select:
tissue_select = [t.lower() for t in tissue_select]
if tissue_ignore:
tissue_ignore = [t.lower() for t in tissue_ignore]
tissues_mask, affine = get_mri(sub_id, contrast="fuzzy", with_affine=True)
tissues_mask = tissues_mask.astype(np.float32)
affine = affine.astype(np.float32)
tissues_mask = np.ascontiguousarray(tissues_mask.T)
tissues_list = []
try:
if isinstance(tissue_file, TissueFile):
tissue_file = tissue_file.value
else:
tissue_file = TissueFile[tissue_file].value
        except (KeyError, ValueError) as exc:
if not os.path.exists(tissue_file):
raise FileNotFoundError(f"File {tissue_file} does not exist.") from exc
finally:
tissue_file = str(tissue_file)
log.info(f"Using tissue file:{tissue_file} ")
with open(tissue_file) as f:
lines = f.readlines()
select = []
for line in lines[1:]:
vals = line.split(",")
t1, t2, t2s, rho, chi = map(np.float32, vals[1:])
name = vals[0]
t = (name, t1, t2, t2s, rho, chi)
if (
(tissue_select and name in tissue_select)
or (tissue_ignore and name not in tissue_ignore)
or (not tissue_select and not tissue_ignore)
):
tissues_list.append(t)
select.append(BrainWebTissuesV2[name.upper()])
log.info(
f"Selected tissues: {select}, {[t[0] for t in tissues_list]}",
)
if len(tissues_list) == 0:
raise ValueError("No tissues selected")
tissues_mask = tissues_mask[select]
shape = tissues_mask.shape
#
# TODO: Use the sim shape properly.
if output_res != 0.5:
if isinstance(output_res, int | float):
output_res = [output_res] * 3
z = np.array([0.5, 0.5, 0.5]) / np.array(output_res)
new_shape = (shape[0], *np.round(np.array(shape[1:]) * z).astype(int))
tissue_resized = np.zeros(new_shape, dtype=np.float32)
for i in range(3):
affine[i, i] = output_res[i]
run_parallel(
resize_tissues,
tissues_mask,
tissue_resized,
parallel_axis=0,
z=tuple(z),
)
tissues_mask = tissue_resized
smaps = None
if sim_conf.hardware.n_coils > 1:
smaps = get_smaps(
tissues_mask.shape[1:],
n_coils=sim_conf.hardware.n_coils,
)
phantom = cls(
"brainweb-{sub_id:02d}",
tissues_mask,
labels=np.array([t[0] for t in tissues_list]),
props=np.array([t[1:] for t in tissues_list]),
smaps=smaps,
affine=affine,
)
if phantom_file:
log.debug(f"Saving phantom to cache: {phantom_file}")
phantom.to_mrd_dataset(phantom_file)
return phantom
@classmethod
def from_mrd_dataset(
cls, dataset: MRDLoader | os.PathLike, imnum: int = 0
) -> Phantom:
"""Load the phantom from a mrd dataset."""
from snake.mrd_utils.loader import get_affine_from_image, MRDLoader
if not isinstance(dataset, MRDLoader):
dataset = MRDLoader(dataset)
with dataset:
image = dataset._read_image("phantom", imnum)
name = image.meta.pop("name")
labels = np.array(image.meta["labels"].split(","))
props = unserialize_array(image.meta["props"])
affine = get_affine_from_image(image)
# smaps
try:
smaps = dataset._read_image("smaps", imnum).data
except LookupError:
smaps = None
return cls(
masks=image.data,
labels=labels,
props=props,
name=name,
affine=affine,
smaps=smaps,
)
def to_mrd_dataset(self, dataset: mrd.Dataset | GenericPath) -> mrd.Dataset:
"""Add the phantom as an image to the dataset."""
# Create the image
if not isinstance(dataset, mrd.Dataset):
dataset = mrd.Dataset(dataset, create_if_needed=True)
meta_sr = mrd.Meta(
{
"name": self.name,
"labels": f"{','.join(self.labels)}",
"props": serialize_array(self.props),
"affine": serialize_array(self.affine),
}
).serialize()
# Convert the affine matrix to position, field of view, etc.
offsets = self.affine[:3, 3]
position = (-offsets[0], -offsets[1], offsets[2])
read_dir = self.affine[:3, 0] / self.affine[0, 0]
read_dir = (-read_dir[0], -read_dir[1], read_dir[2])
phase_dir = self.affine[:3, 1] / self.affine[1, 1]
phase_dir = (-phase_dir[0], -phase_dir[1], phase_dir[2])
slice_dir = self.affine[:3, 2] / self.affine[2, 2]
slice_dir = (-slice_dir[0], -slice_dir[1], slice_dir[2])
fov_mm = tuple(np.float32(np.array(self.anat_shape) * np.diag(self.affine)[:3]))
# Add the phantom data
dataset.append_image(
"phantom",
mrd.image.Image(
head=mrd.image.ImageHeader(
matrix_size=self.anat_shape,
field_of_view=fov_mm,
position=position,
phase_dir=phase_dir,
slice_dir=slice_dir,
read_dir=read_dir,
channels=self.n_tissues,
acquisition_time_stamp=0,
attribute_string_len=len(meta_sr),
),
data=self.masks,
attribute_string=meta_sr,
),
)
# Add the smaps
if self.smaps is not None:
dataset.append_image(
"smaps",
mrd.image.Image(
head=mrd.image.ImageHeader(
matrix_size=self.anat_shape,
field_of_view=fov_mm,
position=position,
phase_dir=phase_dir,
slice_dir=slice_dir,
read_dir=read_dir,
channels=len(self.smaps),
acquisition_time_stamp=0,
),
data=self.smaps,
),
)
return dataset
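    # Round-trip sketch for ``to_mrd_dataset`` / ``from_mrd_dataset``
    # (illustrative only; the target path is a placeholder):
    #
    #     phantom.to_mrd_dataset("phantom.mrd")
    #     restored = Phantom.from_mrd_dataset("phantom.mrd")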
@classmethod
@contextlib.contextmanager
def from_shared_memory(
cls,
name: str,
mask_prop: ArrayProps,
properties_prop: ArrayProps,
label_prop: ArrayProps,
        smaps_prop: ArrayProps | None,
affine_prop: ArrayProps,
) -> Generator[Phantom, None, None]:
"""Give access the tissue masks and properties in shared memory."""
with array_from_shm(
mask_prop, label_prop, properties_prop, smaps_prop, affine_prop
) as arrs:
yield cls(name, *arrs)
def in_shared_memory(
self, manager: SharedMemoryManager
) -> tuple[
tuple[str, ArrayProps, ArrayProps, ArrayProps, ArrayProps | None, ArrayProps],
tuple[
SharedMemory, SharedMemory, SharedMemory, SharedMemory | None, SharedMemory
],
]:
"""Add a copy of the phantom in shared memory."""
        tissue_masks, _, tissue_masks_sm = array_to_shm(self.masks, manager)
        tissue_props, _, tissue_props_sm = array_to_shm(self.props, manager)
        labels, _, labels_sm = array_to_shm(self.labels, manager)
        affine, _, affine_sm = array_to_shm(self.affine, manager)
        if self.smaps is not None:
            smaps, _, smaps_sm = array_to_shm(self.smaps, manager)
        else:
            smaps, smaps_sm = None, None
        return (
            (self.name, tissue_masks, tissue_props, labels, smaps, affine),
            (tissue_masks_sm, tissue_props_sm, labels_sm, smaps_sm, affine_sm),
        )
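    # Usage sketch for ``in_shared_memory`` / ``from_shared_memory``
    # (illustrative; assumes the caller manages a ``SharedMemoryManager``):
    #
    #     with SharedMemoryManager() as smm:
    #         shm_props, shms = phantom.in_shared_memory(smm)
    #         with Phantom.from_shared_memory(*shm_props) as shared_phantom:
    #             ...  # read-only access, e.g. from a worker process
    #     # The shared memory blocks are released when the manager shuts down.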
def masks2nifti(self) -> Nifti1Image:
"""Return the masks of the phantom as a Nifti object."""
return Nifti1Image(
self.masks,
affine=self.affine,
extra={"props": self.props, "labels": self.labels},
)
def smaps2nifti(self) -> Nifti1Image:
"""Return the smaps as a Nifti object."""
if self.smaps is not None:
return Nifti1Image(self.smaps, affine=self.affine)
else:
raise ValueError("No Smaps to convert.")
    def to_nifti(
        self, filename: str | GenericPath | None = None
    ) -> tuple[GenericPath | None, GenericPath | None]:
        """Save the phantom as a pair of NIfTI files (masks and smaps)."""
        mask_nifti = self.masks2nifti()
        smaps_nifti = self.smaps2nifti() if self.smaps is not None else None
        if not filename:
            return filename, smaps_nifti
        mask_nifti.to_filename(filename)
        smaps_filename = None
        if smaps_nifti is not None:
            smaps_filename = Path(str(filename).replace(".nii", "_smaps.nii"))
            smaps_nifti.to_filename(smaps_filename)
        return filename, smaps_filename
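    # Round-trip sketch for ``to_nifti`` / ``from_nifti`` (illustrative; the
    # file name is a placeholder, and props/labels are passed explicitly since
    # the in-memory ``extra`` dictionary is typically not persisted to disk):
    #
    #     phantom.to_nifti("phantom.nii")
    #     restored = Phantom.from_nifti(
    #         "phantom.nii", props=phantom.props, labels=phantom.labels
    #     )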
@classmethod
def from_nifti(
cls,
mask_nifti: Nifti1Image | GenericPath,
        props: NDArray[np.float32] | None = None,
        labels: NDArray[np.str_] | None = None,
smaps: Nifti1Image | GenericPath | None = None,
) -> Phantom:
"""Create a phantom from nifti files."""
if not isinstance(mask_nifti, Nifti1Image):
mask_nifti_name = mask_nifti
mask_nifti = Nifti1Image.from_filename(mask_nifti)
else:
mask_nifti_name = mask_nifti.get_filename() or "from_nifti"
if smaps and not isinstance(smaps, Nifti1Image):
smaps_nifti = Nifti1Image.from_filename(smaps)
else:
smaps_nifti = smaps
affine = mask_nifti.affine
if props is None:
props = mask_nifti.extra["props"]
if labels is None:
labels = mask_nifti.extra["labels"]
masks = np.asarray(mask_nifti.get_fdata()).astype(np.float32)
smaps = None
if smaps_nifti:
smaps = np.asarray(smaps_nifti.get_fdata()).astype(np.complex64)
return cls(
name=mask_nifti_name,
masks=masks,
labels=labels,
props=props,
smaps=smaps,
affine=affine,
)
def contrast(
self,
*,
TR: float | None = None,
TE: float | None = None,
FA: float | None = None,
sequence: Literal["GRE"] = "GRE",
sim_conf: SimConfig | None = None,
resample: bool = True,
aggregate: bool = True,
use_gpu: bool = True,
) -> NDArray[np.float32]:
"""Compute the contrast of the phantom for a given sequence.
Parameters
----------
TR: float
TE: float
FA: float
sim_conf: SimConfig
Other way to provide sequence parameters
aggregate: bool, optional default=True
Sum all the tissues contrast for getting a single image.
sequence="GRE"
Default value, no other value is currently supported.
Results
-------
NDArray
The contrast of the tissues.
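
        Examples
        --------
        Illustrative call with explicit sequence parameters (the values below
        are placeholders, in whatever units the GRE contrast model expects)::

            image = phantom.contrast(TR=50.0, TE=30.0, FA=15.0, resample=False)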
"""
if resample:
if sim_conf is None:
raise ValueError("sim_conf must be provided for resampling.")
affine = sim_conf.fov.affine
shape = sim_conf.fov.shape
self = self.resample(affine, shape, use_gpu=use_gpu)
if sim_conf is not None:
TR = sim_conf.seq.TR_eff # Here we use the effective TR.
TE = sim_conf.seq.TE
FA = sim_conf.seq.FA
if sim_conf is None and TR is None and TE is None and FA is None:
raise ValueError("Missing either sim_conf or TR,TE,FA")
if sequence.upper() == "GRE":
contrasts = _contrast_gre(self.props, TR=TR, TE=TE, FA=FA)
else:
raise NotImplementedError("Contrast not implemented.")
if aggregate:
ret = np.zeros(self.anat_shape, dtype=np.float32)
for c, m in zip(contrasts, self.masks, strict=False):
ret += c * m
return ret
else:
return self.masks * contrasts[(..., *([None] * len(self.anat_shape)))]
def resample(
self,
new_affine: NDArray,
new_shape: ThreeInts,
use_gpu: bool = False,
**kwargs: Any,
) -> Phantom:
"""Resample the phantom to a new shape and affine matrix.
Parameters
----------
new_affine : NDArray
The new affine matrix.
new_shape : ThreeInts
The new shape of the phantom.
use_gpu : bool, optional
Use the GPU for the resampling, by default False.
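
        Examples
        --------
        Illustrative call (``sim_conf`` is assumed to be an existing
        SimConfig; its FOV provides the target grid)::

            resampled = phantom.resample(sim_conf.fov.affine, sim_conf.fov.shape)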
"""
new_masks = apply_affine4d(
self.masks,
old_affine=self.affine,
new_affine=new_affine,
new_shape=new_shape,
use_gpu=use_gpu,
**kwargs,
)
new_smaps = None
if self.smaps is not None:
new_smaps = apply_affine4d(
self.smaps,
old_affine=self.affine,
new_affine=new_affine,
new_shape=new_shape,
use_gpu=use_gpu,
**kwargs,
)
return Phantom(
self.name,
new_masks,
self.labels,
self.props,
smaps=new_smaps,
affine=new_affine,
)
@property
def anat_shape(self) -> tuple[int, ...]:
"""Get the shape of the base volume."""
return self.masks.shape[1:]
@property
def n_tissues(self) -> int:
"""Get the number of tissues."""
return len(self.masks)
    def __repr__(self) -> str:
ret = f"Phantom[{self.name}]: {self.props.shape}\n"
ret += f"{'tissue name':14s}" + "".join(
f"{str(prop):4s}" for prop in PropTissueEnum
)
ret += "\n"
for i, tissue_name in enumerate(self.labels):
props = self.props[i]
ret += (
f"{tissue_name:14s}"
+ "".join(f"{props[p]:4}" for p in PropTissueEnum.__members__.values())
+ "\n"
)
return ret
def __deepcopy__(self, memo: Any) -> Phantom:
"""Create a copy of the phantom."""
return Phantom(
name=self.name,
masks=deepcopy(self.masks, memo),
labels=deepcopy(self.labels, memo),
props=deepcopy(self.props, memo),
smaps=deepcopy(self.smaps, memo),
affine=deepcopy(self.affine, memo),
)
def copy(self) -> Phantom:
"""Return deep copy of the Phantom."""
return deepcopy(self)