triggered_spiking.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. import os
  2. import numpy as np
  3. import pandas as pd
  4. from tqdm import tqdm
  5. import argparse
  6. import quantities as pq
  7. from neo import SpikeTrain
  8. from elephant.statistics import instantaneous_rate
  9. from elephant.kernels import GaussianKernel
  10. from parameters import DATAPATH, MINRATE, NSPIKES, TRIGGEREDAVERAGES, NSHUFFLES, SHUFFLE_BINWIDTH
  11. from util import load_data, filter_units, shuffle_bins
  def cut_windows(rate, starts, stops):
      """Stack the slices ``rate[start:stop]`` into a 2-D array, one row per slice.

      Parameters
      ----------
      rate : sequence
          1-D signal to cut from.
      starts, stops : iterable of int
          Paired slice bounds.  All slices must have the same length and
          there must be at least one pair, otherwise ``np.vstack`` raises
          ``ValueError``.

      Returns
      -------
      numpy.ndarray
          Array with one row per ``(start, stop)`` pair.
      """
      # np.vstack is the function the deprecated alias np.row_stack pointed to.
      return np.vstack([rate[i0:i1] for i0, i1 in zip(starts, stops)])


  def baseline_corrected_mean(windows, b0, b1):
      """Subtract each row's baseline, then average over rows.

      Parameters
      ----------
      windows : numpy.ndarray
          2-D array, one response window per row.
      b0, b1 : int
          Column range ``b0:b1`` whose per-row mean is used as the baseline.

      Returns
      -------
      numpy.ndarray
          1-D array: the column-wise mean of the baseline-subtracted rows.
      """
      baseline = windows[:, b0:b1].mean(axis=1, keepdims=True)
      return (windows - baseline).mean(axis=0)


  if __name__ == "__main__":
      # Command line: experiment name, region, and optional lists of trigger
      # events and spike types to analyse.
      parser = argparse.ArgumentParser()
      parser.add_argument('e_name')
      parser.add_argument('region')
      parser.add_argument('-t', '--triggers', nargs='+', default=['saccade', 'run', 'sit'])
      parser.add_argument('-s', '--spk_types', nargs='+', default=['tonicspk', 'burst'])
      args = parser.parse_args()

      # Build the table of trigger times.  Column 0 of every 'run_bouts'
      # array becomes the 'run' trigger and column 1 the 'sit' trigger.
      df_run = load_data('ball', [args.e_name], region=args.region)
      df_run['run_times'] = df_run['run_bouts'].apply(lambda x: x[:, 0])
      df_run['sit_times'] = df_run['run_bouts'].apply(lambda x: x[:, 1])
      df_pupil = load_data('pupil', [args.e_name], region=args.region)
      df_triggers = pd.merge(df_run, df_pupil)
      if 'trial_on' in args.triggers:
          df_trials = load_data('trials', [args.e_name], region=args.region)
          # Rename so the column matches the '<trigger>_times' lookup below.
          df_trials.rename(columns={'trial_on_time': 'trial_on_times'}, inplace=True)
          df_triggers = pd.merge(df_triggers, df_trials)
      if 'burst' in args.triggers:
          # NOTE(review): requesting 'burst' as a trigger ends the script
          # silently with exit status 0 -- presumably not implemented yet.
          import sys
          sys.exit()
      df_triggers.set_index(['m', 's', 'e'], inplace=True)

      # Spike data, indexed like the trigger table and filtered with MINRATE.
      df_spikes = load_data('spikes', [args.e_name], region=args.region).set_index(['m', 's', 'e'])
      df_spikes = filter_units(df_spikes, MINRATE)

      seriess = []
      for idx, unit in tqdm(df_spikes.iterrows(), total=len(df_spikes)):
          data = {
              'm': idx[0],
              's': idx[1],
              'e': idx[2],
              'u': unit['u'],
          }
          for trigger in args.triggers:
              all_trigger_times = df_triggers.loc[idx]['%s_times' % trigger]
              pars = TRIGGEREDAVERAGES[trigger]
              fs = pars['dt'] * pq.s
              kernel = GaussianKernel(pars['bw'] * pq.s)
              # These depend only on the trigger parameters, so compute them
              # once here rather than per spike type / per shuffle.
              pre_samples = int(pars['pre'] / pars['dt'])
              post_samples = int(pars['post'] / pars['dt'])
              shuffle_binwidth = int(SHUFFLE_BINWIDTH / pars['dt'])

              for spk_type in args.spk_types:
                  # Get inst. rate for whole experiment
                  spk_times = unit['%s_times' % spk_type]
                  if len(spk_times) < NSPIKES:
                      continue
                  t0 = min([unit['spk_tinfo'][0], spk_times.min()])
                  t1 = max([unit['spk_tinfo'][1], spk_times.max()])
                  spk_train = SpikeTrain(spk_times, t_start=t0 * pq.s, t_stop=t1 * pq.s, units='s')
                  inst_rate = instantaneous_rate(spk_train, sampling_period=fs, kernel=kernel)
                  inst_rate = inst_rate.squeeze().magnitude

                  # Keep only triggers whose window fits inside the rate trace.
                  # Filter into a new variable: the full trigger list must stay
                  # intact for the next spike type, whose t0/t1 may differ.
                  spk_tpts = np.linspace(t0, t1, inst_rate.shape[0])
                  fits = (
                      (all_trigger_times < (spk_tpts.max() - pars['post']))
                      & (all_trigger_times > (spk_tpts.min() - pars['pre']))
                  )
                  trigger_times = all_trigger_times[fits]
                  if len(trigger_times) == 0:
                      # Nothing to average; stacking an empty list would raise.
                      continue

                  # Get responses to each trigger
                  trigger_idx = spk_tpts.searchsorted(trigger_times)
                  i0s = trigger_idx + pre_samples
                  i1s = trigger_idx + post_samples
                  responses = cut_windows(inst_rate, i0s, i1s)

                  # Baseline normalize responses and take the mean
                  response_tpts = np.linspace(pars['pre'], pars['post'], responses.shape[1])
                  b0, b1 = response_tpts.searchsorted(pars['baseline'])
                  triggered_average = baseline_corrected_mean(responses, b0, b1)

                  # Get triggered averages from shuffled rates
                  triggered_average_shf = np.full((NSHUFFLES, triggered_average.shape[0]), np.nan)
                  for shf_i in range(NSHUFFLES):
                      inst_rate_shf = shuffle_bins(inst_rate, shuffle_binwidth)
                      responses_shf = cut_windows(inst_rate_shf, i0s, i1s)
                      triggered_average_shf[shf_i] = baseline_corrected_mean(responses_shf, b0, b1)

                  # Significant if the real average leaves the 2.5-97.5
                  # percentile band of the shuffled averages at any time point.
                  ci_low, ci_high = np.percentile(triggered_average_shf, [2.5, 97.5], axis=0)
                  sig = (triggered_average < ci_low).any() | (triggered_average > ci_high).any()

                  data[f'{trigger}_{spk_type}_response'] = triggered_average
                  data[f'{trigger}_{spk_type}_tpts'] = response_tpts
                  data[f'{trigger}_{spk_type}_sig'] = sig
          seriess.append(pd.Series(data=data))

      # One row per unit; skipped trigger/spike-type pairs are simply absent.
      df_resp = pd.DataFrame(seriess)
      filename = f'responses_{args.e_name}_{args.region}.pkl'
      df_resp.to_pickle(DATAPATH + filename)