# Source code for rsatoolbox.util.searchlight

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This code was initially inspired by the following :
https://github.com/machow/pysearchlight

@author: Daniel Lindh
"""
import numpy as np
from scipy.spatial.distance import cdist
from tqdm import tqdm
from joblib import Parallel, delayed
from rsatoolbox.data.dataset import Dataset
from rsatoolbox.rdm.calc import calc_rdm
from rsatoolbox.rdm import RDMs


def _get_searchlight_neighbors(mask, center, radius=3):
    """Return indices for searchlight where distance
        between a voxel and their center < radius (in voxels)

    Args:
        center (index):  point around which to make searchlight sphere

    Returns:
        list: the list of volume indices that respect the
                searchlight radius for the input center.
    """
    center = np.array(center)
    mask_shape = mask.shape
    cx, cy, cz = np.array(center)
    x = np.arange(mask_shape[0])
    y = np.arange(mask_shape[1])
    z = np.arange(mask_shape[2])

    # First mask the obvious points
    # - may actually slow down your calculation depending.
    x = x[abs(x - cx) < radius]
    y = y[abs(y - cy) < radius]
    z = z[abs(z - cz) < radius]

    # Generate grid of points
    X, Y, Z = np.meshgrid(x, y, z)
    data = np.vstack((X.ravel(), Y.ravel(), Z.ravel())).T
    distance = cdist(data, center.reshape(1, -1), 'euclidean').ravel()

    return tuple(data[distance < radius].T.tolist())


def get_volume_searchlight(mask, radius=2, threshold=1.0):
    """Find searchlight centers and their neighborhoods inside a brain mask.

    Every non-zero voxel of ``mask`` is a candidate center. A candidate is
    kept when the mean of ``mask`` over its spherical neighborhood (all
    voxels whose Euclidean distance to the center is strictly smaller than
    ``radius``) is at least ``threshold``.

    Args:
        mask (numpy array): 3-dimensional binary brain mask
        radius (int, optional): the radius of each searchlight, defined in
            voxels. Defaults to 2.
        threshold (float, optional): minimum proportion of the sphere's
            voxels that must lie inside the brain mask for a voxel to be
            accepted as a searchlight center. Values go between 0.0 - 1.0,
            where 1.0 means that 100% of the voxels need to be inside the
            brain mask. Defaults to 1.0.

    Returns:
        numpy array: 1D array (length n_centers) with the flat index
            (as produced by ``np.ravel_multi_index`` for ``mask.shape``)
            of each accepted center
        list: list of length n_centers; element i is a 1D array with the
            flat indices of the voxels in the searchlight around center i

    Raises:
        ValueError: if ``mask`` is not 3-dimensional
    """
    mask = np.array(mask)
    if mask.ndim != 3:
        raise ValueError("Mask needs to be a 3-dimensional numpy array")

    centers = list(zip(*np.nonzero(mask)))
    good_centers = []
    good_neighbors = []
    for center in tqdm(centers, desc='Finding searchlights...'):
        neighbors = _get_searchlight_neighbors(mask, center, radius)
        if mask[neighbors].mean() >= threshold:
            good_centers.append(center)
            good_neighbors.append(neighbors)

    print(f'Found {len(good_neighbors)} searchlights')

    if not good_centers:
        # np.array([]) would be a float array of shape (0,), which
        # np.ravel_multi_index rejects -> return empty results explicitly.
        return np.array([], dtype=np.intp), []

    # turn the 3-dim coordinates into flat array indices
    good_centers = np.array(good_centers)
    centers = np.ravel_multi_index(good_centers.T, mask.shape)
    neighbors = [np.ravel_multi_index(n, mask.shape) for n in good_neighbors]
    return centers, neighbors


def get_searchlight_RDMs(data_2d, centers, neighbors, events,
                         method='correlation', verbose=True):
    """Iterate over all searchlight centers and calculate one RDM per center.

    Args:
        data_2d (2D numpy array): brain data, shape n_observations x
            n_channels (i.e. voxels/vertices)
        centers (1D numpy array): center indices for all searchlights as
            provided by rsatoolbox.util.searchlight.get_volume_searchlight
        neighbors (list): list of lists with neighbor voxel indices for all
            searchlights as provided by
            rsatoolbox.util.searchlight.get_volume_searchlight
        events (1D numpy array): 1D array of length n_observations
        method (str, optional): distance metric, see rsatoolbox.rdm.calc for
            options. Defaults to 'correlation'.
        verbose (bool, optional): show a progress bar while RDMs are computed
            chunk by chunk (only happens with more than 1000 centers).
            Defaults to True.

    Returns:
        RDM [rsatoolbox.rdm.RDMs]: RDMs object with the RDM for each
            searchlight; RDM.rdm_descriptors['voxel_index'] holds the center
            voxel index each RDM is associated with
    """
    data_2d, centers = np.array(data_2d), np.array(centers)
    n_centers = centers.shape[0]

    def _center_dataset(c):
        """Wrap the data of searchlight number ``c`` in a Dataset object."""
        center_neighbors = neighbors[c]
        return Dataset(data_2d[:, center_neighbors],
                       descriptors={'center': centers[c]},
                       obs_descriptors={'events': events},
                       channel_descriptors={'voxels': center_neighbors})

    # For memory reasons, we chunk the data if we have more than 1000 RDMs
    if n_centers > 1000:
        # Split the centers into 100 roughly equal chunks so that only one
        # chunk's Dataset objects are held in memory at a time.
        chunked_center = np.split(
            np.arange(n_centers),
            np.linspace(0, n_centers, 101, dtype=int)[1:-1])
        n_conds = len(np.unique(events))
        RDM = np.zeros((n_centers, n_conds * (n_conds - 1) // 2))
        for chunks in tqdm(chunked_center, desc='Calculating RDMs...',
                           disable=not verbose):
            center_data = [_center_dataset(c) for c in chunks]
            RDM_corr = calc_rdm(center_data, method=method,
                                descriptor='events')
            RDM[chunks, :] = RDM_corr.dissimilarities
    else:
        center_data = [_center_dataset(c) for c in range(n_centers)]
        # calculate RDMs for each dataset object
        RDM = calc_rdm(center_data, method=method,
                       descriptor='events').dissimilarities

    SL_rdms = RDMs(RDM,
                   rdm_descriptors={'voxel_index': centers},
                   dissimilarity_measure=method)
    return SL_rdms


def evaluate_models_searchlight(sl_RDM, models, eval_function, method='corr',
                                theta=None, n_jobs=1):
    """Evaluate the given model(s) on every searchlight RDM.

    Args:
        sl_RDM ([rsatoolbox.rdm.RDMs]): RDMs object as computed by
            rsatoolbox.util.searchlight.get_searchlight_RDMs
        models ([rsatoolbox.model]): model to evaluate - can also be a list
            of models
        eval_function (rsatoolbox.inference evaluation-function): called as
            ``eval_function(models, rdm, method=method, theta=theta)`` once
            per element obtained by iterating ``sl_RDM``
        method (str, optional): see rsatoolbox.rdm.compare for specifics.
            Defaults to 'corr'.
        theta (optional): forwarded unchanged to ``eval_function``.
            Defaults to None.
        n_jobs (int, optional): how many jobs to run. Defaults to 1.

    Returns:
        list: the value returned by ``eval_function`` for each searchlight,
            in the order the searchlights are iterated
    """
    runner = Parallel(n_jobs=n_jobs)
    searchlights = tqdm(sl_RDM,
                        desc='Evaluating models for each searchlight')
    tasks = (delayed(eval_function)(models, single_rdm,
                                    method=method, theta=theta)
             for single_rdm in searchlights)
    return runner(tasks)