Source code for rsatoolbox.rdm.rdms

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Definition of the RSA RDMs class and module-level helpers for creating, loading, concatenating and permuting RDMs

@author: baihan
"""
from __future__ import annotations
from typing import Dict, Optional, Union, List, overload, IO
import warnings
from copy import deepcopy
from collections.abc import Iterable
import numpy as np
from rsatoolbox.io.pandas import rdms_to_df
from rsatoolbox.rdm.combine import _mean
from rsatoolbox.util.rdm_utils import batch_to_vectors
from rsatoolbox.util.rdm_utils import batch_to_matrices
from rsatoolbox.util.descriptor_utils import format_descriptor
from rsatoolbox.util.descriptor_utils import num_index
from rsatoolbox.util.descriptor_utils import subset_descriptor
from rsatoolbox.util.descriptor_utils import check_descriptor_length_error
from rsatoolbox.util.descriptor_utils import append_descriptor
from rsatoolbox.util.descriptor_utils import dict_to_list
from rsatoolbox.util.descriptor_utils import desc_eq
from rsatoolbox.util.data_utils import extract_dict
from rsatoolbox.rdm.combine import _merged_rdm_descriptors
from rsatoolbox.io.hdf5 import read_dict_hdf5, write_dict_hdf5
from rsatoolbox.io.pkl import read_dict_pkl, write_dict_pkl
from rsatoolbox.util.file_io import remove_file


class RDMs:
    """RDMs class

    Args:
        dissimilarities (numpy.ndarray):
            either a 2d np-array (n_rdm x vectorform of dissimilarities)
            or a 3d np-array (n_rdm x n_cond x n_cond)
        dissimilarity_measure (String):
            a description of the dissimilarity measure (e.g. 'Euclidean')
        descriptors (dict):
            descriptors with 1 value per RDMs object
        rdm_descriptors (dict):
            descriptors with 1 value per RDM
        pattern_descriptors (dict):
            descriptors with 1 value per RDM column

    The three descriptor dictionaries are shallow-copied on construction, so
    the caller's dictionaries are never modified and RDMs objects derived from
    one another (indexing, subset, ...) never share a dictionary.

    Attributes:
        n_rdm(int): number of rdms
        n_cond(int): number of patterns
    """

    dissimilarities: np.ndarray
    dissimilarity_measure: Optional[str]
    descriptors: Dict
    rdm_descriptors: Dict
    pattern_descriptors: Dict

    def __init__(
        self,
        dissimilarities,
        dissimilarity_measure=None,
        descriptors=None,
        rdm_descriptors=None,
        pattern_descriptors=None,
    ):
        self.dissimilarities, self.n_rdm, self.n_cond = batch_to_vectors(dissimilarities)
        # Copy every descriptor dict: this constructor inserts keys (wrapped
        # scalars, "index") and reorder()/sort_by() assign keys later. Storing
        # the caller's dict directly would leak those writes into the caller
        # and into every RDMs built from the same dict (e.g. __getitem__).
        self.descriptors = {} if descriptors is None else dict(descriptors)
        self.rdm_descriptors = self._normalize_descriptors(
            rdm_descriptors, "rdm_descriptors", self.n_rdm
        )
        self.pattern_descriptors = self._normalize_descriptors(
            pattern_descriptors, "pattern_descriptors", self.n_cond
        )
        if "index" not in self.pattern_descriptors:
            self.pattern_descriptors["index"] = list(range(self.n_cond))
        if "index" not in self.rdm_descriptors:
            self.rdm_descriptors["index"] = list(range(self.n_rdm))
        self.dissimilarity_measure = dissimilarity_measure

    @staticmethod
    def _normalize_descriptors(descriptors, name, expected_length):
        """Return a new dict with scalar (and string) values wrapped in lists.

        The result is validated with check_descriptor_length_error against
        ``expected_length``; ``None`` yields an empty dict without validation.
        """
        if descriptors is None:
            return {}
        normalized = {}
        for key, value in descriptors.items():
            # strings are iterable but represent a single value
            if not isinstance(value, Iterable) or isinstance(value, str):
                value = [value]
            normalized[key] = value
        check_descriptor_length_error(normalized, name, expected_length)
        return normalized

    def __repr__(self):
        """
        defines string which is printed for the object
        """
        return (
            f"rsatoolbox.rdm.{self.__class__.__name__}(\n"
            f"dissimilarity_measure = \n{self.dissimilarity_measure}\n"
            f"dissimilarities = \n{self.dissimilarities}\n"
            f"descriptors = \n{self.descriptors}\n"
            f"rdm_descriptors = \n{self.rdm_descriptors}\n"
            f"pattern_descriptors = \n{self.pattern_descriptors}\n"
            ")"
        )

    def __eq__(self, other: object) -> bool:
        """Test for equality

        This magic method gets called when you compare two
        RDMs objects: `rdms1 == rdms2`.
        True if the objects are of the same type, and
        dissimilarities and descriptors are equal.
        NaN dissimilarities (as produced by subsample_pattern) compare equal
        to NaN at the same position.

        Args:
            other (RDMs): The second RDMs object to
                compare this one with

        Returns:
            bool: True if equal
        """
        if not isinstance(other, RDMs):
            return False
        # array_equal checks shapes first (plain `==` would broadcast or raise
        # on mismatched shapes); equal_nan keeps NaN-containing RDMs equal to
        # their own copies.
        return bool(
            np.array_equal(self.dissimilarities, other.dissimilarities, equal_nan=True)
            and self.descriptors == other.descriptors
            and desc_eq(self.rdm_descriptors, other.rdm_descriptors)
            and desc_eq(self.pattern_descriptors, other.pattern_descriptors)
        )

    def __str__(self):
        """
        defines the output of print
        """
        string_desc = format_descriptor(self.descriptors)
        rdm_desc = format_descriptor(self.rdm_descriptors)
        pattern_desc = format_descriptor(self.pattern_descriptors)
        diss = self.get_matrices()[0]
        return (
            f"rsatoolbox.rdm.{self.__class__.__name__}\n"
            f"{self.n_rdm} RDM(s) over {self.n_cond} conditions\n\n"
            f"dissimilarity_measure = \n{self.dissimilarity_measure}\n\n"
            f"dissimilarities[0] = \n{diss}\n\n"
            f"descriptors: \n{string_desc}\n"
            f"rdm_descriptors: \n{rdm_desc}\n"
            f"pattern_descriptors: \n{pattern_desc}\n"
        )

    def __getitem__(self, idx):
        """
        allows indexing with []
        and iterating over RDMs with `for rdm in rdms:`
        """
        dissimilarities = self.dissimilarities[np.array(idx)].reshape(
            -1, self.dissimilarities.shape[1]
        )
        rdm_descriptors = subset_descriptor(self.rdm_descriptors, idx)
        rdms = RDMs(
            dissimilarities,
            dissimilarity_measure=self.dissimilarity_measure,
            descriptors=self.descriptors,
            rdm_descriptors=rdm_descriptors,
            pattern_descriptors=self.pattern_descriptors,
        )
        return rdms

    def __len__(self) -> int:
        """
        The number of RDMs in this stack.
        Together with __getitem__, allows `reversed(rdms)`.
        """
        return self.n_rdm

    def get_vectors(self):
        """Returns RDMs as np.ndarray with each RDM as a vector

        Returns:
            numpy.ndarray: RDMs as a matrix with one row per RDM
        """
        return self.dissimilarities

    def get_matrices(self):
        """Returns RDMs as np.ndarray with each RDM as a matrix

        Returns:
            numpy.ndarray: RDMs as a 3-Tensor with one matrix per RDM
        """
        matrices, _, _ = batch_to_matrices(self.dissimilarities)
        return matrices

    def copy(self) -> RDMs:
        """Return a copy of this object, with all properties
        equal to the original's

        Returns:
            RDMs: Value copy
        """
        return RDMs(
            dissimilarities=self.dissimilarities.copy(),
            dissimilarity_measure=self.dissimilarity_measure,
            descriptors=deepcopy(self.descriptors),
            rdm_descriptors=deepcopy(self.rdm_descriptors),
            pattern_descriptors=deepcopy(self.pattern_descriptors),
        )

    def subset_pattern(self, by, value):
        """Returns a smaller RDMs with patterns with certain descriptor values

        Args:
            by(String): the descriptor by which the subset selection
                is made from pattern_descriptors
            value: the value (or iterable of values) by which the subset
                selection is made from pattern_descriptors. A string is
                treated as a single value.

        Returns:
            RDMs object, with fewer patterns
        """
        if by is None:
            by = "index"
        # A string is Iterable, but `p in "ab"` does substring matching, so
        # strings are wrapped like other scalars. Other iterables are turned
        # into a list so a generator is not exhausted by num_index before the
        # membership test below.
        if isinstance(value, str) or not isinstance(value, Iterable):
            value = [value]
        else:
            value = list(value)
        selection = num_index(self.pattern_descriptors[by], value)
        # keep a vector entry only if both of its patterns are selected
        ix, iy = np.triu_indices(self.n_cond, 1)
        pattern_in_value = np.array(
            [p in value for p in self.pattern_descriptors[by]], dtype=bool
        )
        selection_xy = pattern_in_value[ix] & pattern_in_value[iy]
        dissimilarities = self.dissimilarities[:, selection_xy]
        rdms = RDMs(
            dissimilarities=dissimilarities,
            descriptors=self.descriptors,
            rdm_descriptors=self.rdm_descriptors,
            pattern_descriptors=extract_dict(self.pattern_descriptors, selection),
            dissimilarity_measure=self.dissimilarity_measure,
        )
        return rdms

    def subsample_pattern(self, by, value):
        """Returns a subsampled RDMs with repetitions if values are repeated

        This function now generates Nans where the off-diagonal 0s would
        appear. These values are trivial to predict for models and thus
        need to be marked and excluded from the evaluation.

        Args:
            by(String): the descriptor by which the subset selection
                is made from descriptors
            value: the value(s) by which the subset selection
                is made from descriptors

        Returns:
            RDMs object, with subsampled patterns
        """
        if by is None:
            by = "index"
        desc = np.array(self.pattern_descriptors[by])  # desc is list-like
        if isinstance(value, (list, tuple, np.ndarray)):
            selection = np.concatenate(
                [np.asarray(desc == i).nonzero()[0] for i in value]
            )
        else:
            selection = np.where(desc == value)[0]
        selection = np.sort(selection)
        dissimilarities = self.get_matrices()
        # NaN diagonal: repeated patterns would otherwise produce trivial
        # off-diagonal zeros after subsampling
        for i_rdm in range(self.n_rdm):
            np.fill_diagonal(dissimilarities[i_rdm], np.nan)
        dissimilarities = dissimilarities[:, selection][:, :, selection]
        rdms = RDMs(
            dissimilarities=dissimilarities,
            descriptors=self.descriptors,
            rdm_descriptors=self.rdm_descriptors,
            pattern_descriptors=extract_dict(self.pattern_descriptors, selection),
            dissimilarity_measure=self.dissimilarity_measure,
        )
        return rdms

    def subset(self, by, value):
        """Returns a set of fewer RDMs matching descriptor values

        Args:
            by(String): the descriptor by which the subset selection
                is made from descriptors
            value: the value by which the subset selection
                is made from descriptors

        Returns:
            RDMs object, with fewer RDMs
        """
        if by is None:
            by = "index"
        selection = num_index(self.rdm_descriptors[by], value)
        rdms = RDMs(
            dissimilarities=self.dissimilarities[selection, :],
            descriptors=self.descriptors,
            rdm_descriptors=extract_dict(self.rdm_descriptors, selection),
            pattern_descriptors=self.pattern_descriptors,
            dissimilarity_measure=self.dissimilarity_measure,
        )
        return rdms

    def subsample(self, by, value):
        """Returns a subsampled RDMs with repetitions if values are repeated

        Args:
            by(String): the descriptor by which the subset selection
                is made from descriptors
            value: the value by which the subset selection
                is made from descriptors

        Returns:
            RDMs object, with subsampled RDMs
        """
        if by is None:
            by = "index"
        desc = self.rdm_descriptors[by]
        values = value if isinstance(value, (list, tuple, np.ndarray)) else [value]
        # one entry per (requested value, matching RDM), in request order
        selection = [j for v in values for j, d in enumerate(desc) if d == v]
        rdms = RDMs(
            dissimilarities=self.dissimilarities[selection, :],
            descriptors=self.descriptors,
            rdm_descriptors=extract_dict(self.rdm_descriptors, selection),
            pattern_descriptors=self.pattern_descriptors,
            dissimilarity_measure=self.dissimilarity_measure,
        )
        return rdms

    def append(self, rdm):
        """appends an rdm to the object
        The rdm should have the same shape and type as this object.
        Its pattern_descriptor and descriptor are ignored

        Args:
            rdm(rsatoolbox.rdm.RDMs): the rdm to append

        Returns:

        """
        assert isinstance(rdm, RDMs), "appended rdm should be an RDMs"
        assert rdm.n_cond == self.n_cond, "appended rdm had wrong shape"
        assert (
            rdm.dissimilarity_measure == self.dissimilarity_measure
        ), "appended rdm had wrong dissimilarity measure"
        self.dissimilarities = np.concatenate(
            (self.dissimilarities, rdm.dissimilarities), axis=0
        )
        self.rdm_descriptors = append_descriptor(self.rdm_descriptors, rdm.rdm_descriptors)
        self.n_rdm = self.n_rdm + rdm.n_rdm

    def save(self, filename, file_type="hdf5", overwrite=False):
        """saves the RDMs object into a file

        Args:
            filename(String): path to file to save to [or opened file]
            file_type(String): Type of file to create:
                hdf5: hdf5 file
                pkl: pickle file
            overwrite(Boolean): overwrites file if it already exists

        Raises:
            ValueError: if file_type is neither 'hdf5' nor 'pkl'
        """
        # Validate before touching the file system: previously an unknown
        # file_type silently wrote nothing, after deleting the existing file
        # when overwrite=True.
        if file_type not in ("hdf5", "pkl"):
            raise ValueError(f"file_type not understood: {file_type}")
        rdm_dict = self.to_dict()
        if overwrite:
            remove_file(filename)
        if file_type == "hdf5":
            write_dict_hdf5(filename, rdm_dict)
        else:
            write_dict_pkl(filename, rdm_dict)

    def to_dict(self):
        """converts the object into a dictionary, which can be saved to disk

        Returns:
            rdm_dict(dict): dictionary containing all information required to
                recreate the RDMs object
        """
        rdm_dict = {}
        rdm_dict["dissimilarities"] = self.dissimilarities
        rdm_dict["descriptors"] = self.descriptors
        rdm_dict["rdm_descriptors"] = self.rdm_descriptors
        rdm_dict["pattern_descriptors"] = self.pattern_descriptors
        rdm_dict["dissimilarity_measure"] = self.dissimilarity_measure
        return rdm_dict

    def to_df(self):
        """Return a new long-form pandas DataFrame representing this RDM

        See `rsatoolbox.io.pandas.rdms_to_df` for details

        Returns:
            pandas.DataFrame: The DataFrame for this RDMs object
        """
        return rdms_to_df(self)

    def reorder(self, new_order):
        """Reorder the patterns according to the index in new_order

        Args:
            new_order (numpy.ndarray): new order of patterns,
                vector of length equal to the number of patterns
        """
        matrices = self.get_matrices()
        # same index vector for rows and columns of every RDM
        matrices = matrices[(slice(None),) + np.ix_(new_order, new_order)]
        self.dissimilarities = batch_to_vectors(matrices)[0]
        for dname, descriptors in self.pattern_descriptors.items():
            self.pattern_descriptors[dname] = [descriptors[idx] for idx in new_order]

    @staticmethod
    def _descriptor_order(descriptor, method):
        """Return pattern indices that sort ``descriptor`` according to ``method``.

        Args:
            descriptor: the pattern descriptor values
            method: 'alpha' for a stable alphabetical sort, or a list/ndarray
                giving the desired order of descriptor values. Patterns that
                share a value keep their relative order.

        Raises:
            ValueError: unknown method, or an explicit order that does not
                contain every descriptor value
        """
        # isinstance guard: `ndarray == "alpha"` would compare elementwise
        if isinstance(method, str) and method == "alpha":
            return np.argsort(descriptor, kind="stable")
        if isinstance(method, (list, np.ndarray)):
            rank = {}
            for position, val in enumerate(method):
                rank.setdefault(val, position)
            if not set(descriptor).issubset(rank):
                raise ValueError(
                    f"Expected {method} to contain every value of {descriptor}"
                )
            return np.argsort([rank[val] for val in descriptor], kind="stable")
        raise ValueError(f"Unknown sorting method: {method}")

    def sort_by(self, reindex: bool = True, **kwargs):
        """Reorder the patterns by sorting a descriptor

        Args:
            reindex (bool): whether to reset the 'index' descriptor
                following sorting

        Pass keyword arguments that correspond to descriptors,
        with value indicating the sort type. Supported methods:
            'alpha': sort alphabetically (using np.sort)
            list/np.array: specify the new order explicitly. Values should
                correspond to the descriptor values; patterns sharing a
                value keep their relative order.

        Examples:
            The following code sorts the 'condition' descriptor alphabetically:

            ::

                rdms.sort_by(condition='alpha')

            The following code sort the 'condition' descriptor in the order
            1, 3, 2, 4, 5:

            ::

                rdms.sort_by(condition=[1, 3, 2, 4, 5])

        Raises:
            ValueError: Raised if the method chosen is not implemented
        """
        for dname, method in kwargs.items():
            descriptor = self.pattern_descriptors[dname]
            self.reorder(self._descriptor_order(descriptor, method))
        if reindex:
            self.pattern_descriptors["index"] = list(range(self.n_cond))

    def mean(self, weights=None):
        """Average rdm of all rdms contained

        Args:
            weights (str or ndarray, optional): One of:
                None: No weighting applied
                str: Use the weights contained in the `rdm_descriptor` with this name
                ndarray: Weights array of the shape of RDMs.dissimilarities

        Returns:
            `rsatoolbox.rdm.rdms.RDMs`: New RDMs object with one vector
        """
        if isinstance(weights, str) and weights in self.rdm_descriptors:
            # NOTE(review): drops a same-named object-level descriptor; kept
            # for compatibility. Now builds a dict (previously a set of tuples).
            new_descriptors = {
                k: deepcopy(v) for k, v in self.descriptors.items() if k != weights
            }
            weights = self.rdm_descriptors[weights]
        else:
            new_descriptors = deepcopy(self.descriptors)
        return RDMs(
            dissimilarities=np.array([_mean(self.dissimilarities, weights)]),
            dissimilarity_measure=self.dissimilarity_measure,
            descriptors=new_descriptors,
            pattern_descriptors=deepcopy(self.pattern_descriptors),
        )
def rdms_from_dict(rdm_dict):
    """Rebuild an RDMs object from the dictionary form produced by RDMs.to_dict

    Args:
        rdm_dict(dict): mapping with the keys 'dissimilarities',
            'descriptors', 'rdm_descriptors', 'pattern_descriptors'
            and 'dissimilarity_measure'

    Returns:
        rdms(RDMs): the regenerated RDMs object
    """
    dissimilarities = rdm_dict["dissimilarities"]
    descriptors = rdm_dict["descriptors"]
    # descriptor values are passed through dict_to_list before construction
    rdm_descriptors = dict_to_list(rdm_dict["rdm_descriptors"])
    pattern_descriptors = dict_to_list(rdm_dict["pattern_descriptors"])
    measure = rdm_dict["dissimilarity_measure"]
    return RDMs(
        dissimilarities=dissimilarities,
        descriptors=descriptors,
        rdm_descriptors=rdm_descriptors,
        pattern_descriptors=pattern_descriptors,
        dissimilarity_measure=measure,
    )
def load_rdm(filename: Union[str, IO], file_type: Optional[str] = None) -> RDMs:
    """loads a RDMs object from disk

    Args:
        filename(String): path to file to load (str or os.PathLike,
            e.g. pathlib.Path) or open file
        file_type(String): Type of file to load:
            "hdf5": hdf5 file
            "pkl": pickle file
            the file type is optional for paths, for which the type is
            inferred from the file extension (".pkl", ".h5", "...hdf5")

    Raises:
        ValueError: if the file type is unknown or cannot be inferred
    """
    import os  # local import: only needed to accept path-like objects

    # Path objects were previously never matched against the extensions and
    # always failed with "filetype not understood".
    if isinstance(filename, os.PathLike):
        filename = os.fspath(filename)
    if file_type is None and isinstance(filename, str):
        if filename.endswith(".pkl"):
            file_type = "pkl"
        elif filename.endswith((".h5", "hdf5")):
            file_type = "hdf5"
    if file_type == "hdf5":
        rdm_dict = read_dict_hdf5(filename)
    elif file_type == "pkl":
        rdm_dict = read_dict_pkl(filename)
    else:
        raise ValueError("filetype not understood")
    return rdms_from_dict(rdm_dict)
@overload
def concat(*rdms: List[RDMs], target_pdesc: Optional[str] = None) -> RDMs:
    """Typing stub: one iterable of RDMs objects."""


@overload
def concat(*rdms: RDMs, target_pdesc: Optional[str] = None) -> RDMs:
    """Typing stub: RDMs objects passed as separate positional arguments."""
def concat(*rdms, target_pdesc: Optional[str] = None) -> RDMs:
    """Merge into single RDMs object

    requires that the rdms have the same shape
    descriptor and pattern descriptors are taken from the first rdms object
    for rdm_descriptors concatenation is tried
    the rdm index is reinitialized

    If a pattern descriptor with unique values is available (``target_pdesc``
    or, when that is None, the first such descriptor of the first object other
    than 'index'), RDMs whose patterns appear in a different order are
    reordered to match the first object. The reordering is applied to copies;
    the RDMs objects passed in are not modified.

    Args:
        rdms(iterable of rsatoolbox.rdm.RDMs): RDMs objects to be concatenated
            or multiple RDMs as separate arguments
        target_pdesc(optional, str): a pattern descriptor to use for sorting

    Returns:
        rsatoolbox.rdm.RDMs: concatenated rdms object
    """
    if len(rdms) == 1:
        # single argument: one RDMs object, or an iterable of them
        rdms_list = [rdms[0]] if isinstance(rdms[0], RDMs) else list(rdms[0])
    else:
        rdms_list = list(rdms)
    assert len(rdms_list) > 0, "Supply at least one RDMs object to concatenate"
    assert isinstance(
        rdms_list[0], RDMs
    ), "Supply list of RDMs objects, or RDMs objects as separate arguments"
    first = rdms_list[0]
    for rdm_new in rdms_list[1:]:
        assert isinstance(rdm_new, RDMs), "rdm for concat should be an RDMs"
        assert rdm_new.n_cond == first.n_cond, "rdm for concat had wrong shape"
        assert (
            rdm_new.dissimilarity_measure == first.dissimilarity_measure
        ), "appended rdm had wrong dissimilarity measure"
    descriptors, rdm_descriptors = _merged_rdm_descriptors(rdms_list)

    if target_pdesc is None:
        # see if we can find an authoritative descriptor for pattern order
        pdesc_candidates = [
            name
            for name, values in first.pattern_descriptors.items()
            if name != "index" and len(values) == len(set(values))
        ]
        if pdesc_candidates:
            target_pdesc = pdesc_candidates[0]
            if len(pdesc_candidates) > 1:
                warnings.warn(
                    f'[concat] Multiple pattern descriptors found, using "{target_pdesc}"'
                )
    else:
        assert (
            target_pdesc in first.pattern_descriptors
        ), "The provided descriptor is not a pattern descriptor"
        # uniqueness = as many distinct values as conditions
        assert (
            len(set(first.pattern_descriptors[target_pdesc])) == first.n_cond
        ), "The provided descriptor is not unique"

    # np.asarray: descriptors are often plain lists (e.g. after reorder),
    # which do not support the [:, None] broadcasting below
    auth_order = (
        np.asarray(first.pattern_descriptors[target_pdesc])
        if target_pdesc is not None
        else None
    )
    aligned = [first]
    for rdm_new in rdms_list[1:]:
        if auth_order is not None:
            other_order = np.asarray(rdm_new.pattern_descriptors[target_pdesc])
            if not np.array_equal(other_order, auth_order):
                # for each pattern of `first`, its position in rdm_new
                _, new_order = np.where(auth_order[:, None] == other_order)
                assert (
                    len(new_order) == first.n_cond
                ), f'[concat] pattern descriptor "{target_pdesc}" does not match between RDMs'
                rdm_new = rdm_new.copy()  # do not mutate the caller's object
                rdm_new.reorder(new_order)
        aligned.append(rdm_new)

    dissimilarities = np.concatenate([rdm.dissimilarities for rdm in aligned], axis=0)
    # Set dissimilarity measure if it's the same for all rdms in list
    if len({r.dissimilarity_measure for r in aligned}) == 1:
        dissimilarity_measure = first.dissimilarity_measure
    else:
        dissimilarity_measure = None
    return RDMs(
        dissimilarities=dissimilarities,
        dissimilarity_measure=dissimilarity_measure,
        rdm_descriptors=rdm_descriptors,
        descriptors=descriptors,
        pattern_descriptors=deepcopy(first.pattern_descriptors),
    )
def permute_rdms(rdms, p=None):
    """Permute rows, columns and corresponding pattern descriptors
    of RDM matrices according to a permutation vector

    Args:
        rdms (rsatoolbox.rdm.RDMs): the rdms to permute
        p (numpy.ndarray or sequence of int):
           permutation vector (values must be unique integers
           from 0 to n_cond of RDM matrix).
           If p = None, a random permutation vector is created.

    Returns:
        rdm_p(rsatoolbox.rdm.RDMs): the rdm object with a permuted matrix
            and pattern descriptors. The inverse permutation is stored in
            descriptors['p_inv'], the 'index' descriptor is converted to
            strings, and the dissimilarity measure is preserved.
    """
    if p is None:
        p = np.random.permutation(rdms.n_cond)
        print("No permutation vector specified," + " performing random permutation.")
    # accept lists and any integer dtype (previously only the platform
    # default int passed the `dtype == "int"` check)
    p = np.asarray(p)
    assert np.issubdtype(p.dtype, np.integer), "permutation vector must have integer entries."
    assert (
        min(p) == 0 and max(p) == rdms.n_cond - 1
    ), "permutation vector must have entries ranging from 0 to n_cond"
    assert (
        len(np.unique(p)) == rdms.n_cond
    ), "permutation vector must only have unique integer entries"

    rdm_mats = rdms.get_matrices()[:, p, :][:, :, p]
    descriptors = rdms.descriptors.copy()
    rdm_descriptors = rdms.rdm_descriptors.copy()
    # To easily reverse permutation later (argsort of a permutation is its inverse)
    descriptors.update({"p_inv": np.argsort(p)})

    # Permute every pattern descriptor so it stays aligned with the matrix
    # rows; previously only 'index' was permuted.
    pattern_descriptors = {}
    for name, values in rdms.pattern_descriptors.items():
        if name == "index":
            pattern_descriptors[name] = list(np.array(values)[p].astype(np.str_))
        else:
            pattern_descriptors[name] = [values[i] for i in p]

    rdms_p = RDMs(
        dissimilarities=rdm_mats,
        dissimilarity_measure=rdms.dissimilarity_measure,
        descriptors=descriptors,
        rdm_descriptors=rdm_descriptors,
        pattern_descriptors=pattern_descriptors,
    )
    return rdms_p
def inverse_permute_rdms(rdms):
    """Gimmick function to reverse the effect of permute_rdms()

    Args:
        rdms (rsatoolbox.rdm.RDMs): an object returned by permute_rdms,
            whose descriptors contain the inverse permutation under 'p_inv'

    Returns:
        rsatoolbox.rdm.RDMs: the rdms permuted back by 'p_inv'
    """
    inverse = rdms.descriptors["p_inv"]
    return permute_rdms(rdms, p=inverse)
[docs] def get_categorical_rdm(category_vector, category_name="category"): """generates an RDM object containing a categorical RDM, i.e. RDM = 0 if the category is the same and 1 if they are different Args: category_vector(iterable): a category index per condition category_name(String): name for the descriptor in the object, defaults to 'category' Returns: rsatoolbox.rdm.RDMs: constructed RDM """ n = len(category_vector) rdm_list = [] for i_cat in range(n): for j_cat in range(i_cat + 1, n): if isinstance(category_vector[i_cat], Iterable): comparisons = [ np.array(category_vector[i_cat][idx]) != np.array(category_vector[j_cat][idx]) for idx in range(len(category_vector[i_cat])) ] rdm_list.append(np.any(comparisons)) else: rdm_list.append(category_vector[i_cat] != category_vector[j_cat]) rdm = RDMs( np.array(rdm_list, dtype=float), pattern_descriptors={category_name: np.array(category_vector)}, ) return rdm