import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import argparse
from scipy.interpolate import interp1d

from util import zero_runs
from hht import HHT
from parameters import *
from util import (load_data, filter_units, switch_ranges, merge_ranges, angle_subtract, circmean,
                  get_trials, shuffle_bins)


def resample(data, old_tpts, new_tpts, axis=0, fill_value='extrapolate'):
    """
    Use linear interpolation to re-sample data.

    Parameters
    ----------
    data : ndarray
        data sampled at `old_tpts` along `axis`
    old_tpts : 1D array
        original time base
    new_tpts : 1D array
        time base to re-sample onto
    axis : int
        axis of `data` that corresponds to time
    fill_value : float, tuple or 'extrapolate'
        passed to `scipy.interpolate.interp1d`; the default extrapolates
        linearly outside the range of `old_tpts`

    Returns
    -------
    interpolated_data : ndarray
        `data` evaluated at `new_tpts`
    """
    # interpolate time-course with linear splines
    func = interp1d(old_tpts, data, axis=axis, fill_value=fill_value)
    # get new time-course
    interpolated_data = func(new_tpts)
    return interpolated_data


def resample_phase(phase, old_tpts, new_tpts, axis=0, fill_value='extrapolate'):
    """
    Re-sample wrapped phase angles without interpolating across the wrap.

    Linearly interpolating wrapped angles directly is wrong wherever the phase
    jumps by ~2*pi between two samples: interpolated values in that interval
    sweep back through the whole circle. Instead the unit vector (cos, sin) is
    interpolated and converted back to an angle.

    Parameters
    ----------
    phase : ndarray
        phase angles (radians) sampled at `old_tpts` along `axis`
    old_tpts : 1D array
        original time base
    new_tpts : 1D array
        time base to re-sample onto
    axis : int
        axis of `phase` that corresponds to time
    fill_value : float, tuple or 'extrapolate'
        passed through to `resample`

    Returns
    -------
    new_phase : ndarray
        phase evaluated at `new_tpts`, in [-pi, pi], or in [0, 2*pi) if all
        finite input phases were non-negative
    """
    cos_resamp = resample(np.cos(phase), old_tpts, new_tpts, axis=axis, fill_value=fill_value)
    sin_resamp = resample(np.sin(phase), old_tpts, new_tpts, axis=axis, fill_value=fill_value)
    new_phase = np.arctan2(sin_resamp, cos_resamp)
    # arctan2 returns [-pi, pi]; keep a [0, 2*pi) convention if the input used one
    finite_phase = np.asarray(phase)[np.isfinite(phase)]
    if finite_phase.size and finite_phase.min() >= 0:
        new_phase = np.mod(new_phase, 2 * np.pi)
    return new_phase


def phase2rank(alpha):
    """
    Convert angles to circular ranks.

    Parameters
    ----------
    alpha : 1D array
        angles

    Returns
    -------
    ranks : 1D array
        rank of each angle (0 for the smallest) scaled to [0, 2*pi)
    """
    n = len(alpha)
    ranks = np.full(n, np.nan)
    ranks[alpha.argsort()] = np.arange(n)
    return 2 * np.pi * ranks / n


def rank2phase(rank, alpha):
    """
    Convert a circular rank back to phase in the original distribution.

    Parameters
    ----------
    rank : float or ndarray
        circular rank(s) in radians, on the scale produced by `phase2rank`
    alpha : 1D array
        the angle distribution the ranks refer to

    Returns
    -------
    phase : float or ndarray
        the angle(s) in `alpha` with the nearest rank
    """
    n = len(alpha)
    # convert circular rank to linear rank
    linrank = n * rank / 2 / np.pi
    # Ranks are circular: a rank that rounds up to n (or falls outside
    # [0, 2*pi)) wraps around instead of indexing past the end of the array
    inds = np.round(linrank).astype('int') % n
    return np.sort(alpha)[inds]


def inds2train(inds, length):
    """
    Convert event indices into binary time-series.

    Parameters
    ----------
    inds : 1D array
        event indices
    length : int
        total length of segment in which events occur

    Returns
    -------
    event_train : ndarray
        binary time-series
    """
    # initialize output array
    event_train = np.zeros(length, dtype='uint8')
    event_train[inds] = 1
    return event_train


