"""Interface for the BART NUFFT.
BART uses a command-line interface, and reads/writes data from/to files.
The file format is described here: https://bart-doc.readthedocs.io/en/latest/data.html#non-cartesian-datasets
"""
import os
import shutil
import subprocess as subp
import tempfile
from pathlib import Path

import numpy as np

from mrinufft._utils import proper_trajectory
from mrinufft.operators.base import FourierOperatorCPU
from mrinufft.io.cfl import traj2cfl, _writecfl, _readcfl
# BART is considered available when a ``bart`` executable is found on the PATH.
# ``shutil.which`` performs the lookup in-process, so it does not depend on an
# external ``which`` binary being installed (absent on Windows and on some
# minimal containers) and never raises for a missing tool.
BART_AVAILABLE = shutil.which("bart") is not None
[docs]
class RawBartNUFFT:
    """Wrapper around the BART NUFFT command line interface.

    BART exchanges data through files: the trajectory, the k-space
    coefficients and the image grid are written to (and read back from) CFL
    files located in a temporary directory owned by this object.

    Parameters
    ----------
    samples : array-like
        Sampling trajectory; written to disk once, at construction, with
        ``traj2cfl``.
    shape : tuple of int
        Shape of the image grid. It is passed to BART through ``-d`` as a
        ``":"``-separated string; a 2D shape gets a trailing ``":1"``.
    extra_op_args : list of str, optional
        Extra command line arguments for the forward ``bart nufft`` call.
    extra_adj_op_args : list of str, optional
        Extra command line arguments for the adjoint ``bart nufft`` call.
        When it contains ``"-i"`` the ``-a`` flag is not added.
    """

    def __init__(self, samples, shape, extra_op_args=None, extra_adj_op_args=None):
        self.samples = samples  # written to the trajectory file below
        self.shape = shape
        self.shape_str = ":".join(str(s) for s in shape)
        if len(shape) == 2:
            self.shape_str += ":1"
        # Copy the argument lists so that later changes to the caller's
        # objects cannot alter the commands issued by this instance.
        self._op_args = list(extra_op_args or [])
        self._adj_op_args = list(extra_adj_op_args or [])
        self._temp_dir = tempfile.TemporaryDirectory()

        # All files exchanged with BART live in the temporary directory.
        tmp_path = Path(self._temp_dir.name)
        self._traj_file = tmp_path / "traj"
        self._ksp_file = tmp_path / "ksp"
        self._grid_file = tmp_path / "grid"
        traj2cfl(self.samples, self.shape, self._traj_file)

    def _tmp_file(self):
        """Return a temporary file name inside the temporary directory."""
        # NOTE(review): relies on the private ``tempfile._get_candidate_names``
        # helper; kept as is to preserve the naming scheme.
        return os.path.join(self._temp_dir.name, next(tempfile._get_candidate_names()))

    def __del__(self):
        """Delete also the temporary files."""
        # ``_temp_dir`` is missing when ``__init__`` failed before creating it.
        temp_dir = getattr(self, "_temp_dir", None)
        if temp_dir is not None:
            temp_dir.cleanup()

    def _nufft_cmd(self, extra_args, src_file, dst_file):
        """Build the ``bart nufft`` argv reading ``src_file``, writing ``dst_file``."""
        return [
            "bart",
            "nufft",
            "-d",
            self.shape_str,
            *extra_args,
            str(self._traj_file),
            str(src_file),
            str(dst_file),
        ]

    def _op_cmd(self):
        """Build the command line of the forward operator (grid -> k-space)."""
        return self._nufft_cmd(self._op_args, self._grid_file, self._ksp_file)

    def _adj_cmd(self):
        """Build the command line of the adjoint operator (k-space -> grid)."""
        flags = list(self._adj_op_args)
        # ``-a`` is only added when ``-i`` was not requested. The flag must be
        # left out entirely in that case: an empty-string placeholder would be
        # handed to BART as an extra positional argument.
        if "-i" not in flags:
            flags.insert(0, "-a")
        return self._nufft_cmd(flags, self._ksp_file, self._grid_file)

    @staticmethod
    def _run_bart(cmd):
        """Run ``cmd`` and raise a ``RuntimeError`` if it exits with an error.

        Raises
        ------
        RuntimeError
            If the process returns a non-zero exit code; the message holds
            the exit code, the command line and the captured output.
        """
        try:
            subp.run(cmd, check=True, capture_output=True)
        except subp.CalledProcessError as exc:
            msg = "Failed to run BART NUFFT\n"
            msg += f"error code: {exc.returncode}\n"
            msg += "cmd: " + " ".join(cmd) + "\n"
            msg += f"stdout: {exc.output}\n"
            msg += f"stderr: {exc.stderr}"
            raise RuntimeError(msg) from exc

    def op(self, coeffs_data, grid_data):
        """Forward Operator.

        Parameters
        ----------
        coeffs_data : np.ndarray
            Output buffer, filled in place with the k-space data read back
            from BART.
        grid_data : np.ndarray
            Image data; reshaped to ``self.shape`` before being written.

        Returns
        -------
        np.ndarray
            ``coeffs_data``, after it has been filled.

        Raises
        ------
        RuntimeError
            If the ``bart nufft`` call fails.
        """
        _writecfl(grid_data.reshape(self.shape), self._grid_file)
        self._run_bart(self._op_cmd())
        ksp_raw = _readcfl(self._ksp_file)
        np.copyto(coeffs_data, ksp_raw)
        return coeffs_data

    def adj_op(self, coeffs_data, grid_data):
        """Adjoint Operator.

        Parameters
        ----------
        coeffs_data : np.ndarray
            K-space data; reshaped to ``(len(self.samples),)`` before being
            written.
        grid_data : np.ndarray
            Output buffer, filled in place with the image read back from
            BART.

        Returns
        -------
        np.ndarray
            ``grid_data``, after it has been filled.

        Raises
        ------
        RuntimeError
            If the ``bart nufft`` call fails.
        """
        coeffs_ = coeffs_data.reshape(len(self.samples))
        # Singleton axes are added around the sample axis before writing,
        # giving an array of shape (1, n_samples, 1, 1, 1).
        _writecfl(coeffs_[None, ..., None, None, None], self._ksp_file)
        self._run_bart(self._adj_cmd())
        grid_raw = _readcfl(self._grid_file)
        np.copyto(grid_data, grid_raw)
        return grid_data
