Source code for rsatoolbox.data.base

"""Base class for Dataset
"""
from __future__ import annotations
from rsatoolbox.util.descriptor_utils import check_descriptor_length_error
from rsatoolbox.util.descriptor_utils import format_descriptor
from rsatoolbox.util.descriptor_utils import parse_input_descriptor
from rsatoolbox.io.hdf5 import write_dict_hdf5
from rsatoolbox.io.pkl import write_dict_pkl
from rsatoolbox.util.file_io import remove_file


class DatasetBase:
    """
    Abstract dataset class.
    Defines members that every class needs to have, but does not
    implement any interesting behavior. Inherit from this class
    to define specific dataset types

    Args:
        measurements (numpy.ndarray): n_obs x n_channel 2d-array,
        descriptors (dict):           descriptors (metadata)
        obs_descriptors (dict):       observation descriptors (all
            are array-like with shape = (n_obs,...))
        channel_descriptors (dict):   channel descriptors (all are
            array-like with shape = (n_channel,...))
        check_dims (bool):            if True, the lengths of the obs and
            channel descriptors are checked against n_obs and n_channel
            via check_descriptor_length_error

    Returns:
        dataset object
    """

    def __init__(self, measurements, descriptors=None,
                 obs_descriptors=None, channel_descriptors=None,
                 check_dims=True):
        # Only strictly two-dimensional measurements are accepted.
        if measurements.ndim != 2:
            raise AttributeError(
                "measurements must be in dimension n_obs x n_channel")
        self.measurements = measurements
        self.n_obs, self.n_channel = self.measurements.shape
        if check_dims:
            check_descriptor_length_error(obs_descriptors,
                                          "obs_descriptors",
                                          self.n_obs
                                          )
            check_descriptor_length_error(channel_descriptors,
                                          "channel_descriptors",
                                          self.n_channel
                                          )
        self.descriptors = parse_input_descriptor(descriptors)
        self.obs_descriptors = parse_input_descriptor(obs_descriptors)
        self.channel_descriptors = parse_input_descriptor(channel_descriptors)

    def __repr__(self) -> str:
        """
        defines string which is printed for the object

        Contains the class name followed by the full measurements and
        the three descriptor dictionaries.
        """
        return (f'rsatoolbox.data.{self.__class__.__name__}(\n'
                f'measurements = \n{self.measurements}\n'
                f'descriptors = \n{self.descriptors}\n'
                f'obs_descriptors = \n{self.obs_descriptors}\n'
                f'channel_descriptors = \n{self.channel_descriptors}\n'
                )

    def __str__(self) -> str:
        """
        defines the output of print

        Descriptors are rendered with format_descriptor and at most the
        first five rows of the measurements are shown.
        """
        string_desc = format_descriptor(self.descriptors)
        string_obs_desc = format_descriptor(self.obs_descriptors)
        string_channel_desc = format_descriptor(self.channel_descriptors)
        # Truncate long datasets to their first five observations.
        if self.measurements.shape[0] > 5:
            measurements = self.measurements[:5, :]
        else:
            measurements = self.measurements
        return (f'rsatoolbox.data.{self.__class__.__name__}\n'
                f'measurements = \n{measurements}\n...\n\n'
                f'descriptors: \n{string_desc}\n\n'
                f'obs_descriptors: \n{string_obs_desc}\n\n'
                f'channel_descriptors: \n{string_channel_desc}\n'
                )

    def __eq__(self, other: object) -> bool:
        """Equality check, to be implemented in the specific
        Dataset class

        Args:
            other (DatasetBase): The object to compare to.

        Raises:
            NotImplementedError: This is not valid if not implemented
                by the specific Dataset class

        Returns:
            bool: False if other is not a DatasetBase; otherwise never
                returns
        """
        if isinstance(other, DatasetBase):
            raise NotImplementedError()
        else:
            return False

    def copy(self) -> "DatasetBase":
        """Copy Dataset
        To be implemented in child class

        Raises:
            NotImplementedError: raised if not implemented

        Returns:
            DatasetBase: Never returns
        """
        raise NotImplementedError

    def split_obs(self, by):
        """ Returns a list Datasets split by obs

        Args:
            by(String): the descriptor by which the splitting is made

        Returns:
            list of Datasets, splitted by the selected obs_descriptor

        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError(
            "split_obs function not implemented in used Dataset class!")

    def split_channel(self, by):
        """ Returns a list Datasets split by channels

        Args:
            by(String): the descriptor by which the splitting is made

        Returns:
            list of Datasets,  splitted by the selected channel_descriptor

        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError(
            "split_channel function not implemented in used Dataset class!")

    def subset_obs(self, by, value):
        """ Returns a subsetted Dataset defined by certain obs value

        Args:
            by(String): the descriptor by which the subset selection is
                made from obs dimension
            value:      the value by which the subset selection is made
                from obs dimension

        Returns:
            Dataset, with subset defined by the selected obs_descriptor

        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError(
            "subset_obs function not implemented in used Dataset class!")

    def subset_channel(self, by, value):
        """ Returns a subsetted Dataset defined by certain channel value

        Args:
            by(String): the descriptor by which the subset selection is
                made from channel dimension
            value:      the value by which the subset selection is made
                from channel dimension

        Returns:
            Dataset, with subset defined by the selected channel_descriptor

        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError(
            "subset_channel function not implemented in used Dataset class!")

    def save(self, filename, file_type: str = 'hdf5',
             overwrite: bool = False) -> None:
        """ Saves the dataset object to a file

        Args:
            filename(String): path to the file
                [or opened file]
            file_type(String): Type of file to create:
                hdf5: hdf5 file
                pkl: pickle file
            overwrite(Boolean): if True, remove_file is called on
                filename before the new file is written

        Raises:
            ValueError: if file_type is neither 'hdf5' nor 'pkl'
        """
        # Validate the format first: an unknown file_type must neither
        # silently skip writing nor cause the existing file to be removed.
        if file_type not in ('hdf5', 'pkl'):
            raise ValueError(
                f"file_type must be 'hdf5' or 'pkl', got {file_type!r}")
        data_dict = self.to_dict()
        if overwrite:
            remove_file(filename)
        if file_type == 'hdf5':
            write_dict_hdf5(filename, data_dict)
        else:
            write_dict_pkl(filename, data_dict)

    def to_dict(self) -> dict:
        """ Generates a dictionary which contains the information to
        recreate the dataset object. Used for saving to disc

        Returns:
            data_dict(dict): dictionary with dataset information; holds
                the measurements, the three descriptor dictionaries and
                the class name under the key 'type'
        """
        data_dict = {}
        data_dict['measurements'] = self.measurements
        data_dict['descriptors'] = self.descriptors
        data_dict['obs_descriptors'] = self.obs_descriptors
        data_dict['channel_descriptors'] = self.channel_descriptors
        data_dict['type'] = type(self).__name__
        return data_dict