Source code for rsatoolbox.rdm.transform

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
""" transforms, which can be applied to RDMs
"""
from __future__ import annotations
from copy import deepcopy
import numpy as np
from scipy.stats import rankdata
from .rdms import RDMs


def rank_transform(rdms: RDMs, method='average'):
    """ applies a rank_transform and generates a new RDMs object

    This assigns a rank to each dissimilarity estimate in the RDM,
    deals with rank ties and saves ranks as new dissimilarity estimates.
    As an effect, all non-diagonal entries of the RDM will
    range from 1 to (n_dim²-n_dim)/2, if the RDM has the dimensions
    n_dim x n_dim.

    Missing dissimilarities (NaN) are excluded from the ranking and stay
    NaN in the result; the remaining entries of that RDM are ranked among
    themselves. Each RDM is ranked independently of the others.

    Args:
        rdms(RDMs): RDMs object
        method(String):
            controls how ranks are assigned to equal values
            options are: 'average', 'min', 'max', 'dense', 'ordinal'
            (passed through to scipy.stats.rankdata)

    Returns:
        rdms_new(RDMs): RDMs object with rank transformed dissimilarities.
            The dissimilarity_measure gets a ' (ranks)' suffix unless it
            already contains '(ranks)'.

    """
    dissimilarities = np.asarray(rdms.get_vectors())
    # Start from NaN and fill in only the valid positions, so missing
    # dissimilarities stay missing. Passing NaN to rankdata directly would,
    # depending on the scipy version, either rank NaN as the largest value
    # or make every rank in that RDM NaN.
    ranks = np.full(dissimilarities.shape, np.nan)
    for i_rdm in range(rdms.n_rdm):
        vector = dissimilarities[i_rdm]
        valid = ~np.isnan(vector)
        ranks[i_rdm, valid] = rankdata(vector[valid], method=method)
    measure = rdms.dissimilarity_measure or ''
    # avoid stacking '(ranks)' when an already ranked RDMs is transformed
    if '(ranks)' not in measure:
        measure = (measure + ' (ranks)').strip()
    rdms_new = RDMs(ranks,
                    dissimilarity_measure=measure,
                    descriptors=deepcopy(rdms.descriptors),
                    rdm_descriptors=deepcopy(rdms.rdm_descriptors),
                    pattern_descriptors=deepcopy(rdms.pattern_descriptors))
    return rdms_new
def sqrt_transform(rdms):
    """ applies a square root transform and generates a new RDMs object

    This sets values below 0 to 0 and takes a square root of each entry.
    It also adapts the dissimilarity_measure entry: 'squared euclidean'
    becomes 'euclidean', 'squared mahalanobis' becomes 'mahalanobis',
    any other measure is prefixed with 'sqrt of ', and a missing measure
    (None or empty) becomes 'sqrt'.

    The input RDMs object is not modified.

    Args:
        rdms(RDMs): RDMs object

    Returns:
        rdms_new(RDMs): RDMs object with sqrt transformed dissimilarities

    """
    # Work on a copy: get_vectors may hand back the array stored inside
    # ``rdms``, and the in-place clipping below must not alter the input.
    dissimilarities = np.array(rdms.get_vectors())
    dissimilarities[dissimilarities < 0] = 0
    dissimilarities = np.sqrt(dissimilarities)
    measure = rdms.dissimilarity_measure
    if measure == 'squared euclidean':
        dissimilarity_measure = 'euclidean'
    elif measure == 'squared mahalanobis':
        dissimilarity_measure = 'mahalanobis'
    elif measure:
        dissimilarity_measure = 'sqrt of ' + measure
    else:
        # no measure recorded: still tag that a sqrt was applied
        dissimilarity_measure = 'sqrt'
    rdms_new = RDMs(dissimilarities,
                    dissimilarity_measure=dissimilarity_measure,
                    descriptors=deepcopy(rdms.descriptors),
                    rdm_descriptors=deepcopy(rdms.rdm_descriptors),
                    pattern_descriptors=deepcopy(rdms.pattern_descriptors))
    return rdms_new
def positive_transform(rdms):
    """ sets all negative entries in an RDM to zero and returns a new RDMs

    The dissimilarity_measure is carried over unchanged and the input RDMs
    object is not modified.

    Args:
        rdms(RDMs): RDMs object

    Returns:
        rdms_new(RDMs): RDMs object with negative dissimilarities set to 0

    """
    # Work on a copy: get_vectors may hand back the array stored inside
    # ``rdms``, and the in-place clipping below must not alter the input.
    dissimilarities = np.array(rdms.get_vectors())
    dissimilarities[dissimilarities < 0] = 0
    rdms_new = RDMs(dissimilarities,
                    dissimilarity_measure=rdms.dissimilarity_measure,
                    descriptors=deepcopy(rdms.descriptors),
                    rdm_descriptors=deepcopy(rdms.rdm_descriptors),
                    pattern_descriptors=deepcopy(rdms.pattern_descriptors))
    return rdms_new
def transform(rdms, fun):
    """ applies an arbitrary function ``fun`` to the dissimilarities and
    returns a new RDMs object.

    ``fun`` receives a copy of the vectorized dissimilarities (one row per
    RDM, as returned by ``rdms.get_vectors()``), so it may modify its
    argument in place without affecting ``rdms``. Its return value is used
    as the new dissimilarities.

    Args:
        rdms(RDMs): RDMs object
        fun(callable): function mapping the dissimilarity vectors to the
            transformed dissimilarities

    Returns:
        rdms_new(RDMs): RDMs object with transformed dissimilarities. The
            dissimilarity_measure is prefixed with 'transformed ' (or set
            to 'transformed' if the input has no measure).

    """
    # pass a copy so an in-place ``fun`` cannot corrupt the input RDMs
    dissimilarities = fun(np.array(rdms.get_vectors()))
    measure = rdms.dissimilarity_measure
    meas = 'transformed ' + measure if measure else 'transformed'
    rdms_new = RDMs(dissimilarities,
                    dissimilarity_measure=meas,
                    descriptors=deepcopy(rdms.descriptors),
                    rdm_descriptors=deepcopy(rdms.rdm_descriptors),
                    pattern_descriptors=deepcopy(rdms.pattern_descriptors))
    return rdms_new