util.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891
  1. import numpy as np
  2. import pandas as pd
  3. import quantities as pq
  4. from matplotlib import pyplot as plt
  5. from scipy.interpolate import interp1d
  6. from statsmodels.stats.multitest import fdrcorrection
  7. from scipy.stats import norm, sem
  8. from sklearn.mixture import GaussianMixture
  9. from neo import SpikeTrain
  10. from elephant.statistics import instantaneous_rate
  11. from elephant.kernels import GaussianKernel
  12. from parameters import *
  13. ## Data handling
  14. def df2keys(df):
  15. """Return MSE keys for all entries in a DataFrame."""
  16. df = df.reset_index()
  17. keys = [key for key in df.columns if key in ['m', 's', 'e', 'u']]
  18. return [{key: val for key, val in zip(keys, vals)} for vals in df[keys].values]
  19. def key2idx(key):
  20. """Return DataFrame index tuple for the given key."""
  21. return tuple([val for k, val in key.items()])
  22. def df_metadata(df):
  23. """Print metadata for a DataFrame."""
  24. df = df.reset_index()
  25. print("No. mice: {}".format(len(df.groupby('m'))))
  26. print("No. series: {}".format(len(df.groupby(['m', 's']))))
  27. print("No. experiments: {}".format(len(df.groupby(['m', 's', 'e']))))
  28. if 'u' in df.columns:
  29. print("No. units: {}".format(len(df.groupby(['m', 's', 'e', 'u']))))
  30. for idx, df in df.groupby(['m', 's', 'e']):
  31. print("\n{} s{:02d} e{:02d}".format(idx[0], idx[1], idx[2]))
  32. print(" {} units".format(len(df)))
  33. def load_data(data, conditions, **kwargs):
  34. """
  35. Load data of a specified type and recording region, pooling across all
  36. requested conditions.
  37. """
  38. dfs = []
  39. for condition in conditions:
  40. filename = '{}_{}'.format(data, condition)
  41. for kw, arg in kwargs.items():
  42. filename = filename + '_{}'.format(arg)
  43. filename = filename + '.pkl'
  44. print("Loading: ", filename)
  45. df = pd.read_pickle(DATAPATH + filename)
  46. if 'condition' not in df.columns:
  47. df['condition'] = condition
  48. dfs.append(df)
  49. return pd.concat(dfs, axis='index')
  50. ## Plotting utils
  51. def set_plotsize(w, h=None, ax=None):
  52. """
  53. Set the size of a matplotlib axes object in cm.
  54. Parameters
  55. ----------
  56. w, h : float
  57. Desired width and height of plot, if height is None, the axis will be
  58. square.
  59. ax : matplotlib.axes
  60. Axes to resize, if None the output of plt.gca() will be re-sized.
  61. Notes
  62. -----
  63. - Use after subplots_adjust (if adjustment is needed)
  64. - Matplotlib axis size is determined by the figure size and the subplot
  65. margins (r, l; given as a fraction of the figure size), i.e.
  66. w_ax = w_fig * (r - l)
  67. """
  68. if h is None: # assume square
  69. h = w
  70. w /= 2.54 # convert cm to inches
  71. h /= 2.54
  72. if not ax: # get current axes
  73. ax = plt.gca()
  74. # get margins
  75. l = ax.figure.subplotpars.left
  76. r = ax.figure.subplotpars.right
  77. t = ax.figure.subplotpars.top
  78. b = ax.figure.subplotpars.bottom
  79. # set fig dimensions to produce desired ax dimensions
  80. figw = float(w)/(r-l)
  81. figh = float(h)/(t-b)
  82. ax.figure.set_size_inches(figw, figh)
  83. def clip_axes_to_ticks(ax=None, spines=['left', 'bottom'], ext={}):
  84. """
  85. Clip the axis lines to end at the minimum and maximum tick values.
  86. Parameters
  87. ----------
  88. ax : matplotlib.axes
  89. Axes to resize, if None the output of plt.gca() will be re-sized.
  90. spines : list
  91. Axes to keep and clip, axes not included in this list will be removed.
  92. Valid values include 'left', 'bottom', 'right', 'top'.
  93. ext : dict
  94. For each axis in ext.keys() ('left', 'bottom', 'right', 'top'),
  95. the axis line will be extended beyond the last tick by the value
  96. specified, e.g. {'left':[0.1, 0.2]} will results in an axis line
  97. that extends 0.1 units beyond the bottom tick and 0.2 unit beyond
  98. the top tick.
  99. """
  100. if ax is None:
  101. ax = plt.gca()
  102. spines2ax = {
  103. 'left': ax.yaxis,
  104. 'top': ax.xaxis,
  105. 'right': ax.yaxis,
  106. 'bottom': ax.xaxis
  107. }
  108. all_spines = ['left', 'bottom', 'right', 'top']
  109. for spine in spines:
  110. low = min(spines2ax[spine].get_majorticklocs())
  111. high = max(spines2ax[spine].get_majorticklocs())
  112. if spine in ext.keys():
  113. low += ext[spine][0]
  114. high += ext[spine][1]
  115. ax.spines[spine].set_bounds(low, high)
  116. for spine in [spine for spine in all_spines if spine not in spines]:
  117. ax.spines[spine].set_visible(False)
  118. def p2stars(p):
  119. if p <= 0.0001:
  120. return '***'
  121. elif p <= 0.001:
  122. return '**'
  123. elif p<= 0.05:
  124. return '*'
  125. else:
  126. return ''
  127. def violin_plot(dists, colors, ax=None, logscale=False):
  128. if type(colors) is list:
  129. assert len(colors) == len(dists)
  130. if ax is None:
  131. fig, ax = plt.subplots()
  132. violins = ax.violinplot(dists, showmedians=True, showextrema=False)
  133. for violin, color in zip(violins['bodies'], colors):
  134. violin.set_facecolor('none')
  135. violin.set_edgecolor(color)
  136. violin.set_alpha(1)
  137. violin.set_linewidth(2)
  138. violins['cmedians'].set_color('black')
  139. for pos, dist in enumerate(dists):
  140. median = np.median(dist)
  141. text = f'{median:.2f}'
  142. if logscale:
  143. text = f'{10 ** median:.2f}'
  144. ax.text(pos + 1.4, median, text, va='center', ha='center', rotation=-90, fontsize=LABELFONTSIZE)
  145. ax.set_xticks(np.arange(len(dists)) + 1)
  146. ax.tick_params(bottom=False)
  147. return ax
  148. def pupil_area_rate_heatmap(df, cmap='gray', max_rate='high', example=None):
  149. """
  150. Plot a heatmap of event rates (tonic spikes or bursts) where each row is a neuron and each column is a pupil size bin.
  151. Parameters
  152. ----------
  153. df : pandas.DataFrame
  154. Dataframe with neurons in the rows and mean firing rates for each pupil size bin in a column called 'area_means'.
  155. cmap : str or Matplotlib colormap object
  156. max_rate : str
  157. Is the max event rate expected to at 'high' or 'low' pupil sizes?
  158. example : dict
  159. If not none, MSEU key passed will be used to highlight example neuron.
  160. """
  161. fig = plt.figure()
  162. # Find out which pupil size bin has the highest firing rate
  163. df['tuning_max'] = df['area_means'].apply(np.argmax)
  164. # Min-max normalize firing rates for each neuron
  165. df['tuning_norm'] = df['area_means'].apply(lambda x: (x - x.min()) / (x.max() - x.min()))
  166. # Get heatmap for neurons with significant rate differences across pupil size bins
  167. df_sig = df.query('area_p <= 0.05').sort_values('tuning_max')
  168. heatmap_sig = np.row_stack(df_sig['tuning_norm'])
  169. n_sig = len(df_sig) # number of neurons with significant differences
  170. # Make axis with size proportional to the fraction of significant neurons
  171. n_units = len(df) # total number of neurons
  172. ax1_height = n_sig / n_units
  173. ax1 = fig.add_axes([0.1, 0.9 - (0.76 * ax1_height), 0.8, (0.76 * ax1_height)])
  174. # Plot heatmap for signifcant neurons
  175. mat = ax1.matshow(heatmap_sig, cmap=cmap, aspect='auto')
  176. # Scatter plot marking pupil size bin with maximum
  177. ax1.scatter(df_sig['tuning_max'], np.arange(len(df_sig)) - 0.5, s=0.5, color='black', zorder=3)
  178. # Format axes
  179. ax1.set_xticks([])
  180. yticks = np.insert(np.arange(0, n_sig, 20), -1, n_sig) # ticks every 20 neurons, and final count
  181. ax1.set_yticks(yticks - 0.5)
  182. ax1.set_yticklabels(yticks)
  183. ax1.set_ylabel('Neurons')
  184. # Make a colorbar
  185. cbar = fig.colorbar(mat, ax=ax1, ticks=[0, 1], location='top', shrink=0.75)
  186. cbar.ax.set_xticklabels(['min', 'max'])
  187. cbar.ax.set_xlabel('Spikes', labelpad=-5)
  188. # Print dotted line and label for neurons considered to have 'monotonic' modulation profiles
  189. # (highest event rate at one of the pupil size extremes)
  190. if max_rate == 'high':
  191. n_mon = len(df_sig.query('tuning_max >= 9'))
  192. print("Monotonic increasing: %d/%d (%.1f)" % (n_mon, n_sig, n_mon / n_sig * 100))
  193. ax1.axvline(8.5, lw=2, ls='--', color='white')
  194. ax1.set_title(r'Monotonic$\rightarrow$', fontsize=LABELFONTSIZE, loc='right', pad=0)
  195. elif max_rate == 'low':
  196. n_mon = len(df_sig.query('tuning_max <= 1'))
  197. print("Monotonic decreasing: %d/%d (%.1f)" % (n_mon, n_sig, n_mon / n_sig * 100))
  198. ax1.axvline(0.5, lw=2, ls='--', color='white')
  199. ax1.set_title(r'$\leftarrow$Monotonic', fontsize=LABELFONTSIZE, loc='left', pad=0)
  200. # Get heatmap for neurons without significant rate differences acrss pupil sizes
  201. df_ns = df.query('area_p > 0.05').sort_values('tuning_max')
  202. heatmap_ns = np.row_stack(df_ns['tuning_norm'])
  203. # Make axis with size proportional to the fraction of non-significant neurons
  204. n_ns = len(df_ns)
  205. ax2_height = n_ns / n_units
  206. ax2 = fig.add_axes([0.1, 0.1, 0.8, (0.76 * ax2_height)])
  207. # Plot heatmap for neurons without significant rate differences acrss pupil sizes
  208. mat = ax2.matshow(heatmap_ns, cmap='Greys', vmax=2, aspect='auto')
  209. ax2.scatter(df_ns['tuning_max'], np.arange(n_ns) - 0.5, s=0.5, color='black', zorder=3)
  210. # Format axes
  211. ax2.xaxis.set_ticks_position('bottom')
  212. ax2.set_xticks([-0.5, 4.5, 9.5])
  213. ax2.set_xticklabels([0, 0.5, 1]) # x-axis ticks mark percentiles of pupil size range
  214. ax2.set_xlim(right=9.5)
  215. ax2.set_xlabel('Pupil size (norm.)')
  216. yticks = np.insert(np.arange(0, n_ns, 20), -1, n_ns) # ticks every 20 neurons, and final count
  217. ax2.set_yticks(yticks - 0.5)
  218. ax2.set_yticklabels(yticks)
  219. # Highligh example neuron by circling scatter dot
  220. if example is not None:
  221. try: # first check if example neuron is among significant neurons
  222. is_ex = df_sig.index == tuple([v for k, v in example.items()])
  223. assert is_ex.any()
  224. ex_max = df_sig['tuning_max'][is_ex]
  225. ax1.scatter(ex_max, np.where(is_ex)[0], lw=2, s=1, fc='none', ec='magenta', zorder=4)
  226. except:
  227. is_ex = df_ns.index == tuple([v for k, v in example.items()])
  228. ex_max = df_ns['tuning_max'][is_ex]
  229. ax2.scatter(ex_max, np.where(is_ex)[0], lw=2, s=1, fc='none', ec='magenta', zorder=4)
  230. return fig
  231. def cumulative_histogram(data, bins, color='C0', ax=None):
  232. """Convenience function, cleaner looking plot than plt.hist(..., cumulative=True)."""
  233. if ax is None:
  234. fig, ax = plt.subplots()
  235. weights = np.ones_like(data) / len(data)
  236. counts, _ = np.histogram(data, bins=bins, weights=weights)
  237. ax.plot(bins[:-1], np.cumsum(counts), color=color)
  238. return ax
  239. def cumulative_hist(x, bins, density=True, ax=None, **kwargs):
  240. weights = np.ones_like(x)
  241. if density:
  242. weights = weights / len(x)
  243. counts, _ = np.histogram(x, bins=bins, weights=weights)
  244. if ax is None:
  245. fig, ax = plt.subplots()
  246. xs = np.insert(bins, np.arange(len(bins) - 1), bins[:-1])
  247. ys = np.insert(np.insert(np.cumsum(counts), np.arange(len(counts)), np.cumsum(counts)), 0, 0)
  248. ax.plot(xs, ys, lw=2, **kwargs)
  249. ax.set_xticks(bins + 0.5)
  250. ax.set_yticks([0, 0.5, 1])
  251. return ax, counts
  252. def phase_coupling_scatter(df, ax=None):
  253. """Phase-frequency scatter plot for phase coupling."""
  254. if ax is None:
  255. fig, ax = plt.subplots()
  256. for event in ['tonicspk', 'burst']:
  257. df_sig = df.query(f'{event}_sig == True')
  258. ax.scatter(np.log10(df_sig['freq']), df_sig[f'{event}_phase'], ec=COLORS[event], fc='none', lw=0.5, s=3)
  259. n_sig = max([len(df.query(f'{event}_sig == True').groupby(['m', 's', 'e', 'u'])) for event in ['tonicspk', 'burst']])
  260. ax.set_xticks(FREQUENCYTICKS)
  261. ax.set_xticklabels(FREQUENCYTICKLABELS)
  262. ax.set_xlim(left=-3.075)
  263. ax.set_xlabel("Inverse timescale (s$^{-1}$)")
  264. ax.set_yticks(PHASETICKS)
  265. ax.set_yticklabels(PHASETICKLABELS)
  266. ax.set_ylim([-np.pi - 0.15, np.pi + 0.15])
  267. ax.set_ylabel("Preferred phase")
  268. return ax
  269. def plot_circhist(angles, ax=None, bins=np.linspace(0, 2 * np.pi, 17), density=True, **kwargs):
  270. """Plot a circular histogram."""
  271. if ax is None:
  272. fig, ax = plt.subplots(subplot_kw={'polar':True})
  273. weights = np.ones_like(angles)
  274. if density:
  275. weights /= len(angles)
  276. counts, bins = np.histogram(angles, bins=bins, weights=weights)
  277. xs = bins + (np.pi / (len(bins) - 1))
  278. ys = np.append(counts, counts[0])
  279. ax.plot(xs, ys, **kwargs)
  280. ax.set_xticks([0, np.pi / 2, np.pi, 3 * np.pi / 2])
  281. ax.set_xticklabels(['0', '\u03C0/2', '\u03C0', '3\u03C0/2'])
  282. ax.tick_params(axis='x', pad=-5)
  283. return ax, counts
  284. def coupling_strength_line_plot(df, agg=np.mean, err=sem, logscale=True, ax=None, **kwargs):
  285. """
  286. Line plot showing average burst and tonic spike coupling strengths and SE across timescale bins.
  287. """
  288. if ax is None:
  289. fig, ax = plt.subplots()
  290. for event in ['burst', 'tonicspk']:
  291. df_sig = df.query(f'({event}_sig == True) & ({event}_strength > 0)').copy()
  292. strengths = sort_data(df_sig[f'{event}_strength'], df_sig['freq'], bins=FREQUENCYBINS)
  293. ys = np.array([agg(s) for s in strengths])
  294. yerr = np.array([err(s) for s in strengths])
  295. if not logscale:
  296. ax.plot(FREQUENCYXPOS, ys, color=COLORS[event], **kwargs)
  297. ax.plot(FREQUENCYXPOS, ys + yerr, color=COLORS[event], lw=0.5, ls='--', **kwargs)
  298. ax.plot(FREQUENCYXPOS, ys - yerr, color=COLORS[event], lw=0.5, ls='--', **kwargs)
  299. else:
  300. ax.plot(FREQUENCYXPOS, np.log10(ys), color=COLORS[event], **kwargs)
  301. ax.plot(FREQUENCYXPOS, np.log10(ys + yerr), color=COLORS[event], lw=0.5, ls='--', **kwargs)
  302. ax.plot(FREQUENCYXPOS, np.log10(ys - yerr), color=COLORS[event], lw=0.5, ls='--', **kwargs)
  303. ax.set_xticks(FREQUENCYTICKS)
  304. ax.set_xticklabels(FREQUENCYTICKLABELS)
  305. ax.set_xlim(left=-3.1)
  306. ax.set_xlabel('Timescale (Hz)')
  307. ax.set_ylabel('Coupling strength')
  308. return ax
  309. ## Util
  310. def zero_runs(a):
  311. """
  312. Return an array with shape (m, 2), where m is the number of "runs" of zeros
  313. in a. The first column is the index of the first 0 in each run, the second
  314. is the index of the first nonzero element after the run.
  315. """
  316. # Create an array that is 1 where a is 0, and pad each end with an extra 0.
  317. iszero = np.concatenate(([0], np.equal(a, 0).view(np.int8), [0]))
  318. absdiff = np.abs(np.diff(iszero))
  319. # Runs start and end where absdiff is 1.
  320. ranges = np.where(absdiff == 1)[0].reshape(-1, 2)
  321. return ranges
