Source code for mrinufft.operators.interfaces.bart

"""Interface for the BART NUFFT.

BART uses a command-line interface, and reads/writes data from/to files.

The file format is described here: https://bart-doc.readthedocs.io/en/latest/data.html#non-cartesian-datasets

"""

import os
import shutil
import subprocess as subp
import tempfile
from pathlib import Path

import numpy as np

from mrinufft._utils import proper_trajectory
from mrinufft.io.cfl import traj2cfl, _writecfl, _readcfl
from mrinufft.operators.base import FourierOperatorCPU

# BART is considered available when a ``bart`` executable is found on the PATH.
# The lookup is done in-process with ``shutil.which`` rather than by spawning
# the external ``which`` command, which is not present on every platform.
BART_AVAILABLE = shutil.which("bart") is not None


class RawBartNUFFT:
    """Wrapper around BART NUFFT CLI.

    Data is exchanged with the ``bart`` executable through files stored in a
    temporary directory owned by the instance (trajectory, k-space and grid).

    Parameters
    ----------
    samples : array-like
        Sampling locations, written once to the trajectory file at creation.
    shape : tuple of int
        Shape of the image grid.
    extra_op_args : list of str, optional
        Extra command line arguments for the forward ``bart nufft`` call.
    extra_adj_op_args : list of str, optional
        Extra command line arguments for the adjoint ``bart nufft`` call.
    """

    def __init__(self, samples, shape, extra_op_args=None, extra_adj_op_args=None):
        self.samples = samples  # To normalize and send to file
        self.shape = shape
        # Dimensions are passed to BART as "n1:n2:n3"; a 2D shape is padded
        # with a trailing singleton dimension.
        self.shape_str = ":".join([str(s) for s in shape])
        self.shape_str += ":1" if len(shape) == 2 else ""
        self._op_args = extra_op_args or []
        self._adj_op_args = extra_adj_op_args or []
        self._temp_dir = tempfile.TemporaryDirectory()

        # Write trajectory to temp file
        tmp_path = Path(self._temp_dir.name)
        self._traj_file = tmp_path / "traj"
        self._ksp_file = tmp_path / "ksp"
        self._grid_file = tmp_path / "grid"
        traj2cfl(self.samples, self.shape, self._traj_file)

    def _tmp_file(self):
        """Return a temporary file name."""
        return os.path.join(self._temp_dir.name, next(tempfile._get_candidate_names()))

    def __del__(self):
        """Delete also the temporary files."""
        # __init__ may have raised before the temporary directory was created,
        # in which case there is nothing to clean up.
        temp_dir = getattr(self, "_temp_dir", None)
        if temp_dir is not None:
            temp_dir.cleanup()

    @staticmethod
    def _run_bart(cmd):
        """Run a BART command line.

        Parameters
        ----------
        cmd : list of str
            The full command, executable included.

        Raises
        ------
        RuntimeError
            If the command exits with a non-zero return code. The message
            holds the return code, the command and the captured outputs.
        """
        try:
            subp.run(cmd, check=True, capture_output=True)
        except subp.CalledProcessError as exc:
            msg = "Failed to run BART NUFFT\n"
            msg += f"error code: {exc.returncode}\n"
            msg += "cmd: " + " ".join(cmd) + "\n"
            msg += f"stdout: {exc.output}\n"
            msg += f"stderr: {exc.stderr}"
            raise RuntimeError(msg) from exc

    def op(self, coeffs_data, grid_data):
        """Forward Operator.

        Parameters
        ----------
        coeffs_data : np.ndarray
            Output buffer, filled in place with the k-space data read back
            from BART.
        grid_data : np.ndarray
            Image data, reshaped to ``self.shape`` before being written.

        Returns
        -------
        np.ndarray
            ``coeffs_data``, after it has been filled.

        Raises
        ------
        RuntimeError
            If the BART command fails.
        """
        grid_data_ = grid_data.reshape(self.shape)
        _writecfl(grid_data_, self._grid_file)
        cmd = [
            "bart",
            "nufft",
            "-d",
            self.shape_str,
            *self._op_args,
            str(self._traj_file),
            str(self._grid_file),
            str(self._ksp_file),
        ]
        self._run_bart(cmd)

        ksp_raw = _readcfl(self._ksp_file)
        np.copyto(coeffs_data, ksp_raw)
        return coeffs_data

    def adj_op(self, coeffs_data, grid_data):
        """Adjoint Operator.

        Parameters
        ----------
        coeffs_data : np.ndarray
            K-space data, flattened to ``len(self.samples)`` values before
            being written.
        grid_data : np.ndarray
            Output buffer, filled in place with the image read back from BART.

        Returns
        -------
        np.ndarray
            ``grid_data``, after it has been filled.

        Raises
        ------
        RuntimeError
            If the BART command fails.
        """
        # Format the k-space data to cfl format, and write to file
        coeffs_ = coeffs_data.reshape(len(self.samples))
        _writecfl(coeffs_[None, ..., None, None, None], self._ksp_file)

        # "-a" (adjoint) is only added when the caller did not request "-i".
        # The flag is built as a list so that no empty-string argument is ever
        # handed to bart, where it would count as an extra positional argument.
        mode_flag = [] if "-i" in self._adj_op_args else ["-a"]
        cmd = [
            "bart",
            "nufft",
            "-d",
            self.shape_str,
            *mode_flag,
            *self._adj_op_args,
            str(self._traj_file),
            str(self._ksp_file),
            str(self._grid_file),
        ]
        # Run bart nufft with argument in subprocess
        self._run_bart(cmd)

        grid_raw = _readcfl(self._grid_file)
        np.copyto(grid_data, grid_raw)
        return grid_data
