"""Display functions for trajectories."""
import itertools
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
from .utils import (
DEFAULT_GMAX,
DEFAULT_RASTER_TIME,
DEFAULT_SMAX,
KMAX,
compute_gradients_and_slew_rates,
convert_trajectory_to_gradients,
)
[docs]
class displayConfig:
    """
    A container class used to share arguments related to display.

    The values are stored as class attributes, so they are shared by every
    display function. They can be updated either directly (and permanently)
    or temporarily by using a context manager.

    Examples
    --------
    >>> from mrinufft.trajectories.display import displayConfig
    >>> displayConfig.alpha
    0.2
    >>> with displayConfig(alpha=0.5):
    ...     print(displayConfig.alpha)
    0.5
    >>> displayConfig.alpha
    0.2
    """

    alpha: float = 0.2
    """Transparency used for area plots, by default ``0.2``."""
    linewidth: float = 2
    """Width for lines or curves, by default ``2``."""
    pointsize: int = 10
    """Size for points used to show constraints, by default ``10``."""
    fontsize: int = 18
    """Font size for most labels and texts, by default ``18``."""
    small_fontsize: int = 14
    """Font size for smaller texts, by default ``14``."""
    nb_colors: int = 10
    """Number of colors to use in the color cycle, by default ``10``."""
    palette: str = "tab10"
    """Name of the color palette to use, by default ``"tab10"``.
    This can be any of the matplotlib colormaps, or a list of colors."""
    one_shot_color: str = "k"
    """Matplotlib color for the highlighted shot, by default ``"k"`` (black)."""
    gradient_point_color: str = "r"
    """Matplotlib color for gradient constraint points, by default ``"r"`` (red)."""
    slewrate_point_color: str = "b"
    """Matplotlib color for slew rate constraint points, by default ``"b"`` (blue)."""

    def __init__(self, **kwargs):
        """Update the display configuration with the given keyword values."""
        self.update(**kwargs)

    def update(self, **kwargs):
        """Update the display configuration.

        The value each option had before this instance first modified it is
        remembered, so that `reset` restores the true original value even
        after several successive calls to `update`.

        Parameters
        ----------
        **kwargs
            Names and new values of the class attributes to override.

        Raises
        ------
        AttributeError
            If a name does not match an existing attribute. Nothing is
            modified in that case.
        """
        # Validate everything first so a bad key cannot leave a partial update
        unknown = sorted(key for key in kwargs if not hasattr(displayConfig, key))
        if unknown:
            raise AttributeError(
                f"Unknown display option(s): {', '.join(unknown)}"
            )
        saved = self.__dict__.setdefault("_old_values", {})
        for key, value in kwargs.items():
            # Keep only the first recorded value: that is the one to restore
            saved.setdefault(key, getattr(displayConfig, key))
            setattr(displayConfig, key, value)

    def reset(self):
        """Restore the display configuration.

        Every option modified through this instance is set back to the value
        it had before the first modification. Calling `reset` again without
        an intermediate `update` is a no-op.
        """
        for key, value in self.__dict__.pop("_old_values", {}).items():
            setattr(displayConfig, key, value)

    def __enter__(self):
        """Enter the context manager."""
        return self

    def __exit__(self, *args):
        """Exit the context manager and restore the previous values."""
        self.reset()

    @classmethod
    def get_colorlist(cls):
        """Extract a list of colors from the configured palette.

        The palette is read from ``displayConfig.palette`` and the number of
        colors from ``displayConfig.nb_colors``. When the colormap exposes
        fewer colors than requested, its colors are repeated in cycle.
        Otherwise, the colormap is sampled at ``nb_colors`` evenly spaced
        positions between 0 and 1.

        Returns
        -------
        colorlist : list or np.ndarray
            ``nb_colors`` matplotlib colors: a list when the palette colors
            are cycled, an array of RGBA rows when the colormap is sampled.

        Raises
        ------
        TypeError
            If the palette is not a colormap name, a matplotlib colormap
            or a list/tuple of colors.
        """
        palette = cls.palette
        if isinstance(palette, str):
            cm = mpl.colormaps[palette]
        elif isinstance(palette, mpl.colors.Colormap):
            cm = palette
        elif isinstance(palette, (list, tuple)):
            # ListedColormap is defined in matplotlib.colors
            cm = mpl.colors.ListedColormap(palette)
        else:
            raise TypeError(
                "palette must be a colormap name, a matplotlib colormap "
                f"or a list of colors, got {type(palette).__name__}"
            )

        colors = getattr(cm, "colors", [])
        if 0 < len(colors) < cls.nb_colors:
            return list(itertools.islice(itertools.cycle(colors), cls.nb_colors))
        return cm(np.linspace(0, 1, cls.nb_colors))