def merge_ranges(ranges, dt=1):
    """
    Given a set of ranges [start, stop], return new set of ranges where all
    overlapping ranges are merged.

    The ranges are rasterized onto the grid ``tpts = np.arange(min, max, dt)``
    (``max`` itself is not on the grid). Every grid point with
    start <= t < stop for some range is marked, and each contiguous run of
    marked grid points becomes one output range. Ranges that touch on the
    grid (one's stop equals the next one's start) are therefore merged too.

    Parameters
    ----------
    ranges : ndarray
        (n, 2) array of [start, stop] pairs. Must be non-empty, because
        ``ranges.min()`` is taken.
    dt : float
        Grid spacing.

    Returns
    -------
    ndarray
        (m, 2) array of merged [start, stop] values taken from ``tpts``.
        A range's stop is the first grid point after it. A run that reaches
        the end of the grid instead has its stop clamped to the last grid
        point.
    """
    tpts = np.arange(ranges.min(), ranges.max(), dt) # array of time points
    tc = np.ones_like(tpts) # time course of ranges (1 = not covered by any range)
    for t0, t1 in ranges: # for each range
        i0, i1 = tpts.searchsorted([t0, t1])
        tc[i0:i1] = 0 # set values in range to 0
    new_ranges = zero_runs(tc) # get indices of continuous stretches of zero
    # A run reaching the end of the grid has stop index len(tpts), which is
    # out of bounds for tpts; clamp it to the last grid point.
    if new_ranges[-1, -1] == len(tpts): # fix end-point
        new_ranges[-1, -1] = len(tpts) - 1
    return tpts[new_ranges]
