Source code for mrinufft.trajectories.sampling

"""Sampling densities and methods."""

import numpy as np
import numpy.fft as nf
import numpy.linalg as nl
import numpy.random as nr
from tqdm.auto import tqdm

from .utils import KMAX


[docs] def sample_from_density( nb_samples, density, method="random", *, dim_compensation="auto" ): """ Sample points based on a given density distribution. Parameters ---------- nb_samples : int The number of samples to draw. density : np.ndarray An array representing the density distribution from which samples are drawn, normalized automatically by its sum during the call for convenience. method : str, optional The sampling method to use, either 'random' for random sampling over the discrete grid defined by the density or 'lloyd' for Lloyd's algorithm over a continuous space, by default "random". dim_compensation : str, bool, optional Whether to apply a specific dimensionality compensation introduced in [Cha+14]_. An exponent ``N/(N-1)`` with ``N`` the number of dimensions in ``density`` is applied to fix the observed density expectation when set to ``"auto"`` and ``method="lloyd"``. It is also relevant to set it to ``True`` when ``method="random"`` and one wants to create binary masks with continuous paths between drawn samples. Returns ------- np.ndarray An array of range-normalized sampled locations. Raises ------ ValueError If ``nb_samples`` exceeds the total size of the density array or if the specified ``method`` is unknown. References ---------- .. [Cha+14] Chauffert, Nicolas, Philippe Ciuciu, Jonas Kahn, and Pierre Weiss. "Variable density sampling with continuous trajectories." SIAM Journal on Imaging Sciences 7, no. 4 (2014): 1962-1992. """ try: from sklearn.cluster import BisectingKMeans, KMeans except ImportError as err: raise ImportError( "The scikit-learn module is not available. Please install " "it along with the [extra] dependencies " "or using `pip install scikit-learn`." ) from err # Define dimension variables shape = np.array(density.shape) nb_dims = len(shape) max_nb_samples = np.prod(shape) density = density / np.sum(density) if nb_samples > max_nb_samples: raise ValueError("`nb_samples` must be lower than the size of `density`.") # Check for dimensionality compensation if isinstance(dim_compensation, str) and dim_compensation != "auto": raise ValueError(f"Unknown string {dim_compensation} for `dim_compensation`.") if (dim_compensation == "auto" and method == "lloyd") or ( isinstance(dim_compensation, bool) and dim_compensation ): density = density ** (nb_dims / (nb_dims - 1)) density = density / np.sum(density) # Sample using specified method rng = nr.default_rng() if method == "random": choices = rng.choice( np.arange(max_nb_samples), size=nb_samples, p=density.flatten(), replace=False, ) locations = np.indices(shape).reshape((nb_dims, -1))[:, choices] locations = locations.T + 0.5 locations = locations / shape[None, :] locations = 2 * KMAX * locations - KMAX elif method == "lloyd": kmeans = ( KMeans(n_clusters=nb_samples) if nb_dims <= 2 else BisectingKMeans(n_clusters=nb_samples) ) kmeans.fit( np.indices(density.shape).reshape((nb_dims, -1)).T, sample_weight=density.flatten(), ) locations = kmeans.cluster_centers_ - np.array(density.shape) / 2 locations = KMAX * locations / np.max(np.abs(locations)) else: raise ValueError(f"Unknown sampling method {method}.") return locations
[docs] def create_cutoff_decay_density(shape, cutoff, decay, resolution=None): """ Create a density with central plateau and polynomial decay. Create a density composed of a central constant-value ellipsis defined by a cutoff ratio, followed by a polynomial decay over outer regions as defined in [Cha+22]_. Parameters ---------- shape : tuple of int The shape of the density grid, analog to the field-of-view as opposed to ``resolution`` below. cutoff : float The ratio of the largest k-space dimension between 0 and 1 within which density remains uniform and beyond which it decays. decay : float The polynomial decay in density beyond the cutoff ratio. resolution : np.ndarray, optional Resolution scaling factors for each dimension of the density grid, by default ``None``. Returns ------- np.ndarray A density array with values decaying based on the specified cutoff ratio and decay rate. References ---------- .. [Cha+22] Chaithya, G. R., Pierre Weiss, Guillaume Daval-Frérot, Aurélien Massire, Alexandre Vignaud, and Philippe Ciuciu. "Optimizing full 3D SPARKLING trajectories for high-resolution magnetic resonance imaging." IEEE Transactions on Medical Imaging 41, no. 8 (2022): 2105-2117. """ shape = np.array(shape) nb_dims = len(shape) if not resolution: resolution = np.ones(nb_dims) differences = np.indices(shape).astype(float) for i in range(nb_dims): differences[i] = differences[i] + 0.5 - shape[i] / 2 differences[i] = differences[i] / shape[i] / resolution[i] distances = nl.norm(differences, axis=0) cutoff = cutoff * np.max(differences) if cutoff else np.min(differences) density = np.ones(shape) decay_mask = np.where(distances > cutoff, True, False) density[decay_mask] = (cutoff / distances[decay_mask]) ** decay density = density / np.sum(density) return density
[docs] def create_polynomial_density(shape, decay, resolution=None): """ Create a density with polynomial decay from the center. Parameters ---------- shape : tuple of int The shape of the density grid. decay : float The exponent that controls the rate of decay for density. resolution : np.ndarray, optional Resolution scaling factors for each dimension of the density grid, by default None. Returns ------- np.ndarray A density array with polynomial decay. """ return create_cutoff_decay_density( shape, cutoff=0, decay=decay, resolution=resolution )
[docs] def create_energy_density(dataset): """ Create a density based on energy in the Fourier spectrum. A density is made based on the average energy in the Fourier domain of volumes from a target image dataset. Parameters ---------- dataset : np.ndarray The dataset from which to calculate the density based on its Fourier transform, with an expected shape (nb_volumes, dim_1, ..., dim_N). An N-dimensional Fourier transform is performed. Returns ------- np.ndarray A density array derived from the mean energy in the Fourier domain of the input dataset. """ nb_dims = len(dataset.shape) - 1 axes = range(1, nb_dims + 1) kspace = nf.fftshift(nf.fftn(nf.fftshift(dataset, axes=axes), axes=axes), axes=axes) energy = np.mean(np.abs(kspace), axis=0) density = energy / np.sum(energy) return density
[docs] def create_chauffert_density(shape, wavelet_basis, nb_wavelet_scales, verbose=False): """Create a density based on Chauffert's method. This is a reproduction of the proposition from [CCW13]_. A sampling density is derived from compressed sensing equations to maximize guarantees of exact image recovery for a specified sparse wavelet domain decomposition. Parameters ---------- shape : tuple of int The shape of the density grid. wavelet_basis : str, pywt.Wavelet The wavelet basis to use for wavelet decomposition, either as a built-in wavelet name from the PyWavelets package or as a custom ``pywt.Wavelet`` object. nb_wavelet_scales : int The number of wavelet scales to use in decomposition. verbose : bool, optional If ``True``, displays a progress bar. Default to ``False``. Returns ------- np.ndarray A density array created based on wavelet transform coefficients. See Also -------- pywt.wavelist : A list of wavelet decompositions available in the PyWavelets package used inside the function. pywt.Wavelet : A wavelet object accepted to generate Chauffert densities. References ---------- .. [CCW13] Chauffert, Nicolas, Philippe Ciuciu, and Pierre Weiss. "Variable density compressed sensing in MRI. Theoretical vs heuristic sampling strategies." In 2013 IEEE 10th International Symposium on Biomedical Imaging, pp. 298-301. IEEE, 2013. """ try: import pywt as pw except ImportError as err: raise ImportError( "The PyWavelets module is not available. Please install " "it along with the [extra] dependencies " "or using `pip install pywavelets`." ) from err nb_dims = len(shape) indices = np.indices(shape).reshape((nb_dims, -1)).T density = np.zeros(shape) unit_vector = np.zeros(shape) for ids in tqdm(indices, disable=not verbose): ids = tuple(ids) unit_vector[ids] = 1 fourier_vector = nf.ifftn(unit_vector) coeffs = pw.wavedecn( fourier_vector, wavelet=wavelet_basis, level=nb_wavelet_scales ) coeffs, _ = pw.coeffs_to_array(coeffs) density[ids] = np.max(np.abs(coeffs)) ** 2 unit_vector[ids] = 0 density = density / np.sum(density) return nf.ifftshift(density)
[docs] def create_fast_chauffert_density(shape, wavelet_basis, nb_wavelet_scales): """Create a density based on an approximated Chauffert method. This implementation is based on this tutorial: https://github.com/philouc/mri_acq_recon_tutorial. It is a fast approximation of the proposition from [CCW13]_, where a sampling density is derived from compressed sensing equations to maximize guarantees of exact image recovery for a specified sparse wavelet domain decomposition. In this approximation, the decomposition dimensions are considered independent and computed separately to accelerate the density generation. Parameters ---------- shape : tuple of int The shape of the density grid. wavelet_basis : str, pywt.Wavelet The wavelet basis to use for wavelet decomposition, either as a built-in wavelet name from the PyWavelets package or as a custom ``pywt.Wavelet`` object. nb_wavelet_scales : int The number of wavelet scales to use in decomposition. Returns ------- np.ndarray A density array created using a faster approximation based on 1D projections of the wavelet transform. See Also -------- pywt.wavelist : A list of wavelet decompositions available in the PyWavelets package used inside the function. pywt.Wavelet : A wavelet object accepted to generate Chauffert densities. References ---------- .. [CCW13] Chauffert, Nicolas, Philippe Ciuciu, and Pierre Weiss. "Variable density compressed sensing in MRI. Theoretical vs heuristic sampling strategies." In 2013 IEEE 10th International Symposium on Biomedical Imaging, pp. 298-301. IEEE, 2013. """ try: import pywt as pw except ImportError as err: raise ImportError( "The PyWavelets module is not available. Please install " "it along with the [extra] dependencies " "or using `pip install pywavelets`." ) from err nb_dims = len(shape) density = np.ones(shape) for k, s in enumerate(shape): unit_vector = np.zeros(s) density_1d = np.zeros(s) for i in range(s): unit_vector[i] = 1 fourier_vector = nf.ifft(unit_vector) coeffs = pw.wavedec( fourier_vector, wavelet=wavelet_basis, level=nb_wavelet_scales ) coeffs, _ = pw.coeffs_to_array(coeffs) density_1d[i] = np.max(np.abs(coeffs)) ** 2 unit_vector[i] = 0 reshape = np.ones(nb_dims).astype(int) reshape[k] = s density_1d = density_1d.reshape(reshape) density = density * density_1d density = density / np.sum(density) return nf.ifftshift(density)