##############
# TICK UTILS #
##############
def _setup_2D_ticks(figsize, fig=None):
    """Prepare a 2D axes with grid, ticks, limits and labels.

    Parameters
    ----------
    figsize : float
        Side length of the square figure created when `fig` is `None`.
    fig : plt.Figure, plt.SubFigure or plt.Axes, optional
        Existing figure to add an axes to, or existing axes to reuse.
        The default is `None`.

    Returns
    -------
    ax : plt.Axes
        The configured axes, spanning ``[-KMAX, KMAX]`` on both axes.
    """
    target = plt.figure(figsize=(figsize, figsize)) if fig is None else fig
    if isinstance(target, plt.Axes):
        axis = target
    else:
        axis = target.subplots()

    axis.grid(True)
    # Same five tick positions on both axes
    for axis_name in "xy":
        getattr(axis, f"set_{axis_name}ticks")(
            [-KMAX, -KMAX / 2, 0, KMAX / 2, KMAX]
        )
    axis.set_xlim((-KMAX, KMAX))
    axis.set_ylim((-KMAX, KMAX))
    axis.set_xlabel("kx", fontsize=displayConfig.fontsize)
    axis.set_ylabel("ky", fontsize=displayConfig.fontsize)
    return axis
def _setup_3D_ticks(figsize, fig=None):
    """Prepare a 3D axes with ticks, limits, box aspect and labels.

    Parameters
    ----------
    figsize : float
        Side length of the square figure created when `fig` is `None`.
    fig : plt.Figure, plt.SubFigure or plt.Axes, optional
        Existing figure to add a 3D axes to, or existing axes to reuse.
        The default is `None`.

    Returns
    -------
    ax : plt.Axes
        The configured axes, spanning ``[-KMAX, KMAX]`` on the three axes.
    """
    target = plt.figure(figsize=(figsize, figsize)) if fig is None else fig
    if isinstance(target, plt.Axes):
        axis = target
    else:
        axis = target.add_subplot(projection="3d")

    # Same five tick positions on the three axes, set in x, y, z order
    for axis_name in "xyz":
        getattr(axis, f"set_{axis_name}ticks")(
            [-KMAX, -KMAX / 2, 0, KMAX / 2, KMAX]
        )
    axis.axes.set_xlim3d(left=-KMAX, right=KMAX)
    axis.axes.set_ylim3d(bottom=-KMAX, top=KMAX)
    axis.axes.set_zlim3d(bottom=-KMAX, top=KMAX)
    side = 2 * KMAX
    axis.set_box_aspect((side, side, side))
    axis.set_xlabel("kx", fontsize=displayConfig.fontsize)
    axis.set_ylabel("ky", fontsize=displayConfig.fontsize)
    axis.set_zlabel("kz", fontsize=displayConfig.fontsize)
    return axis