def continuous_runs(data, max0len=1, min1len=1, min1prop=0):
    """
    Get start and stop indices of stretches of (relatively) continuous data.

    Runs of True that are separated by short stretches of False are merged.
    Runs that are too short, or that contain too little True data, are then
    discarded.

    Parameters
    ----------
    data : ndarray
        1D boolean array. It is negated with ``~``, so it must be a boolean
        ndarray rather than an int array or a list.
    max0len : int
        Maximum length (in data pts) of False stretches to ignore.
        NOTE(review): by the index arithmetic below, gaps of *fewer than*
        ``max0len`` False points appear to be bridged. Confirm this matches
        the "maximum length" wording.
    min1len : int
        Minimum length (in data pts) of runs to keep. Measured after merging,
        so the length includes any bridged False points.
    min1prop : float
        Minimum proportion of True data in a run necessary for it to be kept.

    Returns
    -------
    out : ndarray
        (m, 2) array of start and stop indices, used as slice bounds
        ``data[start:stop]``, where m is the number of runs of continuous
        True values. If ``data`` has no True values, ``np.array([[]])``
        (shape (1, 0)) is returned.
    """
    # get ranges of True values (start index, index just past the run)
    one_ranges = zero_runs(~data)
    if len(one_ranges) == 0:
        return np.array([[]])
    # merge ranges that are separated by < max0len of False: extend each stop
    # so that nearby runs touch/overlap, then merge them
    one_ranges[:, 1] += (max0len - 1)
    one_ranges = merge_ranges(one_ranges)
    # return indices to normal
    one_ranges[:, 1] -= (max0len - 1)
    # merge_ranges clamps the stop of the run reaching the grid end down to
    # the last grid point; add the 1 back
    one_ranges[-1, -1] += 1
    # remove ranges that are too short
    lengths = np.diff(one_ranges, axis=1).ravel()
    one_ranges = one_ranges[lengths >= min1len]
    # remove ranges that don't have sufficient proportion True
    prop = np.array([data[i0:i1].sum() / (i1 - i0) for (i0, i1) in one_ranges])
    return one_ranges[prop >= min1prop]
  372. def switch_ranges(ranges, dt=1, minval=0, maxval=None):
  373. """
  374. Given a set of (start, stop) pairs, return a new set of pairs for values
  375. outside the given ranges.
  376. Parameters
  377. ----------
  378. ranges : ndarray
  379. N x 2 array containing start and stop values in the first and second
  380. columns respectively
  381. dt : float
  382. minval, maxval : int
  383. the minimum and maximum possible values, if maxval is None it is assumed
  384. that the maximum possible value is the maximum value in the input ranges
  385. Returns
  386. -------
  387. out : ndarray
  388. M x 2 array containing start and stop values of all ranges outside of
  389. the input ranges
  390. """
  391. if ranges.shape[1] == 0:
  392. return np.array([[minval, maxval]])
  393. assert (ranges.ndim == 2) & (ranges.shape[1] == 2), "A two-column array is expected"
  394. maxval = ranges.max() if maxval is None else maxval
  395. # get new ranges
  396. new_ranges = np.zeros_like(ranges)
  397. new_ranges[:,0] = ranges[:,0] - dt # new stop values
  398. new_ranges[:,1] = ranges[:,1] + dt # new start values
  399. # fix boundaries
  400. new_ranges = new_ranges.ravel()
  401. if new_ranges[0] >= (minval + dt): # first new stop within allowed range
  402. new_ranges = np.concatenate((np.array([minval]), new_ranges))
  403. else:
  404. new_ranges = new_ranges[1:]
  405. if new_ranges[-1] <= (maxval - dt): # first new start within allowed range
  406. new_ranges = np.concatenate((new_ranges, np.array([maxval])))
  407. else:
  408. new_ranges = new_ranges[:-1]
  409. return new_ranges.reshape((int(len(new_ranges) / 2), 2))
  410. def shuffle_bins(x, binwidth=1):
  411. """
  412. Randomly shuffle bins of an array.
  413. """
  414. # bin start indices
  415. bins_i0 = np.arange(0, len(x), binwidth)
  416. # shuffled bins
  417. np.random.shuffle(bins_i0)
  418. # concatenate shuffled bins
  419. shf = np.concatenate([x[i0:(i0 + binwidth)] for i0 in bins_i0])
  420. return shf
  421. def take_data_in_bouts(series, data, bouts, trange=None, dt=2, dt0=0, dt1=0, concatenate=True, norm=False):
  422. if series['%s_bouts' % bouts].shape[1] < 1:
  423. return np.array([])
  424. header, _ = data.split('_')
  425. data_in_bouts = []
  426. for t0, t1 in series['%s_bouts' % bouts]:
  427. t0 -= dt0
  428. t1 += dt1
  429. if trange == 'start':
  430. t1 = t0 + dt
  431. elif trange == 'end':
  432. t0 = t1 - dt
  433. elif trange == 'middle':
  434. t0 = t0 + dt
  435. t1 = t1 - dt
  436. if t1 <= t0:
  437. continue
  438. if t0 < series['%s_tpts' % header].min():
  439. continue
  440. if t1 > series['%s_tpts' % header].max():
  441. continue
  442. i0, i1 = series['%s_tpts' % header].searchsorted([t0, t1])
  443. data_in_bout = series[data][i0:i1].copy()
  444. if norm:
  445. data_in_bout = data_in_bout / series[data].max()
  446. data_in_bouts.append(data_in_bout)
  447. if concatenate:
  448. data_in_bouts = np.concatenate(data_in_bouts)
  449. return data_in_bouts
  450. def get_trials(series, stim_id=0, opto=False, multi_stim='warn'):
  451. if opto:
  452. opto = np.isin(series['trial_id'], series['opto_trials'])
  453. elif not opto:
  454. opto = ~np.isin(series['trial_id'], series['opto_trials'])
  455. if stim_id < 0:
  456. stim = np.ones_like(series['stim_id']).astype('bool')
  457. else:
  458. stim = series['stim_id'] == stim_id
  459. series['trial_on_times'] = series['trial_on_times'][opto & stim]
  460. series['trial_off_times'] = series['trial_off_times'][opto & stim]
  461. return series
  462. def sort_data(data, sort_vals, bins=10):
  463. if type(bins) == int:
  464. nbins = bins
  465. bin_edges = np.linspace(sort_vals.min(), sort_vals.max(), nbins + 1)
  466. else:
  467. nbins = len(bins) - 1
  468. bin_edges = bins
  469. digitized_vals = np.digitize(sort_vals, bins=bin_edges).clip(1, nbins)
  470. return [data[digitized_vals == val] for val in np.arange(nbins) + 1]
  471. def apply_sort_data(series, data_col, sort_col, bins=10):
  472. return sort_data(series[data_col], series[sort_col], bins)
  473. ## Statistics
  474. def get_binned_rates(spk_rates, pupil_area, sort=False, nbins=10):
  475. # Get bins base on percentiles to eliminate effect of outliers
  476. min_area, max_area = np.percentile(pupil_area, [2.5, 97.5])
  477. #min_area, max_area = pupil_area.min(), pupil_area.max()
  478. area_bins = np.linspace(min_area, max_area, nbins + 1)
  479. # Bin pupil area
  480. binned_area = np.digitize(pupil_area, bins=area_bins).clip(1, nbins) - 1
  481. # Bin rates according to pupil area
  482. binned_rates = np.array([spk_rates[binned_area == bin_i] for bin_i in np.arange(nbins)], dtype=object)
  483. if sort:
  484. sorted_inds = np.argsort([rates.mean() if len(rates) > 0 else 0 for rates in binned_rates])
  485. binned_rates = binned_rates[sorted_inds]
  486. binned_area = np.squeeze([np.where(sorted_inds == area_bin)[0] for area_bin in binned_area])
  487. return area_bins, binned_area, binned_rates
  488. def rescale(x, method='min_max'):
  489. # Set re-scaling method
  490. if method == 'z_score':
  491. return (x - x.mean()) / x.std()
  492. elif method == 'min_max':
  493. return (x - x.min()) / (x.max() - x.min())
  494. def correlogram(ts1, ts2=None, tau_max=1, dtau=0.01, return_tpts=False):
  495. if ts2 is None:
  496. ts2 = ts1.copy()
  497. auto = True
  498. else:
  499. auto = False
  500. tau_max = (tau_max // dtau) * dtau
  501. tbins = np.arange(-tau_max, tau_max + dtau, dtau)
  502. ccg = np.zeros(len(tbins) - 1)
  503. for t0 in ts1:
  504. dts = ts2 - t0
  505. if auto:
  506. dts = dts[dts != 0]
  507. ccg += np.histogram(dts[np.abs(dts) <= tau_max], bins=tbins)[0]
  508. ccg /= len(ts1)
  509. if not return_tpts:
  510. return ccg
  511. else:
  512. tpts = tbins[:-1] + dtau / 2
  513. return tpts, ccg
  514. def angle_subtract(a1, a2, period=(2 * np.pi)):
  515. return (a1 - a2) % period
  516. def circmean(alpha, w=None, axis=None):
  517. """
  518. Compute mean resultant vector of circular data.
  519. Parameters
  520. ----------
  521. alpha : ndarray
  522. array of angles
  523. w : ndarray
  524. array of weights, must be same shape as alpha
  525. axis : int, None
  526. axis across which to compute mean
  527. Returns
  528. -------
  529. mrl : ndarray
  530. mean resultant vector length
  531. theta : ndarray
  532. mean resultant vector angle
  533. """
  534. # weights default to ones
  535. if w is None:
  536. w = np.ones_like(alpha)
  537. w[np.isnan(alpha)] = 0
  538. # compute weighted mean
  539. mean_vector = np.nansum(w * np.exp(1j * alpha), axis=axis) / w.sum(axis=axis)
  540. mrl = np.abs(mean_vector) # length
  541. theta = np.angle(mean_vector) # angle
  542. return mrl, theta
  543. def circmean_angle(alpha, **kwargs):
  544. return circmean(alpha, **kwargs)[1]
  545. def circhist(angles, n_bins=8, proportion=True, wrap=False):
  546. bins = np.linspace(-np.pi, np.pi, n_bins + 1, endpoint=True)
  547. weights = np.ones(len(angles))
  548. if proportion:
  549. weights /= len(angles)
  550. counts, bins = np.histogram(angles, bins=bins, weights=weights)
  551. if wrap:
  552. counts = np.concatenate([counts, [counts[0]]])
  553. bins = np.concatenate([bins, [bins[0]]])
  554. return counts, bins
  555. def unbiased_variance(data):
  556. if len(data) <= 1:
  557. return np.nan
  558. else:
  559. return np.var(data) * len(data) / (len(data) - 1)
  560. def se_median(sample, n_resamp=1000):
  561. """Standard error of the median."""
  562. medians = np.full(n_resamp, np.nan)
  563. for i in range(n_resamp):
  564. resample = np.random.choice(sample, len(sample), replace=True)
  565. medians[i] = np.median(resample)
  566. return np.std(medians)
  567. def coupling_summary(df):
  568. """Print some basic statistics for phase coupling."""
  569. # Either spike type
  570. units = df.query('(tonicspk_p == tonicspk_p) or (burst_p == burst_p)').groupby(['m', 's', 'e', 'u'])
  571. n_sig = units.apply(lambda x: any(x[f'tonicspk_sig']) or any(x[f'burst_sig'])).sum()
  572. prop_sig = n_sig / len(units)
  573. print(f"Neurons with significant coupling: {prop_sig:.3f} ({n_sig}/{len(units)})")
  574. # For each spike type
  575. for spk_type in ['tonicspk', 'burst']:
  576. units = df.dropna(subset=f'{spk_type}_p').groupby(['m', 's', 'e', 'u'])
  577. n_sig = units.apply(lambda x: any(x[f'{spk_type}_sig'])).sum()
  578. prop_sig = n_sig / len(units)
  579. print(f"{spk_type.capitalize()} prop. significant: {prop_sig:.3f} ({n_sig}/{len(units)})")
  580. n_cpds = units.apply(lambda x: x[f'{spk_type}_sig'].sum())
  581. print(f"{spk_type.capitalize()} num. CPDs per neuron: {n_cpds.mean():.2f}, {n_cpds.std():.2f}")
  582. def kl_divergence(p, q):
  583. return np.sum(np.where(p != 0, p * np.log(p / q), 0))
  584. def match_distributions(x1, x2, x1_bins, x2_bins):
  585. """
  586. For two time series, x2 & x2, return indices of sub-sampled time
  587. periods such that the distribution of x2 is matched across
  588. bins of x1.
  589. """
  590. x1_nbins = len(x1_bins) - 1
  591. x2_nbins = len(x2_bins) - 1
  592. # bin x1
  593. x1_binned = np.digitize(x1, x1_bins).clip(1, x1_nbins) - 1
  594. # get continuous periods where x1 visits each bin
  595. x1_ranges = [zero_runs(~np.equal(x1_binned, x1_bin)) for x1_bin in np.arange(x1_nbins)]
  596. # get mean of x2 for each x1 bin visit
  597. x2_means = [np.array([np.mean(x2[i0:i1]) for i0, i1 in x1_bin]) for x1_bin in x1_ranges]
  598. # find minimum common distribution across x1 bins
  599. x2_counts = np.row_stack([np.histogram(means, bins=x2_bins)[0] for means in x2_means])
  600. x2_mcd = x2_counts.min(axis=0)
  601. # bin x2 means
  602. x2_means_binned = [np.digitize(means, bins=x2_bins).clip(1, x2_nbins) - 1 for means in x2_means]
  603. x2_means_in_bins = [[means[binned_means == x2_bin] for x2_bin in np.arange(x2_nbins)] for means, binned_means in zip(x2_means, x2_means_binned)]
  604. x1_ranges_in_bins = [[ranges[binned_means == x2_bin] for x2_bin in np.arange(x2_nbins)] for ranges, binned_means in zip(x1_ranges, x2_means_binned)]
  605. # loop over x2 bins
  606. matched_ranges = [[], [], [], []]
  607. for x2_bin in np.arange(x2_nbins):
  608. # find the x1 bin matching the MCD
  609. seed_x1_bin = np.where(x2_counts[:, x2_bin] == x2_mcd[x2_bin])[0][0]
  610. assert len(x2_means_in_bins[seed_x1_bin][x2_bin]) == x2_mcd[x2_bin]
  611. # for each bin visit, find the closest matching mean in the other x1 bins
  612. target_means = x2_means_in_bins[seed_x1_bin][x2_bin]
  613. target_ranges = x1_ranges_in_bins[seed_x1_bin][x2_bin]
  614. for target_mean, target_range in zip(target_means, target_ranges):
  615. matched_ranges[seed_x1_bin].append(target_range)
  616. for x1_bin in np.delete(np.arange(x1_nbins), seed_x1_bin):
  617. matching_ind = np.abs(x2_means_in_bins[x1_bin][x2_bin] - target_mean).argmin()
  618. matched_ranges[x1_bin].append(x1_ranges_in_bins[x1_bin][x2_bin][matching_ind])
  619. # delete the matching period
  620. x2_means_in_bins[x1_bin][x2_bin] = np.delete(x2_means_in_bins[x1_bin][x2_bin], matching_ind, axis=0)
  621. x1_ranges_in_bins[x1_bin][x2_bin] = np.delete(x1_ranges_in_bins[x1_bin][x2_bin], matching_ind, axis=0)
  622. return [np.row_stack(ranges) if len(ranges) > 0 else np.array([]) for ranges in matched_ranges]
  623. ## Signal processing
  624. def normalized_xcorr(a, b, dt=None, ts=None):
  625. """
  626. Compute Pearson r between two arrays at various lags
  627. Parameters
  628. ----------
  629. a, b : ndarray
  630. The arrays to correlate.
  631. dt : float
  632. The time step between samples in the arrays.
  633. ts : list
  634. If not None, only the xcorr between the specified lags will be
  635. returned.
  636. Return
  637. ------
  638. xcorr, lags : ndarray
  639. The cross correlation and corresponding lags between a and b.
  640. Positive lags indicate that a is delayed relative to b.
  641. """
  642. assert len(a) == len(b)
  643. n = len(a)
  644. a_norm = (a - a.mean()) / a.std()
  645. b_norm = (b - b.mean()) / b.std()
  646. xcorr = np.correlate(a_norm, b_norm, 'full') / n
  647. lags = np.arange(-n + 1, n)
  648. if dt is not None:
  649. lags = lags * dt
  650. if ts is not None:
  651. assert len(ts) == 2
  652. i0, i1 = lags.searchsorted(ts)
  653. xcorr = xcorr[i0:i1]
  654. lags = lags[i0:i1]
  655. return xcorr, lags
  656. def interpolate(y, x_old, x_new, axis=0, fill_value='extrapolate'):
  657. """
  658. Use linear interpolation to re-sample 1D data.
  659. """
  660. # get interpolation function
  661. func = interp1d(x_old, y, axis=axis, fill_value=fill_value)
  662. # get new y-values
  663. y_interpolated = func(x_new)
  664. return y_interpolated
  665. def interpolate_and_normalize(y, x_old, x_new):
  666. """
  667. Perform linear interpolation and min-max normalization.
  668. """
  669. y_new = interpolate(y, x_old, x_new)
  670. return (y_new - y_new.min()) / (y_new.max() - y_new.min())
  671. def match_signal_length(a, b, a_tpts, b_tpts):
  672. """
  673. Given two signals, truncate to match the length of the shortest.
  674. """
  675. t1 = min(a_tpts.max(), b_tpts.max())
  676. a1 = a[:a_tpts.searchsorted(t1)]
  677. a1_tpts = a_tpts[:a_tpts.searchsorted(t1)]
  678. b1 = b[:b_tpts.searchsorted(t1)]
  679. b1_tpts = b_tpts[:b_tpts.searchsorted(t1)]
  680. return a1, b1, a1_tpts, b1_tpts
  681. def times_to_counts(series, columns, t0=None, t1=None, dt=0.25):
  682. if type(t0) == str:
  683. t0, t1 = series[t0].min(), series[t0].max()
  684. elif t0 is None: # get overlapping time range for all columns
  685. t0 = max([series[f'{col.split("_")[0]}_times'].min() for col in columns])
  686. elif t1 is None:
  687. t1 = min([series[f'{col.split("_")[0]}_times'].max() for col in columns])
  688. # Set time base
  689. tbins = np.arange(t0, t1, dt)
  690. tpts = tbins[:-1] + (dt / 2)
  691. for col in columns:
  692. header = col.split('_')[0]
  693. times = series[f'{header}_times']
  694. counts, _ = np.histogram(times, bins=tbins)
  695. series[f'{header}_counts'] = counts
  696. series[f'{header}_tpts'] = tpts
  697. return series
  698. def resample_timeseries(y, tpts, dt=0.25):
  699. tpts_new = np.arange(tpts.min(), tpts.max(), dt)
  700. return tpts_new, interpolate(y, tpts, tpts_new)
  701. def _resample_data(series, columns, t0=None, t1=None, dt=0.25):
  702. if type(t0) == str:
  703. t0, t1 = series[t0].min(), series[t0].max()
  704. elif t0 is None: # get overlapping time range for all columns
  705. t0 = max([series[f'{col.split("_")[0]}_tpts'].min() for col in columns])
  706. t1 = min([series[f'{col.split("_")[0]}_tpts'].max() for col in columns])
  707. # Set new time base
  708. tbins = np.arange(t0, t1, dt)
  709. tpts_new = tbins[:-1] + (dt / 2)
  710. # Interpolate and re-sample each column
  711. for col in columns:
  712. header = col.split('_')[0]
  713. data = series[col]
  714. tpts = series[f'{header}_tpts']
  715. series[col] = interpolate(data, tpts, tpts_new)
  716. series[f'{header}_tpts'] = tpts_new
  717. return series
  718. ## Neural activity
  719. def get_mean_rates(df):
  720. df['mean_rate'] = df.apply(
  721. lambda x:
  722. len(x['spk_times']) / (x['spk_tinfo'][1] - x['spk_tinfo'][0]),
  723. axis='columns'
  724. )
  725. return df
  726. def get_mean_rate_threshold(df, alpha=0.025):
  727. if 'mean_rate' not in df.columns:
  728. df = get_mean_rates(df)
  729. rates = np.log10(df['mean_rate'])
  730. gmm = GaussianMixture(n_components=2)
  731. gmm.fit(rates[..., np.newaxis])
  732. (mu, var) = (gmm.means_.max(), gmm.covariances_.squeeze()[gmm.means_.argmax()])
  733. threshold = mu + norm.ppf(alpha) * np.sqrt(var)
  734. return threshold
  735. def filter_units(df, threshold):
  736. if 'mean_rate' not in df.columns:
  737. df = get_mean_rates(df)
  738. return df.query(f'mean_rate >= {threshold}')
  739. def get_raster(series, events, spike_type='spk', pre=0, post=1):
  740. events = series['%s_times' % events]
  741. spks = series['%s_times' % spike_type]
  742. raster = np.array([spks[(spks > t0 - pre) & (spks < t0 + post)] - t0 for t0 in events], dtype='object')
  743. return raster
  744. def get_psth(events, spikes, pre=0, post=1, dt=0.001, bw=0.01, baseline=[]):
  745. rate_kernel = GaussianKernel(bw*pq.s)
  746. tpts = np.arange(pre, post, dt)
  747. psth = np.full((len(events), len(tpts)), np.nan)
  748. for i, t0 in enumerate(events):
  749. rel_ts = spikes - t0
  750. rel_ts = rel_ts[(rel_ts >= pre) & (rel_ts <= post)]
  751. try:
  752. rate = instantaneous_rate(
  753. SpikeTrain(rel_ts, t_start=pre, t_stop=post, units='s'),
  754. sampling_period=dt*pq.s,
  755. kernel=rate_kernel
  756. )
  757. except:
  758. continue
  759. psth[i] = rate.squeeze()
  760. if baseline:
  761. b0, b1 = tpts.searchsorted(baseline)
  762. baseline_rate = psth[:, b0:b1].mean(axis=1)
  763. psth = (psth.T - baseline_rate).T
  764. return psth, tpts
  765. def apply_get_psth(series, events, spike_type, **kwargs):
  766. events = series['{}_times'.format(events)]
  767. spikes = series['{}_times'.format(spike_type)]
  768. psth, tpts = get_psth(events, spikes, **kwargs)
  769. return psth
  770. def get_responses(events, data, tpts, pre=0, post=1, baseline=[]):
  771. dt = np.round(np.diff(tpts).mean(), 3) # round to nearest ms
  772. i_pre, i_post = int(pre / dt), int(post / dt)
  773. responses = np.full((len(events), i_pre + i_post), np.nan)
  774. for j, t0 in enumerate(events):
  775. i = tpts.searchsorted(t0)
  776. i0, i1 = i - i_pre, i + i_post
  777. if i0 < 0:
  778. continue
  779. if i1 > len(data):
  780. break
  781. responses[j] = data[i0:i1]
  782. tpts = np.linspace(pre, post, responses.shape[1])
  783. if baseline:
  784. b0, b1 = tpts.searchsorted(baseline)
  785. baseline_resp = responses[:, b0:b1].mean(axis=1)
  786. responses = (responses.T - baseline_resp).T
  787. return responses, tpts
  788. def apply_get_responses(series, events, data, **kwargs):
  789. events = series[f'{events}_times']
  790. tpts = series[f'{data.split("_")[0]}_tpts']
  791. data = series[f'{data}']
  792. responses, tpts = get_responses(events, data, tpts, **kwargs)
  793. return responses