imf_decoding.py

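"""Decode IMF phase and stimulus-segment identity from single-unit spiking.

For the experiment set selected by ``e_name``, the pupil-area trace of each
experiment is decomposed into intrinsic mode functions (IMFs) with the
Hilbert-Huang transform. For every unit and IMF the script then:
  1) decodes binarized IMF phase from tonic-spike and burst rasters, and
  2) decodes stimulus segment identity from all spikes, training and testing
     within and across the two phase bins, plus a random-split control.
Results are saved to ``imfdecoding_<e_name>_norm.pkl`` in DATAPATH.

Usage:
    python imf_decoding.py <e_name>
"""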
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.svm import SVC
from sklearn.model_selection import RepeatedStratifiedKFold, StratifiedKFold, cross_val_score

from util import (load_data, get_trials, filter_units, get_psth, get_responses, circmean,
                  angle_subtract)
from phase_tuning import HHT
from parameters import DATAPATH, NIMFCYCLES, NSPIKES, MINRATE

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('e_name')
    args = parser.parse_args()

    NSPLITS = 5     # segments per trial (and number of 'stimulus' classes)
    NPHASEBINS = 2  # phase bins per IMF cycle
    PHASEBINS = np.linspace(-np.pi, np.pi, NPHASEBINS + 1)

    # Load pupil traces and trial times (stim_id=0) and merge per experiment
    df_pupil = load_data('pupil', [args.e_name])
    df_trials = load_data('trials', [args.e_name])
    df_trials.rename(columns={'trial_on_time': 'trial_on_times', 'trial_off_time': 'trial_off_times'},
                     inplace=True)
    df_trials = df_trials.apply(get_trials, stim_id=0, axis='columns')
    df = pd.merge(df_pupil, df_trials).set_index(['m', 's', 'e'])

    # Spike times per unit; drop units firing below MINRATE
    df_spikes = load_data('spikes', [args.e_name]).set_index(['m', 's', 'e'])
    #df_spikes = df_spikes[df_spikes.index.isin(df_pupil.index)]
    df_spikes = filter_units(df_spikes, MINRATE)

    # Phase-tuning results (tranges='noopto'); preferred phases are used below to set phase bins
    df_tuning = load_data('phasetuning', [args.e_name], tranges='noopto').set_index(['m', 's', 'e', 'u'])

    seriess = []
    for idx, row in tqdm(df.iterrows(), total=len(df)):
        pupil_area = row['pupil_area']
        pupil_tpts = row['pupil_tpts']
        pupil_fs = 1 / np.diff(pupil_tpts).mean()

        # Get IMFs: decompose the pupil trace (EMD) and extract each IMF's
        # instantaneous phase (Hilbert spectral analysis)
        hht = HHT(pupil_area, pupil_fs)
        hht.emd()
        hht.hsa()
        hht.check_number_of_phasebin_visits(ncycles=NIMFCYCLES, remove_invalid=True)
        imf_freqs = hht.characteristic_frequency
        imf_phases = hht.phase.T

        trial_starts = row['trial_on_times']
        #trial_stops = row['trial_off_times']
        #trial_duration = (trial_stops - trial_starts).mean()
        trial_duration = 5  # fixed trial window; presumably seconds
        # Keep only trials fully covered by the pupil recording
        trial_starts = trial_starts[(trial_starts + trial_duration) < pupil_tpts.max()]
        # Segment labels, in the segment-major order used to stack X below
        stim_labels = np.repeat(np.arange(NSPLITS), len(trial_starts))

        # Get units for this experiment
        try:
            df_units = df_spikes.loc[idx]
            df_units = df_units.reset_index().set_index(['m', 's', 'e', 'u'])
        except KeyError:
            print("Spikes missing for {}".format(idx))
            continue

        for u_idx, unit in df_units.iterrows():
            print(u_idx)
            for imf_i, phase in enumerate(imf_phases):
                # Skip if no phase tuning analysis was done for this unit and IMF
                df_unittuning = df_tuning.loc[u_idx].query('imf == %d' % (imf_i + 1))
                if len(df_unittuning) < 1:
                    continue
                print(imf_i)

                data = {
                    'm': idx[0],
                    's': idx[1],
                    'e': idx[2],
                    'u': u_idx[-1],
                    'imf': imf_i + 1,
                    'freq': imf_freqs[imf_i]
                    }

                # IMF phase raster aligned to trial onsets
                phase_raster, raster_tpts = get_responses(trial_starts, phase, pupil_tpts, post=trial_duration)
                split_inds = np.linspace(0, len(raster_tpts), NSPLITS + 1).astype(int)
                # Mean phase for each stimulus segment
                phase_means = np.concatenate([circmean(phase_raster[:, i0:i1], axis=1)[1]
                                              for i0, i1 in zip(split_inds[:-1], split_inds[1:])])
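
                # Each trial is cut into NSPLITS equal segments: the per-segment
                # spike counts are the decoder inputs (rows of X below), the
                # circular-mean IMF phase of each segment (phase_means) provides
                # the phase labels, and the segment's position within its trial
                # (stim_labels) provides the 'stimulus' label.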
                # Decode IMF phase using each spike type
                for spike_type in ['tonicspk', 'burst']:
                    print(spike_type)
                    # Get raster for this spike type and split it into segments
                    spike_times = unit['%s_times' % spike_type]
                    if len(spike_times) <= NSPIKES:
                        continue
                    spike_raster, spike_tpts = get_psth(trial_starts, spike_times, post=trial_duration)
                    split_inds = np.linspace(0, len(spike_tpts), NSPLITS + 1).astype(int)
                    X = np.vstack([spike_raster[:, i0:i1] for i0, i1 in zip(split_inds[:-1], split_inds[1:])])
                    # Use the unit's tuning phase to set phase bins, so the two bins
                    # roughly correspond to the preferred and opposite half-cycle
                    tuning_phase = df_unittuning['%s_phase' % spike_type].iloc[0]
                    if np.isnan(tuning_phase):
                        continue
                    phase_shift = angle_subtract(tuning_phase, -1 * np.pi / 2) - np.pi
                    phase_means_shifted = angle_subtract(phase_means, phase_shift) - np.pi
                    phase_labels = np.digitize(phase_means_shifted, bins=PHASEBINS)
                    phase_labels = phase_labels.clip(1, NPHASEBINS) - 1
                    if len(np.unique(phase_labels)) < 2:
                        raise RuntimeError("all segments fall into a single phase bin")
                    # Predict phase bin from the segment spike counts
                    print("decoding phase")
                    classifier = SVC(kernel='rbf')
                    crossval = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
                    scores = cross_val_score(classifier, X, phase_labels, cv=crossval)
                    data['%s_phase' % spike_type] = scores.mean()
                ## Decode stimulus across phase bins using all spike times
                spike_times = unit['spk_times']
                if len(spike_times) <= NSPIKES:
                    continue
                spike_raster, spike_tpts = get_psth(trial_starts, spike_times, post=trial_duration)
                #i0, i1 = spike_tpts.searchsorted([0, trial_duration])
                #spike_raster = spike_raster[:, i0:i1]
                #spike_tpts = spike_tpts[i0:i1]
                split_inds = np.linspace(0, len(spike_tpts), NSPLITS + 1).astype(int)
                X = np.vstack([spike_raster[:, i0:i1] for i0, i1 in zip(split_inds[:-1], split_inds[1:])])
                # Use the tonic-spike tuning phase to set the phase bins
                tuning_phase = df_unittuning['tonicspk_phase'].iloc[0]
                if np.isnan(tuning_phase):
                    continue
                phase_shift = angle_subtract(tuning_phase, -1 * np.pi / 2) - np.pi
                phase_means_shifted = angle_subtract(phase_means, phase_shift) - np.pi
                phase_labels = np.digitize(phase_means_shifted, bins=PHASEBINS)
                phase_labels = phase_labels.clip(1, NPHASEBINS) - 1
                # Require a reasonably balanced split across the two phase bins
                if ((phase_labels == 0).mean() < 0.25) or ((phase_labels == 1).mean() < 0.25):
                    print("Phase split biased")
                    continue
                # Split segments based on phase bin
                X1 = X[phase_labels.astype(bool)]
                y1 = stim_labels[phase_labels.astype(bool)]
                if len(np.unique(y1)) < NSPLITS:
                    raise RuntimeError("X1 is missing stimulus segment classes")
                X2 = X[~phase_labels.astype(bool)]
                y2 = stim_labels[~phase_labels.astype(bool)]
                if len(np.unique(y2)) < NSPLITS:
                    raise RuntimeError("X2 is missing stimulus segment classes")
                # Train on phase bin 2 (X2) & test on both bins
                print("decoding stimulus, set 2")
                classifier = SVC(kernel='linear').fit(X2, y2)
                data['stim_train2_test2'] = classifier.score(X2, y2)
                data['stim_train2_test1'] = classifier.score(X1, y1)
                # Train on phase bin 1 (X1) & test on both bins
                print("decoding stimulus, set 1")
                classifier = SVC(kernel='linear').fit(X1, y1)
                data['stim_train1_test1'] = classifier.score(X1, y1)
                data['stim_train1_test2'] = classifier.score(X2, y2)
                # Control: random (phase-independent) stratified splits, recording the
                # test-minus-train score difference for each of the 2 x 5 splits
                splitter = RepeatedStratifiedKFold(n_splits=2, n_repeats=5, random_state=0)
                shf_diffs = np.full(10, np.nan)
                y = stim_labels
                for shf_i, (train, test) in enumerate(splitter.split(X, stim_labels)):
                    classifier = SVC(kernel='linear').fit(X[train], y[train])
                    shf_diffs[shf_i] = classifier.score(X[test], y[test]) - classifier.score(X[train], y[train])
                data['stim_testshf'] = shf_diffs

                seriess.append(pd.Series(data=data))

    df_decoding = pd.DataFrame(seriess)
    filename = 'imfdecoding_{}_norm.pkl'.format(args.e_name)
    df_decoding.to_pickle(DATAPATH + filename)