def times2train(evts, tpts):
    """
    Convert event times into binary time-series.

    Parameters
    ----------
    evts : 1D array
        event times
    tpts : 1D array
        time base in which events occur

    Returns
    -------
    event_train : ndarray
        binary time-series array; several events mapping to the same sample
        still produce a single 1

    Notes
    -----
    Events at or outside the first/last time point are discarded.
    """
    if len(tpts) == 0:
        # No time base to place events in (and tpts.min() would raise)
        return np.zeros(0, dtype='uint8')
    # clip events that fall out of time base
    evts_in_tpts = evts[(evts > tpts.min()) & (evts < tpts.max())]
    # convert times to indices
    evis = tpts.searchsorted(evts_in_tpts)
    # get the event train
    ev_train = inds2train(evis, len(tpts))
    return ev_train


def modified_mrl2(alpha, w=None, axis=0):
    """
    A bias-free measure of the squared mean resultant length [1].

    Parameters
    ----------
    alpha : ndarray
        array of angles
    w : ndarray
        array of weights, must be same shape as alpha
    axis : int, None
        axis across which to compute mean; None uses all elements

    Returns
    -------
    out : ndarray
        bias-corrected squared mean resultant length

    Notes
    -----
    - taking the square-root of this measure does *not* provide a bias-free
      measure of the mean resultant length, see [1].

    References
    ----------
    [1] Kutil, R. (2012). Biased and unbiased estimation of the circular mean
        resultant length and its variance. Statistics, 46(4), 549-561.
    """
    mrl, _ = circmean(alpha, w=w, axis=axis)
    # axis=None pools every element, so the sample size is the total size
    n = alpha.size if axis is None else alpha.shape[axis]
    return (n / (n - 1)) * (mrl ** 2 - (1 / n))


def phase_tuning(phase, spk_train, shuffle_binwidth=1000, n_shuffles=1000):
    """
    Measure how strongly spikes are locked to a phase signal.

    Phases are converted to circular ranks so that a non-uniform phase
    distribution does not by itself produce tuning.

    Parameters
    ----------
    phase : 1D array
        phase at each sample of the spike train's time base
    spk_train : 1D array
        binary spike train, same length as `phase`
    shuffle_binwidth : int
        bin width passed to `shuffle_bins` when building the null
        distribution (the caller in this file passes it in samples)
    n_shuffles : int
        number of shuffled spike trains

    Returns
    -------
    r : float
        bias-corrected squared mean resultant length of the spike phase ranks
    theta : float
        phase (taken from `phase`) at the circular mean rank of the spikes
    p : float
        fraction of shuffled trains whose tuning strength exceeds `r`
    """
    assert len(phase) == len(spk_train)
    # Circular ranks centred on zero: [0, 2*pi) -> [-pi, pi)
    ranks = phase2rank(phase) - np.pi
    # Get phase ranks where spikes occur
    spk_ranks = ranks[spk_train == 1]
    # Compute modified mean vector length
    r = modified_mrl2(spk_ranks)
    # Compute mean rank
    _, mean_rank = circmean(spk_ranks)
    # Convert rank back to phase (undo the -pi centring first)
    theta = rank2phase(mean_rank + np.pi, phase)
    # Compute tuning strength for shuffled spike trains
    r_shf = np.full(n_shuffles, np.nan)
    for shf_i in range(n_shuffles):
        # Shuffle time bins of spike train
        spk_train_shf = shuffle_bins(spk_train, shuffle_binwidth)
        # Take phase of shuffled train
        spk_ranks_shf = ranks[spk_train_shf == 1]
        # Compute tuning strength
        r_shf[shf_i] = modified_mrl2(spk_ranks_shf)
    p = (r_shf > r).sum() / n_shuffles
    return r, theta, p