class MRIBartNUFFT(FourierOperatorCPU):
    """BART implementation of MRI NUFFT transform.

    Parameters
    ----------
    samples : array-like
        Sampling trajectory, passed through ``proper_trajectory`` with
        ``normalize="unit"`` before use.
    shape : tuple of int
        Shape of the image grid.
    density : bool, default False
        When exactly ``True``, ``"-i"`` is added to the adjoint arguments of
        the BART call and ``density`` is reset to ``False`` before being
        forwarded to the base class. Any other value is forwarded unchanged.
    n_coils : int, default 1
        Forwarded to ``FourierOperatorCPU``.
    n_batchs : int, default 1
        Forwarded to ``FourierOperatorCPU``.
    smaps : optional
        Forwarded to ``FourierOperatorCPU``.
    squeeze_dims : bool, default True
        Forwarded to ``FourierOperatorCPU``.
    **kwargs
        Forwarded to ``RawBartNUFFT`` (e.g. ``extra_op_args``,
        ``extra_adj_op_args``).
    """

    # TODO override Data consistency function: use Toeplitz

    backend = "bart"
    available = BART_AVAILABLE

    def __init__(
        self,
        samples,
        shape,
        density=False,
        n_coils=1,
        n_batchs=1,
        smaps=None,
        squeeze_dims=True,
        **kwargs,
    ):
        samples_ = proper_trajectory(samples, normalize="unit")
        if density is True:
            density = False
            # ``kwargs`` is a dict, so the option has to be looked up by key
            # (``getattr`` never finds it). A new list is built so that the
            # caller's list is left untouched and "-i" is not duplicated.
            adj_args = list(kwargs.get("extra_adj_op_args") or [])
            if "-i" not in adj_args:
                adj_args.append("-i")
            kwargs["extra_adj_op_args"] = adj_args

        self.raw_op = RawBartNUFFT(samples_, shape, **kwargs)
        super().__init__(
            samples_,
            shape,
            density,
            n_coils=n_coils,
            n_batchs=n_batchs,
            n_trans=1,
            smaps=smaps,
            raw_op=self.raw_op,
            squeeze_dims=squeeze_dims,
        )

    @property
    def norm_factor(self):
        """Normalization factor of the operator."""
        return np.sqrt(2 ** len(self.shape))