######################
# TRAJECTORY DISPLAY #
######################
[docs]
def display_2D_trajectory(
    trajectory,
    figsize=5,
    one_shot=False,
    subfigure=None,
    show_constraints=False,
    gmax=DEFAULT_GMAX,
    smax=DEFAULT_SMAX,
    constraints_order=None,
    **constraints_kwargs,
):
    """Display 2D trajectories.

    Parameters
    ----------
    trajectory : array_like
        Trajectory to display, indexed as (shot, sample, dimension).
    figsize : float, optional
        Size of the figure.
    one_shot : bool or int, optional
        State if a specific shot should be highlighted with
        ``displayConfig.one_shot_color`` and a doubled line width.
        If `True`, highlight the middle shot.
        If `int`, highlight the shot at that index.
        The default is `False`.
    subfigure : plt.Figure, plt.SubFigure or plt.Axes, optional
        The figure where the trajectory should be displayed.
        The default is `None`.
    show_constraints : bool, optional
        Display the points where the gradients and slew rates
        are above the `gmax` and `smax` limits, respectively.
        The default is `False`.
    gmax : float, optional
        Maximum constraint on the gradients in T/m.
        The default is `DEFAULT_GMAX`.
    smax : float, optional
        Maximum constraint on the slew rates in T/m/ms.
        The default is `DEFAULT_SMAX`.
    constraints_order : int, str, optional
        Norm order defining how the constraints are checked,
        typically 2 or `np.inf`, following the `numpy.linalg.norm`
        conventions on parameter `ord`.
        The default is None.
    **constraints_kwargs
        Acquisition parameters used to check on hardware constraints,
        forwarded to
        `mrinufft.trajectories.utils.compute_gradients_and_slew_rates`.

    Returns
    -------
    ax : plt.Axes
        Axes of the figure.
    """
    # Prepare the axes and the color cycle
    nb_shots, _ = trajectory.shape[:2]
    ax = _setup_2D_ticks(figsize, subfigure)
    palette = displayConfig.get_colorlist()

    # Draw all the shots, cycling over the palette
    for shot_idx in range(nb_shots):
        ax.plot(
            trajectory[shot_idx, :, 0],
            trajectory[shot_idx, :, 1],
            color=palette[shot_idx % displayConfig.nb_colors],
            linewidth=displayConfig.linewidth,
        )

    # Overlay the highlighted shot, if any
    if one_shot is not False:
        # `True` selects the middle shot, anything else is used as the index
        highlighted = nb_shots // 2 if one_shot is True else one_shot
        ax.plot(
            trajectory[highlighted, :, 0],
            trajectory[highlighted, :, 1],
            color=displayConfig.one_shot_color,
            linewidth=2 * displayConfig.linewidth,
        )

    if not show_constraints:
        return ax

    # Mark the samples violating the hardware constraints
    gradients, slews = compute_gradients_and_slew_rates(
        trajectory, **constraints_kwargs
    )
    # Zero-pad the start of the sample axis so that the norms line up
    # with the samples once flattened
    gradient_norms = np.linalg.norm(
        np.pad(gradients, ((0, 0), (1, 0), (0, 0))), axis=-1, ord=constraints_order
    )
    slew_norms = np.linalg.norm(
        np.pad(slews, ((0, 0), (2, 0), (0, 0))), axis=-1, ord=constraints_order
    )
    samples = trajectory.reshape((-1, 2))
    over_gmax = samples[np.where(gradient_norms.flatten() > gmax)]
    over_smax = samples[np.where(slew_norms.flatten() > smax)]

    ax.scatter(
        over_gmax[:, 0],
        over_gmax[:, 1],
        color=displayConfig.gradient_point_color,
        s=displayConfig.pointsize,
    )
    ax.scatter(
        over_smax[:, 0],
        over_smax[:, 1],
        color=displayConfig.slewrate_point_color,
        s=displayConfig.pointsize,
    )
    return ax
