from __future__ import annotations
from typing import TYPE_CHECKING, Optional
import math
import matplotlib.pyplot
import sklearn.manifold
import numpy
from scipy.spatial.distance import squareform
from rsatoolbox.util.weighted_mds import Weighted_MDS
from rsatoolbox.vis.icon import Icon
from rsatoolbox.util.rdm_utils import _get_n_from_reduced_vectors
if TYPE_CHECKING:
from rsatoolbox.rdm import RDMs
from numpy.typing import NDArray
from matplotlib.figure import Figure
# Module-level random generator that show_2d hands to MDS as ``random_state``.
# NOTE(review): this one RandomState instance is shared by every call, so its
# internal state presumably advances between calls and repeated MDS plots of
# the same RDMs within one session may differ — confirm this is intended.
seed = numpy.random.RandomState(1)
[docs]
def show_scatter(
    rdms: RDMs,
    coords: NDArray,
    rdm_descriptor: Optional[str] = None,
    pattern_descriptor: Optional[str] = None,
    icon_size: float = 0.1
) -> Figure:
    """Draw a 2-dimensional scatter plot based on the provided coordinates

    One subplot is drawn per RDM on a near-square grid; grid cells beyond
    the number of RDMs are hidden. All subplots share the same x and y
    limits, padded by 5% of the overall coordinate range.

    Args:
        rdms (RDMs): The RDMs object to display
        coords (NDArray): Array of x and y coordinates for each RDM and
            pattern (rdms x patterns x 2)
        rdm_descriptor: (Optional[str]): If provided, this will be used as
            title for each individual RDM.
        pattern_descriptor (Optional[str]): If provided, the chosen pattern
            descriptor will be printed adjacent to each point in the plot
        icon_size: relative size of icons if the pattern descriptor chosen
            is of type Icon

    Returns:
        Figure: A matplotlib figure in which the plot is drawn
    """
    n_rdm = rdms.n_rdm
    ## near-square grid: with n = floor(sqrt(n_rdm)) this is n x n,
    ## (n+1) x n or (n+1) x (n+1), always at least n_rdm cells
    frac, n = math.modf(math.sqrt(n_rdm))
    nrows, ncols = math.floor(n), math.floor(n)
    if frac > 0:
        nrows += 1
    if frac > 0.5:
        ncols += 1
    fig, axes = matplotlib.pyplot.subplots(nrows=nrows, ncols=ncols)
    axes = numpy.array(axes)  ## it's now an array even if there's only one

    ## Shared axis limits. Pad by a fraction of the data range instead of
    ## scaling min/max: scaling a negative minimum by 0.95 moves it towards
    ## zero and clips the outermost points of centred embeddings.
    lo, hi = numpy.nanmin(coords), numpy.nanmax(coords)
    pad = 0.05 * (hi - lo)
    if pad == 0:
        ## all coordinates identical: avoid a zero-width axis range
        pad = 0.05 * abs(hi) or 0.5
    lim_lo, lim_hi = lo - pad, hi + pad

    pat_descs = None
    if pattern_descriptor is not None:
        pat_descs = rdms.pattern_descriptors[pattern_descriptor]

    for r, ax in enumerate(axes.ravel()):
        if r >= n_rdm:
            ## fewer rdms than rows x cols: hide EVERY remaining axis
            ## (continue, not break, so no empty axis is left visible)
            ax.axis('off')
            continue
        ax.scatter(coords[r, :, 0], coords[r, :, 1])
        ax.set_xlim(lim_lo, lim_hi)
        ax.set_ylim(lim_lo, lim_hi)
        ## RDM names
        if rdm_descriptor is not None:
            ax.set_title(rdms.rdm_descriptors[rdm_descriptor][r])
        ## print labels (or draw icons) next to dots
        if pat_descs is not None:
            for p in range(coords.shape[1]):
                pat_desc = pat_descs[p]
                x, y = coords[r, p, 0], coords[r, p, 1]
                if isinstance(pat_desc, Icon):
                    pat_desc.plot(x, y, ax=ax, size=icon_size)
                else:
                    label = ax.annotate(pat_desc, (x, y))
                    label.set_alpha(.6)
        ## turn off all axis ticks and labels
        ax.tick_params(axis='both', which='both', bottom=False, top=False,
                       right=False, left=False, labelbottom=False,
                       labeltop=False, labelleft=False, labelright=False)
    return fig
