Source code for snake.core.simulation

"""Simulation base objects."""

from __future__ import annotations
import logging
from dataclasses import dataclass, field
import numpy as np
from snake._meta import ThreeInts, ThreeFloats
from numpy.typing import NDArray
from scipy.spatial.transform import Rotation as R
from .._meta import dataclass_repr_html

log = logging.getLogger(__name__)


@dataclass
class GreConfig:
    """Gradient Recalled Echo Sequence parameters."""

    TR: float
    """TR: Repetition Time in ms. This is the time between two consecutive RF pulses."""
    TE: float
    """TE: Echo Time in ms. This is the time between the RF pulse and the echo."""
    FA: float
    """FA: Flip Angle in degrees. This is the angle of the RF pulse to the magnetization."""

    _repr_html_ = dataclass_repr_html

    def __post_init__(self) -> None:
        """Validate the parameters and create the effective TR.

        Sets ``TR_eff`` (initially equal to ``TR``) as a plain instance
        attribute; it is not a dataclass field.

        Raises
        ------
        ValueError
            If TR or TE is negative, if TE is not strictly less than TR, or
            if FA is outside [0, 180] degrees.
        """
        # Check signs first, so that e.g. a negative TR is reported as such
        # instead of surfacing as an inconsistent TE/TR pair.
        if self.TR < 0 or self.TE < 0:
            raise ValueError("TR and TE must be positive.")
        if self.TE >= self.TR:
            raise ValueError("TE must be less than TR.")
        if self.FA < 0 or self.FA > 180:
            raise ValueError("FA must be between 0 and 180 degrees.")
        # Effective TR, intended for the contrast calculation; starts equal to
        # TR and may be updated elsewhere if needed.
        self.TR_eff = self.TR
@dataclass
class HardwareConfig:
    """Scanner hardware parameters.

    Plain container of scanner limits and timings; no validation is done.
    """

    #: Maximum gradient amplitude (unit presumably mT/m -- TODO confirm).
    gmax: float = 40
    #: Maximum slew rate (unit presumably T/m/s -- TODO confirm).
    smax: float = 200
    #: Number of coils.
    n_coils: int = 8
    #: Dwell time, in ms.
    dwell_time_ms: float = 0.001
    #: Raster time, in ms.
    raster_time_ms: float = 0.005
    #: Main field strength (presumably in Tesla -- TODO confirm).
    field: float = 3.0

    _repr_html_ = dataclass_repr_html
