import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from dcmri import rel, sig, pk_inv, pk_aorta
# Metadata for the model parameters: default value ('init', when a scalar
# default exists), human-readable name and unit. Array-valued parameters
# without an 'init' entry (B1corr, R10, S0) get their defaults in
# TissueLSArray.__init__; derived parameters (Fb, ve, Te) are computed by
# TissueLSArray.params.
PARAMS = {
    'dt': {
        'init': 1.0,
        'name': 'Time step',
        'unit': 'sec',
    },
    'ca': {
        'init': None,
        'name': 'Arterial Input concentration',
        # Concentration: derived from signals with r1 in Hz/M, hence molar.
        'unit': 'M',
    },
    'irf': {
        'init': None,
        'name': 'Impulse response function',
        'unit': 'mL/sec/cm3',
    },
    'r1': {
        'init': 5000.0,
        'name': 'Contrast agent relaxivity',
        'unit': 'Hz/M',
    },
    'R10a': {
        'init': 0.7,
        'name': 'Arterial precontrast R1',
        'unit': 'Hz',
    },
    'S0a': {
        'init': 1.0,
        'name': 'Arterial signal scaling factor',
        'unit': 'a.u.',
    },
    'B1corr_a': {
        'init': 1,
        'name': 'Arterial B1-correction factor',
        'unit': '',
    },
    'B1corr': {
        'name': 'Tissue B1-correction factor',
        'unit': '',
    },
    'FA': {
        'init': 15,
        'name': 'Flip angle',
        'unit': 'deg',
    },
    'TR': {
        'init': 0.005,
        'name': 'Repetition time',
        'unit': 'sec',
    },
    'TC': {
        'init': 0.2,
        'name': 'Time to k-space center',
        'unit': 'sec',
    },
    'TS': {
        'init': 0,
        'name': 'Sampling time',
        'unit': 'sec',
    },
    'R10': {
        'name': 'Tissue precontrast R1',
        'unit': 'Hz',
    },
    'S0': {
        'name': 'Signal scaling factor',
        'unit': 'a.u.',
    },
    'H': {
        'init': 0.45,
        'name': 'Tissue Hematocrit',
        'unit': '',
    },
    'Fb': {
        'name': 'Parenchymal blood flow',
        'unit': 'mL/sec/cm3',
    },
    've': {
        'name': 'Extracellular volume',
        'unit': 'mL/cm3',
    },
    'Te': {
        'name': 'Extracellular mean transit time',
        'unit': 'sec',
    },
}
class TissueLSArray():
    """Array of linear and stationary tissues with a single inlet.

    These are generic model-free tissue types. Their response to
    an indicator injection is proportional to the dose (linear) and
    independent of the time of injection (stationary).

    Args:
        shape (array-like, required): shape of the tissue array (spatial
            dimensions only). Any number of dimensions is allowed.
        sequence (str, optional): imaging sequence. Possible values
            are 'SS', 'SR' and 'lin' (linear). Defaults to 'SS'.
        kwargs (dict, optional): values for the parameters of the tissue,
            specified as keyword parameters - for instance ``dt``, the time
            interval between values of the arterial input function
            (defaults to 1.0). Defaults are used for any that are not
            provided. The measured arterial signal is not a constructor
            argument; it is passed to `train`.

    See Also:
        `TissueLS`, `TissueArray`

    Example:

        Fit a linear and stationary model to the synthetic test images:

    .. plot::
        :include-source:
        :context: close-figs

        >>> import numpy as np
        >>> import dcmri as dc

        Use `fake_brain` to generate synthetic test data:

        >>> n=64
        >>> time, signal, aif, gt = dc.fake_brain(n)

        The correct ground truth for ve in model-free analysis is the
        extracellular part of the distribution space:

        >>> gt['ve'] = np.where(gt['PS'] > 0, gt['vp'] + gt['vi'], gt['vp'])

        Build a tissue array and set the constants to match the
        experimental conditions of the synthetic test data. We use
        the exact T1-map as baseline values:

        >>> tissue = dc.TissueLSArray(
        ...     (n,n),
        ...     dt = time[1],
        ...     sequence = 'SS',
        ...     r1 = dc.relaxivity(3, 'blood','gadodiamide'),
        ...     TR = 0.005,
        ...     FA = 15,
        ...     R10a = 1/dc.T1(3.0,'blood'),
        ...     R10 = np.where(gt['T1']==0, 0, 1/gt['T1']),
        ... )

        Train the tissue on the data. Since we have noise-free synthetic
        data we use a lower tolerance than the default, which is optimized
        for noisy data:

        >>> tissue.train(signal, aif, n0=10, tol=0.01)

        Plot the reconstructed maps, along with the ground truth for reference.
        We set fixed scaling for the parameter maps so they are comparable.

        >>> vmin = {'Fb':0, 've':0, 'S0':0}
        >>> vmax = {'Fb':0.02, 've':0.2, 'S0':np.amax(gt['S0'])}
        >>> tissue.plot(vmin=vmin, vmax=vmax, ref=gt)

    Notes:
        As the example shows, even under noise-free conditions the maps
        are not reconstructed exactly. While the assumptions of linearity and
        stationarity are valid for the data generated by `fake_brain`, the
        price to pay for a fully model-free analysis is some numerical bias.
        The advantage is a convenient and fast first line analysis that applies
        under practically all conditions and produces robust maps that provide
        a quantitative insight into functional differences between tissue types.
    """

    def __init__(self, shape, sequence='SS', **kwargs):
        # Configuration
        if sequence not in ['SS', 'SR', 'lin']:
            raise ValueError(
                f"Sequence {sequence} is not recognized. "
                f"Current options are 'SS', 'SR', 'lin'."
            )
        # Normalize any array-like (list, tuple, ndarray) to a tuple of ints:
        # a tuple is needed below to append the time axis.
        self.shape = tuple(int(n) for n in np.ravel(shape))
        self.sequence = sequence

        # Scalar parameters that are relevant for the chosen sequence
        scalars = ['dt', 'H', 'R10a', 'S0a', 'r1']
        if sequence == 'SR':
            scalars += ['TC']
        elif sequence == 'SS':
            scalars += ['TR', 'B1corr_a', 'FA']
        pars = {par: PARAMS[par]['init'] for par in scalars}

        # Array parameters. The default input function and impulse response
        # are sampled on the grid of the effective time step, so that a dt
        # passed as keyword argument is respected by these defaults too.
        nt = 120
        time = kwargs.get('dt', pars['dt']) * np.arange(nt)
        Kb = np.ones(self.shape + (1,)) / 5.0
        pars['ca'] = pk_aorta.aif_tristan(time)
        pars['irf'] = 0.01 * np.exp(-Kb * time)
        pars['S0'] = np.ones(self.shape)
        pars['R10'] = np.ones(self.shape)
        if sequence == 'SS':
            pars['B1corr'] = np.ones(self.shape)

        # User-supplied values override the defaults
        pars.update(kwargs)
        self.pars = pars

    def predict_aif(self):
        """Predict the arterial signal at the time points of the input function.

        Returns:
            np.ndarray: Array of predicted arterial signals for each time point.

        Raises:
            ValueError: if the sequence is not one of 'SS', 'SR', 'lin'.
        """
        R1a = rel.relax(self.pars['ca'], self.pars['R10a'], self.pars['r1'])
        if self.sequence == 'SS':
            # The artery uses its own B1 correction, not the tissue map.
            return sig.signal_ss(
                self.pars['S0a'], R1a, self.pars['TR'],
                self.pars['B1corr_a'] * self.pars['FA'])
        if self.sequence == 'SR':
            return sig.signal_src(self.pars['S0a'], R1a, self.pars['TC'])
        if self.sequence == 'lin':
            return sig.signal_lin(self.pars['S0a'], R1a)
        raise ValueError(f"Sequence {self.sequence} is not recognized.")

    def predict_conc(self):
        """Return the tissue concentration.

        The concentration in each pixel is the convolution of the arterial
        concentration with the pixel's impulse response function, evaluated
        as a matrix product with the convolution matrix of ``ca``.

        Returns:
            np.ndarray: Concentration in M, with the same shape as the
            impulse response array (time along the last axis).
        """
        ca_mat = self.pars['dt'] * pk_inv.convmat(self.pars['ca'])
        irf = np.asarray(self.pars['irf'])
        irf_mat = irf.reshape(-1, irf.shape[-1])
        conc = ca_mat @ irf_mat.T
        return conc.T.reshape(irf.shape)

    def predict(self):
        """Predict the tissue signals at the time points of the input function.

        Returns:
            np.ndarray: Array of predicted signals, with the same shape as the
            concentration array (time along the last axis).
        """
        conc = self.predict_conc()
        R1 = rel.relax(conc, self.pars['R10'], self.pars['r1'])
        R1 = R1.reshape(-1, conc.shape[-1])
        pix_shape = conc.shape[:-1]
        # Flatten the per-pixel parameters once, outside the loop.
        S0 = np.broadcast_to(
            np.asarray(self.pars['S0'], dtype=float), pix_shape).ravel()
        if self.sequence == 'SS':
            FA = np.broadcast_to(
                np.asarray(self.pars['B1corr'], dtype=float),
                pix_shape).ravel() * self.pars['FA']
        signal = np.zeros(R1.shape)
        for x in tqdm(range(R1.shape[0]), 'Predicting signals'):
            if self.sequence == 'SS':
                signal[x, :] = sig.signal_ss(
                    S0[x], R1[x, :], self.pars['TR'], FA[x])
            elif self.sequence == 'SR':
                signal[x, :] = sig.signal_src(S0[x], R1[x, :], self.pars['TC'])
            elif self.sequence == 'lin':
                signal[x, :] = sig.signal_lin(S0[x], R1[x, :])
        return signal.reshape(conc.shape)

    def _tissue_conc_ss(self, signal, T10):
        """Tissue concentrations for the 'SS' sequence using the tissue B1 map."""
        TR, FA, r1 = self.pars['TR'], self.pars['FA'], self.pars['r1']
        b1 = np.asarray(self.pars['B1corr'], dtype=float)
        if b1.size <= 1 or np.all(b1 == b1.flat[0]):
            # Spatially uniform correction: a single vectorized call with a
            # scalar flip angle.
            scale = b1.flat[0] if b1.size else 1.0
            return sig.conc_ss(
                signal, TR, scale * FA, T10, r1, S0=self.pars['S0'])
        # Spatially varying correction: solve each pixel with its own flip
        # angle, using the same scalar call signature as for the artery.
        pix_shape = signal.shape[:-1]
        fa = np.broadcast_to(b1, pix_shape).ravel() * FA
        t10 = np.broadcast_to(np.asarray(T10, dtype=float), pix_shape).ravel()
        s0 = np.broadcast_to(
            np.asarray(self.pars['S0'], dtype=float), pix_shape).ravel()
        curves = signal.reshape(-1, signal.shape[-1])
        conc = np.empty(curves.shape)
        for x in tqdm(range(curves.shape[0]), 'Deriving concentrations'):
            conc[x, :] = sig.conc_ss(
                curves[x, :], TR, fa[x], t10[x], r1, S0=s0[x])
        return conc.reshape(signal.shape)

    def train(self, signal, signal_aif, n0=1, tol=0.1, init_s0=True):
        """Train the free parameters.

        Args:
            signal (array-like): Array with measured tissue signals, time
                along the last axis.
            signal_aif (array-like): Measured signal-time curve in the blood
                of the feeding artery.
            n0 (int, optional): Number of precontrast time points averaged
                to estimate the baseline signals. Defaults to 1.
            tol (float, optional): cut-off value for the singular values in
                the computation of the matrix pseudo-inverse. Defaults to 0.1.
            init_s0 (bool, optional): If True, the signal scaling factors
                S0a and S0 are first estimated from the precontrast signals.
                Defaults to True.

        Returns:
            self
        """
        signal = np.asarray(signal, dtype=float)
        signal_aif = np.asarray(signal_aif, dtype=float)

        # Fit baselines if needed
        if init_s0:
            if self.sequence == 'SR':
                scla = sig.signal_src(1, self.pars['R10a'], self.pars['TC'])
                scl = sig.signal_src(1, self.pars['R10'], self.pars['TC'])
            elif self.sequence == 'SS':
                scla = sig.signal_ss(
                    1, self.pars['R10a'], self.pars['TR'],
                    self.pars['B1corr_a'] * self.pars['FA'])
                scl = sig.signal_ss(
                    1, self.pars['R10'], self.pars['TR'],
                    self.pars['B1corr'] * self.pars['FA'])
            elif self.sequence == 'lin':
                scla = sig.signal_lin(1, self.pars['R10a'])
                scl = sig.signal_lin(1, self.pars['R10'])
            self.pars['S0a'] = np.mean(signal_aif[:n0]) / scla if scla > 0 else 0
            # Pixels with zero scaling get S0=0; silence the discarded division.
            with np.errstate(divide="ignore", invalid="ignore"):
                self.pars['S0'] = np.where(
                    scl == 0, 0, np.mean(signal[..., :n0], axis=-1) / scl)

        # Derive concentrations
        with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
            R10 = np.asarray(self.pars['R10'], dtype=float)
            T10 = np.where(R10 == 0, 0, 1 / R10)
            if self.sequence == 'SR':
                self.pars['ca'] = sig.conc_src(
                    signal_aif, self.pars['TC'], 1 / self.pars['R10a'],
                    self.pars['r1'], S0=self.pars['S0a'])
                conc = sig.conc_src(
                    signal, self.pars['TC'], T10,
                    self.pars['r1'], S0=self.pars['S0'])
            elif self.sequence == 'SS':
                self.pars['ca'] = sig.conc_ss(
                    signal_aif, self.pars['TR'],
                    self.pars['B1corr_a'] * self.pars['FA'],
                    1 / self.pars['R10a'], self.pars['r1'], S0=self.pars['S0a'])
                # Tissue concentrations use the tissue B1 map, not B1corr_a.
                conc = self._tissue_conc_ss(signal, T10)
            elif self.sequence == 'lin':
                self.pars['ca'] = sig.conc_lin(
                    signal_aif, 1 / self.pars['R10a'],
                    self.pars['r1'], S0=self.pars['S0a'])
                conc = sig.conc_lin(
                    signal, T10, self.pars['r1'], S0=self.pars['S0'])
        conc[np.isnan(conc)] = 0

        # Deconvolve all pixels at once against the arterial concentration
        conc_mat = conc.reshape(-1, conc.shape[-1])
        irf_mat = pk_inv.deconv(
            conc_mat.T, self.pars['ca'], self.pars['dt'], tol=tol)
        self.pars['irf'] = irf_mat.T.reshape(conc.shape)
        return self

    def params(self, *args):
        """Export the tissue parameters.

        Fb is the peak of the impulse response function, Te its area divided
        by the peak (0 where the peak is 0), and ve its area scaled by
        (1 - H).

        Args:
            args (tuple): parameters to get. If no arguments are
                provided, all available parameters are returned.

        Returns:
            dict: Dictionary with tissue parameters, or a single parameter
            map if exactly one name is requested.
        """
        irf = self.pars['irf']
        amax = np.max(irf, axis=-1)
        auc = np.sum(irf, axis=-1) * self.pars['dt']
        params = {
            'IRF': irf,
            'Fb': amax,
            'Te': np.divide(auc, amax, out=np.zeros_like(auc), where=amax != 0),
            've': auc * (1 - self.pars['H']),
            'S0': self.pars['S0'],
        }
        if not args:
            return params
        if len(args) == 1:
            return params[args[0]]
        return {p: params[p] for p in args}

    def plot(self, vmin=None, vmax=None,
             cmap='gray', ref=None, fname=None, show=True):
        """Plot parameter maps (all on one image).

        Note: this function is currently only available for 2D data.

        Args:
            vmin (dict, optional): Minimum values on display for given
                parameters. Defaults to None (no fixed minima).
            vmax (dict, optional): Maximum values on display for given
                parameters. Defaults to None (no fixed maxima).
            cmap (str, optional): matplotlib colormap. Defaults to 'gray'.
            ref (dict, optional): Reference images - typically used to display
                ground truth data when available. Keys are 'signal' (array of
                data in the same shape as signal), and the parameter maps to
                show. Defaults to None.
            fname (str, optional): File path to save image. Defaults to None.
            show (bool, optional): Determine whether the image is shown or
                not. Defaults to True.

        Raises:
            NotImplementedError: if the data are not 2D.
        """
        vmin = {} if vmin is None else vmin
        vmax = {} if vmax is None else vmax
        # Check dimensionality before running the (slow) prediction.
        ndim = np.ndim(self.pars['S0'])
        if ndim == 1:
            raise NotImplementedError('Cannot plot 1D images.')
        if ndim == 3:
            raise NotImplementedError('3D plot not yet implemented')
        if ndim != 2:
            raise NotImplementedError(f'Cannot plot {ndim}D images.')

        yfit = self.predict()
        params = self.params('Fb', 've', 'S0')
        ncols = 2 + len(params)
        nrows = 1 if ref is None else 2
        # Signal maps share one display range, taken from the reference
        # signal when available and from the prediction otherwise.
        smax = 0.5 * np.amax(yfit if ref is None else ref['signal'])

        fig = plt.figure(figsize=(ncols * 2, nrows * 2))
        figcols = fig.subfigures(
            1, 2, wspace=0.0, hspace=0.0, width_ratios=[2, ncols - 2])

        # Left panel: signal. squeeze=False keeps a 2D axes array even when
        # there is a single row.
        ax = figcols[0].subplots(nrows, 2, squeeze=False)
        figcols[0].subplots_adjust(hspace=0.0, wspace=0)
        for axis in ax.flat:
            axis.set_yticks([])
            axis.set_xticks([])
        ax[0, 0].set_title('max(signal)')
        ax[0, 0].set_ylabel('reconstruction')
        ax[0, 0].imshow(np.amax(yfit, axis=-1), vmin=0, vmax=smax, cmap=cmap)
        ax[0, 1].set_title('mean(signal)')
        ax[0, 1].imshow(np.mean(yfit, axis=-1), vmin=0, vmax=smax, cmap=cmap)
        if ref is not None:
            ax[1, 0].set_ylabel('ground truth')
            ax[1, 0].imshow(np.amax(ref['signal'], axis=-1),
                            vmin=0, vmax=smax, cmap=cmap)
            ax[1, 1].imshow(np.mean(ref['signal'], axis=-1),
                            vmin=0, vmax=smax, cmap=cmap)

        # Right panel: free parameters
        ax = figcols[1].subplots(nrows, ncols - 2, squeeze=False)
        figcols[1].subplots_adjust(hspace=0.0, wspace=0)
        for axis in ax.flat:
            axis.set_yticks([])
            axis.set_xticks([])
        ax[0, 0].set_ylabel('reconstruction')
        if ref is not None:
            ax[1, 0].set_ylabel('ground truth')
        for i, (par, img) in enumerate(params.items()):
            v0 = vmin[par] if par in vmin else np.percentile(img, 1)
            v1 = vmax[par] if par in vmax else np.percentile(img, 99)
            ax[0, i].set_title(par)
            ax[0, i].imshow(img, vmin=v0, vmax=v1, cmap=cmap)
            if ref is not None and par in ref:
                ax[1, i].imshow(ref[par], vmin=v0, vmax=v1, cmap=cmap)

        if fname is not None:
            fig.savefig(fname)
        if show:
            plt.show()
        else:
            plt.close(fig)

    @staticmethod
    def _mosaic_layout(width, height, n_mosaics, aspect_ratio):
        """Return (nrows, ncols) of a mosaic of width x height tiles with roughly the given aspect ratio."""
        nrows = int(np.round(np.sqrt((width * n_mosaics) / (aspect_ratio * height))))
        # Never fewer than one row or column (avoids division by zero).
        nrows = max(1, nrows)
        ncols = max(1, int(np.ceil(n_mosaics / nrows)))
        return nrows, ncols

    def plot_overlay(self, mask=None, vmin=None, vmax=None, aspect_ratio=16/9,
                     cmap='magma', fname=None, show=True):
        """Plot parameter maps (one image per parameter).

        Each slice of the Fb and ve maps is overlaid on the corresponding
        slice of S0, and the slices are arranged in a mosaic.

        Note: this function is currently only available for 3D data.

        Args:
            mask (numpy.ndarray, optional): If provided, only pixels inside
                the mask are shown. Defaults to None (nonzero pixels shown).
            vmin (dict, optional): Minimum values on display for the model
                parameters, and for the S0 background under key 'background'.
                Missing keys use the data range. Defaults to None.
            vmax (dict, optional): Maximum values on display, same keys as
                vmin. Defaults to None.
            aspect_ratio (float, optional): Aspect ratio of the mosaic.
                Defaults to 16/9.
            cmap (str, optional): matplotlib colormap. Defaults to 'magma'.
            fname (str, optional): File path to save image. The parameter
                name is prefixed to the file name. Defaults to None.
            show (bool, optional): Determine whether the image is shown or
                not. Defaults to True.

        Raises:
            NotImplementedError: if the data are not 3D.
        """
        S0 = np.asarray(self.pars['S0'])
        if S0.ndim != 3:
            raise NotImplementedError(
                'plot_overlay is only available for 3D data at this stage.')

        def _limit(limits, key, default):
            # User-supplied display limit when given, data-driven otherwise.
            return limits[key] if limits is not None and key in limits else default

        width, height, n_mosaics = S0.shape
        nrows, ncols = self._mosaic_layout(width, height, n_mosaics, aspect_ratio)
        size = max(width, height)
        bg_min = _limit(vmin, 'background', np.amin(S0))
        bg_max = _limit(vmax, 'background', np.amax(S0))

        for par in ['Fb', 've']:
            img = self.params(par)
            alpha = img != 0 if mask is None else np.asarray(mask, dtype=bool)
            shown = img[alpha]
            # Fall back to the whole map if no pixel is selected.
            ref_vals = shown if shown.size else img
            lo = _limit(vmin, par, np.amin(ref_vals))
            hi = _limit(vmax, par, np.amax(ref_vals))

            # squeeze=False keeps a 2D axes array even for a single row/col.
            fig, ax = plt.subplots(
                nrows=nrows,
                ncols=ncols,
                gridspec_kw={'wspace': 0, 'hspace': 0},
                figsize=(ncols * width / size, nrows * height / size),
                dpi=300,
                squeeze=False,
            )
            fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

            im = None
            for r, row in enumerate(tqdm(ax, desc='Building png')):
                for c, axis in enumerate(row):
                    i = r * ncols + c  # slice index, row-major
                    axis.set_xticklabels([])
                    axis.set_yticklabels([])
                    axis.set_aspect('equal')
                    axis.axis("off")
                    if i < n_mosaics:
                        axis.imshow(
                            S0[:, :, i].T,
                            cmap='gray',
                            interpolation='none',
                            vmin=bg_min,
                            vmax=bg_max,
                        )
                        im = axis.imshow(
                            img[:, :, i].T,
                            cmap=cmap,
                            interpolation='none',
                            alpha=alpha[:, :, i].T.astype(np.float32),
                            vmin=lo,
                            vmax=hi,
                        )

            # Add colorbar on the right
            if im is not None:
                cbar = fig.colorbar(im, ax=ax.ravel().tolist(), location='right',
                                    fraction=0.046, pad=0.04)
                cbar.set_label(f"{par} ({PARAMS[par]['unit']})")
                cbar.ax.yaxis.set_label_position("left")
                cbar.ax.yaxis.set_ticks_position("right")

            if fname is not None:
                folder, filename = os.path.split(fname)
                name, ext = os.path.splitext(filename)
                new_fullpath = os.path.join(folder, f"{par}_{name}{ext}")
                fig.savefig(fname=new_fullpath, bbox_inches='tight', pad_inches=0)
            if show:
                plt.show()
            else:
                plt.close(fig)