Temporal RSA

This demo notebook shows how to work with temporal data in the RSA toolbox.

So far, it demonstrates how to

  1. import a temporal dataset into the rsatoolbox.data.TemporalDataset class and

  2. create RDM movies using the rsatoolbox.rdm.calc_rdm_movie function.

[1]:
import numpy as np
import matplotlib.pyplot as plt
import rsatoolbox
import pickle

from rsatoolbox.rdm import calc_rdm_movie

Load temporal data

Here I use the sample data from mne-python:

https://mne.tools/dev/overview/datasets_index.html#sample

The data comprise the preprocessed MEG data from “sample_audvis_raw.fif”.

Preprocessing includes:

  • downsampling to 60 Hz

  • band-pass filtering between 1 Hz and 20 Hz

  • rejecting bad trials using an amplitude threshold

  • baseline correction (baseline -200 to 0 ms)

See demos/TemporalSampleData/preproc_mn_sample_data.py

The preprocessed data is stored in TemporalSampleData/meg_sample_data.pkl
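
The preprocessing script itself is not reproduced here; the following is a rough, hypothetical sketch of how such a pickle could be created with mne-python. File paths, the rejection threshold, and the exact epoching window are assumptions for illustration, not the values used by the actual script.

import pickle
import mne
from mne.datasets import sample

# load and band-pass filter the raw sample recording (assumed paths)
raw_fname = sample.data_path() / 'MEG' / 'sample' / 'sample_audvis_raw.fif'
raw = mne.io.read_raw_fif(raw_fname, preload=True)
raw.filter(1., 20.)  # band-pass 1-20 Hz

# epoch around the four audio/visual events with baseline correction and rejection
events = mne.find_events(raw, stim_channel='STI 014')
event_id = {'Auditory/Left': 1, 'Auditory/Right': 2,
            'Visual/Left': 3, 'Visual/Right': 4}
epochs = mne.Epochs(raw, events, event_id, tmin=-0.2, tmax=0.75,
                    baseline=(-0.2, 0), reject=dict(grad=4000e-13),
                    picks='grad', preload=True)
epochs.resample(60)  # downsample to 60 Hz

# store measurements and descriptors in the format this notebook expects
dat = dict(data=epochs.get_data(),
           cond_names=event_id,
           cond_idx=epochs.events[:, 2],
           channel_names=epochs.ch_names,
           times=epochs.times)
with open('TemporalSampleData/meg_sample_data.pkl', 'wb') as f:
    pickle.dump(dat, f)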

[2]:
dat = pickle.load( open( "TemporalSampleData/meg_sample_data.pkl", "rb" ) )
measurements = dat['data']
cond_names = [x for x in dat['cond_names'].keys()]
cond_idx = dat['cond_idx']
channel_names = dat['channel_names']
times = dat['times']
[3]:
print('there are %d observations (trials), %d channels, and %d time-points\n' %
      (measurements.shape))

print('conditions:')
print(cond_names)
there are 227 observations (trials), 203 channels, and 58 time-points

conditions:
['Auditory/Left', 'Auditory/Right', 'Visual/Left', 'Visual/Right']

Plot condition averages for two channels:

[4]:
fig, ax = plt.subplots(1, 2, figsize=(12,4))
ax = ax.flatten()
for jj,chan in enumerate(channel_names[:2]):
    for ii, cond_ii in enumerate(np.unique(cond_idx)):
        mn = measurements[cond_ii == cond_idx,jj,:].mean(0).squeeze()
        ax[jj].plot(times, mn, label = cond_names[ii])
        ax[jj].set_title(chan)
ax[jj].legend()
plt.show()
_images/demo_temporal_7_0.png

The rsatoolbox.data.TemporalDataset class

measurements is an np.array of shape n_obs x n_channels x n_times

time_descriptors should contain the time-point vector for the measurements, of length n_times. It is recommended to call this descriptor ‘time’.

[5]:
tim_des = {'time': times}

The other descriptors are identical to those in the rsatoolbox.data.Dataset class.

[6]:
des = {'session': 0, 'subj': 0}
obs_des = {'conds': cond_idx}
chn_des = {'channels': channel_names}
[7]:
data = rsatoolbox.data.TemporalDataset(measurements,
                              descriptors = des,
                              obs_descriptors = obs_des,
                              channel_descriptors = chn_des,
                              time_descriptors = tim_des)
data.sort_by('conds')

convenience methods

rsatoolbox.data.TemporalDataset comes with the same convenience methods as rsatoolbox.data.Dataset.
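
For example (a minimal sketch, assuming that condition codes 1 and 2 correspond to the two auditory conditions, as they do for this dataset), trials can be selected with the inherited subset_obs:

# keep only the auditory trials (conds 1 and 2), using the inherited subset_obs
data_auditory = data.subset_obs('conds', [1, 2])
print(data_auditory.measurements.shape)   # (n_auditory_trials, 203, 58)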

In addition, the following functions are provided:

  • rsatoolbox.data.TemporalDataset.split_time(by)

  • rsatoolbox.data.TemporalDataset.subset_time(by, t_from, t_to)

  • rsatoolbox.data.TemporalDataset.bin_time(by, bins)

  • rsatoolbox.data.TemporalDataset.convert_to_dataset(by)

rsatoolbox.data.TemporalDataset.split_time(by)

splits the rsatoolbox.data.TemporalDataset object into a list of n_times rsatoolbox.data.TemporalDataset objects, splitting the measurements along the time descriptor specified by by

[8]:
print('shape of original measurements')
print(data.measurements.shape)

data_split_time = data.split_time('time')

print('\nafter splitting')
print(len(data_split_time))
print(data_split_time[0].measurements.shape)
shape of original measurements
(227, 203, 58)

after splitting
58
(227, 203, 1)
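
As a small illustrative sketch (not part of the original demo), one of these single-time-point datasets can be turned into a regular Dataset and passed to the standard calc_rdm; this is essentially what calc_rdm_movie automates below:

# build a regular Dataset from the first time slice and compute its RDM
ds_t0 = data_split_time[0]
ds = rsatoolbox.data.Dataset(ds_t0.measurements[:, :, 0],   # drop the singleton time axis
                             descriptors=ds_t0.descriptors,
                             obs_descriptors=ds_t0.obs_descriptors,
                             channel_descriptors=ds_t0.channel_descriptors)
rdm_t0 = rsatoolbox.rdm.calc_rdm(ds, method='euclidean', descriptor='conds')
print(rdm_t0.dissimilarities.shape)                         # (1, 6) for 4 conditions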

rsatoolbox.data.TemporalDataset.subset_time(by, t_from, t_to)

returns a new rsatoolbox.data.TemporalDataset containing only the data for which time_descriptors[by] lies between t_from and t_to

[9]:
print('shape of original measurements')
print(data.measurements.shape)

data_subset_time = data.subset_time('time', t_from = -.1, t_to = .5)

print('\nafter subsetting')
print(data_subset_time.measurements.shape)
print(data_subset_time.time_descriptors['time'][0])
shape of original measurements
(227, 203, 58)

after subsetting
(227, 203, 37)
-0.09989760657919393

rsatoolbox.data.TemporalDataset.bin_time(by, bins)

returns a new rsatoolbox.data.TemporalDataset object with binned temporal data. Data within each bin is averaged.

bins is a list or array whose first dimension indexes the new bins; each entry contains the old time points that should be averaged into that bin.

[10]:
bins = np.reshape(tim_des['time'], [-1, 2])
print(len(bins))
print(bins[0])
29
[-0.19979521 -0.18314561]
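
A hypothetical variant, assuming bin_time also accepts a list of unequally sized bins: coarser bins can be built with np.array_split.

# e.g. 15 coarse bins of 3-4 samples each; data within each bin is averaged
bins_coarse = np.array_split(tim_des['time'], 15)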
[11]:
print('shape of original measurements')
print(data.measurements.shape)

data_binned = data.bin_time('time', bins=bins)

print('\nafter binning')
print(data_binned.measurements.shape)
print(data_binned.time_descriptors['time'][0])
shape of original measurements
(227, 203, 58)

after binning
(227, 203, 29)
-0.1914704126101217

rsatoolbox.data.TemporalDataset.convert_to_dataset(by)

returns a rsatoolbox.data.Dataset object where the time dimension is absorbed into the observation dimension

[12]:
print('shape of original measurements')
print(data.measurements.shape)

data_dataset = data.convert_to_dataset('time')

print('\nafter conversion')
print(data_dataset.measurements.shape)
print(data_dataset.obs_descriptors['time'][0])
shape of original measurements
(227, 203, 58)

after conversion
(13166, 203)
-0.19979521315838786

create RDM movie

The function calc_rdm_movie takes a rsatoolbox.data.TemporalDataset as input and returns a rsatoolbox.rdm.RDMs object. It works like calc_rdm.

[13]:
rdms_data = calc_rdm_movie(data, method = 'euclidean',
                           descriptor = 'conds')
print(rdms_data)
rsatoolbox.rdm.RDMs
58 RDM(s) over 4 conditions

dissimilarity_measure =
squared euclidean

dissimilarities[0] =
[[0.00000000e+00 4.48201826e-25 6.17620803e-25 4.68172301e-25]
 [4.48201826e-25 0.00000000e+00 7.16414642e-25 4.27861898e-25]
 [6.17620803e-25 7.16414642e-25 0.00000000e+00 7.61554453e-25]
 [4.68172301e-25 4.27861898e-25 7.61554453e-25 0.00000000e+00]]

descriptors:

rdm_descriptors:
session = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
subj = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57]
time = [-0.19979521 -0.18314561 -0.16649601 -0.14984641 -0.13319681 -0.11654721
 -0.09989761 -0.08324801 -0.0665984  -0.0499488  -0.0332992  -0.0166496
  0.          0.0166496   0.0332992   0.0499488   0.0665984   0.08324801
  0.09989761  0.11654721  0.13319681  0.14984641  0.16649601  0.18314561
  0.19979521  0.21644481  0.23309442  0.24974402  0.26639362  0.28304322
  0.29969282  0.31634242  0.33299202  0.34964162  0.36629122  0.38294083
  0.39959043  0.41624003  0.43288963  0.44953923  0.46618883  0.48283843
  0.49948803  0.51613763  0.53278724  0.54943684  0.56608644  0.58273604
  0.59938564  0.61603524  0.63268484  0.64933444  0.66598404  0.68263364
  0.69928325  0.71593285  0.73258245  0.74923205]

pattern_descriptors:
index = [0, 1, 2, 3]
conds = [1.0, 2.0, 3.0, 4.0]


Binning can be applied before computing the RDMs by simply specifying the bins argument:

[14]:
rdms_data_binned = calc_rdm_movie(data, method = 'euclidean',
                           descriptor = 'conds',
                           bins=bins)
print(rdms_data_binned)
rsatoolbox.rdm.RDMs
29 RDM(s) over 4 conditions

dissimilarity_measure =
squared euclidean

dissimilarities[0] =
[[0.00000000e+00 3.46836229e-25 4.81401601e-25 3.27640005e-25]
 [3.46836229e-25 0.00000000e+00 4.35938622e-25 3.20443252e-25]
 [4.81401601e-25 4.35938622e-25 0.00000000e+00 5.22469051e-25]
 [3.27640005e-25 3.20443252e-25 5.22469051e-25 0.00000000e+00]]

descriptors:

rdm_descriptors:
session = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
subj = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28]
time = [-0.19147041 -0.15817121 -0.12487201 -0.09157281 -0.0582736  -0.0249744
  0.0083248   0.041624    0.0749232   0.10822241  0.14152161  0.17482081
  0.20812001  0.24141922  0.27471842  0.30801762  0.34131682  0.37461602
  0.40791523  0.44121443  0.47451363  0.50781283  0.54111204  0.57441124
  0.60771044  0.64100964  0.67430884  0.70760805  0.74090725]

pattern_descriptors:
index = [0, 1, 2, 3]
conds = [1.0, 2.0, 3.0, 4.0]


from here on

The following are examples of data analysis and plotting with temporal data. So far, they use the toolbox's functions for non-temporal data; this section should be expanded once new temporal RSA functions are added to the toolbox.

Here I use the standard plotting function.

[15]:
plt.figure(figsize=(10,15))

# add formatted time as rdm_descriptor
rdms_data_binned.rdm_descriptors['time_formatted'] = ['%0.0f ms' % (np.round(x*1000,2)) for x in rdms_data_binned.rdm_descriptors['time']]

rsatoolbox.vis.show_rdm(rdms_data_binned,
                   pattern_descriptor='conds',
                   rdm_descriptor='time_formatted')
[15]:
(<Figure size 597.6x720 with 30 Axes>,
 array of AxesSubplot objects, one panel per time bin ('-191 ms' to '741 ms') plus one empty axis,
 defaultdict mapping each panel to its image handle and tick labels)
<Figure size 720x1080 with 0 Axes>
_images/demo_temporal_32_2.png

Model rdms

This is a simple example with basic model RDMs

[16]:
from rsatoolbox.rdm import get_categorical_rdm
[17]:
rdms_model_in = get_categorical_rdm(['%d' % x for x in range(4)])
rdms_model_lr = get_categorical_rdm(['l','r','l','r'])
rdms_model_av = get_categorical_rdm(['a','a','v','v'])

model_names = ['independent', 'left/right', 'audio/visual']

# append in one RDMs object

model_rdms = rdms_model_in
model_rdms.append(rdms_model_lr)
model_rdms.append(rdms_model_av)

model_rdms.rdm_descriptors['model_names'] = model_names
model_rdms.pattern_descriptors['cond_names'] = cond_names
[18]:
rsatoolbox.vis.show_rdm(model_rdms, rdm_descriptor='model_names', pattern_descriptor = 'cond_names')
None
_images/demo_temporal_37_0.png

data - model similarity across time

[19]:
from rsatoolbox.rdm import compare
[20]:
r = []
for mod in model_rdms:
    r.append(compare(mod, rdms_data_binned, method='cosine'))

for i, r_ in enumerate(r):
    plt.plot(rdms_data_binned.rdm_descriptors['time'], r_.squeeze(), label=model_names[i])

plt.xlabel('time')
plt.ylabel('model-data cosine similarity')
plt.legend()
[20]:
<matplotlib.legend.Legend at 0x7fb0015669d0>
_images/demo_temporal_40_1.png
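
Another simple view of the temporal structure (a sketch added here, not output of the original notebook) is to plot each pairwise dissimilarity over time, straight from the dissimilarities array of the binned RDM movie; the pair order is assumed to follow the usual upper-triangular convention over the condition order.

from itertools import combinations

# dissimilarities has shape (n_times, n_pairs); label each pair by its condition names
pair_labels = [' vs '.join(pair) for pair in combinations(cond_names, 2)]

plt.figure(figsize=(8, 4))
for k, lab in enumerate(pair_labels):
    plt.plot(rdms_data_binned.rdm_descriptors['time'],
             rdms_data_binned.dissimilarities[:, k], label=lab)
plt.xlabel('time')
plt.ylabel('dissimilarity')
plt.legend()
plt.show()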