# Source code for rsatoolbox.vis.rdm_plot

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Plot showing an RDMs object
"""

from __future__ import annotations
import collections
from typing import TYPE_CHECKING, Union, Tuple, Optional
import pkg_resources
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import rsatoolbox.rdm
from rsatoolbox import vis
from rsatoolbox.vis.colors import rdm_colormap_classic
if TYPE_CHECKING:
    import numpy.typing as npt
    import pathlib
    from matplotlib.axes._axes import Axes
    from matplotlib.colors import Colormap

# Filesystem path of the 'rdm.mplstyle' resource shipped in the rsatoolbox.vis
# package; used as the default value of the ``style`` argument of show_rdm.
RDM_STYLE = pkg_resources.resource_filename('rsatoolbox.vis', 'rdm.mplstyle')


def show_rdm(
    rdm: rsatoolbox.rdm.RDMs,
    pattern_descriptor: Optional[str] = None,
    cmap: Union[str, Colormap] = 'bone',
    rdm_descriptor: Optional[str] = None,
    n_column: Optional[int] = None,
    n_row: Optional[int] = None,
    show_colorbar: Optional[str] = None,
    gridlines: Optional[npt.ArrayLike] = None,
    num_pattern_groups: Optional[int] = None,
    figsize: Optional[Tuple[float, float]] = None,
    nanmask: Optional[npt.ArrayLike] = None,
    style: Union[str, pathlib.Path] = RDM_STYLE,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    icon_spacing: float = 1.0,
    linewidth: float = 0.5,
) -> Tuple[matplotlib.figure.Figure, npt.ArrayLike, collections.defaultdict]:
    """show_rdm. Heatmap figure for RDMs instance, with one panel per RDM.

    Args:
        rdm (rsatoolbox.rdm.RDMs): RDMs object to be plotted.
        pattern_descriptor (str): Key into rdm.pattern_descriptors to use for
            axis labels.
        cmap (str or Colormap): Colormap to be used. Either the name of a
            Matplotlib built-in colormap, a Matplotlib Colormap compatible
            object, or 'classic' for the matlab toolbox colormap. Defaults to
            'bone'.
        rdm_descriptor (str): Key for rdm_descriptor to use as panel title, or
            str for direct labeling.
        n_column (int): Number of columns in subplot arrangement.
        n_row (int): Number of rows in subplot arrangement.
        show_colorbar (str): Set to 'panel' or 'figure' to display a colorbar.
            If 'panel' a colorbar is added next to each RDM. If 'figure' a
            shared colorbar (and scale) is used across panels.
        gridlines (npt.ArrayLike): Set to add gridlines at these positions. If
            not set and num_pattern_groups is defined, gridlines are inferred
            from num_pattern_groups.
        num_pattern_groups (int): Number of rows/columns for any image labels.
            Also determines gridlines frequency by default (so e.g.,
            num_pattern_groups=3 results in gridlines every 3 rows/columns).
        figsize (Tuple[float, float]): mpl.Figure argument. By default we
            auto-scale to achieve a figure that fits on a standard A4 / US
            Letter page in portrait orientation.
        nanmask (npt.ArrayLike): boolean mask defining RDM elements to suppress
            (by default, the diagonals).
        style (Union[str, pathlib.Path]): Path to mplstyle file that controls
            various figure aesthetics (default rsatoolbox/vis/rdm.mplstyle).
        vmin (float): Minimum intensity for colorbar mapping. matplotlib imshow
            argument.
        vmax (float): Maximum intensity for colorbar mapping. matplotlib imshow
            argument.
        icon_spacing (float): control spacing of image labels - 1. means no gap
            (the default), 1.1 means pad 10%, .9 means overlap 10% etc.
        linewidth (float): Width of connecting lines from icon labels (if used)
            to axis margin. The default is 0.5 - set to 0. to disable the
            lines.

    Returns:
        Tuple[matplotlib.figure.Figure, npt.ArrayLike, collections.defaultdict]:
            Tuple of

            - Handle to created figure.
            - Subplot axis handles from plt.subplots (rows reversed, so the
              first row of the returned array is the bottom row of the figure).
            - Nested dict containing handles to all other plotted objects (icon
              labels, colorbars, etc). The keys at the first level are the axis
              and figure handles.

    Raises:
        ValueError: If show_colorbar is not None, 'panel' or 'figure', or if
            the requested n_row * n_column grid has fewer cells than there are
            RDMs.
    """
    if show_colorbar and show_colorbar not in ("panel", "figure"):
        raise ValueError(
            f"show_colorbar can be None, panel or figure, got: {show_colorbar}"
        )
    if nanmask is None:
        nanmask = np.eye(rdm.n_cond, dtype=bool)
    else:
        # nanmask is documented as array-like: coerce once so that inversion and
        # boolean indexing below behave the same for lists and arrays
        nanmask = np.asarray(nanmask, dtype=bool)

    n_panel = rdm.n_rdm
    if show_colorbar == "figure":
        # one extra panel is reserved to host the shared colorbar
        n_panel += 1
        # a shared colorbar needs global limits across all RDMs
        if vmin is None or vmax is None:
            # need to load the RDMs here (expensive)
            visible = rdm.get_matrices()[:, ~nanmask]
            # nan-aware reductions: a single NaN dissimilarity must not turn the
            # shared color scale into NaN
            if vmin is None:
                vmin = np.nanmin(visible)
            if vmax is None:
                vmax = np.nanmax(visible)

    # infer whichever grid dimension was not specified
    if n_column is None and n_row is None:
        n_column = np.ceil(np.sqrt(n_panel))
    if n_row is None:
        n_row = np.ceil(n_panel / n_column)
    if n_column is None:
        n_column = np.ceil(n_panel / n_row)
    n_row, n_column = int(n_row), int(n_column)
    # NOTE(review): this only checks that the RDMs themselves fit. With
    # show_colorbar='figure' and a grid of exactly rdm.n_rdm cells there is no
    # spare panel for the colorbar, which then shares the last RDM's cell -
    # confirm whether that case should be rejected as well.
    if (n_column * n_row) < rdm.n_rdm:
        raise ValueError(
            f"invalid n_row*n_column specification for {n_panel} rdms: {n_row}*{n_column}"
        )
    if figsize is None:
        # scale with number of RDMs, up to a point (the intersection of A4 and us
        # letter)
        figsize = (min(2 * n_column, 8.3), min(2 * n_row, 11))
    if not np.any(gridlines):
        # empty list to disable gridlines
        gridlines = []
        if num_pattern_groups:
            # grid by pattern groups if they exist and explicit grid setting does not
            gridlines = np.arange(
                num_pattern_groups - 0.5, rdm.n_cond + 0.5, num_pattern_groups
            )
    if num_pattern_groups is None or num_pattern_groups == 0:
        num_pattern_groups = 1

    # we don't necessarily have the same number of RDMs as panels, so the
    # iterator runs dry before the panel loop ends
    rdm_iter = iter(rdm)
    # handles for image, colorbar, x_labels, y_labels. Some are global for the
    # figure, others local to an axis, so they are stored in dicts keyed by the
    # axis / figure handle
    return_handles = collections.defaultdict(dict)
    with plt.style.context(style):
        fig, ax_array = plt.subplots(
            nrows=n_row,
            ncols=n_column,
            sharex=True,
            sharey=True,
            squeeze=False,
            figsize=figsize,
        )
        # reverse panel order so unfilled rows are at top instead of bottom
        ax_array = ax_array[::-1]
        for row_ind, row in enumerate(ax_array):
            for col_ind, panel in enumerate(row):
                this_rdm = next(rdm_iter, None)
                if this_rdm is None:
                    # hide empty panels
                    panel.set_visible(False)
                    continue
                return_handles[panel]["image"] = show_rdm_panel(
                    this_rdm,
                    ax=panel,
                    cmap=cmap,
                    nanmask=nanmask,
                    rdm_descriptor=rdm_descriptor,
                    gridlines=gridlines,
                    vmin=vmin,
                    vmax=vmax,
                )
                if show_colorbar == "panel":
                    # needs to happen before labels because it resizes the axis
                    return_handles[panel]["colorbar"] = _rdm_colorbar(
                        mappable=return_handles[panel]["image"],
                        fig=fig,
                        ax=panel,
                        title=rdm.dissimilarity_measure,
                    )
                # only the first column / first (bottom) row carry labels
                if col_ind == 0 and pattern_descriptor:
                    return_handles[panel]["y_labels"] = add_descriptor_y_labels(
                        rdm,
                        pattern_descriptor,
                        ax=panel,
                        num_pattern_groups=num_pattern_groups,
                        icon_spacing=icon_spacing,
                        linewidth=linewidth,
                    )
                if row_ind == 0 and pattern_descriptor:
                    return_handles[panel]["x_labels"] = add_descriptor_x_labels(
                        rdm,
                        pattern_descriptor,
                        ax=panel,
                        num_pattern_groups=num_pattern_groups,
                        icon_spacing=icon_spacing,
                        linewidth=linewidth,
                    )
        if show_colorbar == "figure":
            # key challenge is to obtain a similarly-sized colorbar to the
            # 'panel' case BUT positioned centered on the reserved subplot axes
            cbax_parent = ax_array[-1, -1]
            cbax_parent_orgpos = cbax_parent.get_position(original=True)
            # any image can act as mappable: all panels were drawn with the same
            # vmin/vmax, so the first one is used
            return_handles[fig]["colorbar"] = _rdm_colorbar(
                mappable=return_handles[ax_array[0][0]]["image"],
                fig=fig,
                ax=cbax_parent,
                title=rdm.dissimilarity_measure,
            )
            cbax_pos = return_handles[fig]["colorbar"].ax.get_position()
            # halfway through panel, less half the width/height of the colorbar
            x0 = (
                cbax_parent_orgpos.x0
                + cbax_parent_orgpos.width / 2
                - cbax_pos.width / 2
            )
            y0 = (
                cbax_parent_orgpos.y0
                + cbax_parent_orgpos.height / 2
                - cbax_pos.height / 2
            )
            return_handles[fig]["colorbar"].ax.set_position(
                [x0, y0, cbax_pos.width, cbax_pos.height]
            )
    return fig, ax_array, return_handles


def _rdm_colorbar(
    mappable: matplotlib.cm.ScalarMappable = None,
    fig: matplotlib.figure.Figure = None,
    ax: Axes = None,
    title: str = None,
) -> matplotlib.colorbar.Colorbar:
    """_rdm_colorbar. Attach a small colorbar with three ticks to an rdm figure.

    Used internally by show_rdm.

    Args:
        mappable (matplotlib.cm.ScalarMappable): Object whose color mapping is
            shown, typically the image returned by imshow.
        fig (matplotlib.figure.Figure): Matplotlib figure handle on which the
            colorbar is created.
        ax (matplotlib.axes._axes.Axes): Matplotlib axis handle, passed through
            to fig.colorbar.
        title (str): Title string for the colorbar (left aligned).

    Returns:
        matplotlib.colorbar.Colorbar: Matplotlib handle.
    """
    # three linearly spaced ticks on a colorbar shrunk to a quarter of its
    # default length
    three_ticks = matplotlib.ticker.LinearLocator(numticks=3)
    colorbar_options = {"shrink": 0.25, "aspect": 5, "ticks": three_ticks}
    colorbar = fig.colorbar(mappable=mappable, ax=ax, **colorbar_options)
    colorbar.ax.set_title(title, loc="left", fontdict={"fontweight": "normal"})
    return colorbar


def show_rdm_panel(
    rdm: rsatoolbox.rdm.RDMs,
    ax: Optional[Axes] = None,
    cmap: Union[str, Colormap] = 'bone',
    nanmask: Optional[npt.ArrayLike] = None,
    rdm_descriptor: Optional[str] = None,
    gridlines: Optional[npt.ArrayLike] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
) -> matplotlib.image.AxesImage:
    """show_rdm_panel. Add RDM heatmap to the axis ax.

    Args:
        rdm (rsatoolbox.rdm.RDMs): RDMs object to be plotted (n_rdm must be 1).
        ax (matplotlib.axes._axes.Axes): Matplotlib axis handle. plt.gca() by
            default.
        cmap (str or Colormap): Colormap to be used. Either the name of a
            Matplotlib built-in colormap, a Matplotlib Colormap compatible
            object, or 'classic' for the matlab toolbox colormap. Defaults to
            'bone'.
        nanmask (npt.ArrayLike): boolean mask defining RDM elements to suppress
            (by default, the diagonals).
        rdm_descriptor (str): Key for rdm_descriptor to use as panel title, or
            str for direct labeling.
        gridlines (npt.ArrayLike): Set to add gridlines at these positions.
        vmin (float): Minimum intensity for colorbar mapping. matplotlib imshow
            argument.
        vmax (float): Maximum intensity for colorbar mapping. matplotlib imshow
            argument.

    Returns:
        matplotlib.image.AxesImage: Matplotlib handle.

    Raises:
        ValueError: If rdm contains more than one RDM.
    """
    if rdm.n_rdm > 1:
        raise ValueError("expected single rdm - use show_rdm for multi-panel figures")
    if ax is None:
        ax = plt.gca()
    # only compare strings against the special name; cmap may also be a
    # Colormap object
    if isinstance(cmap, str) and cmap == 'classic':
        cmap = rdm_colormap_classic()
    if nanmask is None:
        nanmask = np.eye(rdm.n_cond, dtype=bool)
    else:
        # coerce so that e.g. a 0/1 integer mask is used as a boolean mask
        # rather than as integer fancy indices
        nanmask = np.asarray(nanmask, dtype=bool)
    if not np.any(gridlines):
        gridlines = []
    # float copy: NaN cannot be stored in an integer array, and masking must
    # not write into whatever array get_matrices handed back
    rdmat = np.array(rdm.get_matrices()[0], dtype=float)
    if nanmask.any():
        rdmat[nanmask] = np.nan
    image = ax.imshow(rdmat, cmap=cmap, vmin=vmin, vmax=vmax, interpolation='none')
    ax.set_xlim(-0.5, rdm.n_cond - 0.5)
    # inverted y limits: first condition at the top
    ax.set_ylim(rdm.n_cond - 0.5, -0.5)
    condition_positions = np.arange(rdm.n_cond)
    for axis in (ax.xaxis, ax.yaxis):
        # major ticks carry the (unlabelled) gridlines
        axis.set_ticks(gridlines)
        axis.set_ticklabels([])
        # minor ticks sit on every condition; hidden by default
        axis.set_ticks(condition_positions, minor=True)
        axis.set_tick_params(length=0, which="minor")
    if rdm_descriptor in rdm.rdm_descriptors:
        ax.set_title(rdm.rdm_descriptors[rdm_descriptor][0])
    else:
        # not a known key: use the value itself as the title
        ax.set_title(rdm_descriptor)
    return image


def add_descriptor_x_labels(
    rdm: rsatoolbox.rdm.RDMs,
    pattern_descriptor: str,
    ax: Axes = None,
    num_pattern_groups: int = None,
    icon_spacing: float = 1.0,
    linewidth: float = 0.5,
) -> list:
    """add_descriptor_x_labels. Label the X axis of ax with the entries found
    under the pattern_descriptor key of rdm.pattern_descriptors.

    Args:
        rdm (rsatoolbox.rdm.RDMs): RDMs instance to annotate.
        pattern_descriptor (str): dict key for the rdm.pattern_descriptors dict.
        ax (matplotlib.axes._axes.Axes): Matplotlib axis handle. plt.gca() by
            default.
        num_pattern_groups (int): Number of rows/columns for any image labels.
        icon_spacing (float): control spacing of image labels - 1. means no gap
            (the default), 1.1 means pad 10%, .9 means overlap 10% etc.
        linewidth (float): Width of connecting lines from icon labels (if used)
            to axis margin. The default is 0.5 - set to 0. to disable the
            lines.

    Returns:
        list: Tick label handles.
    """
    target_ax = plt.gca() if ax is None else ax
    label_options = {
        "num_pattern_groups": num_pattern_groups,
        "icon_spacing": icon_spacing,
        "linewidth": linewidth,
        "horizontalalignment": "center",
    }
    return _add_descriptor_labels(
        rdm, pattern_descriptor, "x_tick_label", target_ax.xaxis, **label_options
    )


def add_descriptor_y_labels(
    rdm: rsatoolbox.rdm.RDMs,
    pattern_descriptor: str,
    ax: Axes = None,
    num_pattern_groups: int = None,
    icon_spacing: float = 1.0,
    linewidth: float = 0.5,
) -> list:
    """add_descriptor_y_labels. Label the Y axis of ax with the entries found
    under the pattern_descriptor key of rdm.pattern_descriptors.

    Args:
        rdm (rsatoolbox.rdm.RDMs): RDMs instance to annotate.
        pattern_descriptor (str): dict key for the rdm.pattern_descriptors dict.
        ax (matplotlib.axes._axes.Axes): Matplotlib axis handle. plt.gca() by
            default.
        num_pattern_groups (int): Number of rows/columns for any image labels.
        icon_spacing (float): control spacing of image labels - 1. means no gap
            (the default), 1.1 means pad 10%, .9 means overlap 10% etc.
        linewidth (float): Width of connecting lines from icon labels (if used)
            to axis margin. The default is 0.5 - set to 0. to disable the
            lines.

    Returns:
        list: Tick label handles.
    """
    target_ax = plt.gca() if ax is None else ax
    label_options = {
        "num_pattern_groups": num_pattern_groups,
        "icon_spacing": icon_spacing,
        "linewidth": linewidth,
        "horizontalalignment": "right",
    }
    return _add_descriptor_labels(
        rdm, pattern_descriptor, "y_tick_label", target_ax.yaxis, **label_options
    )


def _add_descriptor_labels( rdm: rsatoolbox.rdm.RDMs, pattern_descriptor: str, icon_method: str, axis: Union[matplotlib.axis.XAxis, matplotlib.axis.YAxis], num_pattern_groups: int = None, icon_spacing: float = 1.0, linewidth: float = 0.5, horizontalalignment: str = "center", ) -> list: """_add_descriptor_labels. Used internally by add_descriptor_y_labels and add_descriptor_x_labels. Args: rdm (rsatoolbox.rdm.RDMs): RDMs instance to annotate. pattern_descriptor (str): dict key for the rdm.pattern_descriptors dict. icon_method (str): method to access on Icon instances (typically y_tick_label or x_tick_label). axis (Union[matplotlib.axis.XAxis, matplotlib.axis.YAxis]): Axis to add tick labels to. num_pattern_groups (int): Number of rows/columns for any image labels. icon_spacing (float): control spacing of image labels - 1. means no gap (the default), 1.1 means pad 10%, .9 means overlap 10% etc. linewidth (float): Width of connecting lines from icon labels (if used) to axis margin. The default is 0.5 - set to 0. to disable the lines. horizontalalignment (str): Horizontal alignment of text tick labels. Returns: list: Tick label handles. """ descriptor_arr = np.asarray(rdm.pattern_descriptors[pattern_descriptor]) if isinstance(descriptor_arr[0], vis.Icon): return _add_descriptor_icons( descriptor_arr, icon_method, n_cond=rdm.n_cond, ax=axis.axes, icon_spacing=icon_spacing, num_pattern_groups=num_pattern_groups, linewidth=linewidth, ) is_x_axis = "x" in icon_method return _add_descriptor_text( descriptor_arr, axis=axis, horizontalalignment=horizontalalignment, is_x_axis=is_x_axis, ) def _add_descriptor_text( descriptor_arr: npt.ArrayLike, axis: Union[matplotlib.axis.XAxis, matplotlib.axis.YAxis], horizontalalignment: str = "center", is_x_axis: bool = False, ) -> list: """_add_descriptor_text. Used internally by _add_descriptor_labels to add vanilla Matplotlib-based text labels to the X or Y axis. Args: descriptor_arr (npt.ArrayLike): np.Array-like version of the labels. 
axis (Union[matplotlib.axis.XAxis, matplotlib.axis.YAxis]): handle for the relevant axis (ax.xaxis or ax.yaxis). horizontalalignment (str): Horizontal alignment of text tick labels. is_x_axis (bool): If set, rotate the text labels 60 degrees to reduce overlap on the X axis. Returns: list: Tick label handles. """ # vanilla matplotlib-based # need to ensure the minor ticks have some length axis.set_tick_params(length=matplotlib.rcParams["xtick.minor.size"], which="minor") label_handles = axis.set_ticklabels( descriptor_arr, verticalalignment="center", horizontalalignment=horizontalalignment, minor=True, ) if is_x_axis: plt.setp( axis.get_ticklabels(minor=True), rotation=60, ha="right", rotation_mode="anchor", ) return label_handles def _add_descriptor_icons( descriptor_arr: npt.ArrayLike, icon_method: str, n_cond: int, ax: Axes = None, num_pattern_groups: int = None, icon_spacing: float = 1.0, linewidth: float = 0.5, ) -> list: """_add_descriptor_icons. Used internally by _add_descriptor_labels to add Icon-based labels to the X or Y axis. Args: descriptor_arr (npt.ArrayLike): np.Array-like version of the labels. icon_method (str): method to access on Icon instances (typically y_tick_label or x_tick_label). n_cond (int): Number of conditions in the RDM (usually from RDMs.n_cond). ax (matplotlib.axes._axes.Axes): Matplotlib axis handle. num_pattern_groups (int): Number of rows/columns for any image labels. icon_spacing (float): control spacing of image labels - 1. means no gap (the default), 1.1 means pad 10%, .9 means overlap 10% etc. linewidth (float): Width of connecting lines from icon labels (if used) to axis margin. The default is 0.5 - set to 0. to disable the lines. Returns: list: Tick label handles. """ # annotated labels with Icon n_to_fit = np.ceil(n_cond / num_pattern_groups) # work out sizing of icons im_max_pix = 20. 
if descriptor_arr[0].final_image: # size by image im_width_pix = max(this_desc.final_image.width for this_desc in descriptor_arr) im_height_pix = max(this_desc.final_image.height for this_desc in descriptor_arr) im_max_pix = max(im_width_pix, im_height_pix) * icon_spacing ax.figure.canvas.draw() extent = ax.get_window_extent(ax.figure.canvas.get_renderer()) ax_size_pix = max((extent.width, extent.height)) size = (ax_size_pix / n_to_fit) / im_max_pix # from proportion of original size to figure pixels offset = im_max_pix * size label_handles = [] for group_ind in range(num_pattern_groups - 1, -1, -1): position = offset * 0.2 + offset * group_ind ticks = np.arange(group_ind, n_cond, num_pattern_groups) label_handles.append( [ getattr(this_desc, icon_method)( this_x, size, offset=position, linewidth=linewidth, ax=ax, ) for (this_x, this_desc) in zip(ticks, descriptor_arr[ticks]) ] ) return label_handles