[docs]
def show_2d(
    rdms: RDMs,
    method: str,
    weights: Optional[NDArray] = None,
    rdm_descriptor: Optional[str] = None,
    pattern_descriptor: Optional[str] = None,
    icon_size: float = 0.1
) -> Figure:
    """Draw a scatter plot of the RDMs reduced to two dimensions

    Each RDM is embedded separately and the resulting coordinates are
    plotted with show_scatter.

    Args:
        rdms (RDMs): The RDMs object to display
        method (str): One of 'MDS', 't-SNE', 'Isomap'.
        weights: Optional array of weights, either one vector per RDM
            (2D) or one matrix per RDM (3D). They are passed to
            ``fit_transform`` as ``weight``, so they are meant for
            method 'MDS' (which then uses Weighted_MDS).
        rdm_descriptor: (Optional[str]): If provided, this will be used as
            title for each individual RDM.
        pattern_descriptor (Optional[str]): If provided, the chosen pattern
            descriptor will be printed adjacent to each point in the plot
        icon_size: relative size of icons if the pattern descriptor chosen
            is of type Icon

    Returns:
        Figure: A matplotlib figure in which the plot is drawn

    Raises:
        NotImplementedError: if method is not one of the supported names
    """
    n_cond = rdms.n_cond
    if method == 'MDS':
        MDS = sklearn.manifold.MDS if weights is None else Weighted_MDS
        ## NOTE(review): ``seed`` is a single module-level RandomState shared
        ## by all calls, so successive layouts of the same RDMs may differ.
        embedding = MDS(
            n_components=2,
            random_state=seed,
            dissimilarity='precomputed',
            normalized_stress='auto',
        )
    elif method == 't-SNE':
        ## scikit-learn requires perplexity < n_samples; keep its default
        ## of 30 whenever there are enough conditions, shrink it otherwise.
        ## NOTE(review): no metric='precomputed', so the rows of each RDM
        ## are treated as feature vectors — confirm this is intended.
        embedding = sklearn.manifold.TSNE(
            n_components=2,
            perplexity=min(30.0, max(n_cond - 1, 1)),
        )
    elif method == 'Isomap':
        ## the neighbour graph needs n_neighbors < n_samples; keep the
        ## scikit-learn default of 5 whenever possible.
        ## NOTE(review): as for t-SNE, RDM rows are used as feature vectors.
        embedding = sklearn.manifold.Isomap(
            n_components=2,
            n_neighbors=min(5, max(n_cond - 1, 1)),
        )
    else:
        raise NotImplementedError('Unknown method: ' + str(method))
    rdm_mats = rdms.get_matrices()
    ## convert the weights to matrix form once, not once per RDM
    weight_mats = None if weights is None else weight_to_matrices(weights)
    coords = numpy.full((rdms.n_rdm, n_cond, 2), numpy.nan)
    for r in range(rdms.n_rdm):
        fit_kwargs = dict()
        if weight_mats is not None:
            fit_kwargs['weight'] = weight_mats[r, :, :]
        coords[r, :, :] = embedding.fit_transform(
            rdm_mats[r, :, :], **fit_kwargs)
    return show_scatter(
        rdms,
        coords,
        rdm_descriptor=rdm_descriptor,
        pattern_descriptor=pattern_descriptor,
        icon_size=icon_size
    )
