Browse source

Upload analysis scripts

Davide Crombie 1 month ago
Parent
commit
eaaff4d5d9
8 changed files with 2564 additions and 0 deletions
  1. hht.py (+738, -0)
  2. imf_correlation.py (+106, -0)
  3. imf_decoding.py (+171, -0)
  4. parameters.py (+114, -0)
  5. phase_tuning.py (+323, -0)
  6. size_tuning.py (+114, -0)
  7. triggered_spiking.py (+107, -0)
  8. util.py (+891, -0)

+ 738 - 0
hht.py

@@ -0,0 +1,738 @@
+import argparse
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+
+import emd
+import scipy.signal
+from scipy import fft
+from scipy.signal import windows, detrend, find_peaks
+from scipy.interpolate import BPoly
+from sklearn.decomposition import PCA
+
+from parameters import DATAPATH, NIMFCYCLES
+from util import (load_data, zero_runs, merge_ranges, circhist, circmean_angle,
+                  kl_divergence, switch_ranges, match_distributions)
+
+def pad_signal(x, method='peaks', npts=None, edge_tolerance=3, npeaks=5):
+    """
+    Extend a signal at both ends with various methods. Return the padded
+    signal and a binary array for recovering the original signal, i.e.
+    orig_signal = padded_signal[inds == 1].
+
+    method: 'even' - reflect signal at edges.
+            'odd' - "rotate" signal 180deg at edges.
+            'peaks' - extrapolate signal by reflecting peaks at edges and
+            fitting a Bernstein polynomial constrained by the 1st derivative.
+
+    npts: number of points added to each end of signal when using the
+          'even' or 'odd' methods.
+
+    edge_tolerance: number of points to extrapolate when checking if the
+                    signal edges represent a peak.
+
+    npeaks: number of peaks to use at each end of signal when using the
+            'peaks' method.
+    """
+    if npts is None:
+        npts = np.round(len(x) / 10).astype('int')
+
+    if method == 'even':
+        assert type(npts) in [int, np.int64], "npts must be an integer"
+        from scipy.signal._arraytools import even_ext
+        padded_signal = even_ext(x, npts) # extend signal using 'even' method
+        # create array indicating original signal
+        inds = np.full(padded_signal.shape, False)
+        inds[npts:-npts] = True
+        return padded_signal, inds
+    elif method == 'odd':
+        assert type(npts) in [int, np.int64], "npts must be an integer"
+        from scipy.signal._arraytools import odd_ext
+        padded_signal = odd_ext(x, npts) # extend signal using 'odd' method
+        # create array indicating original signal
+        inds = np.full(padded_signal.shape, False)
+        inds[npts:-npts] = True
+        return padded_signal, inds
+    elif method == 'peaks':
+        assert type(npeaks) == int, "npeaks must be an integer"
+        pre = _pad_signal_mirror_peaks(x, edge_tolerance, npeaks)
+        post = _pad_signal_mirror_peaks(np.flip(x), edge_tolerance, npeaks)
+        padded_signal = np.concatenate((pre, x, np.flip(post)))
+        # create array indicating original signal
+        inds = np.full(padded_signal.shape, False)
+        inds[len(pre):-len(post)] = True
+
+        return padded_signal, inds
+
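+# Usage sketch (illustrative only, added for documentation; not part of the
+# analysis pipeline): pad a sine wave with the 'peaks' method and recover the
+# original samples from the returned boolean index.
+def _demo_pad_signal():
+    x = np.sin(2 * np.pi * 0.5 * np.arange(0, 10, 0.01))
+    padded, orig_inds = pad_signal(x, method='peaks')
+    assert np.array_equal(padded[orig_inds], x) # original samples are recoverable
+    return padded
+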
+def get_peaks_simple(x, edge_tolerance=2):
+    """Return two arrays of indices indicating peaks & troughs of a signal occur."""
+    assert x.ndim == 1, "Signal must be 1D"
+    peaks, troughs = np.array([]), np.array([])
+    search_range = np.arange(len(x))[1:-1] # search whole signal except endpoints
+    for i in search_range: # go forwards through signal
+        if (x[i-1] < x[i] > x[i+1]): # point is a peak
+            peaks = np.append(peaks, i)
+        elif (x[i-1] > x[i] < x[i+1]): # point is a trough
+            troughs = np.append(troughs, i)
+    # estimate gradient change beyond signal edges for requested tolerance
+    # by assuming a constant 2nd derivative (i.e. peaks are parabolic)
+    gradient = np.gradient(x) # first derivative
+    gradient2 = np.gradient(gradient) # second derivative
+    # before signal
+    grad_pre = gradient[0] - edge_tolerance*gradient2[0]
+    if np.sign(grad_pre) + -np.sign(gradient[0]) == 2: # first point is peak
+        peaks = np.concatenate(([0], peaks))
+    elif np.sign(grad_pre) + -np.sign(gradient[0]) == -2: # first point is trough
+        troughs = np.concatenate(([0], troughs))
+    # after signal
+    grad_post = gradient[-1] + edge_tolerance*gradient2[-1]
+    if np.sign(gradient[-1]) + -np.sign(grad_post) == 2: # last point is peak
+        peaks = np.concatenate((peaks, [len(x) - 1]))
+    elif np.sign(gradient[-1]) + -np.sign(grad_post) == -2: # last point is trough
+        troughs = np.concatenate((troughs, [len(x) - 1]))
+    # return as ints for indexing
+    return peaks.astype('int'), troughs.astype('int')
+
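+# Usage sketch (illustrative only, not part of the analysis pipeline): find the
+# single peak and trough of one sine cycle.
+def _demo_get_peaks_simple():
+    x = np.sin(2 * np.pi * np.linspace(0, 1, 100))
+    peaks, troughs = get_peaks_simple(x)
+    return peaks, troughs # one peak near index 25, one trough near index 74
+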
+def _pad_signal_mirror_peaks(x, edge_tolerance, npeaks):
+    """
+    Extend a signal from the beginning by mirroring npeaks and interpolating
+    over these peak values with splines restricted by the gradients at the
+    signal end points.
+
+    Notes
+    -----
+    - To extend a signal from the end, simply pass np.flip(x) instead of x,
+    then flip the returned array before adding it to the original signal.
+    """
+    peaks = np.concatenate(get_peaks_simple(x, edge_tolerance)) # peaks & troughs
+    peaks.sort()
+    assert len(peaks) >= 2, "<2 peaks present in signal"
+    if np.sign(x[peaks[0]]) == np.sign(x[0]):
+        peaks = peaks[1:npeaks].astype('int') # convert to ints for indexing
+    else:
+        peaks = peaks[:npeaks].astype('int')
+    #half_period = peaks[1] - peaks[0] # half period of first oscillation
+    #offset = half_period - peaks[0] # "phase" of oscillation at beginning
+    grad = np.gradient(x)[0] # derivative at beginning of signal
+    inds = np.concatenate(([0], peaks))
+    # y-values of points to interpolate
+    y = np.concatenate(([x[0] - grad], x[peaks]))
+    # gradients at points to interpolate (0 except for 1st point)
+    grads = np.concatenate(([-grad], np.full(peaks.shape, 0)))
+    # get Bernstein polynomial splines for given values and derivatives
+    splines = BPoly.from_derivatives(inds, np.vstack((y, grads)).T, orders=3)
+    return np.flip(splines(np.arange(inds[-1]))) # interpolate & flip
+
+def hilbert(x, fs=1, axis=1):
+    """
+    Perform Hilbert spectral analysis on a signal.
+
+    Parameters
+    ----------
+    x : ndarray
+        the signal
+
+    fs : float (default = 1)
+        sampling frequency
+
+    axis : int
+        axis of x along which to perform the analysis
+    """
+    y = scipy.signal.hilbert(x, axis=axis) # complex analytic signal
+    phase = np.angle(y) # CCW angle from positive real axis
+    # rate of change of phase
+    freq = np.gradient(np.unwrap(phase), axis=axis) / (2*np.pi) * fs
+    amp = np.abs(y) # length of signal vector at each timepoint
+    return phase, freq, amp
+
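+# Usage sketch (illustrative only, not part of the analysis pipeline): for a
+# pure sinusoid the instantaneous frequency stays close to the true frequency
+# away from the signal edges.
+def _demo_hilbert(f0=2.0, fs=100.0):
+    t = np.arange(0, 5, 1 / fs)
+    x = np.sin(2 * np.pi * f0 * t)
+    phase, freq, amp = hilbert(x, fs=fs, axis=0)
+    return freq[len(t) // 4: -len(t) // 4].mean() # close to f0
+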
+def fft_psd(signal, fs):
+    """
+    Compute the power spectral density of a signal using the Fourier transform.
+    """
+    signal = signal - signal.mean()
+    fft_freq = np.fft.rfftfreq(len(signal), 1 / fs)[1:]
+    signal_ft = np.fft.rfft(signal - signal.mean())[1:]
+    #fft_power = np.abs(signal_ft) ** 2 * 2 / len(signal) ** 2
+    fft_power = np.abs(signal_ft) / len(signal)
+    return fft_freq, fft_power
+
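+# Usage sketch (illustrative only, not part of the analysis pipeline): the
+# spectrum of a 3 Hz sinusoid peaks at the frequency bin closest to 3 Hz.
+def _demo_fft_psd(fs=100.0):
+    t = np.arange(0, 10, 1 / fs)
+    freq, power = fft_psd(np.sin(2 * np.pi * 3.0 * t), fs)
+    return freq[power.argmax()] # ~3 Hz
+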
+def hsa_psd(f, a, fs=1, f_bins=None):
+    """
+    Compute the marginal power spectrum of a set of instantaneous frequency and
+    amplitude traces.
+    """
+    if f_bins is None:
+        f_bins = np.fft.rfftfreq(len(f), 1 / fs)[1:]
+    psd = np.zeros(len(f_bins))
+    f_inds = np.digitize(f, f_bins) # assign a frequency bin to each value
+    for x, i in zip(a.ravel(), f_inds.ravel()): # loop over all values
+        psd[i] += x # accumulate amplitudes
+    psd /= len(f) # normalize by the number of timepoints
+    return psd
+
+def check_binned_visits(data, bins, n_visits):
+    """
+    Check that, for each signal in the set, each of the given phase bins is
+    visited a certain number of times.
+
+    Parameters
+    ----------
+    data : ndarray
+        A set of instantaneous phase traces, rows are signals, columns are
+        time-points.
+    bins : ndarray
+        The start and stop values of a set of phase bins, edge-inclusive.
+    n_visits : int
+        The number of times that each phase bin should be visited.
+
+    Returns
+    -------
+    sufficient_visits : ndarray
+        Boolean array indicating if each signal in the set visited each of the
+        phase bins the desired number of times.
+    """
+    # allow single channel input
+    if data.ndim == 1:
+        data = data[:, np.newaxis].T
+    # initialize boolean array (all true)
+    sufficient_visits = np.ones(len(data)).astype('bool')
+    # loop over instantaneous phase traces
+    for ind, var in enumerate(data):
+        # bin according to phase
+        binned_var = np.digitize(var, bins)
+        # time ranges during which IMF is in each phase bin
+        bin_tranges = [zero_runs(~np.equal(binned_var, b)) for b in np.arange(1, len(bins))]
+        # minimum number of visits across phase bins
+        min_tranges = min([len(tranges) for tranges in bin_tranges])
+        # mark as False if any phase bin is visited too few times
+        if min_tranges < n_visits:
+            sufficient_visits[ind] = False
+    return sufficient_visits
+
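+# Usage sketch (illustrative only, assuming util.zero_runs returns the index
+# ranges of zero-valued runs): a phase trace that completes many cycles visits
+# every bin often enough, a flat trace does not.
+def _demo_check_binned_visits():
+    t = np.linspace(0, 1, 1000)
+    cycling = np.angle(np.exp(1j * 2 * np.pi * 10 * t)) # 10 full cycles
+    flat = np.zeros_like(t)                             # never leaves one bin
+    bins = np.linspace(-np.pi, np.pi, 5)
+    return check_binned_visits(np.vstack([cycling, flat]), bins, n_visits=4)
+    # expected: array([ True, False])
+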
+def compute_fev(signal, components, add_mean=True):
+    """
+    Compute the fraction of variance explained in a target signal by a set of
+    components.
+    """
+    recon = components.sum(axis=0)  # reconstructed signal
+    if add_mean:
+        recon += signal.mean()
+    mse = ((signal - recon) ** 2).mean() # mean squared error
+    fev = 1 - (mse / signal.var()) # fraction explained variance
+    return fev
+
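+# Usage sketch (illustrative only, not part of the analysis pipeline):
+# components that sum back to the mean-subtracted signal explain all of its
+# variance.
+def _demo_compute_fev():
+    rng = np.random.default_rng(0)
+    signal = rng.normal(size=500)
+    components = np.vstack([signal - signal.mean(), np.zeros(500)])
+    return compute_fev(signal, components) # ~1.0
+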
+def mtcsd(x, fs=1, nperseg=None, nfft=None, noverlap=None, nw=3, ntapers=None,
+          detrend_method='constant'):
+    """
+    Pair-wise cross-spectral density using Slepian tapers. Adapted from the
+    mtcsd function in the labbox Matlab toolbox (authors: Partha Mitra,
+    Ken Harris).
+
+    Parameters
+    ----------
+    x : ndarray
+        2D array of signals across which to compute CSD, columns treated as
+        channels
+    fs : float (default = 1)
+        sampling frequency
+    nperseg : int, None (default = None)
+        number of data points per segment, if None nperseg is set to 256
+    nfft : int, None (default = None)
+        number of points to include in scipy.fft.fft, if None nfft is set to
+        2 * nperseg, if nfft > nperseg data will be zero-padded
+    noverlap : int, None (default = None)
+        amount of overlap between consecutive segments, if None noverlap is set
+        to nperseg / 2
+    nw : int (default = 3)
+        time-frequency bandwidth for Slepian tapers, passed on to
+        scipy.signal.windows.dpss
+    ntapers : int, None (default = None)
+        number of tapers, passed on to scipy.signal.windows.dpss, if None
+        ntapers is set to nw * 2 - 1 (as suggested by original authors)
+    detrend_method : {'constant', 'linear'} (default = 'constant')
+        method used by scipy.signal.detrend to detrend each segment
+
+    Returns
+    -------
+    f : ndarray
+        frequency bins
+    csd : ndarray
+        full cross-spectral density matrix
+    """
+    # allow single channel input
+    if x.ndim == 1:
+        x = x[:, np.newaxis]
+
+    # ensure no more than 2D input
+    assert x.ndim == 2
+
+    # set some default parameter values
+    if nperseg is None:
+        nperseg = 256
+
+    if nfft is None:
+        nfft = nperseg * 2
+
+    if noverlap is None:
+        noverlap = nperseg / 2
+
+    if ntapers is None:
+        ntapers = 2 * nw - 1
+
+    # get step size and total number of segments
+    stepsize = nperseg - noverlap
+    nsegs = int(np.floor(len(x) / stepsize))
+
+    # initialize csd matrix
+    csd = np.zeros((x.shape[1], x.shape[1], nfft), dtype='complex128')
+
+    # get FFT frequency bins
+    f = fft.fftfreq(nfft, 1/fs)
+
+    # get tapers
+    tapers = windows.dpss(nperseg, nw, Kmax=ntapers)
+
+    # loop over segments
+    for seg_ind in range(nsegs):
+
+        # prepare segment
+        i0 = int(seg_ind * stepsize)
+        i1 = int(seg_ind * stepsize + nperseg)
+        if i1 > len(x): # stop if segment is out of range of data
+            nsegs -= (nsegs - seg_ind) # reduce segment count
+            break
+        seg = x[i0:i1, :]
+        seg = detrend(seg, type=detrend_method, axis=0)
+
+        # apply tapers
+        tapered_seg = np.full((len(tapers), seg.shape[0], seg.shape[1]), np.nan)
+        for taper_ind, taper in enumerate(tapers):
+            tapered_seg[taper_ind] = (seg.T * taper).T
+
+        # compute FFT for each channel-taper combination
+        fftnorm = np.sqrt(2) # value taken from original matlab function
+        pxx = fft.fft(tapered_seg, n=nfft, axis=1) / fftnorm
+
+        # fill upper triangle of csd matrix
+        for ch1 in range(x.shape[1]): # loop over unique channel combinations
+            for ch2 in range(ch1, x.shape[1]):
+                # compute csd between channels, summing over tapers and segments
+                csd[ch1, ch2, :] += (pxx[:, :, ch1] * np.conjugate(pxx[:, :, ch2])).sum(axis=0)
+
+    # normalize csd by number of taper-segment combinations
+    # (equivalent to averaging over segments and tapers)
+    csdnorm = ntapers * nsegs
+    csd /= csdnorm
+
+    # fill lower triangle of csd matrix with complex conjugate of upper triangle
+    for ch1 in range(x.shape[1]):
+        for ch2 in range(ch1 + 1, x.shape[1]):
+            csd[ch2, ch1, :] = np.conjugate(csd[ch1, ch2, :])
+
+    return f, csd
+
+def mtcoh(x, **kwargs):
+    """
+    Pair-wise multi-taper coherence for a set of signals.
+
+    Parameters
+    ----------
+    See mtcsd documentation.
+
+    Returns
+    -------
+    f : ndarray
+        frequency bins
+    coh : ndarray
+        full spectral coherence matrix
+    """
+    # Compute cross-spectral density
+    f, csd = mtcsd(x, **kwargs)
+    # Compute power normalization matrix
+    powernorm = np.zeros((x.shape[1], x.shape[1], len(f)))
+    for ch1 in range(x.shape[1]):
+        for ch2 in range(x.shape[1]):
+            powernorm[ch1, ch2] = np.sqrt(np.abs(csd[ch1, ch1]) * np.abs(csd[ch2, ch2]))
+    # Normalize CSD to get coherence
+    coh = np.abs(csd) ** 2 / powernorm
+    # Return frequency array, coherence, and phase differences
+    return f, coh, np.angle(csd)
+
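+# Usage sketch (illustrative only, not part of the analysis pipeline): two
+# noisy copies of the same 5 Hz oscillation show a clear coherence peak near
+# 5 Hz.
+def _demo_mtcoh(fs=100.0):
+    rng = np.random.default_rng(0)
+    t = np.arange(0, 20, 1 / fs)
+    common = np.sin(2 * np.pi * 5.0 * t)
+    x = np.column_stack([common + 0.5 * rng.normal(size=t.size),
+                         common + 0.5 * rng.normal(size=t.size)])
+    f, coh, pdiff = mtcoh(x, fs=fs, nperseg=256)
+    return f[f > 0], coh[0, 1][f > 0] # off-diagonal coherence, peaks near 5 Hz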
+
+class HHT():
+    def __init__(self, signal, fs):
+        self.signal = signal
+        self.fs = fs
+        self.n_samples = len(self.signal)
+
+    def emd(self):
+        signal = self.signal - self.signal.mean()
+        self.imfs = emd.sift.sift(signal)
+        self.n_imfs = self.imfs.shape[1]
+
+    def hsa(self):
+        phases, frequencies, amplitudes = np.full((3, self.n_samples, self.n_imfs), np.nan)
+        for i, imf in enumerate(self.imfs.T):
+            # pad signal to reduce edge-effects for Hilbert analysis
+            try: # extrapolate by mirroring peaks
+                imf_padded, orig_inds = pad_signal(imf, method='peaks')
+            except AssertionError: # signal has fewer than two peaks
+                # extrapolate by "reflecting signal 180deg"
+                imf_padded, orig_inds = pad_signal(imf, method='odd')
+            # get analytic signal
+            phase, frequency, amplitude = hilbert(imf_padded, fs=self.fs, axis=0)
+            # keep only values corresponding to original signal
+            phases[:, i] = phase[orig_inds]
+            frequencies[:, i] = frequency[orig_inds]
+            amplitudes[:, i] = amplitude[orig_inds]
+        self.phase, self.frequency, self.amplitude = phases, frequencies, amplitudes
+        # Amplitude-weighted mean frequency for each IMF
+        self.characteristic_frequency = (self.frequency * self.amplitude).sum(axis=0) / self.amplitude.sum(axis=0)
+        # Power of each IMF
+        self.power_density = (self.amplitude ** 2).sum(axis=0) / self.amplitude.shape[0] # mean squared amplitude per IMF
+        self.power_ratio = self.power_density / self.power_density.sum()
+
+    def marginal_spectrum(self, f_bins=None, ranges=None):
+        """
+        Compute the marginal power spectrum of the IMF set.
+        """
+        if f_bins is None:
+            f_bins = np.fft.rfftfreq(self.n_samples, 1 / self.fs)[1:]
+        psd = np.zeros(len(f_bins))
+        if ranges is None:
+            frequency = self.frequency
+            amplitude = self.amplitude
+        else:
+            frequency = np.concatenate([self.frequency[i0:i1] for i0, i1 in ranges])
+            amplitude = np.concatenate([self.amplitude[i0:i1] for i0, i1 in ranges])
+        binned_frequency = np.digitize(frequency, f_bins) # assign a frequency bin to each value
+        for x, i in zip(amplitude.ravel(), binned_frequency.ravel()): # loop over all values
+            psd[i] += x # accumulate squared amplitudes
+        psd /= len(frequency) # normalize by the number of timepoints
+        return psd
+
+    def check_number_of_phasebin_visits(self, phasebins=None, ncycles=4, remove_invalid=False):
+        if phasebins is None:
+            phasebins = np.linspace(-np.pi, np.pi, 5)
+        self.sufficient_phasebin_visits = check_binned_visits(self.phase.T, phasebins, ncycles)
+        if remove_invalid:
+            for attr in ['imfs', 'phase', 'frequency', 'amplitude']:
+                setattr(self, attr, getattr(self, attr)[:, self.sufficient_phasebin_visits])
+            for attr in ['characteristic_frequency', 'power_density', 'power_ratio']:
+                setattr(self, attr, getattr(self, attr)[self.sufficient_phasebin_visits])
+            self.n_imfs = self.sufficient_phasebin_visits.sum()
+
+    def check_imf_significance(self):
+        print("WARNING: IMF significance depricated.")
+        assert hasattr(self, 'imfs')
+        ln_f, ln_E, bounds = imf_statsig(self.imfs.T, return_period=False, use_hilbert=True)
+        self.imf_significance = ln_E > bounds[1]
+
+    def get_synchronous_events(self, dt=0.5, n_cycles=0.25, threshold_qt=0.95):
+        """
+        Perform a sliding window correlation between pairs of IMFs with similar frequencies.
+
+        Notes
+        -----
+        This measure is similar to a time-resolved version of the pseudo mode splitting index
+        from Wang et al. (2018) and Fabus et al. (2021).
+        """
+        imfs = self.imfs[:, np.where(self.characteristic_frequency > 0)[0]]
+        freqs = self.characteristic_frequency[np.where(self.characteristic_frequency > 0)[0]]
+        imfs1 = imfs[:-1]
+        imfs2 = np.roll(imfs, -1, axis=1)[:-1]
+        freqs2 = np.roll(freqs, -1)[:-1] # characteristic_frequency is 1D, no axis argument
+
+        step_size = np.round(dt * self.fs).astype(int)
+        samples = np.arange(0, len(imfs1), step_size)
+
+        sync = np.full((len(samples), imfs1.shape[1]),np.nan)
+        for i, (imf1, imf2, freq) in enumerate(zip(imfs1.T, imfs2.T, freqs2)):
+            window_size = np.round(n_cycles * self.fs / freq).astype(int)
+            starts = np.clip(samples - window_size, a_min=0, a_max=None)
+            stops = np.clip(samples + window_size, a_min=None, a_max=(len(imf1) - 1))
+            sync[:, i] = [np.dot(imf1[start:stop], imf2[start:stop]) / (stop - start) for start, stop in zip(starts, stops)]
+
+        threshold = np.quantile(sync.mean(axis=0), threshold_qt)
+        events = continuous_runs(sync.mean(axis=0) > threshold, min1len=5)
+        self.synchronous_events = pts[events]
+
+    def pairwise_coherence(self, ncycles=4):
+        """
+        Compute phase coherence between all pairs of IMFs.
+        """
+        coh_mat, pdiff_mat = np.full((2, self.n_imfs, self.n_imfs), np.nan)
+        for imfi, imf in enumerate(self.imfs.T):
+            # Get appropriate window size for this IMF's characteristic frequency
+            period = 1 / self.characteristic_frequency[imfi] # get IMF period from characteristic frequency
+            seglen = ncycles * period
+            nperseg = int(2 ** np.floor((np.log2(seglen * self.fs)))) # number of samples
+            # Skip if segment not long enough to estimate coherence
+            if nperseg > self.n_samples:
+                continue
+            # Compute pair-wise cross-spectral density
+            f, coh, pdiff = mtcoh(self.imfs, fs=self.fs, nperseg=nperseg)
+            # Take only the row corresponding to the current IMF
+            coh = coh[imfi]
+            pdiff = pdiff[imfi]
+            # Get index of the appropriate frequency bin (consider only +ve freqs)
+            f_ind = f[f > 0].searchsorted(self.characteristic_frequency[imfi])
+            # Take mean of the two most appropriate frequency bins
+            coh = coh[:, f_ind:(f_ind + 2)].mean(axis=1)
+            pdiff = circmean_angle(pdiff[:, f_ind:(f_ind + 2)], axis=1)
+            # Fill row of matrix
+            coh_mat[imfi] = coh
+            pdiff_mat[imfi] = pdiff
+        # Normalize each row by its maximum to remove the contribution of power
+        self.coherence = (coh_mat.T / coh_mat.max(axis=1)).T
+        self.phasediff = pdiff_mat
+
+    def phase_synchrony(self, n_bins=16, n_shf=1000):
+        phases = self.phase.T
+        n_phases = len(phases)
+        freqs = self.characteristic_frequency
+        bin_edges = np.linspace(-np.pi, np.pi, n_bins + 1)
+        # Get bin areas array to normalize density function
+        D_areas = np.outer(np.diff(bin_edges), np.diff(bin_edges))
+        # Get a reference uniform distribution
+        D_uniform = np.ones((n_bins, n_bins)) / n_bins**2
+        # Initialize array to collect the joint distributions
+        DD = np.full((n_phases, n_phases, n_bins, n_bins), np.nan)
+        # Initialize array to collect KLDs
+        DD_kld = np.full((len(phases), len(phases)), np.nan)
+        # Initialize array to collect distribution p-values
+        DD_p = np.full((len(phases), len(phases)), np.nan)
+        # Initialize array to collect the significance masks
+        DD_masks = np.full((n_phases, n_phases, n_bins, n_bins), np.nan)
+        # Initialize array to collect synchronous time ranges
+        DD_ranges = np.full((n_phases, n_phases), np.nan, dtype='object')
+        # Loop over pairs of phase traces
+        for i in range(len(phases)):
+            for j in range(len(phases)):
+                # Skip the diagonal (self-comparison)
+                if i == j: continue
+                # Make marginal distributions uniform by converting phases to ranks
+                #ranks_i = phase2rank(phases[i]) - np.pi
+                #ranks_j = phase2rank(phases[j]) - np.pi
+                # Get the joint probability functions
+                D = np.histogram2d(phases[i], phases[j], bins=bin_edges, density=True)[0] * D_areas
+                # Normalize by marginal distributions
+                #D = (D.T / np.histogram(phases[i], bins=phase_bins)[0]).T  # normalize rows
+                #D = D / np.histogram(phases[j], bins=phase_bins)[0]  # normalize columns
+                DD[i, j] = D
+                # Compute Kullback-Leibler divergence from uniform
+                ## TODO: add eps to D to ensure no zero values? --> no negative KLDs
+                kld = kl_divergence(D, D_uniform)
+                # Initialize array to collect shuffle distributions
+                DD_shf = np.full((n_shf, D.shape[0], D.shape[1]), np.nan)
+                # Initialize array to collect shuffle KLDs
+                kld_shf = np.full(n_shf, np.nan)
+                # Perform shuffles
+                for shf in range(n_shf):
+                    # Randomly shuffle time points
+                    #shf_i = np.random.choice(np.arange(len(phases[i])), size=len(phases[i]), replace=False)
+                    #shf_j = np.random.choice(np.arange(len(phases[j])), size=len(phases[j]), replace=False)
+                    # Shuffle cycle order
+                    cycles_i = np.split(phases[i], np.where(np.diff(phases[i]) < -np.pi)[0])
+                    cycles_j = np.split(phases[j], np.where(np.diff(phases[j]) < -np.pi)[0])
+                    np.random.shuffle(cycles_i)
+                    np.random.shuffle(cycles_j)
+                    shf_i = np.concatenate(cycles_i)
+                    shf_j = np.concatenate(cycles_j)
+                    # Get PDF of shuffle
+                    D_shf = np.histogram2d(shf_i, shf_j, bins=bin_edges, density=True)[0] * D_areas
+                    DD_shf[shf] = D_shf
+                    # Compute KLD of shuffle
+                    kld_shf[shf] = kl_divergence(D_shf, D_uniform)
+                # Get KLD diff
+                DD_kld[i, j] = (kld - kld_shf.mean()) / kld_shf.std()
+                # Get KLD p-values
+                DD_p[i, j] = (kld_shf > kld).mean()
+                # Get significance mask for joint distribution
+                mask = D > np.percentile(DD_shf, 95, axis=0) # per-bin 95th percentile across shuffles
+                DD_masks[i, j] = mask
+                # Find time ranges where phase traces pass through significant regions
+                ranges = []
+                for pi, pj in np.column_stack(np.where(mask)):
+                    mask_i = (phases[i] > bin_edges[pi]) & (phases[i] <= bin_edges[pi + 1])
+                    mask_j = (phases[j] > bin_edges[pj]) & (phases[j] <= bin_edges[pj + 1])
+                    ranges.append(zero_runs(~(mask_i & mask_j)))
+                DD_ranges[i, j] = merge_ranges(np.concatenate(ranges))
+        return DD, DD_kld, DD_p, DD_masks, DD_ranges
+
+    def modemix(self, alpha=0.05):
+        """
+        Compute a metric for the amount of frequency overlap ('mode mixing')
+        between all pairs of IMFs. Each entry gives the proportion of time
+        during which the two instantaneous frequency traces crossed.
+
+        Parameters
+        ----------
+        alpha : float (default = 0.05)
+            quantile used to define the frequency overlap threshold
+
+        Returns
+        -------
+        metric : ndarray
+            pair-wise matrix of overlap proportions
+
+        Notes
+        -----
+        -   Designed for use as a metric of frequency overlap (i.e. 'mode mixing')
+            between a set of IMFs resulting from EMD, in this case the input
+            should be a set of instantaneous frequency traces
+
+        References
+        ----------
+        [1] Laszuk, D., Cadenas, O., & Nasuto, S. J. (2015, July). Objective
+        empirical mode decomposition metric. In 2015 38th International Conference
+        on Telecommunications and Signal Processing (TSP) (pp. 504-507). IEEE.
+        """
+        metric = np.full((self.n_imfs, self.n_imfs), np.nan) # pair-wise matrix
+        for i in range(self.n_imfs): # loop over unique pairs
+            for j in range(i + 1, self.n_imfs):
+                assert self.characteristic_frequency[i] > self.characteristic_frequency[j]
+                overlap_i = (self.frequency[:, i] < np.quantile(self.frequency[:, j], 1 - alpha)).sum()
+                overlap_j = (self.frequency[:, j] > np.quantile(self.frequency[:, i], alpha)).sum()
+                # proportion of time for which there is overlap
+                metric[i, j] = (overlap_i + overlap_j) / self.n_samples
+        return metric
+
+    def get_imf_colors(self, cmap):
+        color_vals = 1 - np.linspace(0.1, 1, self.n_imfs)
+        return cmap(color_vals)
+
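+# Usage sketch (illustrative only, not part of the analysis pipeline):
+# decompose a synthetic two-component signal and inspect the characteristic
+# frequency of each IMF.
+def _demo_hht(fs=20.0):
+    t = np.arange(0, 60, 1 / fs)
+    signal = np.sin(2 * np.pi * 0.2 * t) + 0.5 * np.sin(2 * np.pi * 1.5 * t)
+    hht = HHT(signal, fs)
+    hht.emd() # intrinsic mode functions via emd.sift.sift
+    hht.hsa() # instantaneous phase / frequency / amplitude per IMF
+    return hht.characteristic_frequency # should include values near 1.5 and 0.2 Hz
+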
+def tranges_with_events(tranges, events):
+    with_event = np.full(len(tranges), False)
+    for i, (t0, t1) in enumerate(tranges):
+        with_event[i] = any([(evt >= t0) & (evt <= t1) for evt in events])
+    return with_event
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('e_name')
+    args = parser.parse_args()
+
+    df_pupil = load_data('pupil', [args.e_name])
+    df_run = load_data('ball', [args.e_name])
+    df = pd.merge(df_pupil, df_run, on=['m', 's', 'e'])
+    seriess = []
+    for _, row in tqdm(df.iterrows(), total=len(df)):
+        pupil_area = row['pupil_area']
+        pupil_tpts = row['pupil_tpts']
+
+        # Get IMFs
+        fs = 1 / np.diff(pupil_tpts).mean()
+        hht = HHT(pupil_area, fs)
+        hht.emd()
+
+        hht.hsa()
+        f_bins, psd = fft_psd(pupil_area, hht.fs)
+        hht_psd = hht.marginal_spectrum(f_bins=f_bins)
+        frequencies = hht.characteristic_frequency.copy()
+        powers = hht.power_ratio.copy()
+
+        run_ranges = row['pupil_tpts'].searchsorted(row['run_bouts'])
+        run_psd = hht.marginal_spectrum(f_bins=f_bins, ranges=run_ranges)
+        sit_ranges = row['pupil_tpts'].searchsorted(row['sit_bouts'])
+        sit_psd = hht.marginal_spectrum(f_bins=f_bins, ranges=sit_ranges)
+        #tranges = np.array([[0, int(len(pupil_tpts) / 2)]])
+        #half1_psd = hht.marginal_spectrum(f_bins=f_bins, ranges=tranges)
+        #tranges = np.array([[int(len(pupil_tpts) / 2), len(pupil_tpts)]])
+        #half2_psd = hht.marginal_spectrum(f_bins=f_bins, ranges=tranges)
+
+        hht.check_number_of_phasebin_visits(ncycles=NIMFCYCLES, remove_invalid=True)
+
+        pbi = np.full(hht.n_imfs, np.nan)
+        cycle_tranges = []
+        for i, phase in enumerate(hht.phase.T):
+            # Get phase bias index
+            counts, _ = circhist(phase)
+            pbi[i] = (counts.max() - counts.min()) / counts.max()
+
+        #phase_components = np.column_stack([np.cos(hht.phase), np.sin(hht.phase)])
+        #pca = PCA()
+        #pca.fit(phase_components)
+        #hht.pairwise_coherence()
+        jpd, sync_kld, sync_p, sync_masks, sync_ranges = hht.phase_synchrony()
+        # Take ranges only for non-uniform distributions
+        #ranges = merge_ranges(np.concatenate(sync_ranges[sync_p <= 0.05]))
+        #synchronous_bouts = pupil_tpts[ranges]
+        sync_boutss = []
+        desync_boutss = []
+        for i, (ranges, ps) in enumerate(zip(sync_ranges, sync_p)):
+            bouts = ranges[ps <= 0.05] # take only if overall distribution is significant
+            #bouts = np.delete(ranges, i) # take all (except self)
+            if len(bouts) > 0:
+                sync_bouts = merge_ranges(np.concatenate(bouts))
+                desync_bouts = switch_ranges(sync_bouts, maxval=(hht.n_samples - 1))
+                sync_bouts = pupil_tpts[sync_bouts]
+                desync_bouts = pupil_tpts[desync_bouts]
+            else: # avoid appending stale or undefined bouts from a previous iteration
+                sync_bouts, desync_bouts = np.array([]), np.array([])
+            sync_boutss.append(sync_bouts)
+            desync_boutss.append(desync_bouts)
+
+        # Get pupil size matched time ranges for each IMF
+        pupilarea_norm = pupil_area / pupil_area.max()
+        phase_bins = np.linspace(-np.pi, np.pi, 5)
+        phase_nbins = len(phase_bins) - 1
+        size_bins = np.linspace(0, 1, 11)
+        unmatched_means, matched_means = np.full((2, hht.n_imfs, phase_nbins, phase_nbins), np.nan)
+        matched_ranges = []
+        # loop over IMFs
+        for imf_i, imf_phase in enumerate(hht.phase.T):
+            phase_binned = np.digitize(imf_phase, phase_bins).clip(1, phase_nbins) - 1
+            phasebin_ranges = [zero_runs(~np.equal(phase_binned, phase_bin)) for phase_bin in np.arange(phase_nbins)]
+            unmatched_dists = [np.concatenate([pupilarea_norm[i0:i1] for i0, i1 in ranges]) for ranges in phasebin_ranges]
+            matched_ranges_imf = match_distributions(imf_phase, pupilarea_norm, phase_bins, size_bins)
+            matched_ranges.append(pupil_tpts[np.concatenate(matched_ranges_imf).clip(0, len(pupil_tpts) - 1).astype(int)])
+            matched_dists = [np.concatenate([pupilarea_norm[i0:i1] for i0, i1 in ranges]) if len(ranges) > 0 else np.array([]) for ranges in matched_ranges_imf]
+            for i in np.arange(phase_nbins):
+                for j in np.arange(phase_nbins):
+                    if j > i:
+                        unmatched_means[imf_i, i, j] = unmatched_dists[i].mean() - unmatched_dists[j].mean()
+                        matched_means[imf_i, i, j] = matched_dists[i].mean() - matched_dists[j].mean()
+
+        nosaccade_ranges = []
+        saccades = pupil_tpts.searchsorted(row['saccade_times'])
+        for imf_i, imf_phase in enumerate(hht.phase.T):
+            phase_binned = np.digitize(imf_phase, phase_bins).clip(1, phase_nbins) - 1
+            phasebin_ranges = [zero_runs(~np.equal(phase_binned, phase_bin)) for phase_bin in np.arange(phase_nbins)]
+            ranges2keep = [~tranges_with_events(ranges, saccades) for ranges in phasebin_ranges]
+            phasebin_ranges = np.concatenate([pupil_tpts.searchsorted(ranges[keep]) for ranges, keep in zip(phasebin_ranges, ranges2keep)])
+            nosaccade_ranges.append(phasebin_ranges)
+
+        data = {
+            'm': row['m'],
+            's': row['s'],
+            'e': row['e'],
+            'condition': args.e_name,
+            't0': pupil_tpts[0],
+            't1': pupil_tpts[-1],
+            'segment_length': pupil_tpts[-1] - pupil_tpts[0],
+            'frequency': frequencies,
+            'power': powers,
+            'hht_psd': hht_psd,
+            'fft_psd': psd,
+            'psd_freq': f_bins,
+            'run_psd': run_psd,
+            'sit_psd': sit_psd,
+            #'half1_psd': half1_psd,
+            #'half2_psd': half2_psd,
+            'pbi': pbi,
+            #'pc_variance':pca.explained_variance_ratio_,
+            #'coherence': hht.coherence,
+            #'phasediff': hht.phasediff,
+            'sync_p': sync_p,
+            'sync_kld': sync_kld,
+            'sync_bouts': sync_boutss,
+            'desync_bouts': desync_boutss,
+            'sizematched_bouts': matched_ranges,
+            'unmatched_meansize': unmatched_means,
+            'matched_meansize': matched_means,
+            'nosaccade_bouts': nosaccade_ranges
+            }
+        seriess.append(pd.Series(data=data))
+    df_hht = pd.DataFrame(seriess)
+    df_hht.to_pickle(DATAPATH + 'hht_{}.pkl'.format(args.e_name))
+    

+ 106 - 0
imf_correlation.py

@@ -0,0 +1,106 @@
+import argparse
+from tqdm import tqdm
+import numpy as np
+import pandas as pd
+from scipy.stats import pearsonr
+
+from parameters import NIMFCYCLES, NSHUFFLES, DATAPATH
+from util import load_data, interpolate, normalized_xcorr
+from hht import HHT
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('e_name')
+    parser.add_argument('-t', '--tranges', default='')
+    args = parser.parse_args()
+
+    df_pupil = load_data('pupil', [args.e_name]).set_index(['m', 's', 'e'])
+    df_run = load_data('ball', [args.e_name]).set_index(['m', 's', 'e'])
+    df_all = pd.read_pickle(DATAPATH + 'run.pkl')
+
+    # TODO: move to parameters
+    max_lag = 30  # seconds
+
+    seriess = []
+    for idx, row in tqdm(df_pupil.iterrows(), total=len(df_pupil)):
+        pupil_area = row['pupil_area']
+        pupil_tpts = row['pupil_tpts']
+        pupil_dt = np.diff(pupil_tpts).mean()
+        pupil_fs = 1 / pupil_dt
+        t0, t1 = (pupil_tpts.min(), pupil_tpts.max())
+        # Prepare run data
+        try:
+            run_speed, run_tpts = df_run.loc[idx, ['run_speed', 'run_tpts']]
+        except KeyError:
+            print("No run data found for ", idx)
+            continue
+        i0, i1 = run_tpts.searchsorted([t0, t1])
+        run_speed = interpolate(run_speed[i0:i1], run_tpts[i0:i1], pupil_tpts)
+        # Get IMFs
+        hht = HHT(pupil_area, pupil_fs)
+        hht.emd()
+        hht.hsa()
+        hht.check_number_of_phasebin_visits(ncycles=NIMFCYCLES, remove_invalid=True)
+        imfs = hht.imfs.T
+        imf_freqs = hht.characteristic_frequency
+        imf_power = hht.power_ratio
+
+        if args.tranges:
+            tranges = df_run.loc[idx, '%s_bouts' % args.tranges]
+            if args.tranges == 'run':
+                ext = np.ones_like(tranges) * np.array([2, -2])
+            elif args.tranges == 'sit':
+                ext = np.ones_like(tranges) * np.array([4, -2])
+            tranges = tranges + ext
+            tranges = np.row_stack([trange for trange in tranges if trange[0] < trange[1]])
+            iranges = pupil_tpts.searchsorted(tranges)
+            imfs = np.column_stack([imfs[:, i0:i1] for i0, i1 in iranges])
+            pupil_area = np.concatenate([pupil_area[i0:i1] for i0, i1 in iranges])
+            run_speed = np.concatenate([run_speed[i0:i1] for i0, i1 in iranges])
+
+        for i, imf in enumerate(np.row_stack([pupil_area, imfs])):
+            data = {
+                'm':    idx[0],
+                's':    idx[1],
+                'e':    idx[2],
+                'imf':  i
+                }
+            if i == 0:
+                data['freq'] = data['power'] = np.nan
+            else:
+                data['freq'] = imf_freqs[i - 1]
+                data['power'] = imf_power[i - 1]
+            xcorr, lags = normalized_xcorr(imf, run_speed, dt=pupil_dt, ts=[-1 * max_lag, max_lag])
+            data['xcorr'] = xcorr
+            data['xcorr_lags'] = lags
+            r_null = np.full(NSHUFFLES, np.nan)
+            j = 0
+            while j < NSHUFFLES:
+                tpts, signal = df_all.iloc[np.random.choice(np.arange(len(df_all)))]
+                signal = interpolate(signal, tpts, np.arange(tpts.min(), tpts.max(), pupil_dt))
+                i_max = min(len(imf), len(signal))
+                if len(np.unique(signal[:i_max])) == 1:
+                    continue
+                r_null[j], _  = pearsonr(imf[:i_max], signal[:i_max])
+                j += 1
+            # Get search window for peak
+            if i == 0:
+                i0, i1 = 0, len(xcorr)
+            else:
+                T = 1 / imf_freqs[i - 1]
+                i0 , i1 = lags.searchsorted([-T, T])
+            # Find peak
+            xcorr_peak = lags[i0:i1][np.abs(xcorr[i0:i1]).argmax()]
+            data['xcorr_peak'] = xcorr_peak
+            # Compare to null distribution
+            xcorr_max = xcorr[i0:i1][np.abs(xcorr[i0:i1]).argmax()]
+            p = (r_null > xcorr_max).sum() / NSHUFFLES
+            data['xcorr_p'] = p
+            data['xcorr_sig'] = (p < 0.025) | (p > 0.975)
+            seriess.append(pd.Series(data=data))
+    df_corr = pd.DataFrame(seriess)
+    if not args.tranges:
+        filename = 'imfcorr_{}.pkl'.format(args.e_name)
+    else:
+        filename = 'imfcorr_{}_{}.pkl'.format(args.e_name, args.tranges)
+    df_corr.to_pickle(DATAPATH + filename)

+ 171 - 0
imf_decoding.py

@@ -0,0 +1,171 @@
+import argparse
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+
+from sklearn.svm import SVC
+from sklearn.model_selection import RepeatedStratifiedKFold, StratifiedKFold, cross_val_score
+
+from util import (load_data, get_trials, filter_units, get_psth, get_responses, circmean,
+                  angle_subtract)
+from phase_tuning import HHT
+from parameters import DATAPATH, NIMFCYCLES, NSPIKES, MINRATE
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('e_name')
+    args = parser.parse_args()
+
+    NSPLITS = 5
+    NPHASEBINS = 2
+    PHASEBINS = np.linspace(-np.pi, np.pi, NPHASEBINS + 1)
+
+    df_pupil = load_data('pupil', [args.e_name])
+    df_trials = load_data('trials', [args.e_name])
+    df_trials.rename(columns={'trial_on_time':'trial_on_times', 'trial_off_time':'trial_off_times'}, inplace=True)
+    df_trials = df_trials.apply(get_trials, stim_id=0, axis='columns')
+    df = pd.merge(df_pupil, df_trials).set_index(['m', 's', 'e'])
+
+    df_spikes = load_data('spikes', [args.e_name]).set_index(['m', 's', 'e'])
+    #df_spikes = df_spikes[df_spikes.index.isin(df_pupil.index)]
+    df_spikes = filter_units(df_spikes, MINRATE)
+
+    df_tuning = load_data('phasetuning', [args.e_name], tranges='noopto').set_index(['m', 's', 'e', 'u'])
+
+    seriess = []
+    for idx, row in tqdm(df.iterrows(), total=len(df)):
+        pupil_area = row['pupil_area']
+        pupil_tpts = row['pupil_tpts']
+        pupil_fs = 1 / np.diff(pupil_tpts).mean()
+
+        # Get IMFs
+        hht = HHT(pupil_area, pupil_fs)
+        hht.emd()
+        hht.hsa()
+        hht.check_number_of_phasebin_visits(ncycles=NIMFCYCLES, remove_invalid=True)
+        imf_freqs = hht.characteristic_frequency
+        imf_phases = hht.phase.T
+
+        trial_starts = row['trial_on_times']
+        #trial_stops = row['trial_off_times']
+        #trial_duration = (trial_stops - trial_starts).mean()
+        trial_duration = 5
+        trial_starts = trial_starts[(trial_starts + trial_duration) < pupil_tpts.max()]
+        stim_labels = np.repeat(np.arange(NSPLITS), len(trial_starts))
+
+        # Get units for this experiment
+        try:
+            df_units = df_spikes.loc[idx]
+            df_units = df_units.reset_index().set_index(['m', 's', 'e', 'u'])
+        except KeyError:
+            print("Spikes missing for {}".format(idx))
+            continue
+
+        for u_idx, unit in df_units.iterrows():
+            print(u_idx)
+
+            for imf_i, phase in enumerate(imf_phases):
+                # Skip if no phase tuning analysis was done for this unit
+                df_unittuning = df_tuning.loc[u_idx].query('imf == %d' % (imf_i + 1))
+                if len(df_unittuning) < 1:
+                    continue
+                print(imf_i)
+
+                data = {
+                    'm':        idx[0],
+                    's':        idx[1],
+                    'e':        idx[2],
+                    'u':        u_idx[-1],
+                    'imf':      imf_i + 1,
+                    'freq':     imf_freqs[imf_i]
+                    }
+
+                phase_raster, raster_tpts = get_responses(trial_starts, phase, pupil_tpts, post=trial_duration)
+                split_inds = np.linspace(0, len(raster_tpts), NSPLITS + 1).astype(int)
+                # Mean phase for each stimulus segment
+                phase_means = np.concatenate(([circmean(phase_raster[:, i0:i1], axis=1)[1] for i0, i1 in zip(split_inds[:-1], split_inds[1:])]))
+
+                # Decode IMF phase using each spike type
+                for spike_type in ['tonicspk', 'burst']:
+                    print(spike_type)
+                    # get raster for spike type and split into segments
+                    spike_times = unit['%s_times' % spike_type]
+                    if len(spike_times) <= NSPIKES:
+                        continue
+                    spike_raster, spike_tpts = get_psth(trial_starts, spike_times, post=trial_duration)
+                    split_inds = np.linspace(0, len(spike_tpts), NSPLITS + 1).astype(int)
+                    X = np.row_stack([spike_raster[:, i0:i1] for i0, i1 in zip(split_inds[:-1], split_inds[1:])])
+                    # use tuning phase to set phase bins
+                    tuning_phase = df_unittuning['%s_phase' % spike_type][0]
+                    if np.isnan(tuning_phase):
+                        continue
+                    phase_shift = angle_subtract(tuning_phase, -1 *  np.pi / 2) - np.pi
+                    phase_means_shifted = angle_subtract(phase_means, phase_shift) - np.pi
+                    phase_labels = np.digitize(phase_means_shifted, bins=PHASEBINS)
+                    phase_labels = phase_labels.clip(1, NPHASEBINS) - 1
+                    if len(np.unique(phase_labels)) < 2:
+                        raise RuntimeError
+                    # predict phase bin
+                    print("decoding phase")
+                    classifier = SVC(kernel='rbf')
+                    crossval = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
+                    scores = cross_val_score(classifier, X, phase_labels, cv=crossval)
+                    data['%s_phase' % spike_type] = scores.mean()
+
+                ## Decode stimulus across phase bins using all spike times
+                spike_times = unit['spk_times']
+                if len(spike_times) <= NSPIKES:
+                    continue
+                spike_raster, spike_tpts = get_psth(trial_starts, spike_times, post=trial_duration)
+                #i0, i1 = spike_tpts.searchsorted([0, trial_duration])
+                #spike_raster = spike_raster[:, i0:i1]
+                #spike_tpts = spike_tpts[i0:i1]
+                split_inds = np.linspace(0, len(spike_tpts), NSPLITS + 1).astype(int)
+                X = np.row_stack([spike_raster[:, i0:i1] for i0, i1 in zip(split_inds[:-1], split_inds[1:])])
+                # use tonic phase to set the phase bins
+                tuning_phase = df_unittuning['tonicspk_phase'][0]
+                if np.isnan(tuning_phase):
+                    continue
+                phase_shift = angle_subtract(tuning_phase, -1 *  np.pi / 2) - np.pi
+                phase_means_shifted = angle_subtract(phase_means, phase_shift) - np.pi
+                phase_labels = np.digitize(phase_means_shifted, bins=PHASEBINS)
+                phase_labels = phase_labels.clip(1, NPHASEBINS) - 1
+                if ((phase_labels == 0).mean() < 0.25) or ((phase_labels == 1).mean() < 0.25):
+                    print("Phase split biased")
+                    continue
+                # split segments based on phase bin
+                X1 = X[phase_labels.astype(bool)]
+                y1 = stim_labels[phase_labels.astype(bool)]
+                if len(np.unique(y1)) < 5:
+                    raise RuntimeError
+                X2 = X[~phase_labels.astype(bool)]
+                y2 = stim_labels[~phase_labels.astype(bool)]
+                if len(np.unique(y2)) < 5:
+                    raise RuntimeError
+                # train on phase bin 2 & test
+                print("decoding stimulus, set 2")
+                classifier = SVC(kernel='linear').fit(X2, y2)
+                data['stim_train2_test2'] = classifier.score(X2, y2)
+                data['stim_train2_test1'] = classifier.score(X1, y1)
+                # train on phase bin 1 & test
+                print("decoding stimulus, set 1")
+                classifier = SVC(kernel='linear').fit(X1, y1)
+                data['stim_train1_test1'] = classifier.score(X1, y1)
+                data['stim_train1_test2'] = classifier.score(X2, y2)
+                # control: generalization gap on random stratified splits ignoring phase bin
+                splitter = RepeatedStratifiedKFold(n_splits=2, n_repeats=5, random_state=0)
+                shf_diffs = np.full(10, np.nan)
+                y = stim_labels
+                for shf_i, (train, test) in enumerate(splitter.split(X, stim_labels)):
+                    classifier = SVC(kernel='linear').fit(X[train], y[train])
+                    shf_diffs[shf_i] = classifier.score(X[test], y[test]) - classifier.score(X[train], y[train])
+                data['stim_testshf'] = shf_diffs
+
+                seriess.append(pd.Series(data=data))
+    df_decoding = pd.DataFrame(seriess)
+    filename = 'imfdecoding_{}_norm.pkl'.format(args.e_name)
+    df_decoding.to_pickle(DATAPATH + filename)
+
+
+

+ 114 - 0
parameters.py

@@ -0,0 +1,114 @@
+"""
+Settings for analysis and plotting.
+"""
+import numpy as np
+from matplotlib import pyplot as plt
+from matplotlib import cm as colormaps
+from matplotlib.colors import to_rgb
+
+SERVERPATH = '/mnt/hux/mudata'
+DATAPATH = 'data/'
+FIGUREPATH = 'figures/'
+FIGSAVEFORMAT = '.svg'
+
+LABELFONTSIZE = 6
+COLORS = {
+    'spk': 'black',
+    'tonicspk': 'C0',
+    'burst': 'C3',
+    'pupil': 'black',
+    'run': 'green',
+    'sit': 'gray',
+    'imfs': colormaps.viridis,
+    'hht': 'black',
+    'spontaneous': 'gray',
+    'sparsenoise': 'pink',
+    'dark': 'black',
+    'natmov': 'darkblue',
+    'natmov_opto': 'darkblue',
+    'elevation': list(to_rgb('purple')) + [0.5],
+    'azimuth': 'purple',
+    'saccade': 'purple',
+    'phasebin2': 'mediumaquamarine',
+    'phasebin1': 'violet'
+}
+plt.rcParams['figure.dpi'] = 180
+plt.rcParams['axes.labelsize'] = LABELFONTSIZE
+plt.rcParams['xtick.labelsize'] = LABELFONTSIZE 
+plt.rcParams['ytick.labelsize'] = LABELFONTSIZE 
+plt.rcParams['legend.fontsize'] = LABELFONTSIZE 
+
+
+NPUPILAREABINS = 10
+NGAMSPLINES = 20
+
+MINRATE = 0.01 # spikes / second
+BURSTCRITERIA = '(fp_dtsilent BETWEEN 0.099 and 0.101) AND (fp_dtburst BETWEEN 0.0039 AND 0.0041)'
+
+FIG1BEXAMPLEKEY = {'m': 'BL6_2014_0191', 's': 6, 'e': 4, 'u': 8}
+FIG1BEXAMPLETRANGE = 220, 520  # seconds
+FIG1BTBINWIDTH = 2.5  # seconds
+FIG1DEXAMPLEKEY = {'m': 'BL6_2014_0191', 's': 6, 'e': 4}
+FIG1DEXAMPLEIMF = 2
+FIG2BEXAMPLEKEY = {'m': 'PVCre_2019_0002', 's': 8, 'e': 8, 'u': 2}
+FIG2BEXAMPLEIMF = 2
+FIG2BEXAMPLETRANGE = 360, 420
+FIG5AEXAMPLEKEY = {'m':'PVCre_2018_0003', 's':3, 'e':3, 'u':53}
+
+# Locomotion bout detection
+RUNTHRESHOLD = 1
+MINRUNTIME = 2
+MAXSITTIME = 2
+MINRUNPROP = 0.5
+
+# MAXIBI = 0.5
+
+NIMFCYCLES = 4
+NSPIKES = 8
+NSHUFFLES = 1000
+SHUFFLE_BINWIDTH = 0.3 # seconds
+
+FREQUENCYBINS = np.logspace(-3, 0, 7)
+FREQUENCYXPOS = np.log10(FREQUENCYBINS[:-1]) + 0.25
+FREQUENCYTICKS = np.linspace(-3, 0, 4)
+FREQUENCYTICKLABELS = ["$10^{%d}$" % tick for tick in FREQUENCYTICKS]
+PHASEBINS = np.linspace(-np.pi, np.pi, 9, endpoint=True)
+PHASETICKS = PHASEBINS[::4]
+PHASETICKLABELS = ['-\u03C0', '0', '\u03C0']
+
+TRIGGEREDAVERAGES = {
+    'run': {
+        'dt':       0.05,
+        'bw':       0.1,
+        'pre':      -5,
+        'post':     5,
+        'baseline': [-5, -3]
+        },
+    'sit': {
+        'dt':       0.05,
+        'bw':       0.1,
+        'pre':      -5,
+        'post':     5,
+        'baseline': [-5, -3]
+        },
+    'saccade': {
+        'dt':       0.05,
+        'bw':       0.1,
+        'pre':      -5,
+        'post':     5,
+        'baseline': [-5, -3]
+        },
+    'trial_on': {
+        'dt':       0.05,
+        'bw':       0.1,
+        'pre':      -1,
+        'post':     5,
+        'baseline': [-1, 0]
+        }
+    }
+
+BEHAVEXCLUSIONS = {
+    'run':      [-2, 2],
+    'sit':      [-1, 4],
+    'saccade':  [-2, 2]
+}

+ 323 - 0
phase_tuning.py

@@ -0,0 +1,323 @@
+import os
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+import argparse
+
+from scipy.interpolate import interp1d
+from util import zero_runs
+from hht import HHT
+
+from parameters import *
+from util import (load_data, filter_units, switch_ranges, merge_ranges, angle_subtract,
+                  circmean, get_trials, shuffle_bins)
+
+
+def resample(data, old_tpts, new_tpts, axis=0, fill_value='extrapolate'):
+    """
+    Use linear interpolation to re-sample data.
+    """
+    # interpolate time-course with linear splines
+    func = interp1d(old_tpts, data, axis=axis, fill_value=fill_value)
+    # get new time-course
+    interpolated_data = func(new_tpts)
+    return interpolated_data
+
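+# Usage sketch (illustrative only, not part of the analysis pipeline): upsample
+# a coarsely sampled quadratic onto a finer time base.
+def _demo_resample():
+    old_tpts = np.linspace(0, 1, 11)
+    new_tpts = np.linspace(0, 1, 101)
+    return resample(old_tpts ** 2, old_tpts, new_tpts) # piecewise-linear estimate of t**2
+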
+def phase2rank(alpha):
+    """
+    Convert angles to circular ranks.
+    """
+    n = len(alpha)
+    ranks = np.full(n, np.nan)
+    ranks[alpha.argsort()] = np.arange(n)
+    return 2 * np.pi * ranks / n
+
+def rank2phase(rank, alpha):
+    """
+    Convert a circular rank back to phase in the original distribution.
+    """
+    n = len(alpha)
+    # convert circular rank to linear rank
+    linrank = n * rank / 2 / np.pi
+    return np.sort(alpha)[np.round(linrank).astype('int')]
+
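+# Usage sketch (illustrative only, not part of the analysis pipeline): circular
+# ranks are uniform on [0, 2*pi) regardless of how the original phases are
+# distributed, and rank2phase maps them back to the original values.
+def _demo_phase_ranks():
+    rng = np.random.default_rng(0)
+    alpha = rng.vonmises(mu=0.0, kappa=2.0, size=1000) # phases biased towards 0
+    ranks = phase2rank(alpha)
+    recovered = rank2phase(ranks, alpha)
+    return np.allclose(recovered, alpha) # True
+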
+def inds2train(inds, length):
+    """
+    Convert event indices into binary time-series.
+
+    Parameters
+    ----------
+    inds : 1D array
+        event indices
+    length : int
+        total length of segment in which events occur
+
+    Returns
+    -------
+    event_train : ndarray
+        binary time-series
+    """
+    # initialize output array
+    event_train = np.zeros(length, dtype='uint8')
+    event_train[inds] = 1
+    return event_train
+
+def times2train(evts, tpts):
+    """
+    Convert event times into binary time-series.
+
+    Parameters
+    ----------
+    evts : 1D array
+        event times
+    tpts : 1D array
+        time base in which events occur
+
+    Returns
+    -------
+    event_train : ndarray
+        binary time-series array
+    """
+    # clip events that fall out of time base
+    evts_in_tpts = evts[(evts > tpts.min()) & (evts < tpts.max())]
+    # convert times to indices
+    evis = tpts.searchsorted(evts_in_tpts)
+    # get the event train
+    ev_train = inds2train(evis, len(tpts))
+    return ev_train
+
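+# Usage sketch (illustrative only, not part of the analysis pipeline): place
+# three spike times on a 100 Hz time base and count them back.
+def _demo_times2train():
+    tpts = np.arange(0, 10, 0.01)
+    train = times2train(np.array([1.0, 2.5, 7.2]), tpts)
+    return train.sum() # 3
+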
+def modified_mrl2(alpha, w=None, axis=0):
+    """
+    A bias-free measure of the squared mean resultant length [1].
+
+    Parameters
+    ----------
+    alpha : ndarray
+        array of angles
+    w : ndarray
+        array of weights, must be same shape as alpha
+    axis : int, None
+        axis across which to compute mean
+
+    Returns
+    -------
+    out : ndarray
+        bias-corrected squared mean resultant length
+
+    Notes
+    -----
+    -   taking the square-root of this measure does *not* provide a bias-free
+        measure of the mean resultant length, see [1].
+
+    References
+    ----------
+    [1] Kutil, R. (2012). Biased and unbiased estimation of the circular mean
+    resultant length and its variance. Statistics, 46(4), 549-561.
+    """
+    mrl, _ = circmean(alpha, w=w, axis=axis)
+    n = alpha.shape[axis]
+    return (n / (n - 1)) * (mrl ** 2 - (1 / n))
+
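+# Usage sketch (illustrative only, relies on util.circmean accepting a bare
+# array of angles): tightly concentrated phases give a value near 1, uniform
+# phases give a value near 0.
+def _demo_modified_mrl2():
+    rng = np.random.default_rng(0)
+    concentrated = rng.normal(loc=0.0, scale=0.1, size=1000)
+    uniform = rng.uniform(-np.pi, np.pi, size=1000)
+    return modified_mrl2(concentrated), modified_mrl2(uniform)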
+
+def phase_tuning(phase, spk_train, shuffle_binwidth=1000, n_shuffles=1000):
+    ranks = phase2rank(phase) - np.pi
+    assert len(phase) == len(spk_train)
+    # Get phase ranks where spikes occur
+    spk_ranks = ranks[spk_train == 1]
+    # Compute modified mean vector length
+    r = modified_mrl2(spk_ranks)
+    # Compute mean rank
+    _, mean_rank = circmean(spk_ranks)
+    # Convert rank back to phase
+    theta = rank2phase(mean_rank + np.pi, phase)
+    # Compute tuning strength for shuffled spike trains
+    r_shf = np.full(n_shuffles, np.nan)
+    for shf_i in range(n_shuffles):
+        # Shuffle time bins of spike train
+        spk_train_shf = shuffle_bins(spk_train, shuffle_binwidth)
+        # Take phase of shuffled train
+        spk_ranks_shf = ranks[spk_train_shf == 1]
+        # Compute tuning strength
+        r_shf[shf_i] = modified_mrl2(spk_ranks_shf)
+    p = (r_shf > r).sum() / n_shuffles
+    return r, theta, p
+
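+# Usage sketch (illustrative only, relies on util.shuffle_bins): spikes emitted
+# preferentially near phase 0 should give clear tuning (theta near 0) with a
+# small shuffle p-value.
+def _demo_phase_tuning(fs=100.0):
+    rng = np.random.default_rng(0)
+    t = np.arange(0, 600, 1 / fs)
+    phase = np.angle(np.exp(1j * 2 * np.pi * 0.1 * t)) # 0.1 Hz oscillation
+    p_spike = 0.02 * (np.cos(phase) + 1) / 2           # spiking locked to phase 0
+    spk_train = (rng.uniform(size=t.size) < p_spike).astype('uint8')
+    return phase_tuning(phase, spk_train, shuffle_binwidth=100, n_shuffles=200)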
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('e_name')
+    parser.add_argument('-s', '--spk_types', nargs='+', default=['tonicspk', 'burst'])
+    parser.add_argument('-t', '--tranges', default='')
+    args = parser.parse_args()
+
+    if args.tranges:
+        assert args.tranges in ['run', 'sit', 'desync', 'sizematched', 'nosaccade', 'noopto']
+
+    df_pupil = load_data('pupil', [args.e_name])
+    df_pupil.set_index(['m', 's', 'e'], inplace=True)
+    df_spikes = load_data('spikes', [args.e_name])
+    df_spikes.set_index(['m', 's', 'e'], inplace=True)
+    df_spikes = filter_units(df_spikes, MINRATE)
+
+    ## TODO: find a better way to integrate saccades
+    if 'saccade' in args.spk_types:
+        df_spikes['saccade_times'] = [df_pupil.loc[idx]['saccade_times'] for idx, unit in df_spikes.iterrows()]
+
+    # Load data for requested time ranges
+    if args.tranges in ['nosaccade']:
+        df_pupil = load_data('pupil', [args.e_name]).set_index(['m', 's', 'e'])
+    elif args.tranges in ['desync', 'sizematched']:
+        df_hht = load_data('hht', [args.e_name]).set_index(['m', 's', 'e'])
+    elif args.tranges in ['run', 'sit']:
+        df_run = load_data('ball', [args.e_name]).set_index(['m', 's', 'e'])
+    elif args.tranges in ['noopto']:
+        df_trials = load_data('trials', [args.e_name])
+        df_trials.rename(columns={'trial_on_time':'trial_on_times', 'trial_off_time':'trial_off_times'}, inplace=True)
+        df_trials = df_trials.apply(get_trials, stim_id=-1, axis='columns').set_index(['m', 's', 'e'])
+
+    seriess = []
+    for idx, row in tqdm(df_pupil.iterrows(), total=len(df_pupil)):    
+        pupil_area = row['pupil_area']
+        pupil_tpts = row['pupil_tpts']
+        # Get IMFs
+        pupil_fs = 1 / np.diff(pupil_tpts).mean()
+        hht = HHT(pupil_area, pupil_fs)
+        hht.emd()
+        # Get phases and frequencies
+        hht.hsa()
+        hht.check_number_of_phasebin_visits(ncycles=NIMFCYCLES, remove_invalid=True)
+        imf_phases = hht.phase.T
+        imf_freqs = hht.characteristic_frequency
+        imf_power = hht.power_ratio
+
+        # Get time-ranges
+        if args.tranges in ['run', 'sit']:
+            try:
+                tranges = df_run.loc[idx, '%s_bouts' % args.tranges]
+            except KeyError:
+                print("No run data found for ", idx)
+                continue
+            if args.tranges == 'run':
+                dt0 = BEHAVEXCLUSIONS['run'][1]
+                dt1 = BEHAVEXCLUSIONS['sit'][0]
+            elif args.tranges == 'sit':
+                dt0 = BEHAVEXCLUSIONS['sit'][1]
+                dt1 = BEHAVEXCLUSIONS['run'][0]
+            ext = np.ones_like(tranges) * np.array([dt0, dt1])
+            tranges = tranges + ext
+            tranges = np.row_stack([trange for trange in tranges if trange[0] < trange[1]])
+            tranges = [tranges for imf in range(hht.n_imfs)]
+        elif args.tranges in ['desync', 'sizematched']:
+            try:
+                tranges = df_hht.loc[idx, '%s_bouts' % args.tranges]
+            except KeyError:
+                print("No HHT data found for ", idx)
+                continue
+        elif args.tranges in ['nosaccade']:
+            saccade_times = df_pupil.loc[idx, 'saccade_times']
+            saccade_tranges = np.column_stack([saccade_times, saccade_times])
+            saccade_tranges += np.array(BEHAVEXCLUSIONS['saccade'])
+            saccade_tranges = merge_ranges(saccade_tranges, dt=(1 / pupil_fs))
+            tranges = switch_ranges(
+                saccade_tranges,
+                dt=(1 / pupil_fs),
+                minval=pupil_tpts.min(),
+                maxval=pupil_tpts.max()
+                )
+            tranges = [tranges for imf in range(hht.n_imfs)]
+        elif args.tranges in ['opto', 'noopto']:
+            try:
+                trial_ids = df_trials.loc[idx, 'trial_id']
+                opto_trials = df_trials.loc[idx, 'opto_trials']
+                trial_on_time = df_trials.loc[idx, 'trial_on_times']
+                trial_off_time = df_trials.loc[idx, 'trial_off_times']
+            except KeyError:
+                print("No trial data found for ", idx)
+                continue
+            t0s = trial_on_time
+            t1s = trial_off_time
+            tranges = np.column_stack([t0s, t1s])
+            tranges = [tranges for imf in range(hht.n_imfs)]
+        elif args.tranges in ['half1', 'half2']:
+            t0, t1 = pupil_tpts.min(), pupil_tpts.max()
+            half_length = (t1 - t0) / 2
+            if args.tranges == 'half1':
+                tranges = [np.array([[t0, t0 + half_length]]) for imf in range(hht.n_imfs)]
+            if args.tranges == 'half2':
+                tranges = [np.array([[t0 + half_length, t1]]) for imf in range(hht.n_imfs)]
+        elif args.tranges in ['split1', 'split2']:
+            imf_cycles = [pupil_tpts[np.where(np.diff(phase) < -np.pi)[0]] for phase in hht.phase.T]
+            imf_cycles = [np.concatenate([pupil_tpts[:1], cycles, pupil_tpts[-1:]]) for cycles in imf_cycles]
+            cycle_tranges = [np.column_stack([cycles[:-1], cycles[1:]]) for cycles in imf_cycles]
+            if args.tranges == 'split1':
+                tranges = [cycles[0::2] for cycles in cycle_tranges]
+            else:
+                tranges = [cycles[1::2] for cycles in cycle_tranges]
+
+        # Get units for this experiment
+        try:
+            df_units = df_spikes.loc[idx]
+        except KeyError:
+            print("Spikes missing for {}".format(idx))
+            continue
+
+        for _, unit in df_units.iterrows():
+            unit_tpts = np.arange(*unit['spk_tinfo'])
+            t0, t1 = row['pupil_tpts'].min(), row['pupil_tpts'].max()
+            i0, i1 = unit_tpts.searchsorted([t0, t1])
+            unit_tpts = unit_tpts[i0:i1]
+            imf_phases_resamp = resample(imf_phases, pupil_tpts, unit_tpts, axis=1)
+            spk_trains = {}
+            for spk_type in args.spk_types:
+                spk_trains[spk_type] = times2train(unit['{}_times'.format(spk_type)], unit_tpts)
+            for imf_i, phase in enumerate(imf_phases_resamp):
+                data = {
+                    'm':        idx[0],
+                    's':        idx[1],
+                    'e':        idx[2],
+                    'u':        unit['u'],
+                    'imf':      imf_i + 1,
+                    'freq':     imf_freqs[imf_i],
+                    'power':    imf_power[imf_i]
+                    }
+                # Get time ranges to analyze
+                if args.tranges:
+                    tranges_imf = tranges[imf_i]
+                else:
+                    tranges_imf = np.array([[t0, t1]])
+                iranges_imf = unit_tpts.searchsorted(tranges_imf)
+                unit_fs = 1 / np.diff(unit_tpts).mean()
+                binwidth = np.floor(unit_fs * SHUFFLE_BINWIDTH).astype('int')
+                for spk_type, spk_train in spk_trains.items():
+                    # Take only data in ranges
+                    if len(iranges_imf) > 0:
+                        phase_clipped = np.concatenate([phase[i0:i1] for i0, i1 in iranges_imf])
+                        train_clipped = np.concatenate([spk_train[i0:i1] for i0, i1 in iranges_imf])
+                    else:
+                        train_clipped = np.array([])
+                    # Check that there are enough spikes to do analysis
+                    nspikes = train_clipped.sum()
+                    if nspikes < NSPIKES:
+                        r = theta = p = np.nan
+                    else:
+                        r, theta, p = phase_tuning(
+                            phase_clipped, train_clipped,
+                            shuffle_binwidth=binwidth, n_shuffles=NSHUFFLES
+                            )
+                    data['_'.join([spk_type, 'n'])] = nspikes
+                    data['_'.join([spk_type, 'strength'])] = r
+                    data['_'.join([spk_type, 'phase'])] = theta
+                    data['_'.join([spk_type, 'p'])] = p
+                seriess.append(pd.Series(data=data))
+    df_tuning = pd.DataFrame(seriess)
+
+    if args.tranges:
+        filename = 'phasetuning_{}_{}.pkl'.format(args.e_name, args.tranges)
+    elif args.spk_types == ['saccade']:
+        filename = 'phasetuning_{}_{}.pkl'.format(args.e_name, 'saccades')
+    else:
+        filename = 'phasetuning_{}.pkl'.format(args.e_name)
+    df_tuning.to_pickle(DATAPATH + filename)
+
+
+

+ 114 - 0
size_tuning.py

@@ -0,0 +1,114 @@
+import argparse
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+
+import statsmodels.api as sm
+from statsmodels.formula.api import ols
+from statsmodels.stats.anova import anova_lm
+#from pygam import PoissonGAM, LinearGAM, GammaGAM
+#from pygam import l as linear_term
+#from pygam import s as spline_term
+
+from parameters import DATAPATH, MINRATE, NPUPILAREABINS, NGAMSPLINES
+from util import load_data, unbiased_variance, times_to_counts, _resample_data, filter_units, sort_data
+
+
+def arousal_rate_matrix(spk_rates, pupil_area, **kwargs):
+    """
+    Get a matrix of firing rate histograms for different levels of arousal.
+    """
+    area_bins, binned_area, binned_rates = get_binned_rates(spk_rates, pupil_area, **kwargs)
+    rate_bins = np.linspace(spk_rates.min(), spk_rates.max() + np.finfo('float').eps, kwargs['nbins'] + 1)
+    # Fill pupil area rate matrix
+    rate_mat = np.zeros((len(rate_bins) - 1, len(area_bins) - 1))
+    for bin_i, rates in enumerate(binned_rates):
+        rate_hist = np.histogram(rates, bins=rate_bins, density=True)[0] / (len(rate_bins) - 1)
+        rate_mat[:, bin_i] += rate_hist
+
+    return rate_bins, area_bins, rate_mat
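+
+# Illustrative usage sketch (synthetic data; the positive rate/size relationship is
+# assumed only for the example):
+# >>> rng = np.random.default_rng(0)
+# >>> pupil_area = rng.normal(1.0, 0.2, 10000)
+# >>> spk_rates = rng.poisson(5 + 2 * pupil_area.clip(0)).astype(float)
+# >>> rate_bins, area_bins, rate_mat = arousal_rate_matrix(spk_rates, pupil_area, nbins=10)
+# rate_mat has one column per pupil-size bin and one row per firing-rate bin.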
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('e_name')
+    parser.add_argument('spk_type')
+    args = parser.parse_args()
+    df_pupil = load_data('pupil', [args.e_name])
+    df_spikes = load_data('spikes', [args.e_name])
+    df_spikes = filter_units(df_spikes, MINRATE)
+    df = pd.merge(df_pupil, df_spikes, on=['m', 's', 'e'])
+
+    seriess = []
+    for idx, row in tqdm(df.iterrows(), total=len(df)):
+        data = {
+            'm':                row['m'],
+            's':                row['s'],
+            'e':                row['e'],
+            'u':                row['u'],
+            }
+
+        if args.spk_type == 'ratio':
+            row = times_to_counts(row, ['tonicspk', 'burstspk'], t0='pupil_tpts')
+            row['ratio_counts'] = (row['burstspk_counts'] + np.finfo('float').eps) / row['tonicspk_counts']
+        else:
+            row = times_to_counts(row, [args.spk_type], t0='pupil_tpts')
+        row = _resample_data(row, ['pupil_area'], t0='pupil_tpts')
+        # Get pupil area gradient
+        row['pupil_gradient'] = np.gradient(row['pupil_area'])
+
+        spk_counts = row['{}_counts'.format(args.spk_type)]
+        data['rate_max'] = spk_counts.max()
+        data['rate_min'] = spk_counts.min()
+
+        for var in ['area', 'gradient']:
+            regressor = row['pupil_{}'.format(var)]
+
+            data['{}_max'.format(var)] = regressor.max()
+            data['{}_min'.format(var)] = regressor.min()
+
+            if var == 'area':
+                # Bin rates
+                bin_min, bin_max = np.percentile(regressor, [2.5, 97.5])
+                #bin_min, bin_max = regressor.min(), regressor.max()
+                bins = np.linspace(bin_min, bin_max, NPUPILAREABINS + 1)
+            elif var == 'gradient':
+                bin_max = np.percentile(np.abs(regressor), 97.5)
+                bins_neg = np.linspace(-1 * bin_max, 0, NPUPILAREABINS + 1)
+                bins_pos = np.linspace(0, bin_max, NPUPILAREABINS + 1)
+                bins = np.concatenate([bins_neg[:-1], bins_pos])
+            binned_counts = sort_data(spk_counts, regressor, bins=bins)
+            data[f'{var}_means'] = np.array([counts.mean() for counts in binned_counts])
+
+            # ANOVA (linear OLS with categorical pupil size)
+            digitized_regressor = np.digitize(regressor, bins=bins).clip(1, len(bins) - 1)
+            df_anova = pd.DataFrame(np.column_stack([spk_counts, digitized_regressor]), columns=['count', 'bin'])
+            anova_model = ols('count ~ C(bin)', data=df_anova).fit() # use formula to specify model
+            data['{}_p'.format(var)] = anova_model.f_pvalue
+
+            # Linear model (linear OLS with sorted categorical pupil size)
+            mean_counts = np.array([counts.mean() for counts in binned_counts if len(counts) > 0])
+            mean_counts.sort()
+            y = mean_counts
+            X = sm.add_constant(np.arange(len(mean_counts))) # use design matrix to specify model
+            linear_model = sm.OLS(y, X).fit()
+            intercept, slope = linear_model.params
+            data['{}_slope'.format(var)] = slope
+
+            """
+            # Poisson GAM model for rates
+            gam_model = PoissonGAM(spline_term(0, n_splines=NGAMSPLINES), fit_intercept=True)
+            # Search over parameter controlling smoothness (lambda) for best fit
+            gam_model.gridsearch(regressor[..., np.newaxis], spk_counts, keep_best=True, progress=False)
+            data['{}_fit'.format(var)] = gam_model.predict(gam_model.generate_X_grid(term=0))
+            data['{}_dev'.format(var)] = gam_model.score(regressor[..., np.newaxis], spk_counts)
+            """
+
+            # Get rate variance in pupil area bins
+            binned_counts = sort_data(spk_counts, regressor, bins=bins)
+            data['{}_var'.format(var)] = np.array([unbiased_variance(rates) for rates in binned_counts])
+
+        seriess.append(pd.Series(data=data))
+    df_tuning = pd.DataFrame(seriess)
+    filename = 'sizetuning_{}_{}.pkl'.format(args.e_name, args.spk_type)
+    df_tuning.to_pickle(DATAPATH + filename)

+ 107 - 0
triggered_spiking.py

@@ -0,0 +1,107 @@
+import os
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+import argparse
+
+import quantities as pq
+from neo import SpikeTrain
+from elephant.statistics import instantaneous_rate
+from elephant.kernels import GaussianKernel
+
+from parameters import DATAPATH, MINRATE, NSPIKES, TRIGGEREDAVERAGES, NSHUFFLES, SHUFFLE_BINWIDTH
+from util import load_data, filter_units, shuffle_bins
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('e_name')
+    parser.add_argument('region')
+    parser.add_argument('-t', '--triggers', nargs='+', default=['saccade', 'run', 'sit'])
+    parser.add_argument('-s', '--spk_types', nargs='+', default=['tonicspk', 'burst'])
+    args = parser.parse_args()
+
+    df_run = load_data('ball', [args.e_name], region=args.region)
+    df_run['run_times'] = df_run['run_bouts'].apply(lambda x: x[:, 0])
+    df_run['sit_times'] = df_run['run_bouts'].apply(lambda x: x[:, 1])
+    df_pupil = load_data('pupil', [args.e_name], region=args.region)
+    df_triggers = pd.merge(df_run, df_pupil)
+
+    if 'trial_on' in args.triggers:
+        df_trials = load_data('trials', [args.e_name], region=args.region)
+        df_trials.rename(columns={'trial_on_time':'trial_on_times'}, inplace=True)
+        df_triggers = pd.merge(df_triggers, df_trials)
+    if 'burst' in args.triggers:
+        import sys; sys.exit()
+
+    df_triggers.set_index(['m', 's', 'e'], inplace=True)
+
+    df_spikes = load_data('spikes', [args.e_name], region=args.region).set_index(['m', 's', 'e'])
+    df_spikes = filter_units(df_spikes, MINRATE)
+
+    seriess = []
+    for idx, unit in tqdm(df_spikes.iterrows(), total=len(df_spikes)):
+        data = {
+            'm':idx[0],
+            's':idx[1],
+            'e':idx[2],
+            'u':unit['u']
+            }
+        for trigger in args.triggers:
+            try:
+                trigger_times = df_triggers.loc[idx]['%s_times' % trigger]
+            except KeyError:
+                continue
+
+            pars = TRIGGEREDAVERAGES[trigger]
+            fs = pars['dt'] * pq.s
+            kernel = GaussianKernel(pars['bw'] * pq.s)
+
+            for spk_type in args.spk_types:
+                # Get inst. rate for whole experiment
+                spk_times = unit['%s_times' % spk_type]
+                if len(spk_times) < NSPIKES:
+                    continue
+                t0 = min([unit['spk_tinfo'][0], spk_times.min()])
+                t1 = max([unit['spk_tinfo'][1], spk_times.max()])
+                spk_train = SpikeTrain(spk_times, t_start=t0 * pq.s, t_stop=t1 * pq.s, units='s')
+                inst_rate = instantaneous_rate(spk_train, sampling_period=fs, kernel=kernel)
+                inst_rate = inst_rate.squeeze().magnitude
+
+                # Get responses to each trigger
+                spk_tpts = np.linspace(t0, t1, inst_rate.shape[0])
+                trigger_times = trigger_times[trigger_times < (spk_tpts.max() - pars['post'])]
+                trigger_times = trigger_times[trigger_times > (spk_tpts.min() - pars['pre'])]
+                i0s = spk_tpts.searchsorted(trigger_times) + int(pars['pre'] / pars['dt'])
+                i1s = spk_tpts.searchsorted(trigger_times) + int(pars['post'] / pars['dt'])
+                responses = np.row_stack([inst_rate[i0:i1] for i0, i1 in zip(i0s, i1s)])
+
+                # Baseline normalize responses
+                response_tpts = np.linspace(pars['pre'], pars['post'], responses.shape[1])
+                b0, b1 = response_tpts.searchsorted(pars['baseline'])
+                responses = (responses.T - responses[:, b0:b1].mean(axis=1)).T
+
+                # Take mean
+                triggered_average = responses.mean(axis=0)
+
+                # Get triggered averages from shuffled rates
+                triggered_average_shf = np.full((NSHUFFLES, triggered_average.shape[0]), np.nan)
+                for shf_i in range(NSHUFFLES):
+                    shuffle_binwidth = int(SHUFFLE_BINWIDTH / pars['dt'])
+                    inst_rate_shf = shuffle_bins(inst_rate, shuffle_binwidth)
+                    responses_shf = np.row_stack([inst_rate_shf[i0:i1] for i0, i1 in zip(i0s, i1s)])
+                    responses_shf = (responses_shf.T - responses_shf[:, b0:b1].mean(axis=1)).T
+                    triggered_average_shf[shf_i] = responses_shf.mean(axis=0)
+                ci_low, ci_high = np.percentile(triggered_average_shf, [2.5, 97.5], axis=0)
+                sig = (triggered_average < ci_low).any() | (triggered_average > ci_high).any()
+
+                data[f'{trigger}_{spk_type}_response'] = triggered_average
+                data[f'{trigger}_{spk_type}_tpts'] = response_tpts
+                data[f'{trigger}_{spk_type}_sig'] = sig
+        seriess.append(pd.Series(data=data))
+
+    df_resp = pd.DataFrame(seriess)
+    filename = f'responses_{args.e_name}_{args.region}.pkl'
+    df_resp.to_pickle(DATAPATH + filename)
+
+
+

+ 891 - 0
util.py

@@ -0,0 +1,891 @@
+import numpy as np
+import pandas as pd
+import quantities as pq
+from matplotlib import pyplot as plt
+from scipy.interpolate import interp1d
+from statsmodels.stats.multitest import fdrcorrection
+from scipy.stats import norm, sem
+from sklearn.mixture import GaussianMixture
+from neo import SpikeTrain
+from elephant.statistics import instantaneous_rate
+from elephant.kernels import GaussianKernel
+
+from parameters import *
+
+
+## Data handling
+def df2keys(df):
+    """Return MSE keys for all entries in a DataFrame."""
+
+    df = df.reset_index()
+    keys = [key for key in df.columns if key in ['m', 's', 'e', 'u']]
+
+    return [{key: val for key, val in zip(keys, vals)} for vals in df[keys].values]
+
+def key2idx(key):
+    """Return DataFrame index tuple for the given key."""
+    return tuple([val for k, val in key.items()])
+
+def df_metadata(df):
+    """Print metadata for a DataFrame."""
+
+    df = df.reset_index()
+    print("No. mice: {}".format(len(df.groupby('m'))))
+    print("No. series: {}".format(len(df.groupby(['m', 's']))))
+    print("No. experiments: {}".format(len(df.groupby(['m', 's', 'e']))))
+    if 'u' in df.columns:
+        print("No. units: {}".format(len(df.groupby(['m', 's', 'e', 'u']))))
+        for idx, df in df.groupby(['m', 's', 'e']):
+            print("\n{} s{:02d} e{:02d}".format(idx[0], idx[1], idx[2]))
+            print("    {} units".format(len(df)))
+
+def load_data(data, conditions, **kwargs):
+    """
+    Load data of a specified type and recording region, pooling across all
+    requested conditions.
+    """
+    dfs = []
+    for condition in conditions:
+        filename = '{}_{}'.format(data, condition)
+        for kw, arg in kwargs.items():
+            filename = filename + '_{}'.format(arg)
+        filename = filename + '.pkl'
+        print("Loading: ", filename)
+        df = pd.read_pickle(DATAPATH + filename)
+        if 'condition' not in df.columns:
+            df['condition'] = condition
+        dfs.append(df)
+    return pd.concat(dfs, axis='index')
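+
+# Illustrative usage sketch (condition names and the 'region' keyword value are
+# hypothetical; files are expected under DATAPATH):
+# >>> df_spikes = load_data('spikes', ['nat'])               # reads 'spikes_nat.pkl'
+# >>> df_pooled = load_data('pupil', ['nat', 'grat'])        # pools two conditions
+# >>> df_region = load_data('spikes', ['nat'], region='V1')  # reads 'spikes_nat_V1.pkl'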
+
+## Plotting utils
+
+def set_plotsize(w, h=None, ax=None):
+    """
+    Set the size of a matplotlib axes object in cm.
+
+    Parameters
+    ----------
+    w, h : float
+        Desired width and height of the plot; if h is None, the axes will be
+        square.
+
+    ax : matplotlib.axes
+        Axes to resize, if None the output of plt.gca() will be re-sized.
+
+    Notes
+    -----
+    - Use after subplots_adjust (if adjustment is needed)
+    - Matplotlib axis size is determined by the figure size and the subplot
+      margins (r, l; given as a fraction of the figure size), i.e.
+      w_ax = w_fig * (r - l)
+    """
+    if h is None: # assume square
+        h = w
+    w /= 2.54 # convert cm to inches
+    h /= 2.54
+    if not ax: # get current axes
+        ax = plt.gca()
+    # get margins
+    l = ax.figure.subplotpars.left
+    r = ax.figure.subplotpars.right
+    t = ax.figure.subplotpars.top
+    b = ax.figure.subplotpars.bottom
+    # set fig dimensions to produce desired ax dimensions
+    figw = float(w)/(r-l)
+    figh = float(h)/(t-b)
+    ax.figure.set_size_inches(figw, figh)
+
+def clip_axes_to_ticks(ax=None, spines=['left', 'bottom'], ext={}):
+    """
+    Clip the axis lines to end at the minimum and maximum tick values.
+
+    Parameters
+    ----------
+    ax : matplotlib.axes
+        Axes to resize, if None the output of plt.gca() will be re-sized.
+
+    spines : list
+        Spines to keep and clip; spines not included in this list will be removed.
+        Valid values include 'left', 'bottom', 'right', 'top'.
+
+    ext : dict
+        For each axis in ext.keys() ('left', 'bottom', 'right', 'top'),
+        the axis line will be extended beyond the last tick by the value
+        specified, e.g. {'left':[0.1, 0.2]} will result in an axis line
+        that extends 0.1 units beyond the bottom tick and 0.2 units beyond
+        the top tick.
+    """
+    if ax is None:
+        ax = plt.gca()
+    spines2ax = {
+        'left': ax.yaxis,
+        'top': ax.xaxis,
+        'right': ax.yaxis,
+        'bottom': ax.xaxis
+    }
+    all_spines = ['left', 'bottom', 'right', 'top']
+    for spine in spines:
+        low = min(spines2ax[spine].get_majorticklocs())
+        high = max(spines2ax[spine].get_majorticklocs())
+        if spine in ext.keys():
+            low += ext[spine][0]
+            high += ext[spine][1]
+        ax.spines[spine].set_bounds(low, high)
+    for spine in [spine for spine in all_spines if spine not in spines]:
+        ax.spines[spine].set_visible(False)
+
+def p2stars(p):
+    if p <= 0.0001:
+        return '***'
+    elif p <= 0.001:
+        return '**'
+    elif p<= 0.05:
+        return '*'
+    else:
+        return ''
+
+def violin_plot(dists, colors, ax=None, logscale=False):
+    if type(colors) is list:
+        assert len(colors) == len(dists)
+    if ax is None:
+        fig, ax = plt.subplots()
+    violins = ax.violinplot(dists, showmedians=True, showextrema=False)
+    for violin, color in zip(violins['bodies'], colors):
+        violin.set_facecolor('none')
+        violin.set_edgecolor(color)
+        violin.set_alpha(1)
+        violin.set_linewidth(2)
+    violins['cmedians'].set_color('black')
+    for pos, dist in enumerate(dists):
+        median = np.median(dist)
+        text = f'{median:.2f}'
+        if logscale:
+            text = f'{10 ** median:.2f}'
+        ax.text(pos + 1.4, median, text, va='center', ha='center', rotation=-90, fontsize=LABELFONTSIZE)
+    ax.set_xticks(np.arange(len(dists)) + 1)
+    ax.tick_params(bottom=False)
+    return ax
+
+def pupil_area_rate_heatmap(df, cmap='gray', max_rate='high', example=None):
+    """
+    Plot a heatmap of event rates (tonic spikes or bursts) where each row is a neuron and each column is a pupil size bin.
+
+    Parameters
+    ----------
+    df : pandas.DataFrame
+        Dataframe with neurons in the rows and mean firing rates for each pupil size bin in a column called 'area_means'.
+        
+    cmap : str or Matplotlib colormap object
+
+    max_rate : str
+        Is the max event rate expected to be at 'high' or 'low' pupil sizes?
+
+    example : dict
+        If not None, the MSEU key passed will be used to highlight the example neuron.
+    """
+    fig = plt.figure()
+    
+    # Find out which pupil size bin has the highest firing rate
+    df['tuning_max'] = df['area_means'].apply(np.argmax)
+    # Min-max normalize firing rates for each neuron
+    df['tuning_norm'] = df['area_means'].apply(lambda x: (x - x.min()) / (x.max() - x.min()))
+
+    # Get heatmap for neurons with significant rate differences across pupil size bins
+    df_sig = df.query('area_p <= 0.05').sort_values('tuning_max')
+    heatmap_sig = np.row_stack(df_sig['tuning_norm'])
+    n_sig = len(df_sig)  # number of neurons with significant differences
+
+    # Make axis with size proportional to the fraction of significant neurons
+    n_units = len(df)  # total number of neurons
+    ax1_height = n_sig / n_units
+    ax1 = fig.add_axes([0.1, 0.9 - (0.76 * ax1_height), 0.8, (0.76 * ax1_height)])
+    # Plot heatmap for significant neurons
+    mat = ax1.matshow(heatmap_sig, cmap=cmap, aspect='auto')
+    # Scatter plot marking pupil size bin with maximum
+    ax1.scatter(df_sig['tuning_max'], np.arange(len(df_sig)) - 0.5, s=0.5, color='black', zorder=3)
+
+    # Format axes
+    ax1.set_xticks([])
+    yticks = np.insert(np.arange(0, n_sig, 20), -1, n_sig)  # ticks every 20 neurons, and final count
+    ax1.set_yticks(yticks - 0.5)
+    ax1.set_yticklabels(yticks)
+    ax1.set_ylabel('Neurons')
+
+    # Make a colorbar
+    cbar = fig.colorbar(mat, ax=ax1, ticks=[0, 1], location='top', shrink=0.75)
+    cbar.ax.set_xticklabels(['min', 'max'])
+    cbar.ax.set_xlabel('Spikes', labelpad=-5)
+
+    # Print dotted line and label for neurons considered to have 'monotonic' modulation profiles
+    # (highest event rate at one of the pupil size extremes)
+    if max_rate == 'high':
+        n_mon = len(df_sig.query('tuning_max >= 9'))
+        print("Monotonic increasing: %d/%d (%.1f)" % (n_mon, n_sig, n_mon / n_sig * 100))
+        ax1.axvline(8.5, lw=2, ls='--', color='white')
+        ax1.set_title(r'Monotonic$\rightarrow$', fontsize=LABELFONTSIZE, loc='right', pad=0)
+    elif max_rate == 'low':
+        n_mon = len(df_sig.query('tuning_max <= 1'))
+        print("Monotonic decreasing: %d/%d (%.1f)" % (n_mon, n_sig, n_mon / n_sig * 100))
+        ax1.axvline(0.5, lw=2, ls='--', color='white')
+        ax1.set_title(r'$\leftarrow$Monotonic', fontsize=LABELFONTSIZE, loc='left', pad=0)
+
+    # Get heatmap for neurons without significant rate differences across pupil sizes
+    df_ns = df.query('area_p > 0.05').sort_values('tuning_max')
+    heatmap_ns = np.row_stack(df_ns['tuning_norm'])
+    
+    # Make axis with size proportional to the fraction of non-significant neurons
+    n_ns = len(df_ns)
+    ax2_height = n_ns / n_units
+    ax2 = fig.add_axes([0.1, 0.1, 0.8, (0.76 * ax2_height)])
+
+    # Plot heatmap for neurons without significant rate differences across pupil sizes
+    mat = ax2.matshow(heatmap_ns, cmap='Greys', vmax=2, aspect='auto')
+    ax2.scatter(df_ns['tuning_max'], np.arange(n_ns) - 0.5, s=0.5, color='black', zorder=3)
+
+    # Format axes
+    ax2.xaxis.set_ticks_position('bottom')
+    ax2.set_xticks([-0.5, 4.5, 9.5])
+    ax2.set_xticklabels([0, 0.5, 1])  # x-axis ticks mark percentiles of pupil size range
+    ax2.set_xlim(right=9.5)
+    ax2.set_xlabel('Pupil size (norm.)')
+    yticks = np.insert(np.arange(0, n_ns, 20), -1, n_ns)  # ticks every 20 neurons, and final count
+    ax2.set_yticks(yticks - 0.5)
+    ax2.set_yticklabels(yticks)
+
+    # Highlight the example neuron by circling its scatter dot
+    if example is not None:
+        try:  # first check if example neuron is among significant neurons
+            is_ex = df_sig.index == tuple([v for k, v in example.items()])
+            assert is_ex.any()
+            ex_max = df_sig['tuning_max'][is_ex]
+            ax1.scatter(ex_max, np.where(is_ex)[0], lw=2, s=1, fc='none', ec='magenta', zorder=4)
+        except:
+            is_ex = df_ns.index == tuple([v for k, v in example.items()])
+            ex_max = df_ns['tuning_max'][is_ex]
+            ax2.scatter(ex_max, np.where(is_ex)[0], lw=2, s=1, fc='none', ec='magenta', zorder=4)
+
+    return fig
+
+def cumulative_histogram(data, bins, color='C0', ax=None):
+    """Convenience function, cleaner looking plot than plt.hist(..., cumulative=True)."""
+    if ax is None:
+        fig, ax = plt.subplots()
+    weights = np.ones_like(data) / len(data)
+    counts, _ = np.histogram(data, bins=bins, weights=weights)
+    ax.plot(bins[:-1], np.cumsum(counts), color=color)
+    return ax
+
+def cumulative_hist(x, bins, density=True, ax=None, **kwargs):
+    weights = np.ones_like(x)
+    if density:
+        weights = weights / len(x)
+    counts, _ = np.histogram(x, bins=bins, weights=weights)
+    if ax is None:
+        fig, ax = plt.subplots()
+    xs = np.insert(bins, np.arange(len(bins) - 1), bins[:-1])
+    ys = np.insert(np.insert(np.cumsum(counts), np.arange(len(counts)), np.cumsum(counts)), 0, 0)
+    ax.plot(xs, ys, lw=2, **kwargs)
+    ax.set_xticks(bins + 0.5)
+    ax.set_yticks([0, 0.5, 1])
+    return ax, counts
+
+def phase_coupling_scatter(df, ax=None):
+    """Phase-frequency scatter plot for phase coupling."""
+    if ax is None:
+        fig, ax = plt.subplots()
+    for event in ['tonicspk', 'burst']:
+        df_sig = df.query(f'{event}_sig == True')
+        ax.scatter(np.log10(df_sig['freq']), df_sig[f'{event}_phase'], ec=COLORS[event], fc='none', lw=0.5, s=3)
+    n_sig = max([len(df.query(f'{event}_sig == True').groupby(['m', 's', 'e', 'u'])) for event in ['tonicspk', 'burst']])
+        
+    ax.set_xticks(FREQUENCYTICKS)
+    ax.set_xticklabels(FREQUENCYTICKLABELS)
+    ax.set_xlim(left=-3.075)
+    ax.set_xlabel("Inverse timescale (s$^{-1}$)")
+    
+    ax.set_yticks(PHASETICKS)
+    ax.set_yticklabels(PHASETICKLABELS)
+    ax.set_ylim([-np.pi - 0.15, np.pi + 0.15])
+    ax.set_ylabel("Preferred phase")
+    return ax
+
+def plot_circhist(angles, ax=None, bins=np.linspace(0, 2 * np.pi, 17), density=True, **kwargs):
+    """Plot a circular histogram."""
+    if ax is None:
+        fig, ax = plt.subplots(subplot_kw={'polar':True})
+    weights = np.ones_like(angles) 
+    if density:
+        weights /= len(angles)
+    counts, bins = np.histogram(angles, bins=bins, weights=weights)
+    xs = bins + (np.pi / (len(bins) - 1))
+    ys = np.append(counts, counts[0])
+    ax.plot(xs, ys, **kwargs)
+    ax.set_xticks([0, np.pi / 2, np.pi, 3 * np.pi / 2])
+    ax.set_xticklabels(['0',  '\u03C0/2', '\u03C0', '3\u03C0/2'])
+    ax.tick_params(axis='x', pad=-5)
+    return ax, counts
+
+def coupling_strength_line_plot(df, agg=np.mean, err=sem, logscale=True, ax=None, **kwargs):
+    """
+    Line plot showing average burst and tonic spike coupling strengths and SE across timescale bins.
+    """
+    if ax is None:
+        fig, ax = plt.subplots()
+    for event in ['burst', 'tonicspk']:
+        df_sig = df.query(f'({event}_sig == True) & ({event}_strength > 0)').copy()
+        strengths = sort_data(df_sig[f'{event}_strength'], df_sig['freq'], bins=FREQUENCYBINS)
+        ys = np.array([agg(s) for s in strengths])
+        yerr = np.array([err(s) for s in strengths])
+        if not logscale:
+            ax.plot(FREQUENCYXPOS, ys, color=COLORS[event], **kwargs)
+            ax.plot(FREQUENCYXPOS, ys + yerr, color=COLORS[event], lw=0.5, ls='--', **kwargs)
+            ax.plot(FREQUENCYXPOS, ys - yerr, color=COLORS[event], lw=0.5, ls='--', **kwargs)
+        else:
+            ax.plot(FREQUENCYXPOS, np.log10(ys), color=COLORS[event], **kwargs)
+            ax.plot(FREQUENCYXPOS, np.log10(ys + yerr), color=COLORS[event], lw=0.5, ls='--', **kwargs)
+            ax.plot(FREQUENCYXPOS, np.log10(ys - yerr), color=COLORS[event], lw=0.5, ls='--', **kwargs)
+    ax.set_xticks(FREQUENCYTICKS)
+    ax.set_xticklabels(FREQUENCYTICKLABELS)
+    ax.set_xlim(left=-3.1)
+    ax.set_xlabel('Timescale (Hz)')
+    ax.set_ylabel('Coupling strength')
+    return ax
+    
+
+## Util
+
+def zero_runs(a):
+    """
+    Return an array with shape (m, 2), where m is the number of "runs" of zeros
+    in a. The first column is the index of the first 0 in each run, the second
+    is the index of the first nonzero element after the run.
+    """
+    # Create an array that is 1 where a is 0, and pad each end with an extra 0.
+    iszero = np.concatenate(([0], np.equal(a, 0).view(np.int8), [0]))
+    absdiff = np.abs(np.diff(iszero))
+    # Runs start and end where absdiff is 1.
+    ranges = np.where(absdiff == 1)[0].reshape(-1, 2)
+
+    return ranges
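+
+# Illustrative example:
+# >>> zero_runs(np.array([1, 0, 0, 1, 0, 1]))
+# array([[1, 3],
+#        [4, 5]])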
+
+def merge_ranges(ranges, dt=1):
+    """
+    Given a set of ranges [start, stop], return new set of ranges where all
+    overlapping ranges are merged.
+    """
+    tpts = np.arange(ranges.min(), ranges.max(), dt) # array of time points
+    tc = np.ones_like(tpts) # time course of ranges
+    for t0, t1 in ranges: # for each range
+        i0, i1 = tpts.searchsorted([t0, t1])
+        tc[i0:i1] = 0 # set values in range to 0
+    new_ranges = zero_runs(tc) # get indices of continuous stretches of zero
+    if new_ranges[-1, -1] == len(tpts): # fix end-point
+        new_ranges[-1, -1] = len(tpts) - 1
+    return tpts[new_ranges]
+
+def continuous_runs(data, max0len=1, min1len=1, min1prop=0):
+    """
+    Get start and stop indices of stretches of (relatively) continuous data.
+
+    Parameters
+    ----------
+    data : ndarray
+        1D boolean array
+    max0len : int
+        maximum length (in data pts) of False stretches to ignore
+    min1len : int
+        minimum length (in data pts) of True runs to keep
+    min1prop : float
+        minimum proportion of True data in the run necessary for it
+        to be considered
+
+    Returns
+    -------
+    out : ndarray
+        (m, 2) array of start and stop indices, where m is the number of runs of
+        continuous True values
+    """
+    # get ranges of True values
+    one_ranges = zero_runs(~data)
+    if len(one_ranges) == 0:
+        return np.array([[]])
+    # merge ranges that are separated by < max0len of False
+    one_ranges[:, 1] += (max0len - 1)
+    one_ranges = merge_ranges(one_ranges)
+    # return indices to normal
+    one_ranges[:, 1] -= (max0len - 1)
+    one_ranges[-1, -1] += 1
+    # remove ranges that are too short
+    lengths = np.diff(one_ranges, axis=1).ravel()
+    one_ranges = one_ranges[lengths >= min1len]
+    # remove ranges that don't have sufficient proportion True
+    prop = np.array([data[i0:i1].sum() / (i1 - i0) for (i0, i1) in one_ranges])
+    return one_ranges[prop >= min1prop]
+
+def switch_ranges(ranges, dt=1, minval=0, maxval=None):
+    """
+    Given a set of (start, stop) pairs, return a new set of pairs for values
+    outside the given ranges.
+
+    Parameters
+    ----------
+    ranges : ndarray
+        N x 2 array containing start and stop values in the first and second
+        columns respectively
+    dt : float
+        step by which the returned range boundaries are offset from the input ranges
+
+    minval, maxval : float
+        the minimum and maximum possible values; if maxval is None, it is assumed
+        that the maximum possible value is the maximum value in the input ranges
+
+    Returns
+    -------
+    out : ndarray
+        M x 2 array containing start and stop values of all ranges outside of
+        the input ranges
+    """
+    if ranges.shape[1] == 0:
+        return np.array([[minval, maxval]])
+    assert (ranges.ndim == 2) & (ranges.shape[1] == 2), "A two-column array is expected"
+    maxval = ranges.max() if maxval is None else maxval
+    # get new ranges
+    new_ranges = np.zeros_like(ranges)
+    new_ranges[:,0] = ranges[:,0] - dt # new stop values
+    new_ranges[:,1] = ranges[:,1] + dt # new start values
+    # fix boundaries
+    new_ranges = new_ranges.ravel()
+    if new_ranges[0] >= (minval + dt): # first new stop within allowed range
+        new_ranges = np.concatenate((np.array([minval]), new_ranges))
+    else:
+        new_ranges = new_ranges[1:]
+    if new_ranges[-1] <= (maxval - dt): # last new start within allowed range
+        new_ranges = np.concatenate((new_ranges, np.array([maxval])))
+    else:
+        new_ranges = new_ranges[:-1]
+    return new_ranges.reshape((int(len(new_ranges) / 2), 2))
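+
+# Illustrative examples (values chosen only for demonstration):
+# >>> merge_ranges(np.array([[0, 2], [1, 4], [6, 9]]), dt=1)
+# array([[0, 4],
+#        [6, 8]])
+# (overlapping ranges are merged; stops are snapped onto the dt grid, so 9 becomes 8)
+# >>> switch_ranges(np.array([[2, 4], [7, 9]]), dt=1, minval=0, maxval=12)
+# array([[ 0,  1],
+#        [ 5,  6],
+#        [10, 12]])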
+
+def shuffle_bins(x, binwidth=1):
+    """
+    Randomly shuffle bins of an array.
+    """
+    # bin start indices
+    bins_i0 = np.arange(0, len(x), binwidth)
+    # shuffled bins
+    np.random.shuffle(bins_i0)
+    # concatenate shuffled bins
+    shf = np.concatenate([x[i0:(i0 + binwidth)] for i0 in bins_i0])
+    return shf
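+
+# Illustrative example: shuffling 2-sample bins permutes bin order but preserves
+# the values, which is what makes it a suitable null for the shuffle tests above.
+# >>> x = np.arange(6)
+# >>> shf = shuffle_bins(x, binwidth=2)   # e.g. array([4, 5, 0, 1, 2, 3])
+# >>> np.array_equal(np.sort(shf), x)
+# True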
+
+def take_data_in_bouts(series, data, bouts, trange=None, dt=2, dt0=0, dt1=0, concatenate=True, norm=False):
+    if series['%s_bouts' % bouts].shape[1] < 1:
+        return np.array([])
+    header, _ = data.split('_')
+    data_in_bouts = []
+    for t0, t1 in series['%s_bouts' % bouts]:
+        t0 -= dt0
+        t1 += dt1
+        if trange == 'start':
+            t1 = t0 + dt
+        elif trange == 'end':
+            t0 = t1 - dt
+        elif trange == 'middle':
+            t0 = t0 + dt
+            t1 = t1 - dt
+        if t1 <= t0:
+            continue
+        if t0 < series['%s_tpts' % header].min():
+            continue
+        if t1 > series['%s_tpts' % header].max():
+            continue 
+        i0, i1 = series['%s_tpts' % header].searchsorted([t0, t1])
+        data_in_bout = series[data][i0:i1].copy()
+        if norm:
+            data_in_bout = data_in_bout / series[data].max()
+        data_in_bouts.append(data_in_bout)
+    if concatenate:
+        data_in_bouts = np.concatenate(data_in_bouts)
+    return data_in_bouts
+
+def get_trials(series, stim_id=0, opto=False, multi_stim='warn'):
+    if opto:
+        opto = np.isin(series['trial_id'], series['opto_trials'])
+    elif not opto:
+        opto = ~np.isin(series['trial_id'], series['opto_trials'])
+    if stim_id < 0:
+        stim = np.ones_like(series['stim_id']).astype('bool')
+    else:
+        stim = series['stim_id'] == stim_id
+    series['trial_on_times'] = series['trial_on_times'][opto & stim]
+    series['trial_off_times'] = series['trial_off_times'][opto & stim]
+    return series
+
+def sort_data(data, sort_vals, bins=10):
+    if type(bins) == int:
+        nbins = bins
+        bin_edges = np.linspace(sort_vals.min(), sort_vals.max(), nbins + 1)
+    else:
+        nbins = len(bins) - 1
+        bin_edges = bins
+    digitized_vals = np.digitize(sort_vals, bins=bin_edges).clip(1, nbins)
+    return [data[digitized_vals == val] for val in np.arange(nbins) + 1]
+
+def apply_sort_data(series, data_col, sort_col, bins=10):
+    return sort_data(series[data_col], series[sort_col], bins)
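+
+# Illustrative example: group data according to binned values of another variable.
+# >>> data = np.array([10., 20., 30., 40.])
+# >>> sort_vals = np.array([0.1, 0.9, 0.2, 0.8])
+# >>> sort_data(data, sort_vals, bins=2)
+# [array([10., 30.]), array([20., 40.])]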
+
+
+## Statistics
+
+def get_binned_rates(spk_rates, pupil_area, sort=False, nbins=10):
+    # Get bins based on percentiles to eliminate the effect of outliers
+    min_area, max_area = np.percentile(pupil_area, [2.5, 97.5])
+    #min_area, max_area = pupil_area.min(), pupil_area.max()
+    area_bins = np.linspace(min_area, max_area, nbins + 1)
+    # Bin pupil area
+    binned_area = np.digitize(pupil_area, bins=area_bins).clip(1, nbins) - 1
+    # Bin rates according to pupil area
+    binned_rates = np.array([spk_rates[binned_area == bin_i] for bin_i in np.arange(nbins)], dtype=object)
+    if sort:
+        sorted_inds = np.argsort([rates.mean() if len(rates) > 0 else 0 for rates in binned_rates])
+        binned_rates = binned_rates[sorted_inds]
+        binned_area = np.squeeze([np.where(sorted_inds == area_bin)[0] for area_bin in binned_area])
+    return area_bins, binned_area, binned_rates
+
+def rescale(x, method='min_max'):
+    # Set re-scaling method
+    if method == 'z_score':
+        return (x - x.mean()) / x.std()
+    elif method == 'min_max':
+        return (x - x.min()) / (x.max() - x.min())
+
+def correlogram(ts1, ts2=None, tau_max=1, dtau=0.01, return_tpts=False):
+    if ts2 is None:
+        ts2 = ts1.copy()
+        auto = True
+    else:
+        auto = False
+    tau_max = (tau_max // dtau) * dtau
+    tbins = np.arange(-tau_max, tau_max + dtau, dtau)
+    ccg = np.zeros(len(tbins) - 1)
+    for t0 in ts1:
+        dts = ts2 - t0
+        if auto:
+            dts = dts[dts != 0]
+        ccg += np.histogram(dts[np.abs(dts) <= tau_max], bins=tbins)[0]
+    ccg /= len(ts1)
+    if not return_tpts:
+        return ccg
+    else:
+        tpts = tbins[:-1] + dtau / 2
+        return tpts, ccg
+
+def angle_subtract(a1, a2, period=(2 * np.pi)):
+    return (a1 - a2) % period
+
+def circmean(alpha, w=None, axis=None):
+    """
+    Compute mean resultant vector of circular data.
+
+    Parameters
+    ----------
+    alpha : ndarray
+        array of angles
+    w : ndarray
+        array of weights, must be same shape as alpha
+    axis : int, None
+        axis across which to compute mean
+
+    Returns
+    -------
+    mrl : ndarray
+        mean resultant vector length
+    theta : ndarray
+        mean resultant vector angle
+    """
+    # weights default to ones
+    if w is None:
+        w = np.ones_like(alpha)
+    w[np.isnan(alpha)] = 0
+
+    # compute weighted mean
+    mean_vector = np.nansum(w * np.exp(1j * alpha), axis=axis) / w.sum(axis=axis)
+    mrl = np.abs(mean_vector) # length
+    theta = np.angle(mean_vector) # angle
+
+    return mrl, theta
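+
+# Illustrative example: angles clustered around pi/2 give a resultant vector that
+# is close to unit length and points near pi/2.
+# >>> mrl, theta = circmean(np.array([np.pi / 2 - 0.1, np.pi / 2, np.pi / 2 + 0.1]))
+# mrl is approximately 0.997 and theta is approximately 1.571 (pi / 2).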
+
+def circmean_angle(alpha, **kwargs):
+    return circmean(alpha, **kwargs)[1]
+
+def circhist(angles, n_bins=8, proportion=True, wrap=False):
+    bins = np.linspace(-np.pi, np.pi, n_bins + 1, endpoint=True)
+    weights = np.ones(len(angles))
+    if proportion:
+        weights /= len(angles)
+    counts, bins = np.histogram(angles, bins=bins, weights=weights)
+    if wrap:
+        counts = np.concatenate([counts, [counts[0]]])
+        bins = np.concatenate([bins, [bins[0]]])
+    return counts, bins
+
+def unbiased_variance(data):
+    if len(data) <= 1:
+        return np.nan
+    else:
+        return np.var(data) * len(data) / (len(data) - 1)
+
+def se_median(sample, n_resamp=1000):
+    """Standard error of the median."""
+    medians = np.full(n_resamp, np.nan)
+    for i in range(n_resamp):
+        resample = np.random.choice(sample, len(sample), replace=True)
+        medians[i] = np.median(resample)
+    return np.std(medians)
+
+def coupling_summary(df):
+    """Print some basic statistics for phase coupling."""
+    # Either spike type
+    units = df.query('(tonicspk_p == tonicspk_p) or (burst_p == burst_p)').groupby(['m', 's', 'e', 'u'])
+    n_sig = units.apply(lambda x: any(x[f'tonicspk_sig']) or any(x[f'burst_sig'])).sum()
+    prop_sig = n_sig / len(units)
+    print(f"Neurons with significant coupling: {prop_sig:.3f} ({n_sig}/{len(units)})")
+    # For each spike type
+    for spk_type in ['tonicspk', 'burst']:
+        units = df.dropna(subset=f'{spk_type}_p').groupby(['m', 's', 'e', 'u'])
+        n_sig = units.apply(lambda x: any(x[f'{spk_type}_sig'])).sum()
+        prop_sig = n_sig / len(units)
+        print(f"{spk_type.capitalize()} prop. significant: {prop_sig:.3f} ({n_sig}/{len(units)})")
+        n_cpds = units.apply(lambda x: x[f'{spk_type}_sig'].sum())
+        print(f"{spk_type.capitalize()} num. CPDs per neuron: {n_cpds.mean():.2f}, {n_cpds.std():.2f}")
+
+def kl_divergence(p, q):
+    return np.sum(np.where(p != 0, p * np.log(p / q), 0))
+
+def match_distributions(x1, x2, x1_bins, x2_bins):
+    """
+    For two time series, x2 & x2, return indices of sub-sampled time
+    periods such that the distribution of x2 is matched across
+    bins of x1.
+    """
+    x1_nbins = len(x1_bins) - 1
+    x2_nbins = len(x2_bins) - 1
+    # bin x1
+    x1_binned = np.digitize(x1, x1_bins).clip(1, x1_nbins) - 1
+    # get continuous periods where x1 visits each bin
+    x1_ranges = [zero_runs(~np.equal(x1_binned, x1_bin)) for x1_bin in np.arange(x1_nbins)]
+    # get mean of x2 for each x1 bin visit
+    x2_means = [np.array([np.mean(x2[i0:i1]) for i0, i1 in x1_bin]) for x1_bin in x1_ranges]
+    # find minimum common distribution across x1 bins
+    x2_counts = np.row_stack([np.histogram(means, bins=x2_bins)[0] for means in x2_means])
+    x2_mcd = x2_counts.min(axis=0)
+    # bin x2 means
+    x2_means_binned = [np.digitize(means, bins=x2_bins).clip(1, x2_nbins) - 1 for means in x2_means]
+    x2_means_in_bins = [[means[binned_means == x2_bin] for x2_bin in np.arange(x2_nbins)] for means, binned_means in zip(x2_means, x2_means_binned)]
+    x1_ranges_in_bins = [[ranges[binned_means == x2_bin] for x2_bin in np.arange(x2_nbins)] for ranges, binned_means in zip(x1_ranges, x2_means_binned)]
+    # loop over x2 bins
+    matched_ranges = [[] for _ in np.arange(x1_nbins)]  # one list of matched ranges per x1 bin
+    for x2_bin in np.arange(x2_nbins):
+        # find the x1 bin matching the MCD
+        seed_x1_bin = np.where(x2_counts[:, x2_bin] == x2_mcd[x2_bin])[0][0]
+        assert len(x2_means_in_bins[seed_x1_bin][x2_bin]) == x2_mcd[x2_bin]
+        # for each bin visit, find the closest matching mean in the other x1 bins
+        target_means = x2_means_in_bins[seed_x1_bin][x2_bin]
+        target_ranges = x1_ranges_in_bins[seed_x1_bin][x2_bin]
+        for target_mean, target_range in zip(target_means, target_ranges):
+            matched_ranges[seed_x1_bin].append(target_range)
+            for x1_bin in np.delete(np.arange(x1_nbins), seed_x1_bin):
+                matching_ind = np.abs(x2_means_in_bins[x1_bin][x2_bin] - target_mean).argmin()
+                matched_ranges[x1_bin].append(x1_ranges_in_bins[x1_bin][x2_bin][matching_ind])
+                # delete the matching period
+                x2_means_in_bins[x1_bin][x2_bin] = np.delete(x2_means_in_bins[x1_bin][x2_bin], matching_ind, axis=0)
+                x1_ranges_in_bins[x1_bin][x2_bin] = np.delete(x1_ranges_in_bins[x1_bin][x2_bin], matching_ind, axis=0)
+    return [np.row_stack(ranges) if len(ranges) > 0 else np.array([]) for ranges in matched_ranges]
+
+## Signal processing
+
+def normalized_xcorr(a, b, dt=None, ts=None):
+    """
+    Compute Pearson r between two arrays at various lags
+
+    Parameters
+    ----------
+    a, b : ndarray
+        The arrays to correlate.
+    dt : float
+        The time step between samples in the arrays.
+    ts : list
+        If not None, only the xcorr between the specified lags will be
+        returned.
+
+    Return
+    ------
+    xcorr, lags : ndarray
+        The cross correlation and corresponding lags between a and b.
+        Positive lags indicate that a is delayed relative to b.
+    """
+    assert len(a) == len(b)
+    n = len(a)
+    a_norm = (a - a.mean()) / a.std()
+    b_norm = (b - b.mean()) / b.std()
+    xcorr = np.correlate(a_norm, b_norm, 'full') / n
+    lags = np.arange(-n + 1, n)
+    if dt is not None:
+        lags = lags * dt
+    if ts is not None:
+        assert len(ts) == 2
+        i0, i1 = lags.searchsorted(ts)
+        xcorr = xcorr[i0:i1]
+        lags = lags[i0:i1]
+
+    return xcorr, lags
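+
+# Illustrative example (synthetic data): correlating a signal with a copy of itself
+# shifted 5 samples into the past (circular shift) is expected to peak at a positive
+# lag of 5, matching the sign convention documented above.
+# >>> rng = np.random.default_rng(0)
+# >>> b = rng.normal(size=1000)
+# >>> a = np.roll(b, 5)                     # a is b delayed by 5 samples
+# >>> xcorr, lags = normalized_xcorr(a, b)
+# >>> lags[np.argmax(xcorr)]
+# 5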
+
+def interpolate(y, x_old, x_new, axis=0, fill_value='extrapolate'):
+    """
+    Use linear interpolation to re-sample 1D data.
+    """
+    # get interpolation function
+    func = interp1d(x_old, y, axis=axis, fill_value=fill_value)
+    # get new y-values
+    y_interpolated = func(x_new)
+    return y_interpolated
+
+def interpolate_and_normalize(y, x_old, x_new):
+    """
+    Perform linear interpolation and min-max normalization.
+    """
+    y_new = interpolate(y, x_old, x_new)
+    return (y_new - y_new.min()) / (y_new.max() - y_new.min())
+
+def match_signal_length(a, b, a_tpts, b_tpts):
+    """
+    Given two signals, truncate to match the length of the shortest.
+    """
+    t1 = min(a_tpts.max(), b_tpts.max())
+    a1 = a[:a_tpts.searchsorted(t1)]
+    a1_tpts = a_tpts[:a_tpts.searchsorted(t1)]
+    b1 = b[:b_tpts.searchsorted(t1)]
+    b1_tpts = b_tpts[:b_tpts.searchsorted(t1)]
+    return a1, b1, a1_tpts, b1_tpts
+
+def times_to_counts(series, columns, t0=None, t1=None, dt=0.25):
+    if type(t0) == str:
+        t0, t1 = series[t0].min(), series[t0].max() 
+    elif t0 is None:  # get overlapping time range for all columns
+        t0 = max([series[f'{col.split("_")[0]}_times'].min() for col in columns])
+    if t1 is None:  # also set t1 when only t0 was given
+        t1 = min([series[f'{col.split("_")[0]}_times'].max() for col in columns])
+    # Set time base
+    tbins = np.arange(t0, t1, dt)
+    tpts = tbins[:-1] + (dt / 2)
+    for col in columns:
+        header = col.split('_')[0]        
+        times = series[f'{header}_times']
+        counts, _ = np.histogram(times, bins=tbins)
+        series[f'{header}_counts'] = counts
+        series[f'{header}_tpts'] = tpts
+    return series
+
+def resample_timeseries(y, tpts, dt=0.25):
+    tpts_new = np.arange(tpts.min(), tpts.max(), dt)
+    return tpts_new, interpolate(y, tpts, tpts_new)
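+
+# Illustrative example (synthetic data): re-sample a 1 Hz sine from an irregular
+# time base onto a regular 0.25 s grid.
+# >>> tpts = np.sort(np.random.uniform(0, 10, 200))
+# >>> y = np.sin(2 * np.pi * tpts)
+# >>> tpts_new, y_new = resample_timeseries(y, tpts, dt=0.25)
+# tpts_new runs from tpts.min() to just below tpts.max() in steps of 0.25 s.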
+
+def _resample_data(series, columns, t0=None, t1=None, dt=0.25):
+    if type(t0) == str:
+        t0, t1 = series[t0].min(), series[t0].max() 
+    elif t0 is None:  # get overlapping time range for all columns
+        t0 = max([series[f'{col.split("_")[0]}_tpts'].min() for col in columns])
+        t1 = min([series[f'{col.split("_")[0]}_tpts'].max() for col in columns])
+    # Set new time base
+    tbins = np.arange(t0, t1, dt)
+    tpts_new = tbins[:-1] + (dt / 2)
+    # Interpolate and re-sample each column
+    for col in columns:
+        header = col.split('_')[0]
+        data = series[col]
+        tpts = series[f'{header}_tpts']
+        series[col] = interpolate(data, tpts, tpts_new)
+        series[f'{header}_tpts'] = tpts_new
+    return series
+
+## Neural activity
+def get_mean_rates(df):
+    df['mean_rate'] = df.apply(
+        lambda x:
+            len(x['spk_times']) / (x['spk_tinfo'][1] - x['spk_tinfo'][0]),
+        axis='columns'
+        )
+    return df
+
+def get_mean_rate_threshold(df, alpha=0.025):
+    if 'mean_rate' not in df.columns:
+        df = get_mean_rates(df)
+    rates = np.log10(df['mean_rate'])
+    gmm = GaussianMixture(n_components=2)
+    gmm.fit(rates[..., np.newaxis])
+    (mu, var) = (gmm.means_.max(), gmm.covariances_.squeeze()[gmm.means_.argmax()])
+    threshold = mu + norm.ppf(alpha) * np.sqrt(var)
+    return threshold
+
+def filter_units(df, threshold):
+    if 'mean_rate' not in df.columns:
+        df = get_mean_rates(df)
+    return df.query(f'mean_rate >= {threshold}')
+
+def get_raster(series, events, spike_type='spk', pre=0, post=1):
+    events = series['%s_times' % events]
+    spks = series['%s_times' % spike_type]
+    raster = np.array([spks[(spks > t0 - pre) & (spks < t0 + post)] - t0 for t0 in events], dtype='object')
+    return raster
+
+def get_psth(events, spikes, pre=0, post=1, dt=0.001, bw=0.01, baseline=[]):
+    rate_kernel = GaussianKernel(bw*pq.s)
+    tpts = np.arange(pre, post, dt)
+    psth = np.full((len(events), len(tpts)), np.nan)
+    for i, t0 in enumerate(events):
+        rel_ts = spikes - t0
+        rel_ts = rel_ts[(rel_ts >= pre) & (rel_ts <= post)]
+        try:
+            rate = instantaneous_rate(
+                SpikeTrain(rel_ts, t_start=pre, t_stop=post, units='s'),
+                sampling_period=dt*pq.s,
+                kernel=rate_kernel
+                )
+        except:
+            continue
+        psth[i] = rate.squeeze()
+    if baseline:
+        b0, b1 = tpts.searchsorted(baseline)
+        baseline_rate = psth[:, b0:b1].mean(axis=1)
+        psth = (psth.T - baseline_rate).T
+    return psth, tpts
+
+def apply_get_psth(series, events, spike_type, **kwargs):
+    events = series['{}_times'.format(events)]
+    spikes = series['{}_times'.format(spike_type)]
+    psth, tpts = get_psth(events, spikes, **kwargs)
+    return psth
+
+def get_responses(events, data, tpts, pre=0, post=1, baseline=[]):
+    dt = np.round(np.diff(tpts).mean(), 3)  # round to nearest ms
+    i_pre, i_post = int(pre / dt), int(post / dt)
+    responses = np.full((len(events), i_pre + i_post), np.nan)
+    for j, t0 in enumerate(events):
+        i = tpts.searchsorted(t0)
+        i0, i1 = i - i_pre, i + i_post
+        if i0 < 0:
+            continue 
+        if i1 > len(data):
+            break
+        responses[j] = data[i0:i1]
+    tpts = np.linspace(pre, post, responses.shape[1])
+    if baseline:
+        b0, b1 = tpts.searchsorted(baseline)
+        baseline_resp = responses[:, b0:b1].mean(axis=1)
+        responses = (responses.T - baseline_resp).T
+    return responses, tpts
+
+def apply_get_responses(series, events, data, **kwargs):
+    events = series[f'{events}_times']
+    tpts = series[f'{data.split("_")[0]}_tpts']
+    data = series[f'{data}']
+    responses, tpts = get_responses(events, data, tpts, **kwargs)
+    return responses