# Module-level default configurations (each is a distinct mutable instance).
default_gre = GreConfig(FA=15, TE=30, TR=50)
default_hardware = HardwareConfig()
@dataclass
class FOVConfig:
    """Field of View configuration.

    This class is used to define the FOV of the simulation.
    It uses the RAS convention and mm units.
    Default values are from the BrainWeb dataset.
    """

    size: ThreeFloats = (181, 217, 181)
    """Size of the FOV in millimeter."""
    offset: ThreeFloats = (-90.25, -126.25, -72.25)
    """World position (in mm, relative to the magnet isocenter) of voxel (0, 0, 0)."""
    angles: ThreeFloats = (0, 0, 0)
    """Euler Rotation Angles ("xyz" sequence) of the FOV in degrees."""
    res_mm: ThreeFloats = (1, 1, 1)
    """Resolution of the FOV in mm."""

    _repr_html_ = dataclass_repr_html

    def __post_init__(self) -> None:
        """Validate the parameters.

        If the resolution exceeds the FOV size along an axis, the size along
        that axis is enlarged to one voxel (a warning is logged).

        Raises
        ------
        ValueError
            If any resolution or size component is not strictly positive, or
            if any angle lies outside [-180, 180] degrees.
        """
        if any(r <= 0 for r in self.res_mm) or any(s <= 0 for s in self.size):
            raise ValueError("resolution and size must be positive.")
        if any(abs(a) > 180 for a in self.angles):
            raise ValueError("Angles must be between -180 and 180 degrees.")
        if any(r > s for r, s in zip(self.res_mm, self.size, strict=False)):
            log.warning(
                "Resolution is higher than the size of the FOV,"
                " setting to 1voxel thickness."
            )
            # Widen thinner-than-a-voxel axes so they span exactly one voxel.
            self.size = tuple(
                max(r, s) for r, s in zip(self.res_mm, self.size, strict=False)
            )

    @classmethod
    def from_affine(cls, affine: NDArray, size: ThreeFloats) -> FOVConfig:
        """Create a FOVConfig from an affine matrix.

        The affine is read with the voxel-to-world convention
        ``affine[:3, :3] = rotation @ diag(res_mm)``: each column is a voxel
        axis whose norm is the voxel size along that axis. This is the inverse
        of the :attr:`affine` property.

        Parameters
        ----------
        affine
            4x4 voxel-to-world affine matrix, in mm.
        size
            Size of the FOV in mm (passed through unchanged).

        Returns
        -------
        FOVConfig
            Configuration whose ``res_mm``, ``offset`` and ``angles`` are plain
            tuples of floats.
        """
        affine = np.asarray(affine, dtype=np.float64)
        linear = affine[:3, :3]
        # Column norms give the voxel size along each voxel axis.
        res_mm = np.sqrt(np.sum(linear**2, axis=0))
        offset = affine[:3, 3]
        # Normalising each column leaves the pure rotation.
        angles = R.from_matrix(linear / res_mm).as_euler("xyz", degrees=True)
        # Store float tuples to match the ThreeFloats annotations: dataclass
        # ``==`` on instances holding numpy arrays raises ValueError.
        return cls(
            res_mm=tuple(float(v) for v in res_mm),
            offset=tuple(float(v) for v in offset),
            angles=tuple(float(v) for v in angles),
            size=size,
        )

    @property
    def affine(self) -> NDArray[np.float32]:
        """Generate the voxel-to-world affine matrix from the FOV configuration.

        The linear part is ``rotation @ diag(res_mm)``, so voxel axis ``i``
        maps to a world vector of length ``res_mm[i]``. The translation is
        ``offset``.
        """
        affine = np.eye(4, dtype=np.float32)
        rotation_matrix = R.from_euler("xyz", self.angles, degrees=True).as_matrix()
        # Scale the voxel axes (columns), then rotate. ``diag(res_mm) @ rotation``
        # would scale world axes instead, which is wrong for anisotropic voxels
        # once a rotation is applied.
        affine[:3, :3] = rotation_matrix @ np.diag(self.res_mm)
        affine[:3, 3] = np.array(self.offset)
        return affine

    @property
    def shape(self) -> ThreeInts:
        """Shape of the associated array in voxels units (size / res, rounded)."""
        return tuple(round(s / r) for s, r in zip(self.size, self.res_mm, strict=False))
@dataclass
class SimConfig:
    """All base configuration of a simulation."""

    #: Maximum simulated duration, in seconds (converted to ms in max_n_shots).
    max_sim_time: float = 300
    #: Sequence parameters; TR also sets the simulation time step.
    seq: GreConfig = field(default_factory=lambda: GreConfig(TR=50, TE=30, FA=15))
    #: Scanner hardware parameters.
    hardware: HardwareConfig = field(default_factory=HardwareConfig)
    #: Field of view; provides the resolution and array shape.
    fov: FOVConfig = field(default_factory=FOVConfig)
    #: Seed used to build ``rng`` in ``__post_init__``.
    rng_seed: int = 19290506

    _repr_html_ = dataclass_repr_html

    def __post_init__(self) -> None:
        """Create the random generator from ``rng_seed``.

        ``rng`` is a plain instance attribute, not a dataclass field, so it is
        neither an ``__init__`` argument nor part of the generated repr/eq.
        """
        self.rng: np.random.Generator = np.random.default_rng(self.rng_seed)

    @property
    def max_n_shots(self) -> int:
        """Maximum number of shots (TRs) fitting in ``max_sim_time`` (truncated)."""
        return int(self.max_sim_time * 1000 / self.sim_tr_ms)

    @property
    def res_mm(self) -> ThreeFloats:
        """Voxel resolution in mm."""
        return self.fov.res_mm

    @property
    def sim_tr_ms(self) -> float:
        """Simulation time resolution in ms (the sequence TR)."""
        return self.seq.TR

    @property
    def shape(self) -> ThreeInts:
        """Shape of the simulation."""
        return self.fov.shape

    @property
    def fov_mm(self) -> ThreeFloats:
        """Size of the FOV in mm."""
        return self.fov.size