# Source code for dcmri.ui_tissue_ls_array

import os

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from dcmri import rel, sig, pk_inv, pk_aorta



# Metadata for every parameter known to TissueLSArray.
# Each entry maps a parameter key to a dict with:
#   'name': human-readable description,
#   'unit': physical unit used for display (e.g. colorbar labels),
#   'init': default value (only present for scalar parameters that are
#           initialised from this table; array-valued parameters are
#           initialised by the class itself).
PARAMS = {
    'dt': {
        'init': 1.0,
        'name': 'Time step',
        'unit': 'sec',
    },
    'ca': {
        'init': None,
        'name': 'Arterial Input concentration',
        # Concentrations are in molar units: they are combined with the
        # relaxivity r1 (Hz/M) to compute relaxation rates.
        'unit': 'M',
    },
    'irf': {
        'init': None,
        'name': 'Impulse response function',
        'unit': 'mL/sec/cm3',
    },
    'r1': {
        'init': 5000.0,
        'name': 'Contrast agent relaxivity',
        'unit': 'Hz/M',
    },
    'R10a': {
        'init': 0.7,
        'name': 'Arterial precontrast R1',
        'unit': 'Hz',
    },
    'S0a': {
        'init': 1.0,
        'name': 'Arterial signal scaling factor',
        'unit': 'a.u.',
    },
    'B1corr_a': {
        'init': 1,
        'name': 'Arterial B1-correction factor',
        'unit': '',
    },
    'B1corr': {
        'name': 'Tissue B1-correction factor',
        'unit': '',
    },
    'FA': {
        'init': 15,
        'name': 'Flip angle',
        'unit': 'deg',
    },
    'TR': {
        'init': 0.005,
        'name': 'Repetition time',
        'unit': 'sec',
    },
    'TC': {
        'init': 0.2,
        'name': 'Time to k-space center',
        'unit': 'sec',
    },
    'TS': {
        'init': 0,
        'name': 'Sampling time',
        'unit': 'sec',
    },
    'R10': {
        'name': 'Tissue precontrast R1',
        'unit': 'Hz',
    },
    'S0': {
        'name': 'Signal scaling factor',
        'unit': 'a.u.',
    },
    'H': {
        'init': 0.45,
        'name': 'Tissue Hematocrit',
        'unit': '',
    },
    'Fb': {
        'name': 'Parenchymal blood flow',
        'unit': 'mL/sec/cm3',
    },
    've': {
        'name': 'Extracellular volume',
        'unit': 'mL/cm3',
    },
    'Te': {
        'name': 'Extracellular mean transit time',
        'unit': 'sec',
    },
}