[docs]
def show_MDS(
    rdms: RDMs,
    weights: Optional[NDArray] = None,
    rdm_descriptor: Optional[str] = None,
    pattern_descriptor: Optional[str] = None,
    icon_size: float = 0.1
) -> Figure:
    """Draw a scatter plot based on Multidimensional Scaling dimensionality reduction

    Convenience wrapper that calls show_2d with method 'MDS'.

    Args:
        rdms (RDMs): The RDMs object to display
        weights: Optional array of weights (vector or matrix per RDM); when
            given, show_2d switches to weighted MDS
        rdm_descriptor: (Optional[str]): If provided, this will be used as
            title for each individual RDM.
        pattern_descriptor (Optional[str]): If provided, the chosen pattern
            descriptor will be printed adjacent to each point in the plot
        icon_size: relative size of icons if the pattern descriptor chosen
            is of type Icon

    Returns:
        Figure: A matplotlib figure in which the plot is drawn
    """
    plot_options = dict(
        weights=weights,
        rdm_descriptor=rdm_descriptor,
        pattern_descriptor=pattern_descriptor,
        icon_size=icon_size,
    )
    return show_2d(rdms, 'MDS', **plot_options)
[docs]
def show_tSNE(
    rdms: RDMs,
    rdm_descriptor: Optional[str] = None,
    pattern_descriptor: Optional[str] = None,
    icon_size: float = 0.1
) -> Figure:
    """Draw a scatter plot based on t-SNE dimensionality reduction

    Convenience wrapper that calls show_2d with method 't-SNE' and no
    weights.

    Args:
        rdms (RDMs): The RDMs object to display
        rdm_descriptor: (Optional[str]): If provided, this will be used as
            title for each individual RDM.
        pattern_descriptor (Optional[str]): If provided, the chosen pattern
            descriptor will be printed adjacent to each point in the plot
        icon_size: relative size of icons if the pattern descriptor chosen
            is of type Icon

    Returns:
        Figure: A matplotlib figure in which the plot is drawn
    """
    return show_2d(
        rdms=rdms,
        method='t-SNE',
        icon_size=icon_size,
        pattern_descriptor=pattern_descriptor,
        rdm_descriptor=rdm_descriptor,
    )
[docs]
def show_iso(
    rdms: RDMs,
    rdm_descriptor: Optional[str] = None,
    pattern_descriptor: Optional[str] = None,
    icon_size: float = 0.1
) -> Figure:
    """Draw a scatter plot based on Isomap dimensionality reduction

    Convenience wrapper that calls show_2d with method 'Isomap' and no
    weights.

    Args:
        rdms (RDMs): The RDMs object to display
        rdm_descriptor: (Optional[str]): If provided, this will be used as
            title for each individual RDM.
        pattern_descriptor (Optional[str]): If provided, the chosen pattern
            descriptor will be printed adjacent to each point in the plot
        icon_size: relative size of icons if the pattern descriptor chosen
            is of type Icon

    Returns:
        Figure: A matplotlib figure in which the plot is drawn
    """
    reduction = 'Isomap'
    return show_2d(
        rdms,
        reduction,
        icon_size=icon_size,
        rdm_descriptor=rdm_descriptor,
        pattern_descriptor=pattern_descriptor,
    )
[docs]
def weight_to_matrices(x: NDArray) -> NDArray:
    """converts a *stack* of weights in vector or matrix form into matrix form

    Args:
        **x** (np.ndarray): stack of weight vectors (2D: rdms x pairs, each
            row in the condensed order understood by scipy's ``squareform``)
            or stack of weight matrices (3D: rdms x conds x conds)

    Returns:
        np.ndarray: 3D, matrix form of the stack of weights. A 3D input is
        returned unchanged (not copied); a 2D input yields a new float64
        array.

    Raises:
        ValueError: if x is neither 2D nor 3D, or if a vector length does
            not correspond to a condensed square matrix
    """
    if x.ndim == 3:
        return x
    if x.ndim != 2:
        raise ValueError('X must have 2 or 3 dimensions')
    n_rdm, n_pairs = x.shape
    ## n_pairs = n_cond * (n_cond - 1) / 2, hence ceil(sqrt(2 * n_pairs))
    ## equals n_cond; a zero-length vector corresponds to a 1 x 1 matrix
    n_cond = max(math.ceil(math.sqrt(2 * n_pairs)), 1)
    mats = numpy.empty((n_rdm, n_cond, n_cond))
    for idx, vec in enumerate(x):
        mats[idx] = squareform(vec)
    return mats