[docs]
class MRIBartNUFFT(FourierOperatorCPU):
    """BART implementation of MRI NUFFT transform.

    Parameters
    ----------
    samples : array-like
        Sampling trajectory; passed through ``proper_trajectory`` with
        ``normalize="unit"`` before use.
    shape : tuple of int
        Shape of the image grid.
    density : bool or array-like, default False
        If exactly ``True``, ``"-i"`` is added to the arguments of the BART
        adjoint call and ``False`` is forwarded to the parent class instead.
        Any other value is forwarded to the parent class unchanged.
    n_coils : int, default 1
        Forwarded to the parent class.
    n_batchs : int, default 1
        Forwarded to the parent class.
    smaps : array-like, optional
        Forwarded to the parent class.
    squeeze_dims : bool, default True
        Forwarded to the parent class.
    **kwargs
        Forwarded to ``RawBartNUFFT`` (``extra_op_args``,
        ``extra_adj_op_args``).
    """

    # TODO override Data consistency function: use Toeplitz
    backend = "bart"
    available = BART_AVAILABLE

    def __init__(
        self,
        samples,
        shape,
        density=False,
        n_coils=1,
        n_batchs=1,
        smaps=None,
        squeeze_dims=True,
        **kwargs,
    ):
        samples_ = proper_trajectory(samples, normalize="unit")
        if density is True:
            density = False
            # ``kwargs`` is a dict: the user-provided arguments have to be
            # looked up by key (an attribute lookup never finds them, which
            # silently discarded them). A new list is built so the caller's
            # list is left untouched.
            adj_args = list(kwargs.get("extra_adj_op_args") or [])
            if "-i" not in adj_args:
                adj_args.append("-i")
            kwargs["extra_adj_op_args"] = adj_args

        self.raw_op = RawBartNUFFT(samples_, shape, **kwargs)
        super().__init__(
            samples_,
            shape,
            density,
            n_coils=n_coils,
            n_batchs=n_batchs,
            n_trans=1,
            smaps=smaps,
            raw_op=self.raw_op,
            squeeze_dims=squeeze_dims,
        )

    @property
    def norm_factor(self):
        """Normalization factor of the operator: ``sqrt(2 ** len(shape))``."""
        return np.sqrt(2 ** len(self.shape))