class TissueLSArray():
    """Array of linear and stationary tissues with a single inlet.

    These are generic model-free tissue types. Their response to an indicator
    injection is proportional to the dose (linear) and independent of the
    time of injection (stationary).

    Args:
        shape (array-like, required): shape of the tissue array (spatial
            dimensions only). Any number of dimensions is allowed.
        sequence (str, optional): imaging sequence. Possible values are 'SS',
            'SR' and 'lin' (linear). Defaults to 'SS'.
        kwargs (dict, optional): values for the parameters of the tissue,
            specified as keyword parameters - for instance ``dt``, the time
            interval between values of the arterial input function
            (defaults to 1.0). Defaults are used for any that are not
            provided.

    See Also:
        `TissueLS`, `TissueArray`

    Example:

        Fit a linear and stationary model to the synthetic test images:

    .. plot::
        :include-source:
        :context: close-figs

        >>> import numpy as np
        >>> import dcmri as dc

        Use `fake_brain` to generate synthetic test data:

        >>> n=64
        >>> time, signal, aif, gt = dc.fake_brain(n)

        The correct ground truth for ve in model-free analysis is the
        extracellular part of the distribution space:

        >>> gt['ve'] = np.where(gt['PS'] > 0, gt['vp'] + gt['vi'], gt['vp'])

        Build a tissue array and set the constants to match the
        experimental conditions of the synthetic test data. We use the
        exact T1-map as baseline values:

        >>> tissue = dc.TissueLSArray(
        ...     (n,n),
        ...     dt = time[1],
        ...     sequence = 'SS',
        ...     r1 = dc.relaxivity(3, 'blood','gadodiamide'),
        ...     TR = 0.005,
        ...     FA = 15,
        ...     R10a = 1/dc.T1(3.0,'blood'),
        ...     R10 = np.where(gt['T1']==0, 0, 1/gt['T1']),
        ... )

        Train the tissue on the data. Since we have noise-free synthetic
        data we use a lower tolerance than the default, which is optimized
        for noisy data:

        >>> tissue.train(signal, aif, n0=10, tol=0.01)

        Plot the reconstructed maps, along with the ground truth for
        reference. We set fixed scaling for the parameter maps so they are
        comparable.

        >>> vmin = {'Fb':0, 've':0, 'S0':0}
        >>> vmax = {'Fb':0.02, 've':0.2, 'S0':np.amax(gt['S0'])}
        >>> tissue.plot(vmin=vmin, vmax=vmax, ref=gt)

    Notes:
        As the example shows, even under noise-free conditions the maps are
        not reconstructed exactly. While the assumptions of linearity and
        stationarity are valid for the data generated by `fake_brain`, the
        price to pay for a fully model-free analysis is some numerical bias.
        The advantage is a convenient and fast first line analysis that
        applies under practically all conditions and produces robust maps
        that provide a quantitative insight into functional differences
        between tissue types.
    """

    def __init__(self, shape, sequence='SS', **kwargs):

        # Configuration
        if sequence not in ['SS', 'SR', 'lin']:
            raise ValueError(
                f"Sequence {sequence} is not recognized. "
                f"Current options are 'SS', 'SR', 'lin'."
            )
        # Normalise to a tuple so that any array-like shape (list, tuple,
        # single int) can be extended with the trailing time axis below.
        if np.ndim(shape) == 0:
            shape = (int(shape),)
        else:
            shape = tuple(int(s) for s in shape)
        self.shape = shape
        self.sequence = sequence
        self.pars = {}

        # Initialize scalar parameters from the PARAMS table
        params = ['dt', 'H', 'R10a', 'S0a', 'r1']
        if self.sequence == 'SR':
            params += ['TC']
        elif self.sequence == 'SS':
            params += ['TR', 'B1corr_a', 'FA']
        for par in params:
            self.pars[par] = PARAMS[par]['init']

        # Initialize array parameters. The placeholder time grid uses the
        # caller's dt (if given) so that the default 'ca' and 'irf' are
        # sampled on the same grid that predict_conc() assumes.
        nt = 120
        dt = kwargs.get('dt', self.pars['dt'])
        time = dt * np.arange(nt)
        Kb = np.ones(shape + (1,)) / 5.0
        self.pars['ca'] = pk_aorta.aif_tristan(time)
        self.pars['irf'] = 0.01 * np.exp(-Kb * time)
        self.pars['S0'] = np.ones(shape)
        self.pars['R10'] = np.ones(shape)
        if self.sequence == 'SS':
            self.pars['B1corr'] = np.ones(shape)

        # Override parameter defaults
        for par in kwargs:
            self.pars[par] = kwargs[par]

    def predict_aif(self):
        """Predict the arterial signal from the arterial concentration 'ca'.

        Returns:
            np.ndarray: Array of predicted arterial signals, one per time
            point of ``self.pars['ca']``.
        """
        R1a = rel.relax(self.pars['ca'], self.pars['R10a'], self.pars['r1'])
        if self.sequence == 'SS':
            # The arterial curve uses the arterial B1 correction; the tissue
            # map 'B1corr' has the spatial tissue shape and does not apply.
            Sa = sig.signal_ss(self.pars['S0a'], R1a, self.pars['TR'],
                               self.pars['B1corr_a'] * self.pars['FA'])
        elif self.sequence == 'SR':
            Sa = sig.signal_src(self.pars['S0a'], R1a, self.pars['TC'])
        elif self.sequence == 'lin':
            Sa = sig.signal_lin(self.pars['S0a'], R1a)
        return Sa

    def predict_conc(self):
        """Return the tissue concentration.

        The concentration in every voxel is the convolution of the arterial
        concentration 'ca' with the voxel's impulse response 'irf', computed
        as a matrix product with the convolution matrix of 'ca'.

        Returns:
            np.ndarray: Concentration in M, with the same shape as
            ``self.pars['irf']`` (spatial dimensions + time).
        """
        ca_mat = self.pars['dt'] * pk_inv.convmat(self.pars['ca'])
        irf_mat = self.pars['irf'].reshape(-1, self.pars['irf'].shape[-1])
        conc = ca_mat @ irf_mat.T
        return conc.T.reshape(self.pars['irf'].shape)

    def predict(self):
        """Predict the tissue signal at the time points of the model.

        Returns:
            np.ndarray: Array of predicted signals, with the same shape as
            the predicted concentration (spatial dimensions + time).
        """
        conc = self.predict_conc()
        R1 = rel.relax(conc, self.pars['R10'], self.pars['r1'])
        R1 = R1.reshape(-1, conc.shape[-1])
        S0 = self.pars['S0'].ravel()
        if self.sequence == 'SS':
            # Per-voxel B1-corrected flip angle, computed once.
            FA = self.pars['B1corr'].ravel() * self.pars['FA']
        signal = np.zeros(R1.shape)
        for x in tqdm(range(R1.shape[0]), 'Predicting signals'):
            if self.sequence == 'SS':
                signal[x, :] = sig.signal_ss(
                    S0[x], R1[x, :], self.pars['TR'], FA[x])
            elif self.sequence == 'SR':
                signal[x, :] = sig.signal_src(
                    S0[x], R1[x, :], self.pars['TC'])
            elif self.sequence == 'lin':
                signal[x, :] = sig.signal_lin(S0[x], R1[x, :])
        return signal.reshape(conc.shape)

    def train(self, signal, signal_aif, n0=1, tol=0.1, init_s0=True):
        """Train the free parameters.

        Args:
            signal (array-like): Measured tissue signals, with the time
                dimension along the last axis.
            signal_aif (array-like): Measured signal in the feeding artery.
            n0 (int, optional): Number of initial (precontrast) time points
                averaged to estimate the baseline signals. Defaults to 1.
            tol (float, optional): cut-off value for the singular values in
                the computation of the matrix pseudo-inverse. Defaults to 0.1.
            init_s0 (bool, optional): If True, the signal scaling factors
                'S0a' and 'S0' are first estimated from the first n0 time
                points. If False, the current values are used. Defaults to
                True.

        Returns:
            self
        """
        signal = np.asarray(signal)
        signal_aif = np.asarray(signal_aif)

        # Fit baselines if needed: the signal for S0=1 at the precontrast
        # R1 gives the scale that maps the baseline signal onto S0.
        if init_s0:
            if self.sequence == 'SR':
                scla = sig.signal_src(1, self.pars['R10a'], self.pars['TC'])
                scl = sig.signal_src(1, self.pars['R10'], self.pars['TC'])
            elif self.sequence == 'SS':
                scla = sig.signal_ss(1, self.pars['R10a'], self.pars['TR'],
                                     self.pars['B1corr_a'] * self.pars['FA'])
                scl = sig.signal_ss(1, self.pars['R10'], self.pars['TR'],
                                    self.pars['B1corr'] * self.pars['FA'])
            elif self.sequence == 'lin':
                scla = sig.signal_lin(1, self.pars['R10a'])
                scl = sig.signal_lin(1, self.pars['R10'])
            self.pars['S0a'] = (
                np.mean(signal_aif[:n0]) / scla if scla > 0 else 0)
            self.pars['S0'] = np.where(
                scl == 0, 0, np.mean(signal[..., :n0], axis=-1) / scl)

        # Derive concentrations. Voxels with R10 == 0 (e.g. background) and
        # ill-conditioned signal inversions produce inf/nan, which are
        # silenced here and zeroed below.
        with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
            T10 = np.where(self.pars['R10'] == 0, 0, 1 / self.pars['R10'])
            if self.sequence == 'SR':
                self.pars['ca'] = sig.conc_src(
                    signal_aif, self.pars['TC'], 1 / self.pars['R10a'],
                    self.pars['r1'], S0=self.pars['S0a'])
                conc = sig.conc_src(
                    signal, self.pars['TC'], T10, self.pars['r1'],
                    S0=self.pars['S0'])
            elif self.sequence == 'SS':
                self.pars['ca'] = sig.conc_ss(
                    signal_aif, self.pars['TR'],
                    self.pars['B1corr_a'] * self.pars['FA'],
                    1 / self.pars['R10a'], self.pars['r1'],
                    S0=self.pars['S0a'])
                # NOTE(review): the tissue conversion uses the arterial
                # B1corr_a, whereas S0 above and predict() use the tissue
                # B1corr map. Confirm whether sig.conc_ss broadcasts a spatial
                # flip-angle map before switching to self.pars['B1corr'].
                conc = sig.conc_ss(
                    signal, self.pars['TR'],
                    self.pars['B1corr_a'] * self.pars['FA'],
                    T10, self.pars['r1'], S0=self.pars['S0'])
            elif self.sequence == 'lin':
                self.pars['ca'] = sig.conc_lin(
                    signal_aif, 1 / self.pars['R10a'], self.pars['r1'],
                    S0=self.pars['S0a'])
                conc = sig.conc_lin(
                    signal, T10, self.pars['r1'], S0=self.pars['S0'])
        conc[np.isnan(conc)] = 0

        # Deconvolve all voxels at once: one column per voxel.
        conc_mat = conc.reshape(-1, conc.shape[-1])
        irf_mat = pk_inv.deconv(
            conc_mat.T, self.pars['ca'], self.pars['dt'], tol=tol)
        self.pars['irf'] = irf_mat.T.reshape(conc.shape)
        return self

    def params(self, *args):
        """Export the tissue parameters.

        'Fb' is the peak of the impulse response, 'Te' its area divided by
        its peak (0 where the peak is 0), and 've' its area scaled by
        (1 - H).

        Args:
            args (tuple): parameters to get. If no arguments are provided,
                all available parameters are returned.

        Returns:
            dict: Dictionary with tissue parameters, or the single requested
            parameter if exactly one name is given.
        """
        amax = np.max(self.pars['irf'], axis=-1)
        auc = np.sum(self.pars['irf'], axis=-1) * self.pars['dt']
        params = {
            'IRF': self.pars['irf'],
            'Fb': amax,
            'Te': np.divide(auc, amax, out=np.zeros_like(auc),
                            where=amax != 0),
            've': auc * (1 - self.pars['H']),
            'S0': self.pars['S0'],
        }
        if args == ():
            return params
        elif len(args) == 1:
            return params[args[0]]
        else:
            return {p: params[p] for p in args}

    def plot(self, vmin=None, vmax=None, cmap='gray', ref=None, fname=None,
             show=True):
        """Plot parameter maps (all on one image).

        Note: this function is currently only available for 2D data.

        Args:
            vmin (dict, optional): Minimum values on display for given
                parameters. Defaults to None (no fixed minimum).
            vmax (dict, optional): Maximum values on display for given
                parameters. Defaults to None (no fixed maximum).
            cmap (str, optional): matplotlib colormap. Defaults to 'gray'.
            ref (dict, optional): Reference images - typically used to
                display ground truth data when available. Keys are 'signal'
                (array of data in the same shape as signal), and the
                parameter maps to show. Defaults to None.
            fname (str, optional): File path to save image.
                Defaults to None.
            show (bool, optional): Determine whether the image is shown or
                not. Defaults to True.

        Raises:
            NotImplementedError: Features that are not currently implemented.
        """
        vmin = {} if vmin is None else vmin
        vmax = {} if vmax is None else vmax

        # Check dimensionality before running the (expensive) prediction.
        ndim = np.ndim(self.pars['S0'])
        if ndim == 1:
            raise NotImplementedError('Cannot plot 1D images.')
        if ndim == 3:
            raise NotImplementedError('3D plot not yet implemented')
        if ndim != 2:
            raise NotImplementedError(f'Cannot plot {ndim}D images.')

        yfit = self.predict()
        params = self.params('Fb', 've', 'S0')

        ncols = 2 + len(params)
        nrows = 1 if ref is None else 2
        # Common display range for all signal maps; without a reference the
        # prediction itself sets the scale.
        smax = 0.5 * np.amax(yfit if ref is None else ref['signal'])

        fig = plt.figure(figsize=(ncols * 2, nrows * 2))
        figcols = fig.subfigures(
            1, 2, wspace=0.0, hspace=0.0, width_ratios=[2, ncols - 2])

        # Left panel: signal. squeeze=False keeps a 2D axes array even when
        # there is only one row (no reference).
        ax = figcols[0].subplots(nrows, 2, squeeze=False)
        figcols[0].subplots_adjust(hspace=0.0, wspace=0)
        for axis in ax.ravel():
            axis.set_yticks([])
            axis.set_xticks([])

        # Signal maps
        ax[0, 0].set_title('max(signal)')
        ax[0, 0].set_ylabel('reconstruction')
        ax[0, 0].imshow(np.amax(yfit, axis=-1), vmin=0, vmax=smax, cmap=cmap)
        ax[0, 1].set_title('mean(signal)')
        ax[0, 1].imshow(np.mean(yfit, axis=-1), vmin=0, vmax=smax, cmap=cmap)
        if ref is not None:
            ax[1, 0].set_ylabel('ground truth')
            ax[1, 0].imshow(np.amax(ref['signal'], axis=-1), vmin=0,
                            vmax=smax, cmap=cmap)
            ax[1, 1].imshow(np.mean(ref['signal'], axis=-1), vmin=0,
                            vmax=smax, cmap=cmap)

        # Right panel: free parameters
        ax = figcols[1].subplots(nrows, ncols - 2, squeeze=False)
        figcols[1].subplots_adjust(hspace=0.0, wspace=0)
        for axis in ax.ravel():
            axis.set_yticks([])
            axis.set_xticks([])
        ax[0, 0].set_ylabel('reconstruction')
        if ref is not None:
            ax[1, 0].set_ylabel('ground truth')
        for i, par in enumerate(params.keys()):
            v0 = vmin[par] if par in vmin else np.percentile(params[par], 1)
            v1 = vmax[par] if par in vmax else np.percentile(params[par], 99)
            ax[0, i].set_title(par)
            ax[0, i].imshow(params[par], vmin=v0, vmax=v1, cmap=cmap)
            if ref is not None:
                ax[1, i].imshow(ref[par], vmin=v0, vmax=v1, cmap=cmap)

        if fname is not None:
            plt.savefig(fname=fname)
        if show:
            plt.show()
        else:
            plt.close()

    def plot_overlay(self, mask=None, vmin=None, vmax=None,
                     aspect_ratio=16/9, cmap='magma', fname=None, show=True):
        """Plot parameter maps (one image per parameter).

        For each of 'Fb' and 've', the slices of the 3D map are tiled in a
        mosaic and overlaid on the S0 map as background.

        Note: this function is currently only available for 3D data.

        Args:
            mask (numpy.ndarray): If provided, only pixels inside the mask
                are shown. Otherwise only non-zero pixels are shown.
            vmin (dict, optional): Minimum values on display for the model
                parameters and for the 'background'. Defaults to None.
            vmax (dict, optional): Maximum values on display for the model
                parameters and for the 'background'. Defaults to None.
            aspect_ratio (float, optional): Aspect ratio of the mosaic.
                Defaults to 16/9.
            cmap (str, optional): matplotlib colormap. Defaults to 'magma'.
            fname (str, optional): File path to save images. Each parameter
                is saved with its name prefixed to the file name.
                Defaults to None.
            show (bool, optional): Determine whether the image is shown or
                not. Defaults to True.

        Raises:
            NotImplementedError: Features that are not currently implemented.
        """
        S0 = self.pars['S0']
        if np.ndim(S0) != 3:
            raise NotImplementedError(
                'plot_params is only available for 3D data at this stage.')

        width, height, n_mosaics = S0.shape
        # At least one row: very tall slices would otherwise round to 0 rows
        # and divide by zero when computing the number of columns.
        nrows = max(1, int(np.round(
            np.sqrt((width * n_mosaics) / (aspect_ratio * height)))))
        ncols = int(np.ceil(n_mosaics / nrows))
        size = max(width, height)
        figsize = (ncols * width / size, nrows * height / size)
        bg_min = np.amin(S0) if vmin is None else vmin['background']
        bg_max = np.amax(S0) if vmax is None else vmax['background']

        for par in ['Fb', 've']:
            img = self.params(par)
            # Boolean selection of displayed pixels (also used as opacity).
            if mask is None:
                alpha = img != 0
            else:
                alpha = np.asarray(mask, dtype=bool)
            v0 = np.amin(img[alpha]) if vmin is None else vmin[par]
            v1 = np.amax(img[alpha]) if vmax is None else vmax[par]

            # Set up figure; squeeze=False keeps a 2D axes grid even for a
            # single row or column.
            fig, ax = plt.subplots(
                nrows=nrows, ncols=ncols,
                gridspec_kw={'wspace': 0, 'hspace': 0},
                figsize=figsize,
                dpi=300,
                squeeze=False,
            )
            plt.subplots_adjust(left=0, right=1, top=1, bottom=0)

            # Build figure
            i = 0
            for row in tqdm(ax, desc='Building png'):
                for col in row:
                    col.set_xticklabels([])
                    col.set_yticklabels([])
                    col.set_aspect('equal')
                    col.axis("off")
                    if i < img.shape[2]:
                        col.imshow(
                            S0[:, :, i].T,
                            cmap='gray',
                            interpolation='none',
                            vmin=bg_min,
                            vmax=bg_max,
                        )
                        im = col.imshow(
                            img[:, :, i].T,
                            cmap=cmap,
                            interpolation='none',
                            alpha=alpha[:, :, i].T.astype(np.float32),
                            vmin=v0,
                            vmax=v1,
                        )
                    i += 1

            # Add colorbar on the right
            cbar = fig.colorbar(im, ax=ax.ravel().tolist(), location='right',
                                fraction=0.046, pad=0.04)
            cbar.set_label(f"{par} ({PARAMS[par]['unit']})")
            cbar.ax.yaxis.set_label_position("left")
            cbar.ax.yaxis.set_ticks_position("right")

            if fname is not None:
                folder, filename = os.path.split(fname)
                name, ext = os.path.splitext(filename)
                new_filename = f"{par}_{name}{ext}"
                new_fullpath = os.path.join(folder, new_filename)
                fig.savefig(fname=new_fullpath, bbox_inches='tight',
                            pad_inches=0)
            if show:
                plt.show()
            else:
                plt.close(fig)