Source code for rsatoolbox.data.ops
"""Operations on multiple Datasets
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Union, List, Set, overload
from copy import deepcopy
from warnings import warn
try:
from typing import Literal # pylint: disable=ungrouped-imports
except ImportError:
from typing_extensions import Literal
from numpy import concatenate, repeat
import rsatoolbox
if TYPE_CHECKING:
DESC_LEVEL = Union[Literal['obs'], Literal['set']]
from rsatoolbox.data.dataset import Dataset, TemporalDataset
@overload
def merge_datasets(sets: List[TemporalDataset]) -> TemporalDataset:
    ...


@overload
def merge_datasets(sets: List[Dataset]) -> Dataset:
    ...


def merge_datasets(sets: Union[List[Dataset], List[TemporalDataset]]
                   ) -> Union[Dataset, TemporalDataset]:
    """Concatenate measurements to create one Dataset of the same type

    Only descriptors that exist on all subsets are assigned to the merged
    dataset.
    Dataset-level `descriptors` that are identical across subsets will be
    passed on, those that vary will become `obs_descriptors`.
    Channel and Time descriptors must be identical across subsets. This is
    not verified here; the ones of the first dataset are copied.

    Args:
        sets (Union[List[Dataset], List[TemporalDataset]]): List of Dataset
            or TemporalDataset objects. Must all be the same type.

    Returns:
        Union[Dataset, TemporalDataset]: The new dataset combining measurements
        and descriptors from the given subset datasets. An empty list
        yields an empty Dataset (and a warning).

    Raises:
        ValueError: If the datasets are not all of the same type, or if
            they are neither Dataset nor TemporalDataset objects.
    """
    if len(sets) == 0:
        warn('[merge_datasets] Received empty list, returning empty Dataset')
        return rsatoolbox.data.dataset.Dataset(measurements=[])
    if len({type(s) for s in sets}) > 1:
        raise ValueError('All datasets must be of the same type')
    ds0 = sets[0]
    # numpy pre-allocates so this seems to be a performant solution:
    meas = concatenate([ds.measurements for ds in sets], axis=0)
    # obs descriptors that all subsets have in common are stacked in the
    # same order as the measurements:
    obs_descs = {
        k: concatenate([ds.obs_descriptors[k] for ds in sets])
        for k in _shared_descriptors(sets, 'obs')
    }
    dat_decs = dict()
    for k in _shared_descriptors(sets):
        values = [ds.descriptors[k] for ds in sets]
        if _descriptor_values_identical(values):
            # descriptor always has the same value
            dat_decs[k] = ds0.descriptors[k]
        else:
            # descriptor varies across subsets, so repeat it by observation.
            # This replaces a shared obs descriptor of the same name.
            obs_descs[k] = repeat(values, [ds.n_obs for ds in sets])
    # order is important as long as TemporalDataset inherits from Dataset
    if isinstance(ds0, rsatoolbox.data.dataset.TemporalDataset):
        return rsatoolbox.data.dataset.TemporalDataset(
            measurements=meas,
            descriptors=dat_decs,
            obs_descriptors=obs_descs,
            channel_descriptors=deepcopy(ds0.channel_descriptors),
            time_descriptors=deepcopy(ds0.time_descriptors),
        )
    if isinstance(ds0, rsatoolbox.data.dataset.Dataset):
        return rsatoolbox.data.dataset.Dataset(
            measurements=meas,
            descriptors=dat_decs,
            obs_descriptors=obs_descs,
            channel_descriptors=deepcopy(ds0.channel_descriptors)
        )
    raise ValueError('Unsupported Dataset type')


def _descriptor_values_identical(values: list) -> bool:
    """Check whether all entries of a non-empty list are the same value

    Hashable values are compared through a set. Unhashable values (such as
    lists, dicts or numpy arrays) cannot be put in a set, so they are
    compared to the first entry with `numpy.array_equal` instead.

    Args:
        values (list): One descriptor value per dataset, at least one entry.

    Returns:
        bool: True if every entry equals the first one.
    """
    try:
        return len(set(values)) == 1
    except TypeError:
        # raised by set() for unhashable entries
        from numpy import array_equal  # pylint: disable=import-outside-toplevel
        reference = values[0]
        return all(array_equal(reference, other) for other in values[1:])
def _shared_descriptors(
datasets: Union[List[Dataset], List[TemporalDataset]],
level: DESC_LEVEL = 'set') -> Set[str]:
"""Find descriptors that all datasets have in common
"""
if level == 'set':
each_keys = [set(d.descriptors.keys()) for d in datasets]
else:
each_keys = [set(d.obs_descriptors.keys()) for d in datasets]
return set.intersection(*each_keys)