#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This code was initially inspired by the following :
https://github.com/machow/pysearchlight
@author: Daniel Lindh
"""
import numpy as np
from scipy.spatial.distance import cdist
from tqdm import tqdm
from joblib import Parallel, delayed
from rsatoolbox.data.dataset import Dataset
from rsatoolbox.rdm.calc import calc_rdm
from rsatoolbox.rdm import RDMs
def _get_searchlight_neighbors(mask, center, radius=3, truncate_at_boundary=False):
    """Collect the voxel coordinates of one searchlight sphere.

    A voxel belongs to the sphere when its euclidean distance to `center`
    is strictly smaller than `radius` (both measured in voxels).

    Args:
        mask (numpy array): 3D volume. Its shape bounds the candidate
            coordinates; its values are only read when
            `truncate_at_boundary` is True.
        center (index): point around which to make the searchlight sphere,
            given as three coordinates.
        radius (int): radius of the searchlight sphere in voxels.
        truncate_at_boundary (bool): if True, keep only sphere voxels whose
            mask value is > 0. If False (default), keep every voxel within
            the radius regardless of its mask value. Setting this to True
            addresses issue #466 by preventing searchlights from including
            voxels outside the mask boundary.

    Returns:
        tuple: three lists (x, y and z coordinates) describing the voxels
            of the sphere; directly usable as a numpy fancy index.
    """
    center = np.array(center)
    cx, cy, cz = center

    # Cheap pre-selection: along every axis keep only the indices that fall
    # inside the bounding cube of the sphere.
    axis_candidates = []
    for axis_length, axis_center in ((mask.shape[0], cx),
                                     (mask.shape[1], cy),
                                     (mask.shape[2], cz)):
        axis_indices = np.arange(axis_length)
        axis_candidates.append(
            axis_indices[abs(axis_indices - axis_center) < radius])

    # Enumerate every coordinate triple of the bounding cube.
    grid_x, grid_y, grid_z = np.meshgrid(*axis_candidates)
    candidates = np.column_stack(
        (grid_x.ravel(), grid_y.ravel(), grid_z.ravel()))

    # Carve the sphere out of the cube.
    dist_to_center = cdist(
        candidates, center.reshape(1, -1), 'euclidean').ravel()
    sphere = candidates[dist_to_center < radius]

    if truncate_at_boundary:
        # Drop the sphere voxels that lie outside the mask.
        sphere_index = tuple(sphere.T.astype(int).tolist())
        inside_mask = mask[sphere_index] > 0
        sphere = sphere[inside_mask]

    return tuple(sphere.T.astype(int).tolist())
# [docs]  (Sphinx source-link residue; not part of the module)
def get_volume_searchlight(mask, radius=2, threshold=1.0, truncate_at_boundary=False):
    """
    Searches through the non-zero voxels of the mask, selects centers where
    proportion of sphere voxels inside the mask >= threshold.

    The threshold is always evaluated on the complete sphere around a
    candidate center. Truncation at the mask boundary, when requested, is
    applied afterwards and only affects which voxels are returned as
    neighbors of an accepted center.

    Args:
        mask ([numpy array]): binary brain mask
        radius (int, optional): the radius of each searchlight, defined in voxels.
            Defaults to 2.
        threshold (float, optional): Threshold of the proportion of voxels that need to
            be inside the brain mask in order for it to be
            considered a good searchlight center.
            Values go between 0.0 - 1.0 where 1.0 means that
            100% of the voxels need to be inside
            the brain mask.
            Defaults to 1.0.
        truncate_at_boundary (bool, optional): if True, the neighbors of an
            accepted center only include voxels where the mask is > 0,
            effectively truncating spheres at the mask boundary. If False
            (default), spheres include all voxels within the radius regardless
            of mask value (maintains backward compatibility).
            When False and threshold < 1.0, this can lead to artifacts as
            reported in issue #466. Setting to True fixes this but may require
            accounting for different variance characteristics in second-level
            analysis. With threshold == 1.0 every accepted sphere already lies
            inside a binary mask, so truncation has no effect.
            Defaults to False.

    Returns:
        numpy array: flat (raveled) volume indices of the accepted centers,
            length n_centers
        list: one array per accepted center holding the flat (raveled)
            volume indices of its neighbors
    """
    mask = np.array(mask)
    assert mask.ndim == 3, "Mask needs to be a 3-dimensional numpy array"

    centers = list(zip(*np.nonzero(mask)))
    good_centers = []
    good_neighbors = []
    for center in tqdm(centers, desc='Finding searchlights...'):
        # Always judge the center on its full sphere. Judging it on the
        # truncated sphere would make the threshold meaningless, because
        # every remaining voxel is inside the mask by construction.
        neighbors = _get_searchlight_neighbors(mask, center, radius)
        mask_values = mask[neighbors]
        if mask_values.mean() >= threshold:
            if truncate_at_boundary:
                # Keep the (x, y, z) tuple-of-lists layout of the helper.
                inside = mask_values > 0
                neighbors = tuple(np.asarray(neighbors)[:, inside].tolist())
            good_centers.append(center)
            good_neighbors.append(neighbors)

    # Force an (n_centers, 3) integer array so that an empty selection still
    # has three coordinate rows for np.ravel_multi_index below.
    good_centers = np.array(good_centers, dtype=np.intp).reshape(-1, 3)
    assert good_centers.shape[0] == len(good_neighbors),\
        "number of centers and sets of neighbors do not match"
    print(f'Found {len(good_neighbors)} searchlights')

    # turn the 3-dim coordinates to array coordinates
    centers = np.ravel_multi_index(good_centers.T, mask.shape)
    neighbors = [np.ravel_multi_index(n, mask.shape) for n in good_neighbors]
    return centers, neighbors
# [docs]  (Sphinx source-link residue; not part of the module)
def _searchlight_datasets(data_2d, centers, neighbors, events, indices):
    """Build one Dataset per searchlight position listed in `indices`."""
    return [
        Dataset(data_2d[:, neighbors[idx]],
                descriptors={'center': centers[idx]},
                obs_descriptors={'events': events},
                channel_descriptors={'voxels': neighbors[idx]})
        for idx in indices
    ]


def get_searchlight_RDMs(data_2d, centers, neighbors, events,
                         method='correlation', verbose=True):
    """Iterates over all the searchlight centers and calculates the RDM

    Args:
        data_2d (2D numpy array): brain data,
            shape n_observations x n_channels (i.e. voxels/vertices)
        centers (1D numpy array): center indices for all searchlights as provided
            by rsatoolbox.util.searchlight.get_volume_searchlight
        neighbors (list): list of lists with neighbor voxel indices for all searchlights
            as provided by rsatoolbox.util.searchlight.get_volume_searchlight
        events (1D numpy array): 1D array of length n_observations
        method (str, optional): distance metric,
            see rsatoolbox.rdm.calc for options. Defaults to 'correlation'.
        verbose (bool, optional): if False, the progress bar shown while
            processing chunks of searchlights is suppressed. Defaults to True.

    Returns:
        RDM [rsatoolbox.rdm.RDMs]: RDMs object with the RDM for each searchlight
            the RDM.rdm_descriptors['voxel_index']
            describes the center voxel index each RDM is associated with
    """
    data_2d, centers = np.array(data_2d), np.array(centers)
    n_centers = centers.shape[0]

    # For memory reasons, we chunk the data if we have more than 1000 RDMs
    if n_centers > 1000:
        # Running all centers at once would take too much memory, so the
        # center positions are split into 100 consecutive chunks.
        chunk_edges = np.linspace(0, n_centers, 101, dtype=int)[1:-1]
        chunked_center = np.split(np.arange(n_centers), chunk_edges)

        # Pre-allocate one row per center; each row is the vectorized upper
        # triangle of an n_conds x n_conds RDM.
        n_conds = len(np.unique(events))
        RDM = np.zeros((n_centers, n_conds * (n_conds - 1) // 2))
        for chunks in tqdm(chunked_center, desc='Calculating RDMs...',
                           disable=not verbose):
            center_data = _searchlight_datasets(
                data_2d, centers, neighbors, events, chunks)
            RDM_corr = calc_rdm(center_data, method=method,
                                descriptor='events')
            RDM[chunks, :] = RDM_corr.dissimilarities
    else:
        # Few enough searchlights to compute every RDM in a single call.
        center_data = _searchlight_datasets(
            data_2d, centers, neighbors, events, range(n_centers))
        RDM = calc_rdm(center_data, method=method,
                       descriptor='events').dissimilarities

    SL_rdms = RDMs(RDM,
                   rdm_descriptors={'voxel_index': centers},
                   dissimilarity_measure=method)
    return SL_rdms
# [docs]  (Sphinx source-link residue; not part of the module)
def evaluate_models_searchlight(sl_RDM, models, eval_function, method='corr', theta=None, n_jobs=1):
    """Evaluate the given model/models on every searchlight RDM.

    `eval_function` is called once per element of `sl_RDM` as
    ``eval_function(models, rdm, method=method, theta=theta)``; the calls
    are dispatched through joblib.

    Args:
        sl_RDM ([rsatoolbox.rdm.RDMs]): RDMs object
            as computed by rsatoolbox.util.searchlight.get_searchlight_RDMs
        models ([rsatoolbox.model]): models to evaluate - can also be list of models
        eval_function (rsatoolbox.inference evaluation-function): callable
            applied to each searchlight RDM
        method (str, optional): see rsatoolbox.rdm.compare for specifics.
            Defaults to 'corr'.
        theta (optional): forwarded unchanged to `eval_function` as the
            `theta` keyword. Defaults to None.
        n_jobs (int, optional): how many jobs to run. Defaults to 1.

    Returns:
        list: the model evaluation for each searchlight center, in the
            order of `sl_RDM`
    """
    runner = Parallel(n_jobs=n_jobs)
    searchlight_rdms = tqdm(
        sl_RDM, desc='Evaluating models for each searchlight')
    evaluate = delayed(eval_function)
    jobs = (
        evaluate(models, rdm, method=method, theta=theta)
        for rdm in searchlight_rdms
    )
    return runner(jobs)