# imf_correlation.py
import argparse

import numpy as np
import pandas as pd
from scipy.stats import pearsonr
from tqdm import tqdm

from hht import HHT
from parameters import NIMFCYCLES, NSHUFFLES, DATAPATH
from util import load_data, interpolate, normalized_xcorr
  9. if __name__ == "__main__":
  10. parser = argparse.ArgumentParser()
  11. parser.add_argument('e_name')
  12. parser.add_argument('-t', '--tranges', default='')
  13. args = parser.parse_args()
  14. df_pupil = load_data('pupil', [args.e_name]).set_index(['m', 's', 'e'])
  15. df_run = load_data('ball', [args.e_name]).set_index(['m', 's', 'e'])
  16. df_all = pd.read_pickle(DATAPATH + 'run.pkl')
  17. # TODO: move to parameters
  18. max_lag = 30 # seconds
  19. seriess = []
  20. for idx, row in tqdm(df_pupil.iterrows(), total=len(df_pupil)):
  21. pupil_area = row['pupil_area']
  22. pupil_tpts = row['pupil_tpts']
  23. pupil_dt = np.diff(pupil_tpts).mean()
  24. pupil_fs = 1 / pupil_dt
  25. t0, t1 = (pupil_tpts.min(), pupil_tpts.max())
  26. # Prepare run data
  27. try:
  28. run_speed, run_tpts = df_run.loc[idx, ['run_speed', 'run_tpts']]
  29. except KeyError:
  30. print("No run data found for ", idx)
  31. continue
  32. i0, i1 = run_tpts.searchsorted([t0, t1])
  33. run_speed = interpolate(run_speed[i0:i1], run_tpts[i0:i1], pupil_tpts)
  34. # Get IMFs
  35. hht = HHT(pupil_area, pupil_fs)
  36. hht.emd()
  37. hht.hsa()
  38. hht.check_number_of_phasebin_visits(ncycles=NIMFCYCLES, remove_invalid=True)
  39. imfs = hht.imfs.T
  40. imf_freqs = hht.characteristic_frequency
  41. imf_power = hht.power_ratio
  42. if args.tranges:
  43. tranges = df_run.loc[idx, '%s_bouts' % args.tranges]
  44. if args.tranges == 'run':
  45. ext = np.ones_like(tranges) * np.array([2, -2])
  46. elif args.tranges == 'sit':
  47. ext = np.ones_like(tranges) * np.array([4, -2])
  48. tranges = tranges + ext
  49. tranges = np.row_stack([trange for trange in tranges if trange[0] < trange[1]])
  50. iranges = pupil_tpts.searchsorted(tranges)
  51. imfs = np.column_stack([imfs[:, i0:i1] for i0, i1 in iranges])
  52. pupil_area = np.concatenate([pupil_area[i0:i1] for i0, i1 in iranges])
  53. run_speed = np.concatenate([run_speed[i0:i1] for i0, i1 in iranges])
  54. for i, imf in enumerate(np.row_stack([pupil_area, imfs])):
  55. data = {
  56. 'm': idx[0],
  57. 's': idx[1],
  58. 'e': idx[2],
  59. 'imf': i
  60. }
  61. if i == 0:
  62. data['freq'] = data['power'] = np.nan
  63. else:
  64. data['freq'] = imf_freqs[i - 1]
  65. data['power'] = imf_power[i - 1]
  66. xcorr, lags = normalized_xcorr(imf, run_speed, dt=pupil_dt, ts=[-1 * max_lag, max_lag])
  67. data['xcorr'] = xcorr
  68. data['xcorr_lags'] = lags
  69. r_null = np.full(NSHUFFLES, np.nan)
  70. j = 0
  71. while j < NSHUFFLES:
  72. tpts, signal = df_all.iloc[np.random.choice(np.arange(len(df_all)))]
  73. signal = interpolate(signal, tpts, np.arange(tpts.min(), tpts.max(), pupil_dt))
  74. i_max = min(len(imf), len(signal))
  75. if len(np.unique(signal[:i_max])) == 1:
  76. continue
  77. r_null[j], _ = pearsonr(imf[:i_max], signal[:i_max])
  78. j += 1
  79. # Get search window for peak
  80. if i == 0:
  81. i0, i1 = 0, len(xcorr)
  82. else:
  83. T = 1 / imf_freqs[i - 1]
  84. i0 , i1 = lags.searchsorted([-T, T])
  85. # Find peak
  86. xcorr_peak = lags[i0:i1][np.abs(xcorr[i0:i1]).argmax()]
  87. data['xcorr_peak'] = xcorr_peak
  88. # Compare to null distibution
  89. xcorr_max = xcorr[i0:i1][np.abs(xcorr[i0:i1]).argmax()]
  90. p = (r_null > xcorr_max).sum() / NSHUFFLES
  91. data['xcorr_p'] = p
  92. data['xcorr_sig'] = (p < 0.025) | (p > 0.975)
  93. seriess.append(pd.Series(data=data))
  94. df_corr = pd.DataFrame(seriess)
  95. if not args.tranges:
  96. filename = 'imfcorr_{}.pkl'.format(args.e_name)
  97. else:
  98. filename = 'imfcorr_{}_{}.pkl'.format(args.e_name, args.tranges)
  99. df_corr.to_pickle(DATAPATH + filename)