|
@@ -0,0 +1,891 @@
|
|
|
+import numpy as np
|
|
|
+import pandas as pd
|
|
|
+import quantities as pq
|
|
|
+from matplotlib import pyplot as plt
|
|
|
+from scipy.interpolate import interp1d
|
|
|
+from statsmodels.stats.multitest import fdrcorrection
|
|
|
+from scipy.stats import norm, sem
|
|
|
+from sklearn.mixture import GaussianMixture
|
|
|
+from neo import SpikeTrain
|
|
|
+from elephant.statistics import instantaneous_rate
|
|
|
+from elephant.kernels import GaussianKernel
|
|
|
+
|
|
|
+from parameters import *
|
|
|
+
|
|
|
+
|
|
|
+## Data handling
|
|
|
+def df2keys(df):
|
|
|
+ """Return MSE keys for all entries in a DataFrame."""
|
|
|
+
|
|
|
+ df = df.reset_index()
|
|
|
+ keys = [key for key in df.columns if key in ['m', 's', 'e', 'u']]
|
|
|
+
|
|
|
+ return [{key: val for key, val in zip(keys, vals)} for vals in df[keys].values]
|
|
|
+
|
|
|
+def key2idx(key):
|
|
|
+ """Return DataFrame index tuple for the given key."""
|
|
|
+ return tuple([val for k, val in key.items()])
|
|
|
+
|
|
|
+def df_metadata(df):
|
|
|
+ """Print metadata for a DataFrame."""
|
|
|
+
|
|
|
+ df = df.reset_index()
|
|
|
+ print("No. mice: {}".format(len(df.groupby('m'))))
|
|
|
+ print("No. series: {}".format(len(df.groupby(['m', 's']))))
|
|
|
+ print("No. experiments: {}".format(len(df.groupby(['m', 's', 'e']))))
|
|
|
+ if 'u' in df.columns:
|
|
|
+ print("No. units: {}".format(len(df.groupby(['m', 's', 'e', 'u']))))
|
|
|
+ for idx, df in df.groupby(['m', 's', 'e']):
|
|
|
+ print("\n{} s{:02d} e{:02d}".format(idx[0], idx[1], idx[2]))
|
|
|
+ print(" {} units".format(len(df)))
|
|
|
+
|
|
|
+def load_data(data, conditions, **kwargs):
|
|
|
+ """
|
|
|
+ Load data of a specified type and recording region, pooling across all
|
|
|
+ requested conditions.
|
|
|
+ """
|
|
|
+ dfs = []
|
|
|
+ for condition in conditions:
|
|
|
+ filename = '{}_{}'.format(data, condition)
|
|
|
+ for kw, arg in kwargs.items():
|
|
|
+ filename = filename + '_{}'.format(arg)
|
|
|
+ filename = filename + '.pkl'
|
|
|
+ print("Loading: ", filename)
|
|
|
+ df = pd.read_pickle(DATAPATH + filename)
|
|
|
+ if 'condition' not in df.columns:
|
|
|
+ df['condition'] = condition
|
|
|
+ dfs.append(df)
|
|
|
+ return pd.concat(dfs, axis='index')
|
|
|
+
|
|
|
+## Plotting utils
|
|
|
+
|
|
|
+def set_plotsize(w, h=None, ax=None):
|
|
|
+ """
|
|
|
+ Set the size of a matplotlib axes object in cm.
|
|
|
+
|
|
|
+ Parameters
|
|
|
+ ----------
|
|
|
+ w, h : float
|
|
|
+ Desired width and height of plot, if height is None, the axis will be
|
|
|
+ square.
|
|
|
+
|
|
|
+ ax : matplotlib.axes
|
|
|
+ Axes to resize, if None the output of plt.gca() will be re-sized.
|
|
|
+
|
|
|
+ Notes
|
|
|
+ -----
|
|
|
+ - Use after subplots_adjust (if adjustment is needed)
|
|
|
+ - Matplotlib axis size is determined by the figure size and the subplot
|
|
|
+ margins (r, l; given as a fraction of the figure size), i.e.
|
|
|
+ w_ax = w_fig * (r - l)
|
|
|
+ """
|
|
|
+ if h is None: # assume square
|
|
|
+ h = w
|
|
|
+ w /= 2.54 # convert cm to inches
|
|
|
+ h /= 2.54
|
|
|
+ if not ax: # get current axes
|
|
|
+ ax = plt.gca()
|
|
|
+ # get margins
|
|
|
+ l = ax.figure.subplotpars.left
|
|
|
+ r = ax.figure.subplotpars.right
|
|
|
+ t = ax.figure.subplotpars.top
|
|
|
+ b = ax.figure.subplotpars.bottom
|
|
|
+ # set fig dimensions to produce desired ax dimensions
|
|
|
+ figw = float(w)/(r-l)
|
|
|
+ figh = float(h)/(t-b)
|
|
|
+ ax.figure.set_size_inches(figw, figh)
|
|
|
+
|
|
|
+def clip_axes_to_ticks(ax=None, spines=['left', 'bottom'], ext={}):
|
|
|
+ """
|
|
|
+ Clip the axis lines to end at the minimum and maximum tick values.
|
|
|
+
|
|
|
+ Parameters
|
|
|
+ ----------
|
|
|
+ ax : matplotlib.axes
|
|
|
+ Axes to resize, if None the output of plt.gca() will be re-sized.
|
|
|
+
|
|
|
+ spines : list
|
|
|
+ Axes to keep and clip, axes not included in this list will be removed.
|
|
|
+ Valid values include 'left', 'bottom', 'right', 'top'.
|
|
|
+
|
|
|
+ ext : dict
|
|
|
+ For each axis in ext.keys() ('left', 'bottom', 'right', 'top'),
|
|
|
+ the axis line will be extended beyond the last tick by the value
|
|
|
+ specified, e.g. {'left':[0.1, 0.2]} will results in an axis line
|
|
|
+ that extends 0.1 units beyond the bottom tick and 0.2 unit beyond
|
|
|
+ the top tick.
|
|
|
+ """
|
|
|
+ if ax is None:
|
|
|
+ ax = plt.gca()
|
|
|
+ spines2ax = {
|
|
|
+ 'left': ax.yaxis,
|
|
|
+ 'top': ax.xaxis,
|
|
|
+ 'right': ax.yaxis,
|
|
|
+ 'bottom': ax.xaxis
|
|
|
+ }
|
|
|
+ all_spines = ['left', 'bottom', 'right', 'top']
|
|
|
+ for spine in spines:
|
|
|
+ low = min(spines2ax[spine].get_majorticklocs())
|
|
|
+ high = max(spines2ax[spine].get_majorticklocs())
|
|
|
+ if spine in ext.keys():
|
|
|
+ low += ext[spine][0]
|
|
|
+ high += ext[spine][1]
|
|
|
+ ax.spines[spine].set_bounds(low, high)
|
|
|
+ for spine in [spine for spine in all_spines if spine not in spines]:
|
|
|
+ ax.spines[spine].set_visible(False)
|
|
|
+
|
|
|
+def p2stars(p):
|
|
|
+ if p <= 0.0001:
|
|
|
+ return '***'
|
|
|
+ elif p <= 0.001:
|
|
|
+ return '**'
|
|
|
+ elif p<= 0.05:
|
|
|
+ return '*'
|
|
|
+ else:
|
|
|
+ return ''
|
|
|
+
|
|
|
+def violin_plot(dists, colors, ax=None, logscale=False):
|
|
|
+ if type(colors) is list:
|
|
|
+ assert len(colors) == len(dists)
|
|
|
+ if ax is None:
|
|
|
+ fig, ax = plt.subplots()
|
|
|
+ violins = ax.violinplot(dists, showmedians=True, showextrema=False)
|
|
|
+ for violin, color in zip(violins['bodies'], colors):
|
|
|
+ violin.set_facecolor('none')
|
|
|
+ violin.set_edgecolor(color)
|
|
|
+ violin.set_alpha(1)
|
|
|
+ violin.set_linewidth(2)
|
|
|
+ violins['cmedians'].set_color('black')
|
|
|
+ for pos, dist in enumerate(dists):
|
|
|
+ median = np.median(dist)
|
|
|
+ text = f'{median:.2f}'
|
|
|
+ if logscale:
|
|
|
+ text = f'{10 ** median:.2f}'
|
|
|
+ ax.text(pos + 1.4, median, text, va='center', ha='center', rotation=-90, fontsize=LABELFONTSIZE)
|
|
|
+ ax.set_xticks(np.arange(len(dists)) + 1)
|
|
|
+ ax.tick_params(bottom=False)
|
|
|
+ return ax
|
|
|
+
|
|
|
+def pupil_area_rate_heatmap(df, cmap='gray', max_rate='high', example=None):
|
|
|
+ """
|
|
|
+ Plot a heatmap of event rates (tonic spikes or bursts) where each row is a neuron and each column is a pupil size bin.
|
|
|
+
|
|
|
+ Parameters
|
|
|
+ ----------
|
|
|
+ df : pandas.DataFrame
|
|
|
+ Dataframe with neurons in the rows and mean firing rates for each pupil size bin in a column called 'area_means'.
|
|
|
+
|
|
|
+ cmap : str or Matplotlib colormap object
|
|
|
+
|
|
|
+ max_rate : str
|
|
|
+ Is the max event rate expected to at 'high' or 'low' pupil sizes?
|
|
|
+
|
|
|
+ example : dict
|
|
|
+ If not none, MSEU key passed will be used to highlight example neuron.
|
|
|
+ """
|
|
|
+ fig = plt.figure()
|
|
|
+
|
|
|
+ # Find out which pupil size bin has the highest firing rate
|
|
|
+ df['tuning_max'] = df['area_means'].apply(np.argmax)
|
|
|
+ # Min-max normalize firing rates for each neuron
|
|
|
+ df['tuning_norm'] = df['area_means'].apply(lambda x: (x - x.min()) / (x.max() - x.min()))
|
|
|
+
|
|
|
+ # Get heatmap for neurons with significant rate differences across pupil size bins
|
|
|
+ df_sig = df.query('area_p <= 0.05').sort_values('tuning_max')
|
|
|
+ heatmap_sig = np.row_stack(df_sig['tuning_norm'])
|
|
|
+ n_sig = len(df_sig) # number of neurons with significant differences
|
|
|
+
|
|
|
+ # Make axis with size proportional to the fraction of significant neurons
|
|
|
+ n_units = len(df) # total number of neurons
|
|
|
+ ax1_height = n_sig / n_units
|
|
|
+ ax1 = fig.add_axes([0.1, 0.9 - (0.76 * ax1_height), 0.8, (0.76 * ax1_height)])
|
|
|
+ # Plot heatmap for signifcant neurons
|
|
|
+ mat = ax1.matshow(heatmap_sig, cmap=cmap, aspect='auto')
|
|
|
+ # Scatter plot marking pupil size bin with maximum
|
|
|
+ ax1.scatter(df_sig['tuning_max'], np.arange(len(df_sig)) - 0.5, s=0.5, color='black', zorder=3)
|
|
|
+
|
|
|
+ # Format axes
|
|
|
+ ax1.set_xticks([])
|
|
|
+ yticks = np.insert(np.arange(0, n_sig, 20), -1, n_sig) # ticks every 20 neurons, and final count
|
|
|
+ ax1.set_yticks(yticks - 0.5)
|
|
|
+ ax1.set_yticklabels(yticks)
|
|
|
+ ax1.set_ylabel('Neurons')
|
|
|
+
|
|
|
+ # Make a colorbar
|
|
|
+ cbar = fig.colorbar(mat, ax=ax1, ticks=[0, 1], location='top', shrink=0.75)
|
|
|
+ cbar.ax.set_xticklabels(['min', 'max'])
|
|
|
+ cbar.ax.set_xlabel('Spikes', labelpad=-5)
|
|
|
+
|
|
|
+ # Print dotted line and label for neurons considered to have 'monotonic' modulation profiles
|
|
|
+ # (highest event rate at one of the pupil size extremes)
|
|
|
+ if max_rate == 'high':
|
|
|
+ n_mon = len(df_sig.query('tuning_max >= 9'))
|
|
|
+ print("Monotonic increasing: %d/%d (%.1f)" % (n_mon, n_sig, n_mon / n_sig * 100))
|
|
|
+ ax1.axvline(8.5, lw=2, ls='--', color='white')
|
|
|
+ ax1.set_title(r'Monotonic$\rightarrow$', fontsize=LABELFONTSIZE, loc='right', pad=0)
|
|
|
+ elif max_rate == 'low':
|
|
|
+ n_mon = len(df_sig.query('tuning_max <= 1'))
|
|
|
+ print("Monotonic decreasing: %d/%d (%.1f)" % (n_mon, n_sig, n_mon / n_sig * 100))
|
|
|
+ ax1.axvline(0.5, lw=2, ls='--', color='white')
|
|
|
+ ax1.set_title(r'$\leftarrow$Monotonic', fontsize=LABELFONTSIZE, loc='left', pad=0)
|
|
|
+
|
|
|
+ # Get heatmap for neurons without significant rate differences acrss pupil sizes
|
|
|
+ df_ns = df.query('area_p > 0.05').sort_values('tuning_max')
|
|
|
+ heatmap_ns = np.row_stack(df_ns['tuning_norm'])
|
|
|
+
|
|
|
+ # Make axis with size proportional to the fraction of non-significant neurons
|
|
|
+ n_ns = len(df_ns)
|
|
|
+ ax2_height = n_ns / n_units
|
|
|
+ ax2 = fig.add_axes([0.1, 0.1, 0.8, (0.76 * ax2_height)])
|
|
|
+
|
|
|
+ # Plot heatmap for neurons without significant rate differences acrss pupil sizes
|
|
|
+ mat = ax2.matshow(heatmap_ns, cmap='Greys', vmax=2, aspect='auto')
|
|
|
+ ax2.scatter(df_ns['tuning_max'], np.arange(n_ns) - 0.5, s=0.5, color='black', zorder=3)
|
|
|
+
|
|
|
+ # Format axes
|
|
|
+ ax2.xaxis.set_ticks_position('bottom')
|
|
|
+ ax2.set_xticks([-0.5, 4.5, 9.5])
|
|
|
+ ax2.set_xticklabels([0, 0.5, 1]) # x-axis ticks mark percentiles of pupil size range
|
|
|
+ ax2.set_xlim(right=9.5)
|
|
|
+ ax2.set_xlabel('Pupil size (norm.)')
|
|
|
+ yticks = np.insert(np.arange(0, n_ns, 20), -1, n_ns) # ticks every 20 neurons, and final count
|
|
|
+ ax2.set_yticks(yticks - 0.5)
|
|
|
+ ax2.set_yticklabels(yticks)
|
|
|
+
|
|
|
+ # Highligh example neuron by circling scatter dot
|
|
|
+ if example is not None:
|
|
|
+ try: # first check if example neuron is among significant neurons
|
|
|
+ is_ex = df_sig.index == tuple([v for k, v in example.items()])
|
|
|
+ assert is_ex.any()
|
|
|
+ ex_max = df_sig['tuning_max'][is_ex]
|
|
|
+ ax1.scatter(ex_max, np.where(is_ex)[0], lw=2, s=1, fc='none', ec='magenta', zorder=4)
|
|
|
+ except:
|
|
|
+ is_ex = df_ns.index == tuple([v for k, v in example.items()])
|
|
|
+ ex_max = df_ns['tuning_max'][is_ex]
|
|
|
+ ax2.scatter(ex_max, np.where(is_ex)[0], lw=2, s=1, fc='none', ec='magenta', zorder=4)
|
|
|
+
|
|
|
+ return fig
|
|
|
+
|
|
|
+def cumulative_histogram(data, bins, color='C0', ax=None):
|
|
|
+ """Convenience function, cleaner looking plot than plt.hist(..., cumulative=True)."""
|
|
|
+ if ax is None:
|
|
|
+ fig, ax = plt.subplots()
|
|
|
+ weights = np.ones_like(data) / len(data)
|
|
|
+ counts, _ = np.histogram(data, bins=bins, weights=weights)
|
|
|
+ ax.plot(bins[:-1], np.cumsum(counts), color=color)
|
|
|
+ return ax
|
|
|
+
|
|
|
+def cumulative_hist(x, bins, density=True, ax=None, **kwargs):
|
|
|
+ weights = np.ones_like(x)
|
|
|
+ if density:
|
|
|
+ weights = weights / len(x)
|
|
|
+ counts, _ = np.histogram(x, bins=bins, weights=weights)
|
|
|
+ if ax is None:
|
|
|
+ fig, ax = plt.subplots()
|
|
|
+ xs = np.insert(bins, np.arange(len(bins) - 1), bins[:-1])
|
|
|
+ ys = np.insert(np.insert(np.cumsum(counts), np.arange(len(counts)), np.cumsum(counts)), 0, 0)
|
|
|
+ ax.plot(xs, ys, lw=2, **kwargs)
|
|
|
+ ax.set_xticks(bins + 0.5)
|
|
|
+ ax.set_yticks([0, 0.5, 1])
|
|
|
+ return ax, counts
|
|
|
+
|
|
|
+def phase_coupling_scatter(df, ax=None):
|
|
|
+ """Phase-frequency scatter plot for phase coupling."""
|
|
|
+ if ax is None:
|
|
|
+ fig, ax = plt.subplots()
|
|
|
+ for event in ['tonicspk', 'burst']:
|
|
|
+ df_sig = df.query(f'{event}_sig == True')
|
|
|
+ ax.scatter(np.log10(df_sig['freq']), df_sig[f'{event}_phase'], ec=COLORS[event], fc='none', lw=0.5, s=3)
|
|
|
+ n_sig = max([len(df.query(f'{event}_sig == True').groupby(['m', 's', 'e', 'u'])) for event in ['tonicspk', 'burst']])
|
|
|
+
|
|
|
+ ax.set_xticks(FREQUENCYTICKS)
|
|
|
+ ax.set_xticklabels(FREQUENCYTICKLABELS)
|
|
|
+ ax.set_xlim(left=-3.075)
|
|
|
+ ax.set_xlabel("Inverse timescale (s$^{-1}$)")
|
|
|
+
|
|
|
+ ax.set_yticks(PHASETICKS)
|
|
|
+ ax.set_yticklabels(PHASETICKLABELS)
|
|
|
+ ax.set_ylim([-np.pi - 0.15, np.pi + 0.15])
|
|
|
+ ax.set_ylabel("Preferred phase")
|
|
|
+ return ax
|
|
|
+
|
|
|
+def plot_circhist(angles, ax=None, bins=np.linspace(0, 2 * np.pi, 17), density=True, **kwargs):
|
|
|
+ """Plot a circular histogram."""
|
|
|
+ if ax is None:
|
|
|
+ fig, ax = plt.subplots(subplot_kw={'polar':True})
|
|
|
+ weights = np.ones_like(angles)
|
|
|
+ if density:
|
|
|
+ weights /= len(angles)
|
|
|
+ counts, bins = np.histogram(angles, bins=bins, weights=weights)
|
|
|
+ xs = bins + (np.pi / (len(bins) - 1))
|
|
|
+ ys = np.append(counts, counts[0])
|
|
|
+ ax.plot(xs, ys, **kwargs)
|
|
|
+ ax.set_xticks([0, np.pi / 2, np.pi, 3 * np.pi / 2])
|
|
|
+ ax.set_xticklabels(['0', '\u03C0/2', '\u03C0', '3\u03C0/2'])
|
|
|
+ ax.tick_params(axis='x', pad=-5)
|
|
|
+ return ax, counts
|
|
|
+
|
|
|
+def coupling_strength_line_plot(df, agg=np.mean, err=sem, logscale=True, ax=None, **kwargs):
|
|
|
+ """
|
|
|
+ Line plot showing average burst and tonic spike coupling strengths and SE across timescale bins.
|
|
|
+ """
|
|
|
+ if ax is None:
|
|
|
+ fig, ax = plt.subplots()
|
|
|
+ for event in ['burst', 'tonicspk']:
|
|
|
+ df_sig = df.query(f'({event}_sig == True) & ({event}_strength > 0)').copy()
|
|
|
+ strengths = sort_data(df_sig[f'{event}_strength'], df_sig['freq'], bins=FREQUENCYBINS)
|
|
|
+ ys = np.array([agg(s) for s in strengths])
|
|
|
+ yerr = np.array([err(s) for s in strengths])
|
|
|
+ if not logscale:
|
|
|
+ ax.plot(FREQUENCYXPOS, ys, color=COLORS[event], **kwargs)
|
|
|
+ ax.plot(FREQUENCYXPOS, ys + yerr, color=COLORS[event], lw=0.5, ls='--', **kwargs)
|
|
|
+ ax.plot(FREQUENCYXPOS, ys - yerr, color=COLORS[event], lw=0.5, ls='--', **kwargs)
|
|
|
+ else:
|
|
|
+ ax.plot(FREQUENCYXPOS, np.log10(ys), color=COLORS[event], **kwargs)
|
|
|
+ ax.plot(FREQUENCYXPOS, np.log10(ys + yerr), color=COLORS[event], lw=0.5, ls='--', **kwargs)
|
|
|
+ ax.plot(FREQUENCYXPOS, np.log10(ys - yerr), color=COLORS[event], lw=0.5, ls='--', **kwargs)
|
|
|
+ ax.set_xticks(FREQUENCYTICKS)
|
|
|
+ ax.set_xticklabels(FREQUENCYTICKLABELS)
|
|
|
+ ax.set_xlim(left=-3.1)
|
|
|
+ ax.set_xlabel('Timescale (Hz)')
|
|
|
+ ax.set_ylabel('Coupling strength')
|
|
|
+ return ax
|
|
|
+
|
|
|
+
|
|
|
+## Util
|
|
|
+
|
|
|
+def zero_runs(a):
|
|
|
+ """
|
|
|
+ Return an array with shape (m, 2), where m is the number of "runs" of zeros
|
|
|
+ in a. The first column is the index of the first 0 in each run, the second
|
|
|
+ is the index of the first nonzero element after the run.
|
|
|
+ """
|
|
|
+ # Create an array that is 1 where a is 0, and pad each end with an extra 0.
|
|
|
+ iszero = np.concatenate(([0], np.equal(a, 0).view(np.int8), [0]))
|
|
|
+ absdiff = np.abs(np.diff(iszero))
|
|
|
+ # Runs start and end where absdiff is 1.
|
|
|
+ ranges = np.where(absdiff == 1)[0].reshape(-1, 2)
|
|
|
+
|
|
|
+ return ranges
|
|
|
+
|
|
|
+def merge_ranges(ranges, dt=1):
|
|
|
+ """
|
|
|
+ Given a set of ranges [start, stop], return new set of ranges where all
|
|
|
+ overlapping ranges are merged.
|
|
|
+ """
|
|
|
+ tpts = np.arange(ranges.min(), ranges.max(), dt) # array of time points
|
|
|
+ tc = np.ones_like(tpts) # time course of ranges
|
|
|
+ for t0, t1 in ranges: # for each range
|
|
|
+ i0, i1 = tpts.searchsorted([t0, t1])
|
|
|
+ tc[i0:i1] = 0 # set values in range to 0
|
|
|
+ new_ranges = zero_runs(tc) # get indices of continuous stretches of zero
|
|
|
+ if new_ranges[-1, -1] == len(tpts): # fix end-point
|
|
|
+ new_ranges[-1, -1] = len(tpts) - 1
|
|
|
+ return tpts[new_ranges]
|
|
|
+
|
|
|
+def continuous_runs(data, max0len=1, min1len=1, min1prop=0):
|
|
|
+ """
|
|
|
+ Get start and stop indices of stretches of (relatively) continuous data.
|
|
|
+
|
|
|
+ Parameters
|
|
|
+ ----------
|
|
|
+ data : ndarray
|
|
|
+ 1D boolean array
|
|
|
+ max0len : int
|
|
|
+ maximum length (in data pts) of False stretches to ignore
|
|
|
+ min1len : int
|
|
|
+ minimum length (in data pts) of True runs to keep
|
|
|
+ min1prop : int
|
|
|
+ minimum proprtion of True data in the run necessary for it
|
|
|
+ to be considered
|
|
|
+
|
|
|
+ Returns
|
|
|
+ -------
|
|
|
+ out : ndarray
|
|
|
+ (m, 2) array of start and stop indices, where m is the number runs of
|
|
|
+ continuous True values
|
|
|
+ """
|
|
|
+ # get ranges of True values
|
|
|
+ one_ranges = zero_runs(~data)
|
|
|
+ if len(one_ranges) == 0:
|
|
|
+ return np.array([[]])
|
|
|
+ # merge ranges that are separated by < min0len of False
|
|
|
+ one_ranges[:, 1] += (max0len - 1)
|
|
|
+ one_ranges = merge_ranges(one_ranges)
|
|
|
+ # return indices to normal
|
|
|
+ one_ranges[:, 1] -= (max0len - 1)
|
|
|
+ one_ranges[-1, -1] += 1
|
|
|
+ # remove ranges that are too short
|
|
|
+ lengths = np.diff(one_ranges, axis=1).ravel()
|
|
|
+ one_ranges = one_ranges[lengths >= min1len]
|
|
|
+ # remove ranges that don't have sufficient proportion True
|
|
|
+ prop = np.array([data[i0:i1].sum() / (i1 - i0) for (i0, i1) in one_ranges])
|
|
|
+ return one_ranges[prop >= min1prop]
|
|
|
+
|
|
|
+def switch_ranges(ranges, dt=1, minval=0, maxval=None):
|
|
|
+ """
|
|
|
+ Given a set of (start, stop) pairs, return a new set of pairs for values
|
|
|
+ outside the given ranges.
|
|
|
+
|
|
|
+ Parameters
|
|
|
+ ----------
|
|
|
+ ranges : ndarray
|
|
|
+ N x 2 array containing start and stop values in the first and second
|
|
|
+ columns respectively
|
|
|
+ dt : float
|
|
|
+
|
|
|
+ minval, maxval : int
|
|
|
+ the minimum and maximum possible values, if maxval is None it is assumed
|
|
|
+ that the maximum possible value is the maximum value in the input ranges
|
|
|
+
|
|
|
+ Returns
|
|
|
+ -------
|
|
|
+ out : ndarray
|
|
|
+ M x 2 array containing start and stop values of all ranges outside of
|
|
|
+ the input ranges
|
|
|
+ """
|
|
|
+ if ranges.shape[1] == 0:
|
|
|
+ return np.array([[minval, maxval]])
|
|
|
+ assert (ranges.ndim == 2) & (ranges.shape[1] == 2), "A two-column array is expected"
|
|
|
+ maxval = ranges.max() if maxval is None else maxval
|
|
|
+ # get new ranges
|
|
|
+ new_ranges = np.zeros_like(ranges)
|
|
|
+ new_ranges[:,0] = ranges[:,0] - dt # new stop values
|
|
|
+ new_ranges[:,1] = ranges[:,1] + dt # new start values
|
|
|
+ # fix boundaries
|
|
|
+ new_ranges = new_ranges.ravel()
|
|
|
+ if new_ranges[0] >= (minval + dt): # first new stop within allowed range
|
|
|
+ new_ranges = np.concatenate((np.array([minval]), new_ranges))
|
|
|
+ else:
|
|
|
+ new_ranges = new_ranges[1:]
|
|
|
+ if new_ranges[-1] <= (maxval - dt): # first new start within allowed range
|
|
|
+ new_ranges = np.concatenate((new_ranges, np.array([maxval])))
|
|
|
+ else:
|
|
|
+ new_ranges = new_ranges[:-1]
|
|
|
+ return new_ranges.reshape((int(len(new_ranges) / 2), 2))
|
|
|
+
|
|
|
+def shuffle_bins(x, binwidth=1):
|
|
|
+ """
|
|
|
+ Randomly shuffle bins of an array.
|
|
|
+ """
|
|
|
+ # bin start indices
|
|
|
+ bins_i0 = np.arange(0, len(x), binwidth)
|
|
|
+ # shuffled bins
|
|
|
+ np.random.shuffle(bins_i0)
|
|
|
+ # concatenate shuffled bins
|
|
|
+ shf = np.concatenate([x[i0:(i0 + binwidth)] for i0 in bins_i0])
|
|
|
+ return shf
|
|
|
+
|
|
|
+def take_data_in_bouts(series, data, bouts, trange=None, dt=2, dt0=0, dt1=0, concatenate=True, norm=False):
|
|
|
+ if series['%s_bouts' % bouts].shape[1] < 1:
|
|
|
+ return np.array([])
|
|
|
+ header, _ = data.split('_')
|
|
|
+ data_in_bouts = []
|
|
|
+ for t0, t1 in series['%s_bouts' % bouts]:
|
|
|
+ t0 -= dt0
|
|
|
+ t1 += dt1
|
|
|
+ if trange == 'start':
|
|
|
+ t1 = t0 + dt
|
|
|
+ elif trange == 'end':
|
|
|
+ t0 = t1 - dt
|
|
|
+ elif trange == 'middle':
|
|
|
+ t0 = t0 + dt
|
|
|
+ t1 = t1 - dt
|
|
|
+ if t1 <= t0:
|
|
|
+ continue
|
|
|
+ if t0 < series['%s_tpts' % header].min():
|
|
|
+ continue
|
|
|
+ if t1 > series['%s_tpts' % header].max():
|
|
|
+ continue
|
|
|
+ i0, i1 = series['%s_tpts' % header].searchsorted([t0, t1])
|
|
|
+ data_in_bout = series[data][i0:i1].copy()
|
|
|
+ if norm:
|
|
|
+ data_in_bout = data_in_bout / series[data].max()
|
|
|
+ data_in_bouts.append(data_in_bout)
|
|
|
+ if concatenate:
|
|
|
+ data_in_bouts = np.concatenate(data_in_bouts)
|
|
|
+ return data_in_bouts
|
|
|
+
|
|
|
+def get_trials(series, stim_id=0, opto=False, multi_stim='warn'):
|
|
|
+ if opto:
|
|
|
+ opto = np.isin(series['trial_id'], series['opto_trials'])
|
|
|
+ elif not opto:
|
|
|
+ opto = ~np.isin(series['trial_id'], series['opto_trials'])
|
|
|
+ if stim_id < 0:
|
|
|
+ stim = np.ones_like(series['stim_id']).astype('bool')
|
|
|
+ else:
|
|
|
+ stim = series['stim_id'] == stim_id
|
|
|
+ series['trial_on_times'] = series['trial_on_times'][opto & stim]
|
|
|
+ series['trial_off_times'] = series['trial_off_times'][opto & stim]
|
|
|
+ return series
|
|
|
+
|
|
|
+def sort_data(data, sort_vals, bins=10):
|
|
|
+ if type(bins) == int:
|
|
|
+ nbins = bins
|
|
|
+ bin_edges = np.linspace(sort_vals.min(), sort_vals.max(), nbins + 1)
|
|
|
+ else:
|
|
|
+ nbins = len(bins) - 1
|
|
|
+ bin_edges = bins
|
|
|
+ digitized_vals = np.digitize(sort_vals, bins=bin_edges).clip(1, nbins)
|
|
|
+ return [data[digitized_vals == val] for val in np.arange(nbins) + 1]
|
|
|
+
|
|
|
+def apply_sort_data(series, data_col, sort_col, bins=10):
|
|
|
+ return sort_data(series[data_col], series[sort_col], bins)
|
|
|
+
|
|
|
+
|
|
|
+## Statistics
|
|
|
+
|
|
|
+def get_binned_rates(spk_rates, pupil_area, sort=False, nbins=10):
|
|
|
+ # Get bins base on percentiles to eliminate effect of outliers
|
|
|
+ min_area, max_area = np.percentile(pupil_area, [2.5, 97.5])
|
|
|
+ #min_area, max_area = pupil_area.min(), pupil_area.max()
|
|
|
+ area_bins = np.linspace(min_area, max_area, nbins + 1)
|
|
|
+ # Bin pupil area
|
|
|
+ binned_area = np.digitize(pupil_area, bins=area_bins).clip(1, nbins) - 1
|
|
|
+ # Bin rates according to pupil area
|
|
|
+ binned_rates = np.array([spk_rates[binned_area == bin_i] for bin_i in np.arange(nbins)], dtype=object)
|
|
|
+ if sort:
|
|
|
+ sorted_inds = np.argsort([rates.mean() if len(rates) > 0 else 0 for rates in binned_rates])
|
|
|
+ binned_rates = binned_rates[sorted_inds]
|
|
|
+ binned_area = np.squeeze([np.where(sorted_inds == area_bin)[0] for area_bin in binned_area])
|
|
|
+ return area_bins, binned_area, binned_rates
|
|
|
+
|
|
|
+def rescale(x, method='min_max'):
|
|
|
+ # Set re-scaling method
|
|
|
+ if method == 'z_score':
|
|
|
+ return (x - x.mean()) / x.std()
|
|
|
+ elif method == 'min_max':
|
|
|
+ return (x - x.min()) / (x.max() - x.min())
|
|
|
+
|
|
|
+def correlogram(ts1, ts2=None, tau_max=1, dtau=0.01, return_tpts=False):
|
|
|
+ if ts2 is None:
|
|
|
+ ts2 = ts1.copy()
|
|
|
+ auto = True
|
|
|
+ else:
|
|
|
+ auto = False
|
|
|
+ tau_max = (tau_max // dtau) * dtau
|
|
|
+ tbins = np.arange(-tau_max, tau_max + dtau, dtau)
|
|
|
+ ccg = np.zeros(len(tbins) - 1)
|
|
|
+ for t0 in ts1:
|
|
|
+ dts = ts2 - t0
|
|
|
+ if auto:
|
|
|
+ dts = dts[dts != 0]
|
|
|
+ ccg += np.histogram(dts[np.abs(dts) <= tau_max], bins=tbins)[0]
|
|
|
+ ccg /= len(ts1)
|
|
|
+ if not return_tpts:
|
|
|
+ return ccg
|
|
|
+ else:
|
|
|
+ tpts = tbins[:-1] + dtau / 2
|
|
|
+ return tpts, ccg
|
|
|
+
|
|
|
+def angle_subtract(a1, a2, period=(2 * np.pi)):
|
|
|
+ return (a1 - a2) % period
|
|
|
+
|
|
|
+def circmean(alpha, w=None, axis=None):
|
|
|
+ """
|
|
|
+ Compute mean resultant vector of circular data.
|
|
|
+
|
|
|
+ Parameters
|
|
|
+ ----------
|
|
|
+ alpha : ndarray
|
|
|
+ array of angles
|
|
|
+ w : ndarray
|
|
|
+ array of weights, must be same shape as alpha
|
|
|
+ axis : int, None
|
|
|
+ axis across which to compute mean
|
|
|
+
|
|
|
+ Returns
|
|
|
+ -------
|
|
|
+ mrl : ndarray
|
|
|
+ mean resultant vector length
|
|
|
+ theta : ndarray
|
|
|
+ mean resultant vector angle
|
|
|
+ """
|
|
|
+ # weights default to ones
|
|
|
+ if w is None:
|
|
|
+ w = np.ones_like(alpha)
|
|
|
+ w[np.isnan(alpha)] = 0
|
|
|
+
|
|
|
+ # compute weighted mean
|
|
|
+ mean_vector = np.nansum(w * np.exp(1j * alpha), axis=axis) / w.sum(axis=axis)
|
|
|
+ mrl = np.abs(mean_vector) # length
|
|
|
+ theta = np.angle(mean_vector) # angle
|
|
|
+
|
|
|
+ return mrl, theta
|
|
|
+
|
|
|
+def circmean_angle(alpha, **kwargs):
|
|
|
+ return circmean(alpha, **kwargs)[1]
|
|
|
+
|
|
|
+def circhist(angles, n_bins=8, proportion=True, wrap=False):
|
|
|
+ bins = np.linspace(-np.pi, np.pi, n_bins + 1, endpoint=True)
|
|
|
+ weights = np.ones(len(angles))
|
|
|
+ if proportion:
|
|
|
+ weights /= len(angles)
|
|
|
+ counts, bins = np.histogram(angles, bins=bins, weights=weights)
|
|
|
+ if wrap:
|
|
|
+ counts = np.concatenate([counts, [counts[0]]])
|
|
|
+ bins = np.concatenate([bins, [bins[0]]])
|
|
|
+ return counts, bins
|
|
|
+
|
|
|
+def unbiased_variance(data):
|
|
|
+ if len(data) <= 1:
|
|
|
+ return np.nan
|
|
|
+ else:
|
|
|
+ return np.var(data) * len(data) / (len(data) - 1)
|
|
|
+
|
|
|
+def se_median(sample, n_resamp=1000):
|
|
|
+ """Standard error of the median."""
|
|
|
+ medians = np.full(n_resamp, np.nan)
|
|
|
+ for i in range(n_resamp):
|
|
|
+ resample = np.random.choice(sample, len(sample), replace=True)
|
|
|
+ medians[i] = np.median(resample)
|
|
|
+ return np.std(medians)
|
|
|
+
|
|
|
+def coupling_summary(df):
|
|
|
+ """Print some basic statistics for phase coupling."""
|
|
|
+ # Either spike type
|
|
|
+ units = df.query('(tonicspk_p == tonicspk_p) or (burst_p == burst_p)').groupby(['m', 's', 'e', 'u'])
|
|
|
+ n_sig = units.apply(lambda x: any(x[f'tonicspk_sig']) or any(x[f'burst_sig'])).sum()
|
|
|
+ prop_sig = n_sig / len(units)
|
|
|
+ print(f"Neurons with significant coupling: {prop_sig:.3f} ({n_sig}/{len(units)})")
|
|
|
+ # For each spike type
|
|
|
+ for spk_type in ['tonicspk', 'burst']:
|
|
|
+ units = df.dropna(subset=f'{spk_type}_p').groupby(['m', 's', 'e', 'u'])
|
|
|
+ n_sig = units.apply(lambda x: any(x[f'{spk_type}_sig'])).sum()
|
|
|
+ prop_sig = n_sig / len(units)
|
|
|
+ print(f"{spk_type.capitalize()} prop. significant: {prop_sig:.3f} ({n_sig}/{len(units)})")
|
|
|
+ n_cpds = units.apply(lambda x: x[f'{spk_type}_sig'].sum())
|
|
|
+ print(f"{spk_type.capitalize()} num. CPDs per neuron: {n_cpds.mean():.2f}, {n_cpds.std():.2f}")
|
|
|
+
|
|
|
+def kl_divergence(p, q):
|
|
|
+ return np.sum(np.where(p != 0, p * np.log(p / q), 0))
|
|
|
+
|
|
|
+def match_distributions(x1, x2, x1_bins, x2_bins):
|
|
|
+ """
|
|
|
+ For two time series, x2 & x2, return indices of sub-sampled time
|
|
|
+ periods such that the distribution of x2 is matched across
|
|
|
+ bins of x1.
|
|
|
+ """
|
|
|
+ x1_nbins = len(x1_bins) - 1
|
|
|
+ x2_nbins = len(x2_bins) - 1
|
|
|
+ # bin x1
|
|
|
+ x1_binned = np.digitize(x1, x1_bins).clip(1, x1_nbins) - 1
|
|
|
+ # get continuous periods where x1 visits each bin
|
|
|
+ x1_ranges = [zero_runs(~np.equal(x1_binned, x1_bin)) for x1_bin in np.arange(x1_nbins)]
|
|
|
+ # get mean of x2 for each x1 bin visit
|
|
|
+ x2_means = [np.array([np.mean(x2[i0:i1]) for i0, i1 in x1_bin]) for x1_bin in x1_ranges]
|
|
|
+ # find minimum common distribution across x1 bins
|
|
|
+ x2_counts = np.row_stack([np.histogram(means, bins=x2_bins)[0] for means in x2_means])
|
|
|
+ x2_mcd = x2_counts.min(axis=0)
|
|
|
+ # bin x2 means
|
|
|
+ x2_means_binned = [np.digitize(means, bins=x2_bins).clip(1, x2_nbins) - 1 for means in x2_means]
|
|
|
+ x2_means_in_bins = [[means[binned_means == x2_bin] for x2_bin in np.arange(x2_nbins)] for means, binned_means in zip(x2_means, x2_means_binned)]
|
|
|
+ x1_ranges_in_bins = [[ranges[binned_means == x2_bin] for x2_bin in np.arange(x2_nbins)] for ranges, binned_means in zip(x1_ranges, x2_means_binned)]
|
|
|
+ # loop over x2 bins
|
|
|
+ matched_ranges = [[], [], [], []]
|
|
|
+ for x2_bin in np.arange(x2_nbins):
|
|
|
+ # find the x1 bin matching the MCD
|
|
|
+ seed_x1_bin = np.where(x2_counts[:, x2_bin] == x2_mcd[x2_bin])[0][0]
|
|
|
+ assert len(x2_means_in_bins[seed_x1_bin][x2_bin]) == x2_mcd[x2_bin]
|
|
|
+ # for each bin visit, find the closest matching mean in the other x1 bins
|
|
|
+ target_means = x2_means_in_bins[seed_x1_bin][x2_bin]
|
|
|
+ target_ranges = x1_ranges_in_bins[seed_x1_bin][x2_bin]
|
|
|
+ for target_mean, target_range in zip(target_means, target_ranges):
|
|
|
+ matched_ranges[seed_x1_bin].append(target_range)
|
|
|
+ for x1_bin in np.delete(np.arange(x1_nbins), seed_x1_bin):
|
|
|
+ matching_ind = np.abs(x2_means_in_bins[x1_bin][x2_bin] - target_mean).argmin()
|
|
|
+ matched_ranges[x1_bin].append(x1_ranges_in_bins[x1_bin][x2_bin][matching_ind])
|
|
|
+ # delete the matching period
|
|
|
+ x2_means_in_bins[x1_bin][x2_bin] = np.delete(x2_means_in_bins[x1_bin][x2_bin], matching_ind, axis=0)
|
|
|
+ x1_ranges_in_bins[x1_bin][x2_bin] = np.delete(x1_ranges_in_bins[x1_bin][x2_bin], matching_ind, axis=0)
|
|
|
+ return [np.row_stack(ranges) if len(ranges) > 0 else np.array([]) for ranges in matched_ranges]
|
|
|
+
|
|
|
+## Signal processing
|
|
|
+
|
|
|
+def normalized_xcorr(a, b, dt=None, ts=None):
|
|
|
+ """
|
|
|
+ Compute Pearson r between two arrays at various lags
|
|
|
+
|
|
|
+ Parameters
|
|
|
+ ----------
|
|
|
+ a, b : ndarray
|
|
|
+ The arrays to correlate.
|
|
|
+ dt : float
|
|
|
+ The time step between samples in the arrays.
|
|
|
+ ts : list
|
|
|
+ If not None, only the xcorr between the specified lags will be
|
|
|
+ returned.
|
|
|
+
|
|
|
+ Return
|
|
|
+ ------
|
|
|
+ xcorr, lags : ndarray
|
|
|
+ The cross correlation and corresponding lags between a and b.
|
|
|
+ Positive lags indicate that a is delayed relative to b.
|
|
|
+ """
|
|
|
+ assert len(a) == len(b)
|
|
|
+ n = len(a)
|
|
|
+ a_norm = (a - a.mean()) / a.std()
|
|
|
+ b_norm = (b - b.mean()) / b.std()
|
|
|
+ xcorr = np.correlate(a_norm, b_norm, 'full') / n
|
|
|
+ lags = np.arange(-n + 1, n)
|
|
|
+ if dt is not None:
|
|
|
+ lags = lags * dt
|
|
|
+ if ts is not None:
|
|
|
+ assert len(ts) == 2
|
|
|
+ i0, i1 = lags.searchsorted(ts)
|
|
|
+ xcorr = xcorr[i0:i1]
|
|
|
+ lags = lags[i0:i1]
|
|
|
+
|
|
|
+ return xcorr, lags
|
|
|
+
|
|
|
+def interpolate(y, x_old, x_new, axis=0, fill_value='extrapolate'):
|
|
|
+ """
|
|
|
+ Use linear interpolation to re-sample 1D data.
|
|
|
+ """
|
|
|
+ # get interpolation function
|
|
|
+ func = interp1d(x_old, y, axis=axis, fill_value=fill_value)
|
|
|
+ # get new y-values
|
|
|
+ y_interpolated = func(x_new)
|
|
|
+ return y_interpolated
|
|
|
+
|
|
|
+def interpolate_and_normalize(y, x_old, x_new):
|
|
|
+ """
|
|
|
+ Perform linear interpolation and min-max normalization.
|
|
|
+ """
|
|
|
+ y_new = interpolate(y, x_old, x_new)
|
|
|
+ return (y_new - y_new.min()) / (y_new.max() - y_new.min())
|
|
|
+
|
|
|
+def match_signal_length(a, b, a_tpts, b_tpts):
|
|
|
+ """
|
|
|
+ Given two signals, truncate to match the length of the shortest.
|
|
|
+ """
|
|
|
+ t1 = min(a_tpts.max(), b_tpts.max())
|
|
|
+ a1 = a[:a_tpts.searchsorted(t1)]
|
|
|
+ a1_tpts = a_tpts[:a_tpts.searchsorted(t1)]
|
|
|
+ b1 = b[:b_tpts.searchsorted(t1)]
|
|
|
+ b1_tpts = b_tpts[:b_tpts.searchsorted(t1)]
|
|
|
+ return a1, b1, a1_tpts, b1_tpts
|
|
|
+
|
|
|
+def times_to_counts(series, columns, t0=None, t1=None, dt=0.25):
|
|
|
+ if type(t0) == str:
|
|
|
+ t0, t1 = series[t0].min(), series[t0].max()
|
|
|
+ elif t0 is None: # get overlapping time range for all columns
|
|
|
+ t0 = max([series[f'{col.split("_")[0]}_times'].min() for col in columns])
|
|
|
+ elif t1 is None:
|
|
|
+ t1 = min([series[f'{col.split("_")[0]}_times'].max() for col in columns])
|
|
|
+ # Set time base
|
|
|
+ tbins = np.arange(t0, t1, dt)
|
|
|
+ tpts = tbins[:-1] + (dt / 2)
|
|
|
+ for col in columns:
|
|
|
+ header = col.split('_')[0]
|
|
|
+ times = series[f'{header}_times']
|
|
|
+ counts, _ = np.histogram(times, bins=tbins)
|
|
|
+ series[f'{header}_counts'] = counts
|
|
|
+ series[f'{header}_tpts'] = tpts
|
|
|
+ return series
|
|
|
+
|
|
|
+def resample_timeseries(y, tpts, dt=0.25):
|
|
|
+ tpts_new = np.arange(tpts.min(), tpts.max(), dt)
|
|
|
+ return tpts_new, interpolate(y, tpts, tpts_new)
|
|
|
+
|
|
|
+def _resample_data(series, columns, t0=None, t1=None, dt=0.25):
|
|
|
+ if type(t0) == str:
|
|
|
+ t0, t1 = series[t0].min(), series[t0].max()
|
|
|
+ elif t0 is None: # get overlapping time range for all columns
|
|
|
+ t0 = max([series[f'{col.split("_")[0]}_tpts'].min() for col in columns])
|
|
|
+ t1 = min([series[f'{col.split("_")[0]}_tpts'].max() for col in columns])
|
|
|
+ # Set new time base
|
|
|
+ tbins = np.arange(t0, t1, dt)
|
|
|
+ tpts_new = tbins[:-1] + (dt / 2)
|
|
|
+ # Interpolate and re-sample each column
|
|
|
+ for col in columns:
|
|
|
+ header = col.split('_')[0]
|
|
|
+ data = series[col]
|
|
|
+ tpts = series[f'{header}_tpts']
|
|
|
+ series[col] = interpolate(data, tpts, tpts_new)
|
|
|
+ series[f'{header}_tpts'] = tpts_new
|
|
|
+ return series
|
|
|
+
|
|
|
+## Neural activity
|
|
|
+def get_mean_rates(df):
|
|
|
+ df['mean_rate'] = df.apply(
|
|
|
+ lambda x:
|
|
|
+ len(x['spk_times']) / (x['spk_tinfo'][1] - x['spk_tinfo'][0]),
|
|
|
+ axis='columns'
|
|
|
+ )
|
|
|
+ return df
|
|
|
+
|
|
|
+def get_mean_rate_threshold(df, alpha=0.025):
|
|
|
+ if 'mean_rate' not in df.columns:
|
|
|
+ df = get_mean_rates(df)
|
|
|
+ rates = np.log10(df['mean_rate'])
|
|
|
+ gmm = GaussianMixture(n_components=2)
|
|
|
+ gmm.fit(rates[..., np.newaxis])
|
|
|
+ (mu, var) = (gmm.means_.max(), gmm.covariances_.squeeze()[gmm.means_.argmax()])
|
|
|
+ threshold = mu + norm.ppf(alpha) * np.sqrt(var)
|
|
|
+ return threshold
|
|
|
+
|
|
|
+def filter_units(df, threshold):
|
|
|
+ if 'mean_rate' not in df.columns:
|
|
|
+ df = get_mean_rates(df)
|
|
|
+ return df.query(f'mean_rate >= {threshold}')
|
|
|
+
|
|
|
+def get_raster(series, events, spike_type='spk', pre=0, post=1):
|
|
|
+ events = series['%s_times' % events]
|
|
|
+ spks = series['%s_times' % spike_type]
|
|
|
+ raster = np.array([spks[(spks > t0 - pre) & (spks < t0 + post)] - t0 for t0 in events], dtype='object')
|
|
|
+ return raster
|
|
|
+
|
|
|
+def get_psth(events, spikes, pre=0, post=1, dt=0.001, bw=0.01, baseline=[]):
|
|
|
+ rate_kernel = GaussianKernel(bw*pq.s)
|
|
|
+ tpts = np.arange(pre, post, dt)
|
|
|
+ psth = np.full((len(events), len(tpts)), np.nan)
|
|
|
+ for i, t0 in enumerate(events):
|
|
|
+ rel_ts = spikes - t0
|
|
|
+ rel_ts = rel_ts[(rel_ts >= pre) & (rel_ts <= post)]
|
|
|
+ try:
|
|
|
+ rate = instantaneous_rate(
|
|
|
+ SpikeTrain(rel_ts, t_start=pre, t_stop=post, units='s'),
|
|
|
+ sampling_period=dt*pq.s,
|
|
|
+ kernel=rate_kernel
|
|
|
+ )
|
|
|
+ except:
|
|
|
+ continue
|
|
|
+ psth[i] = rate.squeeze()
|
|
|
+ if baseline:
|
|
|
+ b0, b1 = tpts.searchsorted(baseline)
|
|
|
+ baseline_rate = psth[:, b0:b1].mean(axis=1)
|
|
|
+ psth = (psth.T - baseline_rate).T
|
|
|
+ return psth, tpts
|
|
|
+
|
|
|
+def apply_get_psth(series, events, spike_type, **kwargs):
|
|
|
+ events = series['{}_times'.format(events)]
|
|
|
+ spikes = series['{}_times'.format(spike_type)]
|
|
|
+ psth, tpts = get_psth(events, spikes, **kwargs)
|
|
|
+ return psth
|
|
|
+
|
|
|
+def get_responses(events, data, tpts, pre=0, post=1, baseline=[]):
|
|
|
+ dt = np.round(np.diff(tpts).mean(), 3) # round to nearest ms
|
|
|
+ i_pre, i_post = int(pre / dt), int(post / dt)
|
|
|
+ responses = np.full((len(events), i_pre + i_post), np.nan)
|
|
|
+ for j, t0 in enumerate(events):
|
|
|
+ i = tpts.searchsorted(t0)
|
|
|
+ i0, i1 = i - i_pre, i + i_post
|
|
|
+ if i0 < 0:
|
|
|
+ continue
|
|
|
+ if i1 > len(data):
|
|
|
+ break
|
|
|
+ responses[j] = data[i0:i1]
|
|
|
+ tpts = np.linspace(pre, post, responses.shape[1])
|
|
|
+ if baseline:
|
|
|
+ b0, b1 = tpts.searchsorted(baseline)
|
|
|
+ baseline_resp = responses[:, b0:b1].mean(axis=1)
|
|
|
+ responses = (responses.T - baseline_resp).T
|
|
|
+ return responses, tpts
|
|
|
+
|
|
|
+def apply_get_responses(series, events, data, **kwargs):
|
|
|
+ events = series[f'{events}_times']
|
|
|
+ tpts = series[f'{data.split("_")[0]}_tpts']
|
|
|
+ data = series[f'{data}']
|
|
|
+ responses, tpts = get_responses(events, data, tpts, **kwargs)
|
|
|
+ return responses
|