loading.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import sys, os
  2. sys.path.append(os.path.join(os.getcwd(), '..'))
  3. sys.path.append(os.path.join(os.getcwd(), '..', '..'))
  4. from imports import *
  5. from target import build_tgt_matrix
  6. import pandas as pd
  7. def load_session_data(session, load_units=True, load_aeps=True, load_moseq=True):
  8. all_areas = ['A1', 'PPC', 'HPC']
  9. animal = session.split('_')[0]
  10. sessionpath = os.path.join(source, animal, session)
  11. h5_file = os.path.join(sessionpath, session + '.h5')
  12. aeps_file = os.path.join(sessionpath, 'AEPs.h5')
  13. moseq_file = os.path.join(sessionpath, 'moseq.h5')
  14. report_path = os.path.join(report, 'PSTH', session)
  15. if not os.path.exists(report_path):
  16. os.makedirs(report_path)
  17. # load timeline and configuration
  18. with h5py.File(h5_file, 'r') as f:
  19. tl = np.array(f['processed']['timeline']) # time, X, Y, speed, etc.
  20. trials = np.array(f['processed']['trial_idxs']) # t_start_idx, t_end_idx, x_tgt, y_tgt, r_tgt, result
  21. cfg = json.loads(f['processed'].attrs['parameters'])
  22. # load units
  23. unit_names, single_units, spike_times = [], {}, {}
  24. if load_units:
  25. with h5py.File(h5_file, 'r') as f:
  26. unit_names = [x for x in f['units']]
  27. with h5py.File(h5_file, 'r') as f:
  28. for unit_name in unit_names:
  29. spike_times[unit_name] = np.array(f['units'][unit_name][H5NAMES.spike_times['name']])
  30. single_units[unit_name] = np.array(f['units'][unit_name][H5NAMES.inst_rate['name']])
  31. #single_units[unit_name] = instantaneous_rate(unit_times, tl[:, 0], k_width=50)
  32. # load AEPs
  33. areas, aeps, aeps_events, lfp = [], {}, [], {}
  34. AEP_metrics_lims, AEP_metrics_raw, AEP_metrics_norm = {}, {}, {}
  35. tgt_matrix = []
  36. if load_aeps:
  37. with h5py.File(aeps_file, 'r') as f:
  38. for area in all_areas:
  39. if not area in f:
  40. continue
  41. aeps[area] = np.array(f[area]['aeps'])
  42. aeps_events = np.array(f['aeps_events'])
  43. areas = list(aeps.keys())
  44. # TODO find better way. Remove outliers
  45. if 'A1' in areas:
  46. aeps['A1'][aeps['A1'] > 5000] = 5000
  47. aeps['A1'][aeps['A1'] < -5000] = -5000
  48. if 'PPC' in areas:
  49. aeps['PPC'][aeps['PPC'] > 1500] = 1500
  50. aeps['PPC'][aeps['PPC'] < -1500] = -1500
  51. if 'HPC' in areas:
  52. aeps['HPC'][aeps['HPC'] > 1500] = 1500
  53. aeps['HPC'][aeps['HPC'] < -1500] = -1500
  54. aeps[areas[0]].shape
  55. # load LFP
  56. lfp = {}
  57. with h5py.File(aeps_file, 'r') as f:
  58. for area in areas:
  59. if 'LFP' in f[area]:
  60. lfp[area] = np.array(f[area]['LFP'])
  61. # load AEP metrics
  62. AEP_metrics_lims = dict([(area, {}) for area in areas])
  63. AEP_metrics_raw = dict([(area, {}) for area in areas])
  64. AEP_metrics_norm = dict([(area, {}) for area in areas])
  65. with h5py.File(aeps_file, 'r') as f:
  66. for area in areas:
  67. grp = f[area]
  68. for metric_name in grp['raw']:
  69. AEP_metrics_raw[area][metric_name] = np.array(grp['raw'][metric_name])
  70. AEP_metrics_norm[area][metric_name] = np.array(grp['norm'][metric_name])
  71. AEP_metrics_lims[area][metric_name] = [int(x) for x in grp['raw'][metric_name].attrs['limits'].split(',')]
  72. # build target matrix
  73. tgt_matrix = build_tgt_matrix(tl, trials, aeps_events)
  74. # load moseq
  75. moseq = []
  76. if load_moseq:
  77. with h5py.File(moseq_file, 'r') as f:
  78. moseq_matrix = np.array(f['moseq'])
  79. moseq_headers = f['moseq'].attrs['headers']
  80. moseq_headers = moseq_headers.split(',')
  81. moseq_headers = [moseq_headers[0]] + [x[1:] for x in moseq_headers[1:]]
  82. moseq = pd.DataFrame(moseq_matrix, columns=moseq_headers)
  83. return {
  84. 'tl': tl,
  85. 'trials': trials,
  86. 'cfg': cfg,
  87. 'areas': areas,
  88. 'aeps': aeps,
  89. 'aeps_events': aeps_events,
  90. 'lfp': lfp,
  91. 'AEP_metrics_lims': AEP_metrics_lims,
  92. 'AEP_metrics_raw': AEP_metrics_raw,
  93. 'AEP_metrics_norm': AEP_metrics_norm,
  94. 'tgt_matrix': tgt_matrix,
  95. 'single_units': single_units,
  96. 'spike_times': spike_times,
  97. 'unit_names': unit_names,
  98. 'animal': animal,
  99. 'aeps_file': aeps_file,
  100. 'moseq_file': moseq_file,
  101. 'h5_file': h5_file,
  102. 'report_path': report_path,
  103. 'moseq': moseq
  104. }