Source code for rsatoolbox.util.searchlight

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This code was initially inspired by the following :
https://github.com/machow/pysearchlight

@author: Daniel Lindh
"""
import numpy as np
from scipy.spatial.distance import cdist
from tqdm import tqdm
from joblib import Parallel, delayed
from rsatoolbox.data.dataset import Dataset
from rsatoolbox.rdm.calc import calc_rdm
from rsatoolbox.rdm import RDMs


def _get_searchlight_neighbors(mask, center, radius=3, truncate_at_boundary=False):
    """Return indices for searchlight where distance
        between a voxel and their center < radius (in voxels)

    Args:
        center (index):  point around which to make searchlight sphere
        radius (int): radius of the searchlight sphere in voxels
        truncate_at_boundary (bool): if True, only include voxels where mask is True.
            if False (default), include all voxels within radius regardless of mask value.
            Setting to True addresses issue #466 by preventing searchlights from including
            voxels outside the mask boundary.

    Returns:
        list: the list of volume indices that respect the
                searchlight radius for the input center.
    """
    center = np.array(center)
    mask_shape = mask.shape
    cx, cy, cz = np.array(center)
    x = np.arange(mask_shape[0])
    y = np.arange(mask_shape[1])
    z = np.arange(mask_shape[2])

    # First mask the obvious points
    # - may actually slow down your calculation depending.
    x = x[abs(x - cx) < radius]
    y = y[abs(y - cy) < radius]
    z = z[abs(z - cz) < radius]

    # Generate grid of points
    X, Y, Z = np.meshgrid(x, y, z)
    data = np.vstack((X.ravel(), Y.ravel(), Z.ravel())).T
    distance = cdist(data, center.reshape(1, -1), 'euclidean').ravel()

    # Get voxels within radius
    within_radius = data[distance < radius]

    # Optionally filter to only include voxels inside the mask
    if truncate_at_boundary:
        neighbors_tuple = tuple(within_radius.T.astype(int).tolist())
        mask_filter = mask[neighbors_tuple]
        within_radius = within_radius[mask_filter > 0]

    return tuple(within_radius.T.astype(int).tolist())


def get_volume_searchlight(mask, radius=2, threshold=1.0, truncate_at_boundary=False):
    """
    Searches through the non-zero voxels of the mask, selects centers where
    proportion of sphere voxels >= threshold.

    Args:
        mask ([numpy array]): binary brain mask (3-dimensional)
        radius (int, optional): the radius of each searchlight, defined in
            voxels. Defaults to 2.
        threshold (float, optional): Threshold of the proportion of voxels
            that need to be inside the brain mask in order for it to be
            considered a good searchlight center. Values go between
            0.0 - 1.0 where 1.0 means that 100% of the voxels need to be
            inside the brain mask. The proportion is computed as the mean of
            the mask values over the sphere. Defaults to 1.0.
        truncate_at_boundary (bool, optional): if True, searchlight spheres
            will only include voxels inside the mask, effectively truncating
            spheres at the mask boundary. If False (default), spheres include
            all voxels within the radius regardless of mask value (maintains
            backward compatibility). When False and threshold < 1.0, this can
            lead to artifacts as reported in issue #466. Setting to True fixes
            this but may require accounting for different variance
            characteristics in second-level analysis. Defaults to False.

    Returns:
        numpy array: 1-D integer array of length n_centers with the flat
            (raveled, see ``np.ravel_multi_index``) index of each searchlight
            center in the mask volume. Empty when no center qualifies.
        list: list of length n_centers; element i is a 1-D numpy array with
            the flat indices of the voxels belonging to searchlight i.
    """
    mask = np.array(mask)
    assert mask.ndim == 3, "Mask needs to be a 3-dimensional numpy array"

    centers = list(zip(*np.nonzero(mask)))
    good_centers = []
    good_neighbors = []

    for center in tqdm(centers, desc='Finding searchlights...'):
        neighbors = _get_searchlight_neighbors(mask, center, radius, truncate_at_boundary)
        # Mean mask value over the sphere: the in-mask proportion for a binary mask
        if mask[neighbors].mean() >= threshold:
            good_centers.append(center)
            good_neighbors.append(neighbors)

    # Force an integer (n_centers, 3) array even when nothing qualified:
    # np.array([]) would be a float array of shape (0,), which
    # np.ravel_multi_index rejects.
    good_centers = np.array(good_centers, dtype=int).reshape(-1, 3)

    assert good_centers.shape[0] == len(good_neighbors),\
        "number of centers and sets of neighbors do not match"
    print(f'Found {len(good_neighbors)} searchlights')

    # turn the 3-dim coordinates to array coordinates
    centers = np.ravel_multi_index(good_centers.T, mask.shape)
    neighbors = [np.ravel_multi_index(n, mask.shape) for n in good_neighbors]

    return centers, neighbors
def _searchlight_dataset(data_2d, center, center_neighbors, events):
    """Wrap the channels of one searchlight in a Dataset labelled by its center."""
    return Dataset(data_2d[:, center_neighbors],
                   descriptors={'center': center},
                   obs_descriptors={'events': events},
                   channel_descriptors={'voxels': center_neighbors})


def get_searchlight_RDMs(data_2d, centers, neighbors, events,
                         method='correlation', verbose=True):
    """Iterates over all the searchlight centers and calculates the RDM

    Args:
        data_2d (2D numpy array): brain data,
            shape n_observations x n_channels (i.e. voxels/vertices)
        centers (1D numpy array): center indices for all searchlights as provided
            by rsatoolbox.util.searchlight.get_volume_searchlight
        neighbors (list): list of lists with neighbor voxel indices for all searchlights
            as provided by rsatoolbox.util.searchlight.get_volume_searchlight
        events (1D numpy array): 1D array of length n_observations
        method (str, optional): distance metric,
            see rsatoolbox.rdm.calc for options. Defaults to 'correlation'.
        verbose (bool, optional): if True, show a progress bar while RDMs are
            computed in chunks (more than 1000 centers). Defaults to True.

    Returns:
        RDM [rsatoolbox.rdm.RDMs]: RDMs object with the RDM for each searchlight
            the RDM.rdm_descriptors['voxel_index']
            describes the center voxel index each RDM is associated with
    """
    data_2d, centers = np.array(data_2d), np.array(centers)
    n_centers = centers.shape[0]

    # For memory reasons, we chunk the data if we have more than 1000 RDMs
    if n_centers > 1000:
        # Split the centers into 100 roughly equal chunks so only one chunk's
        # Dataset objects are alive at a time.
        chunked_center = np.split(np.arange(n_centers),
                                  np.linspace(0, n_centers, 101, dtype=int)[1:-1])
        # RDMs are computed per unique event, giving n_conds*(n_conds-1)/2 pairs
        n_conds = len(np.unique(events))
        RDM = np.zeros((n_centers, n_conds * (n_conds - 1) // 2))
        for chunks in tqdm(chunked_center, desc='Calculating RDMs...',
                           disable=not verbose):
            center_data = [
                _searchlight_dataset(data_2d, centers[c], neighbors[c], events)
                for c in chunks]
            RDM_corr = calc_rdm(center_data, method=method, descriptor='events')
            RDM[chunks, :] = RDM_corr.dissimilarities
    else:
        center_data = [
            _searchlight_dataset(data_2d, centers[c], neighbors[c], events)
            for c in range(n_centers)]
        # calculate RDMs for each database object
        RDM = calc_rdm(center_data, method=method, descriptor='events').dissimilarities

    SL_rdms = RDMs(RDM,
                   rdm_descriptors={'voxel_index': centers},
                   dissimilarity_measure=method)

    return SL_rdms
def evaluate_models_searchlight(sl_RDM, models, eval_function, method='corr', theta=None, n_jobs=1):
    """Evaluate the given model(s) on the RDM of every searchlight.

    Args:
        sl_RDM ([rsatoolbox.rdm.RDMs]): RDMs object as computed by
            rsatoolbox.util.searchlight.get_searchlight_RDMs; it is iterated
            one searchlight RDM at a time
        models ([rsatoolbox.model]): model to evaluate - can also be a list of models
        eval_function (rsatoolbox.inference evaluation-function): called as
            ``eval_function(models, rdm, method=method, theta=theta)`` for
            each searchlight RDM
        method (str, optional): comparison method passed on to eval_function,
            see rsatoolbox.rdm.compare for specifics. Defaults to 'corr'.
        theta (optional): model parameters passed on to eval_function.
            Defaults to None.
        n_jobs (int, optional): number of joblib jobs to run in parallel.
            Defaults to 1.

    Returns:
        list: the model evaluation for each searchlight center, in the order
            the searchlights appear in sl_RDM
    """
    runner = Parallel(n_jobs=n_jobs)
    evaluate = delayed(eval_function)
    progress = tqdm(sl_RDM, desc='Evaluating models for each searchlight')
    tasks = (evaluate(models, single_rdm, method=method, theta=theta)
             for single_rdm in progress)
    return runner(tasks)