[docs]
def display_3D_trajectory(
    trajectory,
    nb_repetitions=None,
    figsize=5,
    per_plane=True,
    one_shot=False,
    subfigure=None,
    show_constraints=False,
    gmax=DEFAULT_GMAX,
    smax=DEFAULT_SMAX,
    constraints_order=None,
    **constraints_kwargs,
):
    """Display 3D trajectories.

    Parameters
    ----------
    trajectory : array_like
        Trajectory to display, indexed as (shot, sample, dimension).
        A trajectory with two coordinates per sample is displayed in the
        ``kz = 0`` plane.
    nb_repetitions : int
        Number of repetitions (planes, cones, shells, etc).
        The default is `None`, in which case the number of shots is used.
    figsize : float, optional
        Size of the figure.
    per_plane : bool, optional
        If True, display the trajectory with a different color
        for each plane.
    one_shot : bool or int, optional
        State if a specific shot should be highlighted with
        ``displayConfig.one_shot_color`` and a doubled line width.
        If `True`, highlight the middle shot among all shots.
        If `int`, highlight the shot at that index among all shots.
        The default is `False`.
    subfigure : plt.Figure, plt.SubFigure or plt.Axes, optional
        The figure where the trajectory should be displayed.
        The default is `None`.
    show_constraints : bool, optional
        Display the points where the gradients and slew rates
        are above the `gmax` and `smax` limits, respectively.
        The default is `False`.
    gmax : float, optional
        Maximum constraint on the gradients in T/m.
        The default is `DEFAULT_GMAX`.
    smax : float, optional
        Maximum constraint on the slew rates in T/m/ms.
        The default is `DEFAULT_SMAX`.
    constraints_order : int, str, optional
        Norm order defining how the constraints are checked,
        typically 2 or `np.inf`, following the `numpy.linalg.norm`
        conventions on parameter `ord`.
        The default is None.
    **constraints_kwargs
        Acquisition parameters used to check on hardware constraints,
        forwarded to
        `mrinufft.trajectories.utils.compute_gradients_and_slew_rates`.

    Returns
    -------
    ax : plt.Axes
        Axes of the figure.
    """
    # Setup figure and ticks, and handle 2D trajectories
    ax = _setup_3D_ticks(figsize, subfigure)
    if nb_repetitions is None:
        nb_repetitions = trajectory.shape[0]
    if trajectory.shape[-1] == 2:
        # Add a null third coordinate to display 2D trajectories in 3D
        trajectory = np.concatenate(
            [trajectory, np.zeros((*(trajectory.shape[:2]), 1))], axis=-1
        )
    trajectory = trajectory.reshape((nb_repetitions, -1, trajectory.shape[-2], 3))
    Nc, Ns = trajectory.shape[1:3]
    # Flat (shot, sample, 3) view used for shot selection and constraints
    shots = trajectory.reshape((-1, Ns, 3))
    colors = displayConfig.get_colorlist()

    # Display every shot
    for i in range(nb_repetitions):
        for j in range(Nc):
            ax.plot(
                trajectory[i, j, :, 0],
                trajectory[i, j, :, 1],
                trajectory[i, j, :, 2],
                # One color per repetition if per_plane, else one per shot
                color=colors[(i + j * (not per_plane)) % displayConfig.nb_colors],
                linewidth=displayConfig.linewidth,
            )

    # Display one shot in particular if requested
    if one_shot is not False:  # If True or int
        # The index refers to the flat list of shots, so the middle shot is
        # the middle of all shots and not of a single repetition
        shot_id = shots.shape[0] // 2 if one_shot is True else one_shot
        ax.plot(
            shots[shot_id, :, 0],
            shots[shot_id, :, 1],
            shots[shot_id, :, 2],
            color=displayConfig.one_shot_color,
            linewidth=2 * displayConfig.linewidth,
        )

    # Point out violated constraints if requested
    if show_constraints:
        gradients, slewrates = compute_gradients_and_slew_rates(
            shots, **constraints_kwargs
        )

        # Zero-pad the start of the sample axis so that the norms line up
        # with the samples once flattened
        gradients = np.linalg.norm(
            np.pad(gradients, ((0, 0), (1, 0), (0, 0))), axis=-1, ord=constraints_order
        )
        slewrates = np.linalg.norm(
            np.pad(slewrates, ((0, 0), (2, 0), (0, 0))), axis=-1, ord=constraints_order
        )

        # Check constraints
        samples = shots.reshape((-1, 3))
        gradients = samples[np.where(gradients.flatten() > gmax)]
        slewrates = samples[np.where(slewrates.flatten() > smax)]

        # Scatter points with vivid colors
        ax.scatter(
            *(gradients.T),
            color=displayConfig.gradient_point_color,
            s=displayConfig.pointsize,
        )
        ax.scatter(
            *(slewrates.T),
            color=displayConfig.slewrate_point_color,
            s=displayConfig.pointsize,
        )
    return ax
