Source code for mrinufft.trajectories.display

"""Display functions for trajectories."""

from __future__ import annotations

import itertools
from typing import Any

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
from numpy.typing import NDArray

from .utils import (

class displayConfig:
    """
    A container class used to share arguments related to display.

    The values can be updated either directly (and permanently) by assigning
    to the class attributes, or temporarily by using an instance as a context
    manager: the values given at construction (or through ``update``) are
    applied to the class immediately and restored by ``reset`` / on exit.

    Examples
    --------
    >>> from mrinufft.trajectories.display import displayConfig
    >>> displayConfig.alpha
    0.2
    >>> with displayConfig(alpha=0.5):
    ...     print(displayConfig.alpha)
    0.5
    >>> displayConfig.alpha
    0.2
    """

    alpha: float = 0.2
    """Transparency used for area plots, by default ``0.2``."""
    linewidth: float = 2
    """Width for lines or curves, by default ``2``."""
    pointsize: int = 10
    """Size for points used to show constraints, by default ``10``."""
    fontsize: int = 18
    """Font size for most labels and texts, by default ``18``."""
    small_fontsize: int = 14
    """Font size for smaller texts, by default ``14``."""
    nb_colors: int = 10
    """Number of colors to use in the color cycle, by default ``10``."""
    palette: str = "tab10"
    """Name of the color palette to use, by default ``"tab10"``.
    This can be any of the matplotlib colormaps, or a list of colors."""
    one_shot_color: str = "k"
    """Matplotlib color for the highlighted shot, by default ``"k"`` (black)."""
    one_shot_linewidth_factor: float = 2
    """Factor to multiply the linewidth of the highlighted shot, by default ``2``."""
    gradient_point_color: str = "r"
    """Matplotlib color for gradient constraint points, by default ``"r"`` (red)."""
    slewrate_point_color: str = "b"
    """Matplotlib color for slew rate constraint points, by default ``"b"`` (blue)."""

    def __init__(self, **kwargs: Any) -> None:  # noqa ANN401
        """Update the display configuration.

        Parameters
        ----------
        **kwargs
            Configuration values to override, keyed by attribute name.
        """
        self.update(**kwargs)

    def update(self, **kwargs: Any) -> None:  # noqa ANN401
        """Update the display configuration.

        The value each option had before this instance first touched it is
        remembered, so that ``reset`` restores the original configuration
        even after several successive calls to ``update``.

        Parameters
        ----------
        **kwargs
            Configuration values to override, keyed by attribute name.

        Raises
        ------
        AttributeError
            If a key is not an existing attribute of ``displayConfig``.
            Nothing is modified in that case.
        """
        # Validate everything first so a typo cannot leave the shared
        # configuration half-updated with no way to restore it.
        unknown = [key for key in kwargs if not hasattr(displayConfig, key)]
        if unknown:
            raise AttributeError(
                f"type object 'displayConfig' has no attribute {unknown[0]!r}"
            )

        saved = self.__dict__.setdefault("_old_values", {})
        for key, value in kwargs.items():
            # Keep the oldest known value: it is the one to restore.
            saved.setdefault(key, getattr(displayConfig, key))
            setattr(displayConfig, key, value)

    def reset(self) -> None:
        """Restore the display configuration.

        Calling this method when there is nothing left to restore is a no-op.
        """
        for key, value in self.__dict__.pop("_old_values", {}).items():
            setattr(displayConfig, key, value)

    def __enter__(self) -> displayConfig:
        """Enter the context manager."""
        return self

    def __exit__(self, *args: Any) -> None:  # noqa ANN401
        """Exit the context manager."""
        self.reset()

    @classmethod
    def get_colorlist(cls) -> list[str | NDArray]:
        """Extract a list of colors from the configured palette.

        The palette is read from ``displayConfig.palette`` and the number of
        colors from ``displayConfig.nb_colors``. If the palette exposes fewer
        discrete colors than requested, they are repeated in cycle. Otherwise
        the colormap is sampled at ``nb_colors`` evenly spaced positions.

        Returns
        -------
        colorlist : list of matplotlib colors, or array of RGBA values
            ``nb_colors`` colors usable as matplotlib ``color`` arguments.

        Raises
        ------
        TypeError
            If the palette is neither a name, a colormap nor a list of colors.
        ValueError
            If the palette is an empty list of colors.
        """
        palette = cls.palette
        if isinstance(palette, str):
            cm = mpl.colormaps[palette]
        elif isinstance(palette, mpl.colors.Colormap):
            cm = palette
        elif isinstance(palette, (list, tuple)):
            if not palette:
                raise ValueError("The palette must contain at least one color.")
            # Wrap explicit colors so they follow the same path as colormaps.
            cm = mpl.colors.ListedColormap(list(palette))
        else:
            raise TypeError(
                "The palette must be a colormap name, a matplotlib colormap "
                f"or a list of colors, got {type(palette).__name__}."
            )

        colors = getattr(cm, "colors", [])
        if 0 < len(colors) < cls.nb_colors:
            # Categorical palette too short: cycle through its colors.
            return list(itertools.islice(itertools.cycle(colors), cls.nb_colors))
        return cm(np.linspace(0, 1, cls.nb_colors))
##############
# TICK UTILS #
##############


def _setup_2D_ticks(figsize: float, fig: plt.Figure | None = None) -> plt.Axes:
    """Prepare a 2D axes with a grid, k-space ticks, limits and labels.

    Parameters
    ----------
    figsize : float
        Side length of the square figure created when ``fig`` is `None`.
    fig : plt.Figure or plt.Axes, optional
        Existing figure to draw in, or axes to use directly.
        The default is `None`.

    Returns
    -------
    ax : plt.Axes
        The configured axes.
    """
    if fig is None:
        fig = plt.figure(figsize=(figsize, figsize))
    if isinstance(fig, plt.Axes):
        ax = fig
    else:
        ax = fig.subplots()

    tick_positions = [-KMAX, -KMAX / 2, 0, KMAX / 2, KMAX]
    ax.grid(True)
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xlim((-KMAX, KMAX))
    ax.set_ylim((-KMAX, KMAX))
    ax.set_xlabel("kx", fontsize=displayConfig.fontsize)
    ax.set_ylabel("ky", fontsize=displayConfig.fontsize)
    return ax


def _setup_3D_ticks(figsize: float, fig: plt.Figure | None = None) -> plt.Axes:
    """Prepare a 3D axes with k-space ticks, limits, aspect and labels.

    Parameters
    ----------
    figsize : float
        Side length of the square figure created when ``fig`` is `None`.
    fig : plt.Figure or plt.Axes, optional
        Existing figure to draw in, or axes to use directly.
        The default is `None`.

    Returns
    -------
    ax : plt.Axes
        The configured axes.
    """
    if fig is None:
        fig = plt.figure(figsize=(figsize, figsize))
    if isinstance(fig, plt.Axes):
        ax = fig
    else:
        ax = fig.add_subplot(projection="3d")

    tick_positions = [-KMAX, -KMAX / 2, 0, KMAX / 2, KMAX]
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_zticks(tick_positions)

    ax.axes.set_xlim3d(left=-KMAX, right=KMAX)
    ax.axes.set_ylim3d(bottom=-KMAX, top=KMAX)
    ax.axes.set_zlim3d(bottom=-KMAX, top=KMAX)

    side = 2 * KMAX
    ax.set_box_aspect((side, side, side))

    ax.set_xlabel("kx", fontsize=displayConfig.fontsize)
    ax.set_ylabel("ky", fontsize=displayConfig.fontsize)
    ax.set_zlabel("kz", fontsize=displayConfig.fontsize)
    return ax


######################
# TRAJECTORY DISPLAY #
######################
def display_2D_trajectory(
    trajectory: NDArray,
    figsize: float = 5,
    one_shot: bool | int = False,
    subfigure: plt.Figure | plt.Axes | None = None,
    show_constraints: bool = False,
    gmax: float = DEFAULT_GMAX,
    smax: float = DEFAULT_SMAX,
    constraints_order: int | str | None = None,
    **constraints_kwargs: Any,  # noqa ANN401
) -> plt.Axes:
    """Display 2D trajectories.

    Parameters
    ----------
    trajectory : NDArray
        Trajectory to display.
    figsize : float, optional
        Size of the figure.
    one_shot : bool or int, optional
        State if a specific shot should be highlighted in bold black.
        If `True`, highlight the middle shot.
        If `int`, highlight the shot at that index.
        The default is `False`.
    subfigure : plt.Figure, plt.SubFigure or plt.Axes, optional
        The figure where the trajectory should be displayed.
        The default is `None`.
    show_constraints : bool, optional
        Display the points where the gradients and slew rates
        are above the `gmax` and `smax` limits, respectively.
        The default is `False`.
    gmax : float, optional
        Maximum constraint on the gradients in T/m.
        The default is `DEFAULT_GMAX`.
    smax : float, optional
        Maximum constraint on the slew rates in T/m/ms.
        The default is `DEFAULT_SMAX`.
    constraints_order : int, str, optional
        Norm order defining how the constraints are checked,
        typically 2 or `np.inf`, following the `numpy.linalg.norm`
        conventions on parameter `ord`.
        The default is None.
    **constraints_kwargs
        Acquisition parameters used to check on hardware constraints,
        following the parameter convention from
        `mrinufft.trajectories.utils.compute_gradients_and_slew_rates`.

    Returns
    -------
    ax : plt.Axes
        Axes of the figure.
    """
    # Prepare the axes and the color cycle
    nb_shots, _ = trajectory.shape[:2]
    ax = _setup_2D_ticks(figsize, subfigure)
    palette = displayConfig.get_colorlist()

    # Draw all the shots, cycling through the palette
    for shot_idx in range(nb_shots):
        shot = trajectory[shot_idx]
        ax.plot(
            shot[:, 0],
            shot[:, 1],
            color=palette[shot_idx % displayConfig.nb_colors],
            linewidth=displayConfig.linewidth,
        )

    # Overlay the highlighted shot, if any (`True` means the middle one)
    if one_shot is not False:
        highlighted = nb_shots // 2 if one_shot is True else one_shot
        ax.plot(
            trajectory[highlighted, :, 0],
            trajectory[highlighted, :, 1],
            color=displayConfig.one_shot_color,
            linewidth=displayConfig.one_shot_linewidth_factor
            * displayConfig.linewidth,
        )

    if not show_constraints:
        return ax

    # Mark the samples violating the hardware constraints
    gradients, slews = compute_gradients_and_slew_rates(
        trajectory, **constraints_kwargs
    )

    # Left-pad the derivatives so that they align with the trajectory samples
    gradient_norms = np.linalg.norm(
        np.pad(gradients, ((0, 0), (1, 0), (0, 0))), axis=-1, ord=constraints_order
    )
    slew_norms = np.linalg.norm(
        np.pad(slews, ((0, 0), (2, 0), (0, 0))), axis=-1, ord=constraints_order
    )

    samples = trajectory.reshape((-1, 2))
    gradient_violations = samples[gradient_norms.flatten() > gmax]
    slew_violations = samples[slew_norms.flatten() > smax]

    ax.scatter(
        gradient_violations[:, 0],
        gradient_violations[:, 1],
        color=displayConfig.gradient_point_color,
        s=displayConfig.pointsize,
    )
    ax.scatter(
        slew_violations[:, 0],
        slew_violations[:, 1],
        color=displayConfig.slewrate_point_color,
        s=displayConfig.pointsize,
    )
    return ax
def display_3D_trajectory(
    trajectory: NDArray,
    nb_repetitions: int | None = None,
    figsize: float = 5,
    per_plane: bool = True,
    one_shot: bool | int = False,
    subfigure: plt.Figure | plt.Axes | None = None,
    show_constraints: bool = False,
    gmax: float = DEFAULT_GMAX,
    smax: float = DEFAULT_SMAX,
    constraints_order: int | str | None = None,
    **constraints_kwargs: Any,  # noqa ANN401
) -> plt.Axes:
    """Display 3D trajectories.

    Parameters
    ----------
    trajectory : NDArray
        Trajectory to display. A 2D trajectory is accepted as well and is
        displayed in the ``kz = 0`` plane.
    nb_repetitions : int
        Number of repetitions (planes, cones, shells, etc).
        The default is `None`, in which case the number of shots is used.
    figsize : float, optional
        Size of the figure.
    per_plane : bool, optional
        If True, display the trajectory with a different color
        for each plane.
    one_shot : bool or int, optional
        State if a specific shot should be highlighted in bold black.
        If `True`, highlight the middle shot.
        If `int`, highlight the shot at that index.
        The default is `False`.
    subfigure : plt.Figure, plt.SubFigure or plt.Axes, optional
        The figure where the trajectory should be displayed.
        The default is `None`.
    show_constraints : bool, optional
        Display the points where the gradients and slew rates
        are above the `gmax` and `smax` limits, respectively.
        The default is `False`.
    gmax : float, optional
        Maximum constraint on the gradients in T/m.
        The default is `DEFAULT_GMAX`.
    smax : float, optional
        Maximum constraint on the slew rates in T/m/ms.
        The default is `DEFAULT_SMAX`.
    constraints_order : int, str, optional
        Norm order defining how the constraints are checked,
        typically 2 or `np.inf`, following the `numpy.linalg.norm`
        conventions on parameter `ord`.
        The default is None.
    **constraints_kwargs
        Acquisition parameters used to check on hardware constraints,
        following the parameter convention from
        `mrinufft.trajectories.utils.compute_gradients_and_slew_rates`.

    Returns
    -------
    ax : plt.Axes
        Axes of the figure.
    """
    # Setup figure and ticks, and handle 2D trajectories
    ax = _setup_3D_ticks(figsize, subfigure)
    if nb_repetitions is None:
        nb_repetitions = trajectory.shape[0]
    if trajectory.shape[-1] == 2:
        # Embed 2D trajectories in 3D by adding a null kz coordinate
        trajectory = np.concatenate(
            [trajectory, np.zeros((*(trajectory.shape[:2]), 1))], axis=-1
        )
    trajectory = trajectory.reshape((nb_repetitions, -1, trajectory.shape[-2], 3))
    Nc, Ns = trajectory.shape[1:3]
    colors = displayConfig.get_colorlist()

    # Display every shot
    for i in range(nb_repetitions):
        for j in range(Nc):
            ax.plot(
                trajectory[i, j, :, 0],
                trajectory[i, j, :, 1],
                trajectory[i, j, :, 2],
                # One color per repetition when `per_plane`, else per shot
                color=colors[(i + j * (not per_plane)) % displayConfig.nb_colors],
                linewidth=displayConfig.linewidth,
            )

    # Shots as a flat list, used for highlighting and constraint checks
    shots = trajectory.reshape((-1, Ns, 3))

    # Display one shot in particular if requested
    if one_shot is not False:  # If True or int
        # The middle shot is taken over all the shots: `Nc` is only the number
        # of shots per repetition and equals 1 when `nb_repetitions` is left
        # to its default, which would always select the first shot.
        shot_id = shots.shape[0] // 2 if one_shot is True else one_shot

        # Highlight the shot in black
        ax.plot(
            shots[shot_id, :, 0],
            shots[shot_id, :, 1],
            shots[shot_id, :, 2],
            color=displayConfig.one_shot_color,
            linewidth=displayConfig.one_shot_linewidth_factor
            * displayConfig.linewidth,
        )

    # Point out violated constraints if requested
    if show_constraints:
        gradients, slewrates = compute_gradients_and_slew_rates(
            shots, **constraints_kwargs
        )

        # Pad and compute norms
        gradients = np.linalg.norm(
            np.pad(gradients, ((0, 0), (1, 0), (0, 0))), axis=-1, ord=constraints_order
        )
        slewrates = np.linalg.norm(
            np.pad(slewrates, ((0, 0), (2, 0), (0, 0))), axis=-1, ord=constraints_order
        )

        # Check constraints
        samples = shots.reshape((-1, 3))
        gradients = samples[np.where(gradients.flatten() > gmax)]
        slewrates = samples[np.where(slewrates.flatten() > smax)]

        # Scatter points with vivid colors
        ax.scatter(
            *(gradients.T),
            color=displayConfig.gradient_point_color,
            s=displayConfig.pointsize,
        )
        ax.scatter(
            *(slewrates.T),
            color=displayConfig.slewrate_point_color,
            s=displayConfig.pointsize,
        )
    return ax
#################### # GRADIENT DISPLAY # ####################
def display_gradients_simply(
    trajectory: NDArray,
    shot_ids: tuple[int, ...] = (0,),
    figsize: float = 5,
    fill_area: bool = True,
    show_signal: bool = True,
    uni_signal: str | None = "gray",
    uni_gradient: str | None = None,
    subfigure: plt.Figure | None = None,
) -> tuple[plt.Axes]:
    """Display gradients based on trajectory of any dimension.

    The gradients are shown as the finite differences of the trajectory
    along the sample axis, without units nor tick labels.

    Parameters
    ----------
    trajectory : NDArray
        Trajectory to display.
    shot_ids : tuple[int, ...], optional
        Indices of the shots to display.
        The default is `(0,)`.
    figsize : float, optional
        Size of the figure.
    fill_area : bool, optional
        Fills the area under the curve for improved visibility and
        representation of the integral, aka trajectory.
        The default is `True`.
    show_signal : bool, optional
        Show an additional illustration of the signal as
        the modulated distance to the center.
        The default is `True`.
    uni_signal : str or None, optional
        Define whether the signal should be represented by a
        unique color given as argument or just by the default
        color cycle when `None`.
        The default is `"gray"`.
    uni_gradient : str or None, optional
        Define whether the gradients should be represented by a
        unique color given as argument or just by the default
        color cycle when `None`.
        The default is `None`.
    subfigure : plt.Figure, optional
        The figure where the trajectory should be displayed.
        The default is `None`.

    Returns
    -------
    axes : plt.Axes
        Axes of the figure.
    """
    # One row per gradient axis, plus one for the signal when requested
    nb_dims = trajectory.shape[-1]
    nb_rows = nb_dims + show_signal
    fig = subfigure
    if fig is None:
        fig = plt.figure(figsize=(figsize, figsize * nb_rows / nb_dims))
    axes = fig.subplots(nb_rows, 1)
    gradient_axes = axes[:nb_dims]

    # Labels
    for dim, ax in enumerate(gradient_axes):
        ax.set_ylabel(
            "G{}".format(("x", "y", "z")[dim]), fontsize=displayConfig.fontsize
        )
    axes[-1].set_xlabel("Time", fontsize=displayConfig.fontsize)

    # Grid without tick labels
    for ax in axes:
        ax.grid(True)
        ax.xaxis.set_tick_params(labelbottom=False)
        ax.yaxis.set_tick_params(labelleft=False)

    # Gradient curves, sharing a symmetric range over the selected shots
    gradients = np.diff(trajectory, axis=1)
    vmax = 1.1 * np.max(np.abs(gradients[shot_ids, ...]))
    timeline = np.arange(gradients.shape[1])
    palette = displayConfig.get_colorlist()
    for rank, shot in enumerate(shot_ids):
        if uni_gradient is not None:
            curve_color = uni_gradient
        else:
            curve_color = palette[rank % displayConfig.nb_colors]
        for dim, ax in enumerate(gradient_axes):
            curve = gradients[shot, ..., dim]
            ax.set_ylim((-vmax, vmax))
            ax.plot(timeline, curve, color=curve_color)
            if fill_area:
                ax.fill_between(
                    timeline,
                    curve,
                    alpha=displayConfig.alpha,
                    color=curve_color,
                )

    if not show_signal:
        return axes

    # Illustrative signal: distance to the center, repeated 10 times per
    # sample, modulated by a complex exponential and sharpened
    nb_selected = len(shot_ids)
    distances = np.linalg.norm(trajectory[shot_ids, 1:-1], axis=-1)
    distances = np.tile(distances.reshape((nb_selected, -1, 1)), (1, 1, 10))
    signal = 1 - distances.reshape((nb_selected, -1)) / np.max(distances)
    carrier = np.exp(2j * np.pi * figsize / 100 * np.arange(signal.shape[1]))
    signal = (signal * carrier).real
    signal = signal * np.abs(signal) ** 3

    # Signal curves on the last row
    signal_ax = axes[-1]
    signal_ax.set_ylim((-1, 1))
    signal_ax.set_ylabel("Signal", fontsize=displayConfig.fontsize)
    signal_timeline = np.arange(signal.shape[1])
    for rank in range(nb_selected):
        if uni_signal is not None:
            curve_color = uni_signal
        else:
            curve_color = palette[rank % displayConfig.nb_colors]
        signal_ax.plot(signal_timeline, signal[rank], color=curve_color)
    return axes
def display_gradients(
    trajectory: NDArray,
    shot_ids: tuple[int, ...] = (0,),
    figsize: float = 5,
    fill_area: bool = True,
    show_signal: bool = True,
    uni_signal: str | None = "gray",
    uni_gradient: str | None = None,
    subfigure: plt.Figure | plt.Axes | None = None,
    show_constraints: bool = False,
    gmax: float = DEFAULT_GMAX,
    smax: float = DEFAULT_SMAX,
    constraints_order: int | str | None = None,
    raster_time: float = DEFAULT_RASTER_TIME,
    **constraints_kwargs: Any,  # noqa ANN401
) -> tuple[plt.Axes]:
    """Display gradients based on trajectory of any dimension.

    The figure is first drawn by `display_gradients_simply`, then the axes
    are relabelled with physical units and, optionally, the samples violating
    the hardware constraints are marked.

    Parameters
    ----------
    trajectory : NDArray
        Trajectory to display.
    shot_ids : tuple of int
        Indices of the shots to display.
        The default is `(0,)`.
    figsize : float, optional
        Size of the figure.
    fill_area : bool, optional
        Fills the area under the curve for improved visibility and
        representation of the integral, aka trajectory.
        The default is `True`.
    show_signal : bool, optional
        Show an additional illustration of the signal as
        the modulated distance to the center.
        The default is `True`.
    uni_signal : str or None, optional
        Define whether the signal should be represented by a
        unique color given as argument or just by the default
        color cycle when `None`.
        The default is `"gray"`.
    uni_gradient : str or None, optional
        Define whether the gradients should be represented by a
        unique color given as argument or just by the default
        color cycle when `None`.
        The default is `None`.
    subfigure : plt.Figure or plt.SubFigure, optional
        The figure where the trajectory should be displayed.
        The default is `None`.
    show_constraints : bool, optional
        Display the points where the gradients and slew rates
        are above the `gmax` and `smax` limits, respectively.
        The default is `False`.
    gmax : float, optional
        Maximum constraint on the gradients in T/m.
        The default is `DEFAULT_GMAX`.
    smax : float, optional
        Maximum constraint on the slew rates in T/m/ms.
        The default is `DEFAULT_SMAX`.
    constraints_order : int, str, optional
        Norm order defining how the constraints are checked,
        typically 2 or `np.inf`, following the `numpy.linalg.norm`
        conventions on parameter `ord`.
        The default is None.
    raster_time : float, optional
        Amount of time between the acquisition of two
        consecutive samples in ms.
        The default is `DEFAULT_RASTER_TIME`.
    **constraints_kwargs
        Acquisition parameters used to check on hardware constraints,
        following the parameter convention from
        `mrinufft.trajectories.utils.compute_gradients_and_slew_rates`.

    Returns
    -------
    axes : plt.Axes
        Axes of the figure.
    """
    # Initialize figure with a simpler version
    axes = display_gradients_simply(
        trajectory,
        shot_ids,
        figsize,
        fill_area,
        show_signal,
        uni_signal,
        uni_gradient,
        subfigure,
    )

    # Setup figure and labels
    Nd = trajectory.shape[-1]
    for i, ax in enumerate(axes[:Nd]):
        ax.set_ylabel(
            "G{} (mT/m)".format(["x", "y", "z"][i]),
            fontsize=displayConfig.small_fontsize,
        )
    axes[-1].set_xlabel("Time (ms)", fontsize=displayConfig.small_fontsize)
    if show_signal:
        axes[-1].set_ylabel("Signal (a.u.)", fontsize=displayConfig.small_fontsize)

    # Gradients of the first two samples of the first shot, used below to
    # derive the conversion factor from trajectory differences to T/m.
    # The value does not depend on the axis, so it is computed only once.
    reference_gradients = convert_trajectory_to_gradients(
        trajectory[:1, :2], raster_time=raster_time, **constraints_kwargs
    )[0]

    # Update axis ticks with rescaled values
    for i, ax in enumerate(axes):
        is_last = ax is axes[-1]

        # Update xtick labels with time values, on the bottom row only
        if is_last:
            ax.xaxis.set_tick_params(labelbottom=True)
            ticks = ax.get_xticks()
            # The signal row holds 10 points per trajectory sample
            scale = (0.1 if show_signal else 1) * raster_time
            locator = mticker.FixedLocator(ticks)
            formatter = mticker.FixedFormatter(np.around(scale * ticks, 2))
            ax.xaxis.set_major_locator(locator)
            ax.xaxis.set_major_formatter(formatter)

        # Update ytick labels with gradient values
        ax.yaxis.set_tick_params(labelleft=True)
        if show_signal and is_last:
            # The signal is in arbitrary units: keep its default ticks
            continue
        ticks = ax.get_yticks()
        idx = min(i, Nd - 1)
        norms = np.diff(trajectory[:1, :2, idx]).squeeze()
        norms = np.where(norms != 0, norms, 1)  # Avoid dividing by zero
        scale = reference_gradients[0, 0, idx] / norms
        scale = 1e3 * scale  # Convert from T/m to mT/m
        locator = mticker.FixedLocator(ticks)
        formatter = mticker.FixedFormatter(np.around(scale * ticks, 1))
        ax.yaxis.set_major_locator(locator)
        ax.yaxis.set_major_formatter(formatter)

    # Move on with constraints if requested
    if not show_constraints:
        return axes

    # Compute true gradients and slew rates
    # NOTE(review): `raster_time` is not forwarded here, only through
    # `constraints_kwargs` - confirm whether the constraint check should use it.
    gradients, slewrates = compute_gradients_and_slew_rates(
        trajectory[shot_ids, :], **constraints_kwargs
    )
    gradients = np.linalg.norm(gradients, axis=-1, ord=constraints_order)
    slewrates = np.linalg.norm(slewrates, axis=-1, ord=constraints_order)
    slewrates = np.pad(slewrates, ((0, 0), (0, 1)))

    # `np.where` yields one index array per dimension (shot, sample): only the
    # sample index is a position on the time axis. Passing the whole tuple to
    # `scatter` would also draw the shot indices as spurious points.
    gradient_pts = np.where(gradients > gmax)[-1]
    slewrate_pts = np.where(slewrates > smax)[-1]

    # Point out hardware constraint violations
    for ax in axes[:Nd]:
        ax.scatter(
            gradient_pts,
            np.zeros_like(gradient_pts),
            color=displayConfig.gradient_point_color,
            s=displayConfig.pointsize,
        )
        ax.scatter(
            slewrate_pts,
            np.zeros_like(slewrate_pts),
            color=displayConfig.slewrate_point_color,
            s=displayConfig.pointsize,
        )
    return axes