phase_tuning.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. import os
  2. import numpy as np
  3. import pandas as pd
  4. from tqdm import tqdm
  5. import argparse
  6. from scipy.interpolate import interp1d
  7. from util import zero_runs
  8. from hht import HHT
  9. from parameters import *
  10. from util import (load_data, filter_units, switch_ranges, merge_ranges, angle_subtract,
  11. circmean, get_trials, shuffle_bins)
  12. def resample(data, old_tpts, new_tpts, axis=0, fill_value='extrapolate'):
  13. """
  14. Use linear interpolation to re-sample data.
  15. """
  16. # interpolate time-course with linear splines
  17. func = interp1d(old_tpts, data, axis=axis, fill_value=fill_value)
  18. # get new time-course
  19. interpolated_data = func(new_tpts)
  20. return interpolated_data
  21. def phase2rank(alpha):
  22. """
  23. Convert angles to circular ranks.
  24. """
  25. n = len(alpha)
  26. ranks = np.full(n, np.nan)
  27. ranks[alpha.argsort()] = np.arange(n)
  28. return 2 * np.pi * ranks / n
  29. def rank2phase(rank, alpha):
  30. """
  31. Convert a circular rank back to phase in the original distribution.
  32. """
  33. n = len(alpha)
  34. # convert circular rank to linear rank
  35. linrank = n * rank / 2 / np.pi
  36. return np.sort(alpha)[np.round(linrank).astype('int')]
  37. def inds2train(inds, length):
  38. """
  39. Convert event indices into binary time-series.
  40. Parameters
  41. ----------
  42. inds : 1D array
  43. event indices
  44. length : int
  45. total length of segment in which events occur
  46. Returns
  47. -------
  48. event_train : ndarray
  49. binary time-series
  50. """
  51. # initialize output array
  52. event_train = np.zeros(length, dtype='uint8')
  53. event_train[inds] = 1
  54. return event_train
  55. def times2train(evts, tpts):
  56. """
  57. Convert event times into binary time-series.
  58. Parameters
  59. ----------
  60. evts : 1D array
  61. event times
  62. tpts : 1D array
  63. time base in which events occur
  64. Returns
  65. -------
  66. event_train : ndarray
  67. binary time-series array
  68. """
  69. # clip events that fall out of time base
  70. evts_in_tpts = evts[(evts > tpts.min()) & (evts < tpts.max())]
  71. # convert times to indices
  72. evis = tpts.searchsorted(evts_in_tpts)
  73. # get the event train
  74. ev_train = inds2train(evis, len(tpts))
  75. return ev_train
  76. def modified_mrl2(alpha, w=None, axis=0):
  77. """
  78. A bias-free measure of the squared mean resultant length [1].
  79. Parameters
  80. ----------
  81. alpha : ndarray
  82. array of angles
  83. w : ndarray
  84. array of weights, must be same shape as alpha
  85. axis : int, None
  86. axis across which to compute mean
  87. Returns
  88. -------
  89. out : ndarray
  90. bias-corrected squared mean resultant length
  91. Notes
  92. -----
  93. - taking the square-root of this measure does *not* provide a bias-free
  94. measure of the mean resultant length, see [1].
  95. References
  96. ----------
  97. [1] Kutil, R. (2012). Biased and unbiased estimation of the circular mean
  98. resultant length and its variance. Statistics, 46(4), 549-561.
  99. """
  100. mrl, _ = circmean(alpha, w=w, axis=axis)
  101. n = alpha.shape[axis]
  102. return (n / (n - 1)) * (mrl ** 2 - (1 / n))
  103. def phase_tuning(phase, spk_train, shuffle_binwidth=1000, n_shuffles=1000):
  104. ranks = phase2rank(phase) - np.pi
  105. assert len(phase) == len(spk_train)
  106. # Get phase ranks where spikes occur
  107. spk_ranks = ranks[spk_train == 1]
  108. # Compute modified mean vector length
  109. r = modified_mrl2(spk_ranks)
  110. # Compute mean rank
  111. _, mean_rank = circmean(spk_ranks)
  112. # Convert rank back to phase
  113. theta = rank2phase(mean_rank + np.pi, phase)
  114. # Compute tuning strength for shuffled spike trains
  115. r_shf = np.full(n_shuffles, np.nan)
  116. for shf_i in range(n_shuffles):
  117. # Shuffle time bins of spike train
  118. spk_train_shf = shuffle_bins(spk_train, shuffle_binwidth)
  119. # Take phase of shuffled train
  120. spk_ranks_shf = ranks[spk_train_shf == 1]
  121. # Compute tuning strength
  122. r_shf[shf_i] = modified_mrl2(spk_ranks_shf)
  123. p = (r_shf > r).sum() / n_shuffles
  124. return r, theta, p
  125. if __name__ == "__main__":
  126. parser = argparse.ArgumentParser()
  127. parser.add_argument('e_name')
  128. parser.add_argument('-s', '--spk_types', nargs='+', default=['tonicspk', 'burst'])
  129. parser.add_argument('-t', '--tranges', default='')
  130. args = parser.parse_args()
  131. if args.tranges:
  132. assert args.tranges in ['run', 'sit', 'desync', 'sizematched', 'nosaccade', 'noopto']
  133. df_pupil = load_data('pupil', [args.e_name])
  134. df_pupil.set_index(['m', 's', 'e'], inplace=True)
  135. df_spikes = load_data('spikes', [args.e_name])
  136. df_spikes.set_index(['m', 's', 'e'], inplace=True)
  137. df_spikes = filter_units(df_spikes, MINRATE)
  138. ## TODO: find a better way to integrate saccades
  139. if 'saccade' in args.spk_types:
  140. df_spikes['saccade_times'] = [df_pupil.loc[idx]['saccade_times'] for idx, unit in df_spikes.iterrows()]
  141. # Load data for requested time ranges
  142. if args.tranges in ['nosaccade']:
  143. df_pupil = load_data('pupil', [args.e_name]).set_index(['m', 's', 'e'])
  144. elif args.tranges in ['desync', 'sizematched']:
  145. df_hht = load_data('hht', [args.e_name]).set_index(['m', 's', 'e'])
  146. elif args.tranges in ['run', 'sit']:
  147. df_run = load_data('ball', [args.e_name]).set_index(['m', 's', 'e'])
  148. elif args.tranges in ['noopto']:
  149. df_trials = load_data('trials', [args.e_name])
  150. df_trials.rename(columns={'trial_on_time':'trial_on_times', 'trial_off_time':'trial_off_times'}, inplace=True)
  151. df_trials = df_trials.apply(get_trials, stim_id=-1, axis='columns').set_index(['m', 's', 'e'])
  152. seriess = []
  153. for idx, row in tqdm(df_pupil.iterrows(), total=len(df_pupil)):
  154. pupil_area = row['pupil_area']
  155. pupil_tpts = row['pupil_tpts']
  156. # Get IMFs
  157. pupil_fs = 1 / np.diff(pupil_tpts).mean()
  158. hht = HHT(pupil_area, pupil_fs)
  159. hht.emd()
  160. # Get phases and frequencies
  161. hht.hsa()
  162. hht.check_number_of_phasebin_visits(ncycles=NIMFCYCLES, remove_invalid=True)
  163. imf_phases = hht.phase.T
  164. imf_freqs = hht.characteristic_frequency
  165. imf_power = hht.power_ratio
  166. # Get time-ranges
  167. if args.tranges in ['run', 'sit']:
  168. try:
  169. tranges = df_run.loc[idx, '%s_bouts' % args.tranges]
  170. except KeyError:
  171. print("No run data found for ", idx)
  172. continue
  173. if args.tranges == 'run':
  174. dt0 = BEHAVEXCLUSIONS['run'][1]
  175. dt1 = BEHAVEXCLUSIONS['sit'][0]
  176. elif args.tranges == 'sit':
  177. dt0 = BEHAVEXCLUSIONS['sit'][1]
  178. dt1 = BEHAVEXCLUSIONS['run'][0]
  179. ext = np.ones_like(tranges) * np.array([dt0, dt1])
  180. tranges = tranges + ext
  181. tranges = np.row_stack([trange for trange in tranges if trange[0] < trange[1]])
  182. tranges = [tranges for imf in range(hht.n_imfs)]
  183. elif args.tranges in ['desync', 'sizematched']:
  184. try:
  185. tranges = df_hht.loc[idx, '%s_bouts' % args.tranges]
  186. except KeyError:
  187. print("No HHT data found for ", idx)
  188. continue
  189. elif args.tranges in ['nosaccade']:
  190. saccade_times = df_pupil.loc[idx, 'saccade_times']
  191. saccade_tranges = np.column_stack([saccade_times, saccade_times])
  192. saccade_tranges += np.array(BEHAVEXCLUSIONS['saccade'])
  193. saccade_tranges = merge_ranges(saccade_tranges, dt=(1 / pupil_fs))
  194. tranges = switch_ranges(
  195. saccade_tranges,
  196. dt=(1 / pupil_fs),
  197. minval=pupil_tpts.min(),
  198. maxval=pupil_tpts.max()
  199. )
  200. tranges = [tranges for imf in range(hht.n_imfs)]
  201. elif args.tranges in ['opto', 'noopto']:
  202. try:
  203. trial_ids = df_trials.loc[idx, 'trial_id']
  204. opto_trials = df_trials.loc[idx, 'opto_trials']
  205. trial_on_time = df_trials.loc[idx, 'trial_on_times']
  206. trial_off_time = df_trials.loc[idx, 'trial_off_times']
  207. except KeyError:
  208. print("No trial data found for ", idx)
  209. continue
  210. t0s = trial_on_time
  211. t1s = trial_off_time
  212. tranges = np.column_stack([t0s, t1s])
  213. tranges = [tranges for imf in range(hht.n_imfs)]
  214. elif args.tranges in ['half1', 'half2']:
  215. t0, t1 = pupil_tpts.min(), pupil_tpts.max()
  216. half_length = (t1 - t0) / 2
  217. if args.tranges == 'half1':
  218. tranges = [np.array([[t0, t0 + half_length]]) for imf in range(hht.n_imfs)]
  219. if args.tranges == 'half2':
  220. tranges = [np.array([[t0 + half_length, t1]]) for imf in range(hht.n_imfs)]
  221. elif args.tranges in ['split1', 'split2']:
  222. imf_cycles = [pupil_tpts[np.where(np.diff(phase) < -np.pi)[0]] for phase in hht.phase.T]
  223. imf_cycles = [np.concatenate([pupil_tpts[:1], cycles, pupil_tpts[-1:]]) for cycles in imf_cycles]
  224. cycle_tranges = [np.column_stack([cycles[:-1], cycles[1:]]) for cycles in imf_cycles]
  225. if args.tranges == 'split1':
  226. tranges = [cycles[0::2] for cycles in cycle_tranges]
  227. else:
  228. tranges = [cycles[1::2] for cycles in cycle_tranges]
  229. # Get units for this experiment
  230. try:
  231. df_units = df_spikes.loc[idx]
  232. except KeyError:
  233. print("Spikes missing for {}".format(idx))
  234. continue
  235. for _, unit in df_units.iterrows():
  236. unit_tpts = np.arange(*unit['spk_tinfo'])
  237. t0, t1 = row['pupil_tpts'].min(), row['pupil_tpts'].max()
  238. i0, i1 = unit_tpts.searchsorted([t0, t1])
  239. unit_tpts = unit_tpts[i0:i1]
  240. imf_phases_resamp = resample(imf_phases, pupil_tpts, unit_tpts, axis=1)
  241. spk_trains = {}
  242. for spk_type in args.spk_types:
  243. spk_trains[spk_type] = times2train(unit['{}_times'.format(spk_type)], unit_tpts)
  244. for imf_i, phase in enumerate(imf_phases_resamp):
  245. data = {
  246. 'm': idx[0],
  247. 's': idx[1],
  248. 'e': idx[2],
  249. 'u': unit['u'],
  250. 'imf': imf_i + 1,
  251. 'freq': imf_freqs[imf_i],
  252. 'power': imf_power[imf_i]
  253. }
  254. # Get time ranges to analyze
  255. if args.tranges:
  256. tranges_imf = tranges[imf_i]
  257. else:
  258. tranges_imf = np.array([[t0, t1]])
  259. iranges_imf = unit_tpts.searchsorted(tranges_imf)
  260. unit_fs = 1 / np.diff(unit_tpts).mean()
  261. binwidth = np.floor(unit_fs * SHUFFLE_BINWIDTH).astype('int')
  262. for spk_type, spk_train in spk_trains.items():
  263. # Take only data in ranges
  264. if len(iranges_imf) > 0:
  265. phase_clipped = np.concatenate([phase[i0:i1] for i0, i1 in iranges_imf])
  266. train_clipped = np.concatenate([spk_train[i0:i1] for i0, i1 in iranges_imf])
  267. else:
  268. train_clipped = np.array([])
  269. # Check that there are enough spikes to do analysis
  270. nspikes = train_clipped.sum()
  271. if nspikes < NSPIKES:
  272. r = theta = p = np.nan
  273. else:
  274. r, theta, p = phase_tuning(
  275. phase_clipped, train_clipped,
  276. shuffle_binwidth=binwidth, n_shuffles=NSHUFFLES
  277. )
  278. data['_'.join([spk_type, 'n'])] = nspikes
  279. data['_'.join([spk_type, 'strength'])] = r
  280. data['_'.join([spk_type, 'phase'])] = theta
  281. data['_'.join([spk_type, 'p'])] = p
  282. seriess.append(pd.Series(data=data))
  283. df_tuning = pd.DataFrame(seriess)
  284. if args.tranges:
  285. filename = 'phasetuning_{}_{}.pkl'.format(args.e_name, args.tranges)
  286. elif args.spk_types == ['saccade']:
  287. filename = 'phasetuning_{}_{}.pkl'.format(args.e_name, 'saccades')
  288. else:
  289. filename = 'phasetuning_{}.pkl'.format(args.e_name)
  290. df_tuning.to_pickle(DATAPATH + filename)