####################
# GRADIENT DISPLAY #
####################
[docs]
def display_gradients_simply(
    trajectory,
    shot_ids=(0,),
    figsize=5,
    fill_area=True,
    show_signal=True,
    uni_signal="gray",
    uni_gradient=None,
    subfigure=None,
):
    """Display gradients based on trajectory of any dimension.

    The gradients are shown as the finite differences of the trajectory
    along the sample axis, without physical units: tick labels are hidden.

    Parameters
    ----------
    trajectory : array_like
        Trajectory to display, indexed as (shot, sample, dimension).
    shot_ids : list of int
        Indices of the shots to display.
        The default is `(0,)`.
    figsize : float, optional
        Size of the figure.
    fill_area : bool, optional
        Fills the area under the curve for improved visibility and
        representation of the integral, aka trajectory.
        The default is `True`.
    show_signal : bool, optional
        Show an additional illustration of the signal as
        the modulated distance to the center.
        The default is `True`.
    uni_signal : str or None, optional
        Define whether the signal should be represented by a
        unique color given as argument or just by the default
        color cycle when `None`.
        The default is `"gray"`.
    uni_gradient : str or None, optional
        Define whether the gradients should be represented by a
        unique color given as argument or just by the default
        color cycle when `None`.
        The default is `None`.
    subfigure : plt.Figure or plt.SubFigure, optional
        The figure where the trajectory should be displayed.
        The default is `None`.

    Returns
    -------
    axes : np.ndarray of plt.Axes
        Axes of the figure, one per dimension followed by the signal
        axes when `show_signal` is set.
    """
    shot_ids = list(shot_ids)

    # Setup figure and labels
    Nd = trajectory.shape[-1]
    nb_rows = Nd + show_signal
    if subfigure is None:
        fig = plt.figure(figsize=(figsize, figsize * nb_rows / Nd))
    else:
        fig = subfigure
    # Without squeeze=False a single row comes back as a bare Axes, which
    # cannot be sliced; always work on a 1D array of axes instead
    axes = fig.subplots(nb_rows, 1, squeeze=False)[:, 0]
    for i, ax in enumerate(axes[:Nd]):
        ax.set_ylabel("G{}".format(["x", "y", "z"][i]), fontsize=displayConfig.fontsize)
    axes[-1].set_xlabel("Time", fontsize=displayConfig.fontsize)

    # Setup axes ticks
    for ax in axes:
        ax.grid(True)
        ax.xaxis.set_tick_params(labelbottom=False)
        ax.yaxis.set_tick_params(labelleft=False)

    # Plot the curves for each axis
    gradients = np.diff(trajectory, axis=1)
    # Common vertical range with a 10% margin over the displayed shots
    vmax = 1.1 * np.max(np.abs(gradients[shot_ids, ...]))
    x_axis = np.arange(gradients.shape[1])
    colors = displayConfig.get_colorlist()
    for ax in axes[:Nd]:
        ax.set_ylim((-vmax, vmax))
    for j, s_id in enumerate(shot_ids):
        color = (
            uni_gradient
            if uni_gradient is not None
            else colors[j % displayConfig.nb_colors]
        )
        for i, ax in enumerate(axes[:Nd]):
            ax.plot(x_axis, gradients[s_id, ..., i], color=color)
            if fill_area:
                ax.fill_between(
                    x_axis,
                    gradients[s_id, ..., i],
                    alpha=displayConfig.alpha,
                    color=color,
                )

    # Return axes alone
    if not show_signal:
        return axes

    # Show signal as modulated distance to center
    distances = np.linalg.norm(trajectory[shot_ids, 1:-1], axis=-1)
    # Repeat every sample 10 times to draw the modulation at a finer rate
    distances = np.tile(distances.reshape((len(shot_ids), -1, 1)), (1, 1, 10))
    signal = 1 - distances.reshape((len(shot_ids), -1)) / np.max(distances)
    signal = (
        signal * np.exp(2j * np.pi * figsize / 100 * np.arange(signal.shape[1]))
    ).real
    signal = signal * np.abs(signal) ** 3

    # Show signal for each requested shot
    axes[-1].set_ylim((-1, 1))
    axes[-1].set_ylabel("Signal", fontsize=displayConfig.fontsize)
    signal_x_axis = np.arange(signal.shape[1])
    for j in range(len(shot_ids)):
        color = (
            uni_signal
            if (uni_signal is not None)
            else colors[j % displayConfig.nb_colors]
        )
        axes[-1].plot(signal_x_axis, signal[j], color=color)
    return axes