if __name__ == "__main__":

    # Time-range restrictions accepted by -t/--tranges; each one has a
    # matching branch in the per-experiment loop below.
    valid_tranges = ['run', 'sit', 'desync', 'sizematched', 'nosaccade', 'noopto',
                     'half1', 'half2', 'split1', 'split2']

    parser = argparse.ArgumentParser()
    parser.add_argument('e_name')
    parser.add_argument('-s', '--spk_types', nargs='+', default=['tonicspk', 'burst'])
    parser.add_argument('-t', '--tranges', default='')
    args = parser.parse_args()
    # Validate explicitly: `assert` is stripped when running under `python -O`
    if args.tranges and args.tranges not in valid_tranges:
        parser.error("invalid --tranges {!r}; choose from: {}".format(
            args.tranges, ', '.join(valid_tranges)))

    df_pupil = load_data('pupil', [args.e_name])
    df_pupil.set_index(['m', 's', 'e'], inplace=True)
    df_spikes = load_data('spikes', [args.e_name])
    df_spikes.set_index(['m', 's', 'e'], inplace=True)
    df_spikes = filter_units(df_spikes, MINRATE)
    ## TODO: find a better way to integrate saccades
    if 'saccade' in args.spk_types:
        # Treat saccades as a pseudo spike type: attach each experiment's
        # saccade times to every unit recorded in it
        df_spikes['saccade_times'] = [df_pupil.loc[idx]['saccade_times'] for idx, unit in df_spikes.iterrows()]

    # Load extra data for the requested time ranges ('nosaccade' only needs
    # the saccade times already present in df_pupil)
    if args.tranges in ['desync', 'sizematched']:
        df_hht = load_data('hht', [args.e_name]).set_index(['m', 's', 'e'])
    elif args.tranges in ['run', 'sit']:
        df_run = load_data('ball', [args.e_name]).set_index(['m', 's', 'e'])
    elif args.tranges == 'noopto':
        df_trials = load_data('trials', [args.e_name])
        df_trials.rename(columns={'trial_on_time': 'trial_on_times', 'trial_off_time': 'trial_off_times'},
                         inplace=True)
        df_trials = df_trials.apply(get_trials, stim_id=-1, axis='columns').set_index(['m', 's', 'e'])

    seriess = []
    for idx, row in tqdm(df_pupil.iterrows(), total=len(df_pupil)):
        pupil_area = row['pupil_area']
        pupil_tpts = row['pupil_tpts']
        t0, t1 = pupil_tpts.min(), pupil_tpts.max()

        # Get IMFs
        pupil_fs = 1 / np.diff(pupil_tpts).mean()
        hht = HHT(pupil_area, pupil_fs)
        hht.emd()
        # Get phases and frequencies
        hht.hsa()
        hht.check_number_of_phasebin_visits(ncycles=NIMFCYCLES, remove_invalid=True)
        imf_phases = hht.phase.T  # one row per IMF
        imf_freqs = hht.characteristic_frequency
        imf_power = hht.power_ratio

        # Get time-ranges: `tranges` becomes a per-IMF sequence of
        # [start, stop] time ranges (for 'desync'/'sizematched' it is taken
        # as stored in the HHT data; presumably the same layout)
        if args.tranges in ['run', 'sit']:
            try:
                tranges = df_run.loc[idx, '%s_bouts' % args.tranges]
            except KeyError:
                print("No run data found for ", idx)
                continue
            if args.tranges == 'run':
                dt0 = BEHAVEXCLUSIONS['run'][1]
                dt1 = BEHAVEXCLUSIONS['sit'][0]
            else:
                dt0 = BEHAVEXCLUSIONS['sit'][1]
                dt1 = BEHAVEXCLUSIONS['run'][0]
            # Offset bout starts by dt0 and stops by dt1, then drop bouts that
            # became empty. Masking an (n, 2) array (instead of stacking a
            # filtered list) yields an empty range array when no bout survives.
            bouts = np.asarray(tranges, dtype=float).reshape(-1, 2) + np.array([dt0, dt1])
            bouts = bouts[bouts[:, 0] < bouts[:, 1]]
            tranges = [bouts for imf in range(hht.n_imfs)]
        elif args.tranges in ['desync', 'sizematched']:
            try:
                tranges = df_hht.loc[idx, '%s_bouts' % args.tranges]
            except KeyError:
                print("No HHT data found for ", idx)
                continue
        elif args.tranges == 'nosaccade':
            saccade_times = row['saccade_times']
            # Pad each saccade time by the exclusion window; merge_ranges and
            # switch_ranges (util) presumably merge overlapping windows and
            # return the complementary, saccade-free ranges within the recording
            saccade_tranges = np.column_stack([saccade_times, saccade_times])
            saccade_tranges = saccade_tranges + np.array(BEHAVEXCLUSIONS['saccade'])
            saccade_tranges = merge_ranges(saccade_tranges, dt=(1 / pupil_fs))
            tranges = switch_ranges(
                saccade_tranges,
                dt=(1 / pupil_fs),
                minval=t0,
                maxval=t1
            )
            tranges = [tranges for imf in range(hht.n_imfs)]
        elif args.tranges in ['opto', 'noopto']:
            try:
                trial_ids = df_trials.loc[idx, 'trial_id']
                opto_trials = df_trials.loc[idx, 'opto_trials']
                trial_on_time = df_trials.loc[idx, 'trial_on_times']
                trial_off_time = df_trials.loc[idx, 'trial_off_times']
            except KeyError:
                print("No trial data found for ", idx)
                continue
            # NOTE(review): trial_ids/opto_trials are looked up (a missing
            # column skips the experiment) but not used to select trials here
            tranges = np.column_stack([trial_on_time, trial_off_time])
            tranges = [tranges for imf in range(hht.n_imfs)]
        elif args.tranges in ['half1', 'half2']:
            half_length = (t1 - t0) / 2
            if args.tranges == 'half1':
                tranges = [np.array([[t0, t0 + half_length]]) for imf in range(hht.n_imfs)]
            else:
                tranges = [np.array([[t0 + half_length, t1]]) for imf in range(hht.n_imfs)]
        elif args.tranges in ['split1', 'split2']:
            # Cycle boundaries are where each IMF's phase wraps (drops by > pi)
            imf_cycles = [pupil_tpts[np.where(np.diff(phase) < -np.pi)[0]] for phase in hht.phase.T]
            imf_cycles = [np.concatenate([pupil_tpts[:1], cycles, pupil_tpts[-1:]]) for cycles in imf_cycles]
            cycle_tranges = [np.column_stack([cycles[:-1], cycles[1:]]) for cycles in imf_cycles]
            # Alternate cycles go to the two splits
            if args.tranges == 'split1':
                tranges = [cycles[0::2] for cycles in cycle_tranges]
            else:
                tranges = [cycles[1::2] for cycles in cycle_tranges]

        # Get units for this experiment. `.loc[[idx]]` always returns a
        # DataFrame, whereas `.loc[idx]` can return a single row (Series)
        # when the index holds just one unit.
        if idx not in df_spikes.index:
            print("Spikes missing for {}".format(idx))
            continue
        df_units = df_spikes.loc[[idx]]

        for _, unit in df_units.iterrows():
            # Restrict the unit's time base to the pupil recording
            unit_tpts = np.arange(*unit['spk_tinfo'])
            i0, i1 = unit_tpts.searchsorted([t0, t1])
            unit_tpts = unit_tpts[i0:i1]
            # Resample IMF phases onto the unit time base without smearing the
            # phase wrap (see resample_phase)
            imf_phases_resamp = resample_phase(imf_phases, pupil_tpts, unit_tpts, axis=1)
            # Shuffle bin width in samples of the unit time base; assumes
            # SHUFFLE_BINWIDTH uses the same time units as unit_tpts
            unit_fs = 1 / np.diff(unit_tpts).mean()
            binwidth = np.floor(unit_fs * SHUFFLE_BINWIDTH).astype('int')

            spk_trains = {
                spk_type: times2train(unit['{}_times'.format(spk_type)], unit_tpts)
                for spk_type in args.spk_types
            }

            for imf_i, phase in enumerate(imf_phases_resamp):
                data = {
                    'm': idx[0],
                    's': idx[1],
                    'e': idx[2],
                    'u': unit['u'],
                    'imf': imf_i + 1,
                    'freq': imf_freqs[imf_i],
                    'power': imf_power[imf_i]
                }
                # Get time ranges to analyze, converted to index ranges on unit_tpts
                tranges_imf = tranges[imf_i] if args.tranges else np.array([[t0, t1]])
                iranges_imf = unit_tpts.searchsorted(tranges_imf)
                for spk_type, spk_train in spk_trains.items():
                    # Take only data in ranges
                    if len(iranges_imf) > 0:
                        phase_clipped = np.concatenate([phase[j0:j1] for j0, j1 in iranges_imf])
                        train_clipped = np.concatenate([spk_train[j0:j1] for j0, j1 in iranges_imf])
                    else:
                        phase_clipped = np.array([])
                        train_clipped = np.array([])
                    # Check that there are enough spikes to do analysis
                    nspikes = train_clipped.sum()
                    if nspikes < NSPIKES:
                        r = theta = p = np.nan
                    else:
                        r, theta, p = phase_tuning(
                            phase_clipped,
                            train_clipped,
                            shuffle_binwidth=binwidth,
                            n_shuffles=NSHUFFLES
                        )
                    data['_'.join([spk_type, 'n'])] = nspikes
                    data['_'.join([spk_type, 'strength'])] = r
                    data['_'.join([spk_type, 'phase'])] = theta
                    data['_'.join([spk_type, 'p'])] = p
                seriess.append(pd.Series(data=data))

    df_tuning = pd.DataFrame(seriess)
    if args.tranges:
        filename = 'phasetuning_{}_{}.pkl'.format(args.e_name, args.tranges)
    elif args.spk_types == ['saccade']:
        filename = 'phasetuning_{}_{}.pkl'.format(args.e_name, 'saccades')
    else:
        filename = 'phasetuning_{}.pkl'.format(args.e_name)
    df_tuning.to_pickle(DATAPATH + filename)