# Source code for rsatoolbox.rdm.calc

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Calculation of RDMs from datasets
@author: heiko, benjamin
"""
from __future__ import annotations
from collections.abc import Iterable
from copy import deepcopy
from typing import TYPE_CHECKING, Optional, Tuple
import numpy as np
from rsatoolbox.rdm.rdms import concat
from rsatoolbox.rdm.calc_unbalanced import calc_rdm_unbalanced
from rsatoolbox.rdm.combine import from_partials
from rsatoolbox.data import average_dataset_by
from rsatoolbox.util.rdm_utils import _extract_triu_
from rsatoolbox.util.build_rdm import _build_rdms

if TYPE_CHECKING:
    from rsatoolbox.data.base import DatasetBase
    from numpy.typing import NDArray


def calc_rdm(
        dataset: DatasetBase,
        method: str = 'euclidean',
        descriptor: Optional[str] = None,
        noise: Optional[NDArray] = None,
        cv_descriptor: Optional[str] = None,
        prior_lambda: float = 1,
        prior_weight: float = 0.1,
        remove_mean: bool = False):
    """
    calculates an RDM from an input dataset

    This should usually be called with the method and the descriptor
    argument to specify the dissimilarity measure and which observations
    in the dataset belong to which condition.

    Args:
        dataset (rsatoolbox.data.dataset.DatasetBase or iterable thereof):
            The dataset the RDM is computed from. If an iterable of datasets
            is passed, one RDM is computed per dataset (with identical
            settings) and the results are combined.
        method (String):
            a description of the dissimilarity measure (e.g. 'Euclidean')
        descriptor (String):
            obs_descriptor used to define the rows/columns of the RDM
        noise (numpy.ndarray):
            dataset.n_channel x dataset.n_channel
            precision matrix used to calculate the RDM
            used only for Mahalanobis and Crossnobis estimators
            defaults to an identity matrix, i.e. euclidean distance.
            When dataset is an iterable, noise may also be an iterable
            holding one precision matrix per dataset.
        cv_descriptor (String):
            obs_descriptor defining the crossvalidation folds
            (used by 'crossnobis' and 'poisson_cv')
        prior_lambda (float):
            prior rate for the poisson based measures
        prior_weight (float):
            weight of the prior rate for the poisson based measures
        remove_mean (bool):
            whether the mean of each pattern shall be removed
            before distance calculation. This has no effect on poisson
            based and correlation distances.

    Returns:
        rsatoolbox.rdm.rdms.RDMs: RDMs object with the one RDM

    Raises:
        NotImplementedError: if method is not a supported measure
    """
    if isinstance(dataset, Iterable):
        rdms = []
        for i_dat, ds_i in enumerate(dataset):
            # A single 2D precision matrix is shared by all datasets; any
            # other iterable provides one precision matrix per dataset.
            if isinstance(noise, Iterable) and not (
                    isinstance(noise, np.ndarray) and noise.ndim == 2):
                noise_i = noise[i_dat]
            else:
                noise_i = noise
            # every setting (including remove_mean) is forwarded so each
            # dataset is processed exactly like a single-dataset call
            rdms.append(calc_rdm(
                ds_i, method=method, descriptor=descriptor, noise=noise_i,
                cv_descriptor=cv_descriptor, prior_lambda=prior_lambda,
                prior_weight=prior_weight, remove_mean=remove_mean))
        if descriptor is None:
            rdm = concat(rdms)
        else:
            rdm = from_partials(rdms, descriptor=descriptor)
    else:
        if method == 'euclidean':
            rdm = calc_rdm_euclidean(dataset, descriptor, remove_mean)
        elif method == 'correlation':
            rdm = calc_rdm_correlation(dataset, descriptor)
        elif method == 'mahalanobis':
            rdm = calc_rdm_mahalanobis(dataset, descriptor, noise,
                                       remove_mean)
        elif method == 'crossnobis':
            rdm = calc_rdm_crossnobis(dataset, descriptor, noise,
                                      cv_descriptor, remove_mean)
        elif method == 'poisson':
            rdm = calc_rdm_poisson(dataset, descriptor,
                                   prior_lambda=prior_lambda,
                                   prior_weight=prior_weight)
        elif method == 'poisson_cv':
            rdm = calc_rdm_poisson_cv(dataset, descriptor,
                                      cv_descriptor=cv_descriptor,
                                      prior_lambda=prior_lambda,
                                      prior_weight=prior_weight)
        else:
            raise NotImplementedError(
                f'unknown dissimilarity method: {method}')
        if descriptor is not None:
            rdm.sort_by(**{descriptor: 'alpha'})
    return rdm
def calc_rdm_movie(
        dataset, method='euclidean', descriptor=None, noise=None,
        cv_descriptor=None, prior_lambda=1, prior_weight=0.1,
        time_descriptor='time', bins=None, unbalanced=False):
    """
    calculates an RDM movie from an input TemporalDataset

    Args:
        dataset (rsatoolbox.data.dataset.TemporalDataset or iterable):
            The dataset the RDM is computed from. If an iterable is passed,
            a movie is computed for every dataset (with identical settings)
            and the results are concatenated.
        method (String):
            a description of the dissimilarity measure (e.g. 'Euclidean')
        descriptor (String):
            obs_descriptor used to define the rows/columns of the RDM
        noise (numpy.ndarray):
            dataset.n_channel x dataset.n_channel
            precision matrix used to calculate the RDM
            used only for Mahalanobis and Crossnobis estimators
            defaults to an identity matrix, i.e. euclidean distance.
            When dataset is an iterable, noise may also be an iterable
            holding one precision matrix per dataset.
        cv_descriptor (String):
            obs_descriptor defining the crossvalidation folds
        prior_lambda (float):
            prior rate for the poisson based measures
        prior_weight (float):
            weight of the prior rate for the poisson based measures
        time_descriptor (String): descriptor key that points to the time
            dimension in dataset.time_descriptors. Defaults to 'time'.
        bins (array-like): list of bins, with bins[i] containing the vector
            of time-points for the i-th bin. Defaults to no binning.
        unbalanced (bool): if set to True use calc_rdm_unbalanced,
            else and by default use calc_rdm

    Returns:
        rsatoolbox.rdm.rdms.RDMs: RDMs object with RDM movie
    """
    if isinstance(dataset, Iterable):
        rdms = []
        for i_dat, ds_i in enumerate(dataset):
            # A single 2D precision matrix is shared by all datasets; any
            # other iterable provides one precision matrix per dataset.
            if isinstance(noise, Iterable) and not (
                    isinstance(noise, np.ndarray) and noise.ndim == 2):
                noise_i = noise[i_dat]
            else:
                noise_i = noise
            # forward every setting, so binning, time descriptor,
            # crossvalidation and unbalanced mode apply to each dataset
            rdms.append(calc_rdm_movie(
                ds_i, method=method, descriptor=descriptor, noise=noise_i,
                cv_descriptor=cv_descriptor, prior_lambda=prior_lambda,
                prior_weight=prior_weight, time_descriptor=time_descriptor,
                bins=bins, unbalanced=unbalanced))
        rdm = concat(rdms)
    else:
        if bins is not None:
            binned_data = dataset.bin_time(time_descriptor, bins)
            splited_data = binned_data.split_time(time_descriptor)
            time = binned_data.time_descriptors[time_descriptor]
        else:
            splited_data = dataset.split_time(time_descriptor)
            time = dataset.time_descriptors[time_descriptor]

        rdm_function = calc_rdm_unbalanced if unbalanced else calc_rdm
        rdms = []
        for dat in splited_data:
            # one static dataset (and thus one RDM) per time point / bin
            dat_single = dat.convert_to_dataset(time_descriptor)
            rdms.append(rdm_function(
                dat_single, method=method, descriptor=descriptor,
                noise=noise, cv_descriptor=cv_descriptor,
                prior_lambda=prior_lambda, prior_weight=prior_weight))
        rdm = concat(rdms)
        rdm.rdm_descriptors[time_descriptor] = time
    rdm.dissimilarity_measure = method
    return rdm
def calc_rdm_euclidean(
        dataset: DatasetBase,
        descriptor: Optional[str] = None,
        remove_mean: bool = False):
    """
    calculates an RDM from an input dataset using squared euclidean
    distance, divided by the number of channels

    If a descriptor is given, observations sharing a descriptor value are
    averaged before the distances are computed.

    Args:
        dataset (rsatoolbox.data.DatasetBase):
            The dataset the RDM is computed from
        descriptor (String):
            obs_descriptor used to define the rows/columns of the RDM
            defaults to one row/column per row in the dataset
        remove_mean (bool):
            whether the mean of each pattern shall be removed
            before calculating distances.

    Returns:
        rsatoolbox.rdm.rdms.RDMs: RDMs object with the one RDM
    """
    patterns, cond_values = _parse_input(dataset, descriptor, remove_mean)
    n_features = patterns.shape[1]
    # |a - b|^2 = |a|^2 + |b|^2 - 2 a.b, evaluated for all pairs at once
    sq_norms = np.sum(patterns ** 2, axis=1, keepdims=True)
    gram = np.dot(patterns, patterns.T)
    full_dist = sq_norms + sq_norms.T - 2 * gram
    dissimilarities = _extract_triu_(full_dist) / n_features
    return _build_rdms(dissimilarities, dataset, 'squared euclidean',
                       descriptor, cond_values)
def calc_rdm_correlation(dataset, descriptor=None):
    """
    calculates an RDM from an input dataset using correlation distance
    If multiple instances of the same condition are found in the dataset
    they are averaged.

    Args:
        dataset (rsatoolbox.data.DatasetBase):
            The dataset the RDM is computed from
        descriptor (String):
            obs_descriptor used to define the rows/columns of the RDM
            defaults to one row/column per row in the dataset

    Returns:
        rsatoolbox.rdm.rdms.RDMs: RDMs object with the one RDM
    """
    # patterns are mean-centred first, so after scaling each row to unit
    # length their dot products are Pearson correlations
    centered, cond_values = _parse_input(dataset, descriptor,
                                         remove_mean=True)
    norms = np.sqrt(np.einsum('ij,ij->i', centered, centered))
    centered /= norms[:, None]
    similarity = np.einsum('ik,jk', centered, centered)
    return _build_rdms(1 - similarity, dataset, 'correlation', descriptor,
                       cond_values)
def calc_rdm_mahalanobis(dataset, descriptor=None, noise=None,
                         remove_mean: bool = False):
    """
    calculates an RDM from an input dataset using mahalanobis distance
    If multiple instances of the same condition are found in the dataset
    they are averaged.

    Args:
        dataset (rsatoolbox.data.dataset.DatasetBase):
            The dataset the RDM is computed from
        descriptor (String):
            obs_descriptor used to define the rows/columns of the RDM
            defaults to one row/column per row in the dataset
        noise (numpy.ndarray):
            dataset.n_channel x dataset.n_channel
            precision matrix used to calculate the RDM
            default: identity matrix, i.e. euclidean distance
        remove_mean (bool):
            whether the mean of each pattern shall be removed
            before calculating distances.

    Returns:
        rsatoolbox.rdm.rdms.RDMs: RDMs object with the one RDM
    """
    if noise is None:
        # identity precision: identical to the squared euclidean RDM
        return calc_rdm_euclidean(dataset, descriptor, remove_mean)
    patterns, cond_values = _parse_input(dataset, descriptor, remove_mean)
    precision = _check_noise(noise, dataset.n_channel)
    kernel = patterns @ precision @ patterns.T
    self_sim = np.diag(kernel)
    full_dist = self_sim[np.newaxis, :] + self_sim[:, np.newaxis] \
        - 2 * kernel
    dissimilarities = _extract_triu_(full_dist) / patterns.shape[1]
    return _build_rdms(
        dissimilarities, dataset, 'squared mahalanobis', descriptor,
        cond_values, noise=precision
    )
def calc_rdm_crossnobis(dataset, descriptor, noise=None,
                        cv_descriptor=None,
                        remove_mean: bool = False):
    """
    calculates an RDM from an input dataset using Cross-nobis distance
    This performs leave one out crossvalidation over the cv_descriptor.

    As the minimum input provide a dataset and a descriptor-name to
    define the rows & columns of the RDM.
    You may pass a noise precision. If you don't an identity is assumed.
    Also a cv_descriptor can be passed to define the crossvalidation folds.
    It is recommended to do this, to assure correct calculations. If you do
    not, this function infers a split in order of the dataset, which is
    guaranteed to fail if there are any unbalances.

    This function also accepts a list of noise precision matrices.
    It is then assumed that this is the precision of the mean from
    the corresponding crossvalidation fold, i.e. if multiple measurements
    enter a fold, please compute the resulting noise precision in advance!
    In that case all pairs of folds are compared instead of
    leave-one-out.

    To assert equal ordering in the folds the dataset is initially sorted
    according to the descriptor used to define the patterns.

    Args:
        dataset (rsatoolbox.data.dataset.DatasetBase):
            The dataset the RDM is computed from
        descriptor (String):
            obs_descriptor used to define the rows/columns of the RDM.
            Required: crossvalidation needs repeated measurements per
            condition.
        noise (numpy.ndarray or list of numpy.ndarray):
            dataset.n_channel x dataset.n_channel
            precision matrix used to calculate the RDM, or one such
            matrix per crossvalidation fold
            default: identity matrix, i.e. euclidean distance
        cv_descriptor (String):
            obs_descriptor which determines the cross-validation folds
        remove_mean (bool):
            whether the mean of each pattern shall be removed
            before calculating distances.

    Returns:
        rsatoolbox.rdm.rdms.RDMs: RDMs object with the one RDM

    Raises:
        ValueError: if descriptor is None
    """
    if descriptor is None:
        raise ValueError('descriptor must be a string! Crossvalidation '
                         + 'requires multiple measurements to be grouped')
    noise = _check_noise(noise, dataset.n_channel)
    if noise is None:
        noise = np.eye(dataset.n_channel)
    # work on a copy: the cv descriptor may be added and the data re-sorted
    dataset_copy = deepcopy(dataset)
    if cv_descriptor is None:
        cv_desc = _gen_default_cv_descriptor(dataset_copy, descriptor)
        dataset_copy.obs_descriptors['cv_desc'] = cv_desc
        cv_descriptor = 'cv_desc'
    dataset_copy.sort_by(descriptor)
    cv_folds = np.unique(
        np.array(dataset_copy.obs_descriptors[cv_descriptor]))
    rdms = []
    if isinstance(noise, np.ndarray) and noise.ndim == 2:
        # one shared precision: leave-one-fold-out crossvalidation
        for fold in cv_folds:
            data_test = dataset_copy.subset_obs(cv_descriptor, fold)
            data_train = dataset_copy.subset_obs(
                cv_descriptor, np.setdiff1d(cv_folds, fold)
            )
            measurements_train, _, _ = \
                average_dataset_by(data_train, descriptor)
            measurements_test, _, _ = \
                average_dataset_by(data_test, descriptor)
            if remove_mean:
                measurements_train -= measurements_train.mean(
                    axis=1, keepdims=True)
                measurements_test -= measurements_test.mean(
                    axis=1, keepdims=True)
            rdms.append(_calc_rdm_crossnobis_single(
                measurements_train, measurements_test, noise))
    else:
        # one precision per fold: compare every pair of folds, using the
        # precision of the average of the two folds' noise covariances
        measurements = []
        variances = []
        for i_fold, fold in enumerate(cv_folds):
            data = dataset_copy.subset_obs(cv_descriptor, fold)
            ma = average_dataset_by(data, descriptor)[0]
            if remove_mean:
                ma -= ma.mean(axis=1, keepdims=True)
            measurements.append(ma)
            variances.append(np.linalg.inv(noise[i_fold]))
        n_folds = len(cv_folds)
        for i_fold in range(n_folds):
            for j_fold in range(i_fold + 1, n_folds):
                rdms.append(_calc_rdm_crossnobis_single(
                    measurements[i_fold], measurements[j_fold],
                    np.linalg.inv(
                        (variances[i_fold] + variances[j_fold]) / 2)
                ))
    rdms = np.array(rdms)
    # average the per-fold (or per-pair) estimates
    rdm = np.einsum('ij->j', rdms) / rdms.shape[0]
    return _build_rdms(
        rdm,
        dataset_copy,
        'crossnobis',
        descriptor,
        noise=noise,
        cv=cv_descriptor
    )
def calc_rdm_poisson(dataset, descriptor=None, prior_lambda=1,
                     prior_weight=0.1):
    """
    calculates an RDM from an input dataset using the symmetrized
    KL-divergence assuming a poisson distribution.
    If multiple instances of the same condition are found in the dataset
    they are averaged.

    Args:
        dataset (rsatoolbox.data.DatasetBase):
            The dataset the RDM is computed from
        descriptor (String):
            obs_descriptor used to define the rows/columns of the RDM
            defaults to one row/column per row in the dataset
        prior_lambda (float):
            prior rate; the measurements are replaced by
            (measurements + prior_lambda * prior_weight) / (1 + prior_weight)
        prior_weight (float):
            weight of the prior rate relative to the measurements

    Returns:
        rsatoolbox.rdm.rdms.RDMs: RDMs object with the one RDM
    """
    rates, cond_values = _parse_input(dataset, descriptor)
    # shrink the rates towards the prior rate
    rates = (rates + prior_lambda * prior_weight) / (1 + prior_weight)
    kernel = rates @ np.log(rates).T
    self_term = np.diag(kernel)
    full_dist = (self_term[np.newaxis, :] + self_term[:, np.newaxis]
                 - kernel - kernel.T)
    dissimilarities = _extract_triu_(full_dist) / rates.shape[1]
    return _build_rdms(dissimilarities, dataset, 'poisson', descriptor,
                       cond_values)
def calc_rdm_poisson_cv(dataset, descriptor=None, prior_lambda=1,
                        prior_weight=0.1, cv_descriptor=None):
    """
    calculates an RDM from an input dataset using the crossvalidated
    symmetrized KL-divergence assuming a poisson distribution

    Leave-one-fold-out crossvalidation is performed over the
    cv_descriptor and the resulting estimates are averaged over folds.
    To assert equal ordering in the folds the dataset is initially sorted
    according to the descriptor used to define the patterns.

    Args:
        dataset (rsatoolbox.data.DatasetBase):
            The dataset the RDM is computed from
        descriptor (String):
            obs_descriptor used to define the rows/columns of the RDM.
            Required: crossvalidation needs repeated measurements per
            condition.
        prior_lambda (float):
            prior rate; measurements are replaced by
            (measurements + prior_lambda * prior_weight) / (1 + prior_weight)
        prior_weight (float):
            weight of the prior rate relative to the measurements
        cv_descriptor (str): The descriptor that indicates the folds
            to use for crossvalidation

    Returns:
        rsatoolbox.rdm.rdms.RDMs: RDMs object with the one RDM

    Raises:
        ValueError: if descriptor is None or no crossvalidation fold exists
    """
    if descriptor is None:
        raise ValueError('descriptor must be a string! Crossvalidation '
                         + 'requires multiple measurements to be grouped')
    # work on a copy: the cv descriptor may be added and the data re-sorted
    dataset = deepcopy(dataset)
    if cv_descriptor is None:
        cv_desc = _gen_default_cv_descriptor(dataset, descriptor)
        dataset.obs_descriptors['cv_desc'] = cv_desc
        cv_descriptor = 'cv_desc'
    dataset.sort_by(descriptor)
    cv_folds = np.unique(np.array(dataset.obs_descriptors[cv_descriptor]))
    if len(cv_folds) == 0:
        raise ValueError('no crossvalidation folds found in '
                         + str(cv_descriptor))

    def _shrink(meas):
        # shrink the rates towards the prior rate
        return (meas + prior_lambda * prior_weight) / (1 + prior_weight)

    rdms = []
    for fold in cv_folds:
        data_test = dataset.subset_obs(cv_descriptor, fold)
        data_train = dataset.subset_obs(cv_descriptor,
                                        np.setdiff1d(cv_folds, fold))
        measurements_train, _, _ = average_dataset_by(data_train, descriptor)
        measurements_test, _, _ = average_dataset_by(data_test, descriptor)
        measurements_train = _shrink(measurements_train)
        measurements_test = _shrink(measurements_test)
        kernel = measurements_train @ np.log(measurements_test).T
        rdm = np.expand_dims(np.diag(kernel), 0) + \
            np.expand_dims(np.diag(kernel), 1) - kernel - kernel.T
        rdms.append(_extract_triu_(rdm) / measurements_train.shape[1])
    # average over folds (previously only the last fold was kept)
    rdm = np.mean(np.array(rdms), axis=0)
    return _build_rdms(rdm, dataset, 'poisson_cv', descriptor)
def _calc_rdm_crossnobis_single(meas1, meas2, noise) -> NDArray: kernel = meas1 @ noise @ meas2.T rdm = np.expand_dims(np.diag(kernel), 0) + \ np.expand_dims(np.diag(kernel), 1) - kernel - kernel.T return _extract_triu_(rdm) / meas1.shape[1] def _gen_default_cv_descriptor(dataset, descriptor) -> np.ndarray: """ generates a default cv_descriptor for crossnobis This assumes that the first occurence each descriptor value forms the first group, the second occurence forms the second group, etc. """ desc = dataset.obs_descriptors[descriptor] values, counts = np.unique(desc, return_counts=True) assert np.all(counts == counts[0]), ( 'cv_descriptor generation failed:\n' + 'different number of observations per pattern') n_repeats = counts[0] cv_descriptor = np.zeros_like(desc) for i_val in values: cv_descriptor[desc == i_val] = np.arange(n_repeats) return cv_descriptor def _parse_input( dataset: DatasetBase, descriptor: Optional[str], remove_mean: bool = False ) -> Tuple[np.ndarray, Optional[np.ndarray]]: if descriptor is None: measurements = dataset.measurements desc = None else: measurements, desc, _ = average_dataset_by(dataset, descriptor) if remove_mean: measurements = measurements - measurements.mean(axis=1, keepdims=True) return measurements, desc def _check_noise(noise, n_channel): """ checks that a noise pattern is a matrix with correct dimension n_channel x n_channel Args: noise: noise input to be checked Returns: noise(np.ndarray): n_channel x n_channel noise precision matrix """ if noise is None: pass elif isinstance(noise, np.ndarray) and noise.ndim == 2: assert np.all(noise.shape == (n_channel, n_channel)) elif isinstance(noise, dict): for key in noise.keys(): noise[key] = _check_noise(noise[key], n_channel) elif isinstance(noise, Iterable): for idx, noise_i in enumerate(noise): noise[idx] = _check_noise(noise_i, n_channel) else: raise ValueError('noise(s) must have shape n_channel x n_channel') return noise