[docs]
def display_gradients(
    trajectory,
    shot_ids=(0,),
    figsize=5,
    fill_area=True,
    show_signal=True,
    uni_signal="gray",
    uni_gradient=None,
    subfigure=None,
    show_constraints=False,
    gmax=DEFAULT_GMAX,
    smax=DEFAULT_SMAX,
    constraints_order=None,
    raster_time=DEFAULT_RASTER_TIME,
    **constraints_kwargs,
):
    """Display gradients based on trajectory of any dimension.

    The figure is first drawn with `display_gradients_simply`, then the tick
    labels are rescaled to time and gradient values.

    Parameters
    ----------
    trajectory : array_like
        Trajectory to display, indexed as (shot, sample, dimension).
    shot_ids : list of int
        Indices of the shots to display.
        The default is `(0,)`.
    figsize : float, optional
        Size of the figure.
    fill_area : bool, optional
        Fills the area under the curve for improved visibility and
        representation of the integral, aka trajectory.
        The default is `True`.
    show_signal : bool, optional
        Show an additional illustration of the signal as
        the modulated distance to the center.
        The default is `True`.
    uni_signal : str or None, optional
        Define whether the signal should be represented by a
        unique color given as argument or just by the default
        color cycle when `None`.
        The default is `"gray"`.
    uni_gradient : str or None, optional
        Define whether the gradients should be represented by a
        unique color given as argument or just by the default
        color cycle when `None`.
        The default is `None`.
    subfigure : plt.Figure or plt.SubFigure, optional
        The figure where the trajectory should be displayed.
        The default is `None`.
    show_constraints : bool, optional
        Display the points where the gradients and slew rates
        are above the `gmax` and `smax` limits, respectively.
        The default is `False`.
    gmax : float, optional
        Maximum constraint on the gradients in T/m.
        The default is `DEFAULT_GMAX`.
    smax : float, optional
        Maximum constraint on the slew rates in T/m/ms.
        The default is `DEFAULT_SMAX`.
    constraints_order : int, str, optional
        Norm order defining how the constraints are checked,
        typically 2 or `np.inf`, following the `numpy.linalg.norm`
        conventions on parameter `ord`.
        The default is None.
    raster_time : float, optional
        Amount of time between the acquisition of two
        consecutive samples in ms.
        The default is `DEFAULT_RASTER_TIME`.
    **constraints_kwargs
        Acquisition parameters used to check on hardware constraints,
        forwarded to
        `mrinufft.trajectories.utils.compute_gradients_and_slew_rates`
        and `mrinufft.trajectories.utils.convert_trajectory_to_gradients`.

    Returns
    -------
    axes : np.ndarray of plt.Axes
        Axes of the figure.
    """
    shot_ids = list(shot_ids)

    # Initialize figure with a simpler version
    axes = display_gradients_simply(
        trajectory,
        shot_ids,
        figsize,
        fill_area,
        show_signal,
        uni_signal,
        uni_gradient,
        subfigure,
    )

    # Setup figure and labels
    Nd = trajectory.shape[-1]
    for i, ax in enumerate(axes[:Nd]):
        ax.set_ylabel(
            "G{} (mT/m)".format(["x", "y", "z"][i]),
            fontsize=displayConfig.small_fontsize,
        )
    axes[-1].set_xlabel("Time (ms)", fontsize=displayConfig.small_fontsize)
    if show_signal:
        axes[-1].set_ylabel("Signal (a.u.)", fontsize=displayConfig.small_fontsize)

    # Physical gradient of the very first step, used as the reference to
    # rescale the unitless finite differences drawn by the simple display
    first_gradient = convert_trajectory_to_gradients(
        trajectory[:1, :2], raster_time=raster_time, **constraints_kwargs
    )[0][0, 0]

    # Update axis ticks with rescaled values
    last = len(axes) - 1
    for i, ax in enumerate(axes):
        is_signal_axis = show_signal and i == last

        # Update xtick labels with time values
        if i == last:
            ax.xaxis.set_tick_params(labelbottom=True)
        ticks = ax.get_xticks()
        # The signal is drawn with 10 points per sample
        scale = (0.1 if is_signal_axis else 1) * raster_time
        ax.xaxis.set_major_locator(mticker.FixedLocator(ticks))
        ax.xaxis.set_major_formatter(
            mticker.FixedFormatter(np.around(scale * ticks, 2))
        )

        # Update ytick labels with gradient values
        ax.yaxis.set_tick_params(labelleft=True)
        if is_signal_axis:
            continue
        ticks = ax.get_yticks()
        norms = np.diff(trajectory[:1, :2, i]).squeeze()
        # NOTE(review): a null first step yields a zero scale below, and
        # thus all-zero tick labels on that axis
        norms = np.where(norms != 0, norms, 1)
        scale = 1e3 * first_gradient[i] / norms  # Convert from T/m to mT/m
        ax.yaxis.set_major_locator(mticker.FixedLocator(ticks))
        ax.yaxis.set_major_formatter(
            mticker.FixedFormatter(np.around(scale * ticks, 1))
        )

    # Move on with constraints if requested
    if not show_constraints:
        return axes

    # Compute true gradients and slew rates, with the same raster time as
    # the one used for the tick labels
    gradients, slewrates = compute_gradients_and_slew_rates(
        trajectory[shot_ids, :], raster_time=raster_time, **constraints_kwargs
    )
    gradients = np.linalg.norm(gradients, axis=-1, ord=constraints_order)
    slewrates = np.linalg.norm(slewrates, axis=-1, ord=constraints_order)
    slewrates = np.pad(slewrates, ((0, 0), (0, 1)))

    # Keep only the sample indices: `np.where` on these 2D arrays returns
    # (shot indices, sample indices) and only the latter are time positions
    gradient_pts = np.where(gradients > gmax)[1]
    slewrate_pts = np.where(slewrates > smax)[1]

    # Point out hardware constraint violations
    for ax in axes[:Nd]:
        ax.scatter(
            gradient_pts,
            np.zeros_like(gradient_pts),
            color=displayConfig.gradient_point_color,
            s=displayConfig.pointsize,
        )
        ax.scatter(
            slewrate_pts,
            np.zeros_like(slewrate_pts),
            color=displayConfig.slewrate_point_color,
            s=displayConfig.pointsize,
        )
    return axes