# Source code for mrinufft._utils

"""General utility functions for MRI-NUFFT."""

import warnings
from inspect import cleandoc
from collections import defaultdict
from collections.abc import Callable
from functools import wraps

import numpy as np
from numpy.typing import DTypeLike, NDArray

from mrinufft._array_compat import get_array_module


def check_error(ier, message):
    """Raise a ``RuntimeError`` with ``message`` when ``ier`` is non-zero.

    Parameters
    ----------
    ier: int
        Status code; any value different from 0 is treated as an error.
    message: str
        Message attached to the raised exception.

    Raises
    ------
    RuntimeError
        If ``ier != 0``.
    """
    if ier == 0:
        return
    raise RuntimeError(message)
def sizeof_fmt(num, suffix="B"):
    """
    Format a quantity with binary (power-of-1024) prefixes.

    Parameters
    ----------
    num: int or float
        The quantity to format.
    suffix: str, default "B"
        Unit symbol appended after the binary prefix.

    Returns
    -------
    str
        ``num`` divided by the largest fitting power of 1024, printed with one
        decimal and followed by its prefix and ``suffix`` (e.g. ``"1.5KiB"``).
        Anything that does not fit below the "Zi" range is expressed in "Yi".

    References
    ----------
    https://stackoverflow.com/a/1094933
    """
    units = ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi")
    magnitude = num
    for unit in units:
        fits = abs(magnitude) < 1024.0
        if fits:
            break
        magnitude /= 1024.0
    else:
        # Every prefix was exhausted: fall back to the yobi unit.
        return f"{magnitude:.1f}Yi{suffix}"
    return f"{magnitude:3.1f}{unit}{suffix}"
def proper_trajectory(trajectory, normalize="pi"):
    """Normalize the trajectory to be used by NUFFT operators.

    Parameters
    ----------
    trajectory: array_like
        The trajectory to normalize, it might be of shape (Nc, Ns, dim) or
        (Ns, dim). Torch tensors are cloned; anything else is converted with
        ``np.asarray`` and copied, so the input itself is never modified.
    normalize: str
        If "pi", the trajectory is rescaled to [-pi, pi) when it looks like it
        was in [-0.5, 0.5).
        If "unit", the trajectory is rescaled to [-0.5, 0.5) when it looks like
        it was not already in [-0.5, 0.5], and is then wrapped periodically
        into [-0.5, 0.5) if its original maximum magnitude was >= 0.5.
        Any other value leaves the samples unscaled (only flattened).

    Returns
    -------
    new_traj: np.ndarray or torch.Tensor
        The normalized trajectory of shape (Nc * Ns, dim) or (Ns, dim).

    Raises
    ------
    ValueError
        If ``trajectory`` cannot be converted to an array.
    """
    # Detect the array backend so that torch tensors stay torch tensors.
    xp = get_array_module(trajectory)
    try:
        new_traj = (
            trajectory.clone()
            if xp.__name__ == "torch"
            else np.asarray(trajectory).copy()
        )
    except Exception as e:
        raise ValueError(
            "trajectory should be array_like, with the last dimension being coordinates"
        ) from e

    # Flatten to a list of points. Read the shape from the converted array:
    # plain array_likes (e.g. nested lists) have no ``.shape`` attribute.
    new_traj = new_traj.reshape(-1, new_traj.shape[-1])

    max_abs_val = xp.max(xp.abs(new_traj))

    # A small 1e-4 tolerance is applied around the 0.5 bound when guessing
    # which convention the samples currently follow.
    if normalize == "pi" and max_abs_val - 1e-4 < 0.5:
        warnings.warn(
            "Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)"
        )
        new_traj *= 2 * xp.pi
    elif normalize == "unit" and max_abs_val - 1e-4 > 0.5:
        warnings.warn(
            "Samples will be rescaled to [-0.5, 0.5), assuming they were in [-pi, pi)"
        )
        new_traj *= 1 / (2 * xp.pi)

    # Wrap periodically into [-0.5, 0.5). Note that this test uses the
    # maximum measured *before* the rescaling above.
    if normalize == "unit" and max_abs_val >= 0.5:
        new_traj = (new_traj + 0.5) % 1 - 0.5
    return new_traj
[docs] class MethodRegister: """ A Decorator to register methods of the same type in dictionnaries. Parameters ---------- name: str The register name docstring_sub: dict[str,str] List of potential subsititutions to apply to the docstring. """ registry = defaultdict(dict) def __init__( self, register_name: str, docstring_subs: dict[str, str] | None = None ): self.register_name = register_name self.docstring_subs = docstring_subs
[docs] def __call__(self, method_name=None): """Register the function in the registry. It also substitute placeholder in docstrings. """ def decorator(func): self.registry[self.register_name][method_name] = func if self.docstring_subs is not None and func.__doc__: docstring = cleandoc(func.__doc__) for key, sub in self.docstring_subs.items(): docstring = docstring.replace(f"${{{key}}}", sub) func.__doc__ = docstring return func if callable(method_name): func = method_name method_name = func.__name__ return decorator(func) else: return decorator
def make_getter(self) -> Callable: def getter(method_name, *args, **kwargs): try: method = self.registry[self.register_name][method_name] except KeyError as e: raise ValueError( f"Unknown {self.register_name} method {method_name}." " Available methods are \n" f"{list(self.registry[self.register_name].keys())}" ) from e if args or kwargs: return method(*args, **kwargs) return method getter.__doc__ = f"""Get the {self.register_name} function from its name.""" getter.__name__ = f"get_{self.register_name}" return getter