"""
Plot showing an RDMs object
public API:
- show_rdm()
- show_rdm_panel()
"""
from __future__ import annotations
import itertools
from pathlib import Path
from typing import TYPE_CHECKING, Union, Tuple, Optional, Literal, Dict, Any, List, Iterator
from enum import Enum
from math import ceil
import numpy as np
from scipy.spatial.distance import squareform
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import LinearLocator
from matplotlib.colors import ListedColormap
from matplotlib.patches import Polygon
import rsatoolbox.rdm
from rsatoolbox.rdm.rdms import RDMs
from rsatoolbox import vis
from rsatoolbox.vis.colors import rdm_colormap_classic
from rsatoolbox.resources import get_style
if TYPE_CHECKING:
from matplotlib.axes._axes import Axes
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Colormap
from matplotlib.colorbar import Colorbar
from matplotlib.figure import Figure
from matplotlib.text import Text
from matplotlib.image import AxesImage
from matplotlib.axis import XAxis, YAxis
from numpy.typing import NDArray, ArrayLike
ArrayOrRdmDescriptor = NDArray | Tuple[str, str]
[docs]class Axis(Enum):
"""X or Y axis Enum
"""
X = 'x'
Y = 'y'
[docs]class Symmetry(Enum):
"""RDM Triangle Enum: both, upper or lower
"""
BOTH = 'both'
UPPER = 'upper'
LOWER = 'lower'
[docs]def show_rdm(
rdms: rsatoolbox.rdm.RDMs,
pattern_descriptor: Optional[str] = None,
cmap: Union[str, Colormap] = 'bone_r',
rdm_descriptor: Optional[str] = None,
n_column: Optional[int] = None,
n_row: Optional[int] = None,
show_colorbar: Optional[str] = None,
gridlines: Optional[ArrayLike] = None,
num_pattern_groups: Optional[int] = None,
figsize: Optional[Tuple[float, float]] = None,
nanmask: NDArray | str | None = "diagonal",
style: Optional[Union[str, Path]] = None,
vmin: Optional[float] = None,
vmax: Optional[float] = None,
icon_spacing: float = 1.0,
linewidth: float = 0.5,
overlay: Optional[ArrayOrRdmDescriptor] = None,
overlay_color: str = '#00ff0050',
overlay_symmetry: Symmetry = Symmetry.BOTH,
contour: Optional[ArrayOrRdmDescriptor] = None,
contour_color: str = 'red',
contour_symmetry: Symmetry = Symmetry.BOTH,
) -> Tuple[Figure, NDArray, Dict[int, Dict[str, Any]]]:
"""show_rdm. Heatmap figure for RDMs instance, with one panel per RDM.
Args:
rdm (rsatoolbox.rdm.RDMs): RDMs object to be plotted.
pattern_descriptor (str): Key into rdm.pattern_descriptors to use for axis
labels.
cmap (str or Colormap): Colormap to be used.
Either the name of a Matplotlib built-in colormap, a Matplotlib
Colormap compatible object, or 'classic' for the matlab toolbox
colormap. Defaults to 'bone_r'.
rdm_descriptor (str): Key for rdm_descriptor to use as panel title, or
str for direct labeling.
n_column (int): Number of columns in subplot arrangement.
n_row (int): Number of rows in subplot arrangement.
show_colorbar (str): Set to 'panel' or 'figure' to display a colorbar. If
'panel' a colorbar is added next to each RDM. If 'figure' a shared colorbar
(and scale) is used across panels.
gridlines (ArrayLike): Set to add gridlines at these positions. If
num_pattern_groups is defined this is used to infer gridlines.
num_pattern_groups (int): Number of rows/columns for any image labels. Also
determines gridlines frequency by default (so e.g., num_pattern_groups=3
with results in gridlines every 3 rows/columns).
figsize (Tuple[float, float]): mpl.Figure argument. By default we
auto-scale to achieve a figure that fits on a standard A4 / US Letter page
in portrait orientation.
nanmask (Union[ArrayLike, str, None]): boolean mask defining RDM elements to suppress
(by default, the diagonal).
Use the string "diagonal" to suppress the diagonal.
style (Union[str, Path]): Path to mplstyle file that controls
various figure aesthetics (default rsatoolbox/vis/rdm.mplstyle).
vmin (float): Minimum intensity for colorbar mapping. matplotlib imshow
argument.
vmax (float): Maximum intensity for colorbar mapping. matplotlib imshow
argument.
icon_spacing (float): control spacing of image labels - 1. means no gap (the
default), 1.1 means pad 10%, .9 means overlap 10% etc.
linewidth (float): Width of connecting lines from icon labels (if used) to axis
margin. The default is 0.5 - set to 0. to disable the lines.
overlay ((str, str) or NDArray): RDM descriptor name-value tuple,
or vector (one value per pair) which indicates whether to highlight the given cells
overlay_color (str): Color to use to highlight the pairs in the overlay argument.
Use RGBA to specify transparency. Default is 50% opaque green.
contour ((str, str) or NDArray): RDM descriptor name-value tuple,
or vector (one value per pair) which indicates whether to add a border
to the given cells
contour_color (str): Color to use for a border around pairs in the contour argument.
Use RGBA to specify transparency. Default is red.
Returns:
Tuple[Figure, ArrayLike, Dict]:
Tuple of
- Handle to created figure.
- Subplot axis handles from plt.subplots.
- Nested dict containing handles to all other plotted
objects (icon labels, colorbars, etc). The key at the first level is the axis index.
"""
# create a plot "configuration" object which resolves all parameters
conf = MultiRdmPlot.from_show_rdm_args(
rdms, pattern_descriptor, cmap, rdm_descriptor, n_column, n_row,
show_colorbar, gridlines, num_pattern_groups, figsize, nanmask,
style, vmin, vmax, icon_spacing, linewidth, overlay, overlay_color,
overlay_symmetry, contour, contour_color, contour_symmetry
)
return _plot_multi_rdm(conf)
def _plot_multi_rdm(conf: MultiRdmPlot) -> Tuple[Figure, NDArray, Dict[int, Dict[str, Any]]]:
# A dictionary of figure element handles
handles = dict()
handles[-1] = dict() # fig level handles
# create a list of (row index, column index) tuples
rc_tuples = list(itertools.product(range(conf.n_row), range(conf.n_column)))
# number of empty panels at the top
n_empty = (conf.n_row * conf.n_column) - conf.rdms.n_rdm
with plt.style.context(conf.style):
fig, ax_array = plt.subplots(
nrows=conf.n_row,
ncols=conf.n_column,
sharex=True,
sharey=True,
squeeze=False,
figsize=conf.figsize,
)
p = 0
for p, (r, c) in enumerate(rc_tuples):
handles[p] = dict()
rdm_index = p - n_empty # rdm index
if rdm_index < 0:
ax_array[r, c].set_visible(False)
continue
handles[p]["image"] = _show_rdm_panel(conf.for_single(rdm_index), ax_array[r, c])
if conf.show_colorbar == "panel":
# needs to happen before labels because it resizes the axis
handles[p]["colorbar"] = _rdm_colorbar(
mappable=handles[p]["image"],
fig=fig,
ax=ax_array[r, c],
title=conf.dissimilarity_measure
)
if c == 0 and conf.pattern_descriptor:
handles[p]["y_labels"] = _add_descriptor_labels(Axis.Y, ax_array[r, c], conf)
if r == 0 and conf.pattern_descriptor:
handles[p]["x_labels"] = _add_descriptor_labels(Axis.X, ax_array[r, c], conf)
if conf.show_colorbar == "figure":
handles[-1]["colorbar"] = _rdm_colorbar(
mappable=handles[p]["image"],
fig=fig,
ax=ax_array[0, 0],
title=conf.dissimilarity_measure,
)
_adjust_colorbar_pos(handles[-1]["colorbar"], ax_array[0, 0])
return fig, ax_array, handles
def _adjust_colorbar_pos(cb: Colorbar, parent: Axes) -> None:
"""Moves figure-level colorbar to the right position
Args:
cb (Colorbar): The matplotlib colorbar object
parent (Axes): Parent object axes
"""
# key challenge is to obtain a similarly-sized colorbar to the 'panel' case
# BUT positioned centered on the reserved subplot axes
# parent = ax_array[-1, -1]
cbax_parent_orgpos = parent.get_position(original=True)
# use last instance of 'image' (should all be yoked at this point)
cbax_pos = cb.ax.get_position()
# halfway through panel, less the width/height of the colorbar itself
x0 = (
cbax_parent_orgpos.x0
+ cbax_parent_orgpos.width / 2
- cbax_pos.width / 2
)
y0 = (
cbax_parent_orgpos.y0
+ cbax_parent_orgpos.height / 2
- cbax_pos.height / 2
)
cb.ax.set_position((x0, y0, cbax_pos.width, cbax_pos.height))
def _rdm_colorbar(mappable: ScalarMappable, fig: Figure, ax: Axes, title: str) -> Colorbar:
"""_rdm_colorbar. Add vertically-oriented, small colorbar to rdm figure. Used
internally by show_rdm.
Args:
mappable (matplotlib.cm.ScalarMappable): Typically plt.imshow instance.
fig (matplotlib.figure.Figure): Matplotlib figure handle.
ax (matplotlib.axes._axes.Axes): Matplotlib axis handle. plt.gca() by default.
title (str): Title string for the colorbar (positioned top, left aligned).
Returns:
matplotlib.colorbar.Colorbar: Matplotlib handle.
"""
cb = fig.colorbar(
mappable=mappable,
ax=ax,
shrink=0.25,
aspect=5,
ticks=LinearLocator(numticks=3),
)
cb.ax.set_title(title, loc="left", fontdict=dict(fontweight="normal"))
return cb
[docs]def show_rdm_panel(
rdms: rsatoolbox.rdm.RDMs,
ax: Optional[Axes] = None,
cmap: Union[str, Colormap] = 'bone_r',
nanmask: Optional[NDArray] = None,
rdm_descriptor: Optional[str] = None,
gridlines: Optional[ArrayLike] = None,
vmin: Optional[float] = None,
vmax: Optional[float] = None,
overlay: Optional[NDArray] = None,
overlay_color: str = '#00ff0050',
overlay_symmetry: Symmetry = Symmetry.BOTH,
contour: Optional[NDArray] = None,
contour_color: str = 'red',
contour_symmetry: Symmetry = Symmetry.BOTH
) -> AxesImage:
"""show_rdm_panel. Add RDM heatmap to the axis ax.
Args:
rdm (rsatoolbox.rdm.RDMs): RDMs object to be plotted (n_rdm must be 1).
ax (matplotlib.axes._axes.Axes): Matplotlib axis handle. plt.gca() by default.
cmap (str or Colormap): Colormap to be used.
Either the name of a Matplotlib built-in colormap, a Matplotlib
Colormap compatible object, or 'classic' for the matlab toolbox
colormap. Defaults to 'bone_r'.
nanmask (ArrayLike): boolean mask defining RDM elements to suppress
(by default, the diagonals).
rdm_descriptor (str): Key for rdm_descriptor to use as panel title, or
str for direct labeling.
gridlines (ArrayLike): Set to add gridlines at these positions.
vmin (float): Minimum intensity for colorbar mapping. matplotlib imshow
argument.
vmax (float): Maximum intensity for colorbar mapping. matplotlib imshow
argument.
overlay ((str, str) or NDArray): RDM descriptor name-value tuple, or vector
(one value per pair) which indicates whether to highlight the given cells
overlay_color (str): Color to use to highlight the pairs in the overlay argument.
Use RGBA to specify transparency. Default is 50% opaque green.
contour ((str, str) or NDArray): RDM descriptor name-value tuple, or vector
(one value per pair) which indicates whether to add a border to the given cells
contour_color (str): Color to use for a border around pairs in the contour argument.
Use RGBA to specify transparency. Default is red.
Returns:
matplotlib.image.AxesImage: Matplotlib handle.
"""
conf = SingleRdmPlot.from_show_rdm_panel_args(
rdms, cmap, nanmask,
rdm_descriptor, gridlines, vmin, vmax, overlay, overlay_color,
overlay_symmetry, contour, contour_color, contour_symmetry
)
return _show_rdm_panel(conf, ax or plt.gca())
def _show_rdm_panel(conf: SingleRdmPlot, ax: Axes) -> AxesImage:
"""Plot a single RDM based on a plot configuration object
Args:
conf (SingleRdmPlot): _description_
ax (Axes): _description_
Returns:
AxesImage: _description_
"""
rdmat = conf.rdms.get_matrices()[0, :, :]
if np.any(conf.nanmask):
rdmat[conf.nanmask] = np.nan
image = ax.imshow(
rdmat, cmap=conf.cmap, vmin=conf.vmin, vmax=conf.vmax,
interpolation='none')
_overlay(conf, ax)
_contour(conf, ax)
ax.set_xlim(-0.5, conf.rdms.n_cond - 0.5)
ax.set_ylim(conf.rdms.n_cond - 0.5, -0.5)
ax.xaxis.set_ticks(conf.gridlines)
ax.yaxis.set_ticks(conf.gridlines)
ax.xaxis.set_ticklabels([])
ax.yaxis.set_ticklabels([])
ax.xaxis.set_ticks(np.arange(conf.rdms.n_cond), minor=True)
ax.yaxis.set_ticks(np.arange(conf.rdms.n_cond), minor=True)
# hide minor ticks by default
ax.xaxis.set_tick_params(length=0, which="minor")
ax.yaxis.set_tick_params(length=0, which="minor")
ax.set_title(conf.title)
return image
def _overlay(conf: SingleRdmPlot, ax: Axes) -> None:
"""Add overlay image to axes
Args:
conf (SingleRdmPlot): Plot configuration
ax (Axes): axes object for this plot
"""
if not np.any(conf.overlay_mask):
return
cmap = ListedColormap(['none', conf.overlay_color])
ax.imshow(conf.overlay_mask, cmap=cmap, interpolation='none')
def _contour(conf: SingleRdmPlot, ax: Axes) -> None:
"""Add contour outline to axes
Args:
conf (SingleRdmPlot): Plot configuration
ax (Axes): axes object for this plot
"""
if not np.any(conf.contour_mask):
return
for (x1, y1, x2, y2) in _contour_coords(conf.contour_mask, -0.5):
ax.add_patch(
Polygon(
[(x1, y1), (x2, y2)],
facecolor='none',
edgecolor=conf.contour_color,
linewidth=3,
closed=True,
joinstyle='round'
)
)
def _mask_from_vector(vector: NDArray, triangles: Symmetry) -> NDArray:
"""Turn a triangular vector into a matrix mask, with given symmetry
Returns:
NDArray: 2-D boolean matrix
"""
mask = squareform(vector)
if triangles == Symmetry.BOTH:
return mask
elif triangles == Symmetry.LOWER:
return np.tril(mask)
elif triangles == Symmetry.UPPER:
return np.triu(mask)
def _contour_coords(mask: NDArray, offset: float) -> Iterator[Tuple[float, float, float, float]]:
"""Determine filled edges for the given mask
Returns a tuple of x1, y1, x2, y2 coordinates for each line.
Args:
mask (NDArray): nconds x nconds mask
offset (float): value to add for matplotlib indexing
Yields:
Iterator[Tuple[float, float, float, float]]: coordinates
"""
mask_t = mask.T
mask_idx = np.where(mask_t)
sides = [
((0, -1), (0, 0, 1, 0)), # top
((1, 0), (1, 0, 1, 1)), # right
((0, 1), (1, 1, 0, 1)), # bottom
((-1, 0), (0, 1, 0, 0)), # left
]
for x, y in np.vstack(mask_idx).T:
for neighbor, edge in sides:
if not mask_t[(x+neighbor[0], y+neighbor[1])]:
x1, y1, x2, y2 = edge
yield (x+x1+offset, y+y1+offset, x+x2+offset, y+y2+offset)
def _add_descriptor_labels(which_axis: Axis, ax: Axes, conf: MultiRdmPlot) -> List:
"""_add_descriptor_labels.
Args:
rdm (rsatoolbox.rdm.RDMs): RDMs instance to annotate.
pattern_descriptor (str): dict key for the rdm.pattern_descriptors dict.
icon_method (str): method to access on Icon instances (typically y_tick_label or
x_tick_label).
axis (Union[matplotlib.axis.XAxis, matplotlib.axis.YAxis]): Axis to add
tick labels to.
num_pattern_groups (int): Number of rows/columns for any image labels.
icon_spacing (float): control spacing of image labels - 1. means no gap (the
default), 1.1 means pad 10%, .9 means overlap 10% etc.
linewidth (float): Width of connecting lines from icon labels (if used) to axis
margin. The default is 0.5 - set to 0. to disable the lines.
horizontalalignment (str): Horizontal alignment of text tick labels.
Returns:
list: Tick label handles.
"""
if which_axis == Axis.X:
icon_method = "x_tick_label"
axis = ax.xaxis
horizontalalignment = "center"
else:
icon_method = "y_tick_label"
axis = ax.yaxis
horizontalalignment = "right"
descriptor_arr = np.asarray(conf.rdms.pattern_descriptors[conf.pattern_descriptor])
if isinstance(descriptor_arr[0], vis.Icon):
return _add_descriptor_icons(
descriptor_arr,
icon_method,
n_cond=conf.rdms.n_cond,
ax=axis.axes,
icon_spacing=conf.icon_spacing,
num_pattern_groups=conf.num_pattern_groups,
linewidth=conf.linewidth,
)
is_x_axis = "x" in icon_method
return _add_descriptor_text(
descriptor_arr,
axis=axis,
horizontalalignment=horizontalalignment,
is_x_axis=is_x_axis,
)
def _add_descriptor_text(
descriptor_arr: ArrayLike,
axis: Union[XAxis, YAxis],
horizontalalignment: str = "center",
is_x_axis: bool = False,
) -> List[Text]:
"""_add_descriptor_text. Used internally by _add_descriptor_labels to add vanilla
Matplotlib-based text labels to the X or Y axis.
Args:
descriptor_arr (ArrayLike): np.Array-like version of the labels.
axis (Union[matplotlib.axis.XAxis, matplotlib.axis.YAxis]): handle for
the relevant axis (ax.xaxis or ax.yaxis).
horizontalalignment (str): Horizontal alignment of text tick labels.
is_x_axis (bool): If set, rotate the text labels 60 degrees to reduce overlap on
the X axis.
Returns:
list: Tick label handles.
"""
# vanilla matplotlib-based
# need to ensure the minor ticks have some length
axis.set_tick_params(length=matplotlib.rcParams["xtick.minor.size"], which="minor")
label_handles = axis.set_ticklabels(
descriptor_arr,
verticalalignment="center",
horizontalalignment=horizontalalignment,
minor=True,
)
if is_x_axis:
plt.setp(
axis.get_ticklabels(minor=True),
rotation=60,
ha="right",
rotation_mode="anchor",
)
return label_handles
def _add_descriptor_icons(
descriptor_arr: ArrayLike,
icon_method: str,
n_cond: int,
ax: Optional[Axes] = None,
num_pattern_groups: Optional[int] = None,
icon_spacing: float = 1.0,
linewidth: float = 0.5,
) -> list:
"""_add_descriptor_icons. Used internally by _add_descriptor_labels to add
Icon-based labels to the X or Y axis.
Args:
descriptor_arr (ArrayLike): np.Array-like version of the labels.
icon_method (str): method to access on Icon instances (typically y_tick_label or
x_tick_label).
n_cond (int): Number of conditions in the RDM (usually from RDMs.n_cond).
ax (matplotlib.axes._axes.Axes): Matplotlib axis handle.
num_pattern_groups (int): Number of rows/columns for any image labels.
icon_spacing (float): control spacing of image labels - 1. means no gap (the
default), 1.1 means pad 10%, .9 means overlap 10% etc.
linewidth (float): Width of connecting lines from icon labels (if used) to axis
margin. The default is 0.5 - set to 0. to disable the lines.
Returns:
list: Tick label handles.
"""
# annotated labels with Icon
n_to_fit = np.ceil(n_cond / num_pattern_groups)
# work out sizing of icons
im_max_pix = 20.
if descriptor_arr[0].final_image:
# size by image
im_width_pix = max(this_desc.final_image.width for this_desc in descriptor_arr)
im_height_pix = max(this_desc.final_image.height for this_desc in descriptor_arr)
im_max_pix = max(im_width_pix, im_height_pix) * icon_spacing
ax.figure.canvas.draw()
extent = ax.get_window_extent(ax.figure.canvas.get_renderer())
ax_size_pix = max((extent.width, extent.height))
size = (ax_size_pix / n_to_fit) / im_max_pix
# from proportion of original size to figure pixels
offset = (im_max_pix / icon_spacing) * size
label_handles = []
for group_ind in range(num_pattern_groups - 1, -1, -1): # e.g. 2->1->0 for npg = 3
position = offset * 0.2 + offset * group_ind
ticks = np.arange(group_ind, n_cond, num_pattern_groups)
label_handles.append(
[
getattr(this_desc, icon_method)(
this_x, size, offset=position, linewidth=linewidth, ax=ax,
)
for (this_x, this_desc) in zip(ticks, descriptor_arr[ticks])
]
)
return label_handles
[docs]class MultiRdmPlot:
"""Configuration for the multi-rdm plot
"""
rdms: RDMs
pattern_descriptor: Optional[str]
cmap: Union[str, Colormap]
rdm_descriptor: str
n_column: int
n_row: int
show_colorbar: Optional[Literal["panel"] | Literal["figure"]]
gridlines: NDArray
num_pattern_groups: int
figsize: Tuple[float, float]
nanmask: NDArray
style: Path
vmin: Optional[float]
vmax: Optional[float]
icon_spacing: float
linewidth: float
n_panel: int
dissimilarity_measure: str
overlay: NDArray
overlay_color: str
overlay_symmetry: Symmetry
contour: NDArray
contour_color: str
contour_symmetry: Symmetry
overlay_mask: NDArray
contour_mask: NDArray
fig: Optional[Figure]
ax: Optional[NDArray]
handles: Optional[Dict[int, Dict[str, Any]]]
[docs] @classmethod
def from_show_rdm_args(
cls,
rdm: RDMs,
pattern_descriptor: Optional[str],
cmap: Union[str, Colormap],
rdm_descriptor: Optional[str],
n_column: Optional[int],
n_row: Optional[int],
show_colorbar: Optional[str],
gridlines: Optional[ArrayLike],
num_pattern_groups: Optional[int],
figsize: Optional[Tuple[float, float]],
nanmask: NDArray | str | None,
style: Optional[Union[str, Path]],
vmin: Optional[float],
vmax: Optional[float],
icon_spacing: float,
linewidth: float,
overlay: Optional[Tuple[str, str] | NDArray],
overlay_color: str,
overlay_symmetry: Symmetry,
contour: Optional[Tuple[str, str] | NDArray],
contour_color: str,
contour_symmetry: Symmetry
) -> MultiRdmPlot:
"""Create an object from the original arguments to show_rdm()
"""
conf = __class__(rdm)
if show_colorbar not in (None, "panel", "figure"):
raise ValueError(
f"show_colorbar can be None, panel or figure, got: {show_colorbar}"
)
conf.show_colorbar = show_colorbar
conf.nanmask = cls.init_nan_mask(nanmask, rdm)
conf.n_panel = rdm.n_rdm + int(show_colorbar == "figure")
if show_colorbar == "figure":
rdmat = rdm.get_matrices()
vmin = vmin or rdmat[:, ~conf.nanmask].min()
vmax = vmax or rdmat[:, ~conf.nanmask].max()
conf.vmin = vmin
conf.vmax = vmax
conf.n_row, conf.n_column = cls.determine_rows_cols_panels(
n_row, n_column, conf.n_panel)
conf.figsize = figsize or cls.calc_figsize(conf.n_column, conf.n_row)
gridlines = np.asarray(gridlines or list())
if num_pattern_groups and (not np.any(gridlines)):
# grid by pattern groups if they exist and explicit grid setting does not
gridlines = np.arange(
num_pattern_groups - 0.5, rdm.n_cond + 0.5, num_pattern_groups
)
conf.gridlines = np.asarray(gridlines)
if num_pattern_groups is None or num_pattern_groups == 0:
num_pattern_groups = 1
conf.num_pattern_groups = num_pattern_groups
conf.style = Path(str(style)) if style is not None else get_style()
conf.icon_spacing = icon_spacing
conf.linewidth = linewidth
if cmap == 'classic':
cmap = rdm_colormap_classic()
conf.cmap = cmap
conf.rdms = rdm
conf.pattern_descriptor = pattern_descriptor
conf.rdm_descriptor = rdm_descriptor or ''
conf.dissimilarity_measure = rdm.dissimilarity_measure or ''
conf.overlay = conf.interpret_rdm_arg(overlay, rdm)
conf.overlay_color = overlay_color
conf.overlay_symmetry = overlay_symmetry
conf.overlay_mask = _mask_from_vector(conf.overlay, conf.overlay_symmetry)
conf.contour = conf.interpret_rdm_arg(contour, rdm)
conf.contour_color = contour_color
conf.contour_symmetry = contour_symmetry
conf.contour_mask = _mask_from_vector(conf.contour, conf.contour_symmetry)
return conf
[docs] def interpret_rdm_arg(self, val: Optional[ArrayOrRdmDescriptor], rdms: RDMs) -> NDArray:
"""Resolve argument that can be an rdm descriptor key/value pair or a utv
"""
if val is None:
n_pairs = rdms.dissimilarities.shape[1]
return np.zeros(n_pairs)
if isinstance(val, np.ndarray):
return val
else:
return rdms.subset(*val).dissimilarities[0, :]
[docs] @classmethod
def determine_rows_cols_panels(
cls,
n_row: Optional[int],
n_column: Optional[int],
n_panel: int
) -> Tuple[int, int]:
"""Choose the number of rows and columns of panels
"""
if (n_column is None) and (n_row is None):
n_column = ceil(np.sqrt(n_panel))
if n_row is None:
n_row = ceil(n_panel / n_column)
if n_column is None:
n_column = ceil(n_panel / n_row)
return n_row, n_column
[docs] @classmethod
def init_nan_mask(
cls,
nanmask: NDArray | str | None,
rdms: RDMs,
) -> NDArray:
"""Interpret user's choice of nanmask
"""
if nanmask is None:
nanmask = np.zeros((rdms.n_cond, rdms.n_cond), dtype=bool)
elif isinstance(nanmask, str):
if nanmask == "diagonal":
nanmask = np.eye(rdms.n_cond, dtype=bool)
else:
raise ValueError("Invalid nanmask value")
return nanmask
[docs] @classmethod
def calc_figsize(cls, n_column: int, n_row: int) -> Tuple[float, float]:
""""
scale with number of RDMs, up to (intersection of A4 and us letter)
"""
return (
min(2 * n_column, 8.3), min(2 * n_row, 11)
)
[docs] def for_single(self, index: int) -> SingleRdmPlot:
"""Create a SingleRdmPlot object for the given rdm index
Args:
index (int): Index for the rdms
Returns:
SingleRdmPlot: _description_
"""
conf = SingleRdmPlot()
conf.rdms = self.rdms[index]
conf.cmap = self.cmap
conf.rdm_descriptor = self.rdm_descriptor
conf.gridlines = self.gridlines
conf.nanmask = self.nanmask
conf.vmin = self.vmin
conf.vmax = self.vmax
conf.overlay = self.overlay
conf.overlay_mask = self.overlay_mask
conf.overlay_color = self.overlay_color
conf.overlay_symmetry = self.overlay_symmetry
conf.contour = self.contour
conf.contour_mask = self.contour_mask
conf.contour_color = self.contour_color
conf.contour_symmetry = self.contour_symmetry
if self.rdm_descriptor in conf.rdms.rdm_descriptors:
conf.title = conf.rdms.rdm_descriptors[self.rdm_descriptor][0]
else:
conf.title = self.rdm_descriptor
return conf
def __init__(self, rdms: RDMs):
self.rdms = rdms
self.pattern_descriptor = None
self.cmap = 'bone_r'
self.rdm_descriptor = ''
self.gridlines = np.array([])
self.num_pattern_groups = 1
self.show_colorbar = None
self.n_row, self.n_column = self.determine_rows_cols_panels(
None, None, self.rdms.n_rdm)
self.figsize = self.calc_figsize(self.n_column, self.n_row)
self.nanmask = self.init_nan_mask('diagonal', self.rdms)
self.style = get_style()
self.vmin = None
self.vmax = None
self.icon_spacing = 1.0
self.linewidth = 0.5
n_pairs = rdms.dissimilarities.shape[1]
self.overlay = np.zeros(n_pairs)
self.overlay_color = '#00ff0050'
self.overlay_symmetry = Symmetry.BOTH
self.contour = np.zeros(n_pairs)
self.contour_color = 'red'
self.contour_symmetry = Symmetry.BOTH
[docs] def addOverlay(self, mask: ArrayOrRdmDescriptor, color: str, triangles: Symmetry):
self.overlay = self.interpret_rdm_arg(mask, self.rdms)
self.overlay_color = color
self.overlay_symmetry = triangles
self.overlay_mask = _mask_from_vector(self.overlay, triangles)
[docs] def addContour(self, mask: ArrayOrRdmDescriptor, color: str, triangles: Symmetry):
self.contour = self.interpret_rdm_arg(mask, self.rdms)
self.contour_color = color
self.contour_symmetry = triangles
self.contour_mask = _mask_from_vector(self.contour, triangles)
[docs] def plot(self):
self.fig, self.ax, self.handles = _plot_multi_rdm(self)
return self.fig
[docs]class SingleRdmPlot:
"""Configuration for the single-rdm plot
"""
rdms: RDMs
cmap: Union[str, Colormap]
rdm_descriptor: str
gridlines: ArrayLike
nanmask: NDArray
vmin: Optional[float]
vmax: Optional[float]
title: str
overlay: NDArray
overlay_color: str
overlay_symmetry: Symmetry
contour: NDArray
contour_color: str
contour_symmetry: Symmetry
overlay_mask: NDArray
contour_mask: NDArray
fig: Optional[Figure]
ax: Optional[NDArray]
handles: Optional[Dict[int, Dict[str, Any]]]
[docs] @classmethod
def from_show_rdm_panel_args(
cls,
rdms: RDMs,
cmap: Union[str, Colormap],
nanmask: Optional[NDArray],
rdm_descriptor: Optional[str],
gridlines: Optional[ArrayLike],
vmin: Optional[float],
vmax: Optional[float],
overlay: Optional[NDArray],
overlay_color: str,
overlay_symmetry: Symmetry,
contour: Optional[NDArray],
contour_color: str,
contour_symmetry: Symmetry
) -> SingleRdmPlot:
"""Create an object from the original arguments to show_rdm_panel()
"""
conf = __class__()
if rdms.n_rdm > 1:
raise ValueError("expected single rdm - use show_rdm for multi-panel figures")
if cmap == 'classic':
cmap = rdm_colormap_classic()
conf.cmap = cmap
if nanmask is None:
nanmask = np.eye(rdms.n_cond, dtype=bool)
conf.nanmask = nanmask
gridlines = gridlines or list()
if not np.any(gridlines):
gridlines = []
conf.gridlines = gridlines
if rdm_descriptor in rdms.rdm_descriptors:
conf.title = rdms.rdm_descriptors[rdm_descriptor][0]
else:
conf.title = rdm_descriptor or ''
conf.vmin = vmin
conf.vmax = vmax
conf.overlay = conf.interpret_rdm_arg(overlay, rdms)
conf.overlay_color = overlay_color
conf.overlay_symmetry = overlay_symmetry
conf.overlay_mask = _mask_from_vector(conf.overlay, conf.overlay_symmetry)
conf.contour = conf.interpret_rdm_arg(contour, rdms)
conf.contour_color = contour_color
conf.contour_symmetry = contour_symmetry
conf.contour_mask = _mask_from_vector(conf.contour, conf.contour_symmetry)
return conf
[docs] def interpret_rdm_arg(self, val: Optional[ArrayOrRdmDescriptor], rdms: RDMs) -> NDArray:
"""Resolve argument that can be an rdm descriptor key/value pair or a utv
"""
if val is None:
n_pairs = rdms.dissimilarities.shape[1]
return np.zeros(n_pairs)
if isinstance(val, np.ndarray):
return val
else:
return rdms.subset(*val).dissimilarities[0, :]