# pack.py

# include modules to the path
import sys, os
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
sys.path.append(parent_dir)
sys.path.append(os.path.join(parent_dir, 'session'))

import json, h5py, time
import numpy as np
import scipy.ndimage as ndi
from scipy import signal

from head_direction import head_direction
from spatial import place_field_2D, map_stats, get_field_patches, best_match_rotation_polar
from spatial import bins2meters, cart2pol, pol2cart
from spiketrain import instantaneous_rate, spike_idxs
from spiking_metrics import mean_firing_rate, isi_cv, isi_fano
from session.utils import get_sessions_list, get_sampling_rate, cleaned_epochs
from session.adapters import load_clu_res, H5NAMES, create_dataset

def pack(session_path):
    """
    Pack independent tracking datasets into a single HDF5 file.

    The file has the following structure:

    /raw
        /positions   - raw positions from .csv
        /events      - raw events from .csv
        /sounds      - raw sounds from .csv
        /islands     - raw island infos from .csv (if it exists)
    /processed
        /timeline    - matrix of [time, x, y, speed, HD, trial_no, sound_id] sampled at 100Hz,
                       data is smoothed using gaussian kernels,
                       inter-trial intervals have trial_no = 0
        /trial_idxs  - matrix of trial indices into the timeline
        /sound_idxs  - matrix of sound indices into the timeline

    Each dataset has an attribute 'headers' with the description of its columns.
    """
    params_file = [x for x in os.listdir(session_path) if x.endswith('.json')][0]

    with open(os.path.join(session_path, params_file)) as json_file:
        parameters = json.load(json_file)

    h5name = os.path.join(session_path, '%s.h5' % params_file.split('.')[0])
    with h5py.File(h5name, 'w') as f:  # overwrite mode

        # -------- save raw data ------------
        raw = f.create_group('raw')
        raw.attrs['parameters'] = json.dumps(parameters)

        for ds_name in ['positions', 'events', 'sounds', 'islands']:
            filename = os.path.join(session_path, '%s.csv' % ds_name)
            if not os.path.exists(filename):
                continue

            with open(filename) as ff:
                headers = ff.readline()
            data = np.loadtxt(filename, delimiter=',', skiprows=1)

            ds = raw.create_dataset(ds_name, data=data)
            ds.attrs['headers'] = headers

        # TODO - saving contours! and get file names from the config
        # with open(os.path.join(session_path, '%s.csv' % 'contours')) as ff:
        #     data = ff.readlines()
        #     headers = data[0]  # skip headers line
        #     contours = [[(x.split(':')[0], x.split(':')[1]) for x in contour.split(',')] for contour in data[1:]]
        #     contours = [np.array(contour) for contour in contours]

        # read raw data and normalize to session start
        positions = np.array(f['raw']['positions'])
        s_start, s_end = positions[:, 0][0], positions[:, 0][-1]
        positions[:, 0] = positions[:, 0] - s_start

        events = np.array(f['raw']['events'])
        events[:, 0] = events[:, 0] - s_start

        sounds = np.array(f['raw']['sounds'])
        sounds[:, 0] = sounds[:, 0] - s_start

        # squeeze - if the session was interrupted, adjust times
        # to have a continuous timeline
        end_idxs = np.where(events[:, 5] == -1)[0]
        if len(end_idxs) > 1:
            # diffs in time between pauses
            deltas = [events[idx + 1][0] - events[idx][0] for idx in end_idxs[:-1]]

            for df, delta in zip(end_idxs, deltas):  # squeezing events
                events[df+1:][:, 0] = events[df+1:][:, 0] - delta

            end_idxs = np.where(np.diff(sounds[:, 0]) > 20)[0]  # squeezing sounds
            for df, delta in zip(end_idxs, deltas):
                sounds[df+1:][:, 0] = sounds[df+1:][:, 0] - delta

            end_idxs = np.where(np.diff(positions[:, 0]) > 20)[0]  # squeezing positions - pauses longer than 20(?) secs
            for df, delta in zip(end_idxs, deltas):
                positions[df+1:][:, 0] = positions[df+1:][:, 0] - delta

            parameters['experiment']['timepoints'] = [positions[df+1][0] for df in end_idxs]  # update session parameters
            parameters['experiment']['session_duration'] = positions[-1][0]

        # -------- save processed ------------
        proc = f.create_group('processed')
        proc.attrs['parameters'] = json.dumps(parameters)

        # TODO remove outliers - position jumps over 20cm?
        #diffs_x = np.diff(positions[:, 1])
        #diffs_y = np.diff(positions[:, 2])
        #dists = np.sqrt(diffs_x**2 + diffs_y**2)
        #np.where(dists > 0.2 / pixel_size)[0]

        # convert timeline to 100 Hz
        time_freq = 100  # at 100 Hz
        s_start, s_end = positions[:, 0][0], positions[:, 0][-1]
        times = np.linspace(s_start, s_end, int((s_end - s_start) * time_freq))
        pos_at_freq = np.zeros((len(times), 3))
        curr_idx = 0
        for i, t in enumerate(times):
            if curr_idx < len(positions) - 1 and \
                    np.abs(t - positions[:, 0][curr_idx]) > np.abs(t - positions[:, 0][curr_idx + 1]):
                curr_idx += 1
            pos_at_freq[i] = (t, positions[curr_idx][1], positions[curr_idx][2])

        # save trials
        t_count = len(np.unique(events[events[:, -1] != 0][:, -2]))
        trials = np.zeros((t_count, 6))
        for i in range(t_count):
            t_start_idx = (np.abs(pos_at_freq[:, 0] - events[2*i][0])).argmin()
            t_end_idx = (np.abs(pos_at_freq[:, 0] - events[2*i + 1][0])).argmin()
            state = 0 if events[2*i + 1][-1] > 1 else 1

            trials[i] = (t_start_idx, t_end_idx, events[2*i][1], events[2*i][2], events[2*i][3], state)

        trial_idxs = proc.create_dataset('trial_idxs', data=trials)
        trial_idxs.attrs['headers'] = 't_start_idx, t_end_idx, target_x, target_y, target_r, fail_or_success'

        # save sounds
        sound_idxs = np.zeros((len(sounds), 2))
        left_idx = 0
        delta = 10**5
        for i in range(len(sounds)):
            while left_idx < len(pos_at_freq) and \
                    np.abs(sounds[i][0] - pos_at_freq[:, 0][left_idx]) < delta:
                delta = np.abs(sounds[i][0] - pos_at_freq[:, 0][left_idx])
                left_idx += 1

            sound_idxs[i] = (left_idx, sounds[i][1])
            delta = 10**5

        sound_idxs = proc.create_dataset('sound_idxs', data=sound_idxs)
        sound_idxs.attrs['headers'] = 'timeline_idx, sound_id'

        # building the timeline
        width = 50  # 100 points ~= 1 sec at 100 Hz
        kernel = signal.gaussian(width, std=width / 7.2)

        x_smooth = np.convolve(pos_at_freq[:, 1], kernel, 'same') / kernel.sum()
        y_smooth = np.convolve(pos_at_freq[:, 2], kernel, 'same') / kernel.sum()

        # speed
        dx = np.sqrt(np.square(np.diff(x_smooth)) + np.square(np.diff(y_smooth)))
        dt = np.diff(pos_at_freq[:, 0])
        speed = np.concatenate([dx/dt, [dx[-1]/dt[-1]]])

        # head direction
        temp_tl = np.column_stack([pos_at_freq[:, 0], x_smooth, y_smooth, speed])
        hd = head_direction(temp_tl)

        # trial numbers
        trials_data = np.zeros(len(temp_tl))
        for i, trial in enumerate(trials):
            idx1, idx2 = trial[0], trial[1]
            trials_data[int(idx1):int(idx2)] = i + 1

        # sounds played
        sound_tl = np.zeros(len(temp_tl))
        curr_sound_idx = 0
        for i in range(len(temp_tl)):
            if curr_sound_idx + 1 >= len(sounds):
                break

            if temp_tl[i][0] > sounds[curr_sound_idx][0]:
                curr_sound_idx += 1
            sound_tl[i] = sounds[curr_sound_idx][1]

        timeline = proc.create_dataset('timeline', data=np.column_stack(
            [pos_at_freq[:, 0], x_smooth, y_smooth, speed, hd, trials_data, sound_tl]
        ))
        timeline.attrs['headers'] = 'time, x, y, speed, hd, trial_no, sound_ids'

    return h5name
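
# Illustrative read-back sketch (not part of the original module): shows how the file
# written by pack() can be loaded. The function name is hypothetical; the dataset path
# and the 'headers' attribute match the layout documented in the docstring above.
def example_read_timeline(h5name):
    """Return the processed timeline matrix and its column description."""
    with h5py.File(h5name, 'r') as f:
        timeline = np.array(f['processed']['timeline'])  # time, x, y, speed, hd, trial_no, sound_id
        headers = f['processed']['timeline'].attrs['headers']
    return timeline, headers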

def write_units(sessionpath):
    filebase = os.path.basename(os.path.normpath(sessionpath))
    h5name = os.path.join(sessionpath, filebase + '.h5')

    # loading unit data
    units = load_clu_res(sessionpath)  # spikes are in samples, not seconds
    sampling_rate = get_sampling_rate(sessionpath)

    # loading trajectory
    with h5py.File(h5name, 'r') as f:
        tl = np.array(f['processed']['timeline'])  # time, X, Y, speed, HD, trials, sounds

    # packing
    with h5py.File(h5name, 'a') as f:
        if 'units' not in f:
            f.create_group('units')

    for electrode_idx in units.keys():
        unit_idxs = units[electrode_idx]

        for unit_idx, spiketrain in unit_idxs.items():
            unit_name = '%s-%s' % (electrode_idx, unit_idx)

            s_times = spiketrain / sampling_rate
            i_rate = instantaneous_rate(s_times, tl[:, 0])
            s_idxs = spike_idxs(s_times, tl[:, 0])

            with h5py.File(h5name, 'a') as f:
                if unit_name not in f['units']:
                    grp = f['units'].create_group(unit_name)

            create_dataset(h5name, '/units/%s' % unit_name, H5NAMES.spike_times, s_times)
            create_dataset(h5name, '/units/%s' % unit_name, H5NAMES.inst_rate, i_rate)
            create_dataset(h5name, '/units/%s' % unit_name, H5NAMES.spike_idxs, s_idxs)

def write_spiking_metrics(sessionpath):
    filebase = os.path.basename(os.path.normpath(sessionpath))
    h5name = os.path.join(sessionpath, filebase + '.h5')
    epochs = cleaned_epochs(sessionpath)

    with h5py.File(h5name, 'r') as f:
        unit_names = [name for name in f['units']]

    for unit_name in unit_names:
        with h5py.File(h5name, 'r') as f:
            st = np.array(f['units'][unit_name][H5NAMES.spike_times['name']])

        mfr_vals = np.zeros(len(epochs))
        isi_cv_vals = np.zeros(len(epochs))
        isi_fn_vals = np.zeros(len(epochs))

        for i, epoch in enumerate(epochs):
            st_cut = st[(st > epoch[0]) & (st < epoch[1])]
            mfr_vals[i] = mean_firing_rate(st_cut)
            isi_cv_vals[i] = isi_cv(st_cut)
            isi_fn_vals[i] = isi_fano(st_cut)

        create_dataset(h5name, '/units/%s' % unit_name, H5NAMES.mfr, np.array(mfr_vals))
        create_dataset(h5name, '/units/%s' % unit_name, H5NAMES.isi_cv, np.array(isi_cv_vals))
        create_dataset(h5name, '/units/%s' % unit_name, H5NAMES.isi_fano, np.array(isi_fn_vals))

def write_spatial_metrics(sessionpath):
    metric_names = (H5NAMES.o_maps, H5NAMES.f_maps, H5NAMES.sparsity, H5NAMES.selectivity,
                    H5NAMES.spat_info, H5NAMES.peak_FR, H5NAMES.f_patches, H5NAMES.f_COM,
                    H5NAMES.pfr_center, H5NAMES.occ_info, H5NAMES.o_patches, H5NAMES.o_COM)
    xy_range = [-0.5, 0.5, -0.5, 0.5]  # make fixed for cross-comparisons
    bin_size = 0.02

    filebase = os.path.basename(os.path.normpath(sessionpath))
    h5name = os.path.join(sessionpath, filebase + '.h5')
    epochs = cleaned_epochs(sessionpath)

    with h5py.File(h5name, 'r') as f:
        tl = np.array(f['processed']['timeline'])  # time, X, Y, speed, etc.
        unit_names = [name for name in f['units']]

    s_rate_pos = round(1.0 / np.diff(tl[:, 0]).mean())
    run_idxs = np.where(tl[:, 3] > 0.04)[0]

    for unit_name in unit_names:
        with h5py.File(h5name, 'r') as f:
            spk_idxs = np.array(f['units'][unit_name][H5NAMES.spike_idxs['name']])
        spk_idxs = np.intersect1d(spk_idxs, run_idxs)

        collected = []
        for epoch in epochs:
            epoch_idxs = np.where((tl[:, 0] > epoch[0]) & (tl[:, 0] < epoch[1]))[0]

            # filter for epoch and speed > 4 cm/s
            unit_pos = tl[np.intersect1d(spk_idxs, epoch_idxs)][:, 1:3]
            traj_pos = tl[epoch_idxs][:, 1:3]

            # compute 2D maps: occupancy and firing rate (place fields)
            #xy_range = [tl[:, 1].min(), tl[:, 1].max(), tl[:, 2].min(), tl[:, 2].max()]
            o_map, s1_map, s2_map, f_map = place_field_2D(traj_pos, unit_pos, s_rate_pos, bin_size=bin_size, xy_range=xy_range)

            # firing map metrics
            sparsity, selectivity, spat_info, peak_FR = map_stats(f_map, o_map)

            # place field metrics
            patches = get_field_patches(f_map)  # 2D matrix, patches labeled according to the size
            #f_sizes = np.bincount(patches.flat)[1:]  # 1D array of field sizes, sorted
            if f_map.max() == 0:
                f_COM_rho, f_COM_phi, pfr_rho, pfr_phi = 0, 0, 0, 0
            else:
                x_in_b, y_in_b = ndi.center_of_mass(f_map, labels=patches, index=1)  # largest field COM, in bins
                f_COM_rho, f_COM_phi = cart2pol(*bins2meters(x_in_b, y_in_b, xy_range))  # largest field COM, in polar coords.
                x, y = np.where(f_map == np.max(f_map))  # location of the peak unit firing, in bins
                pfr_rho, pfr_phi = cart2pol(*bins2meters(x[0], y[0], xy_range))  # location of the peak unit firing, in polar coords.

            # same for occupancy
            _, _, occ_info, _ = map_stats(o_map, o_map)
            o_patches = get_field_patches(o_map)  # 2D matrix, patches labeled according to the size
            x, y = ndi.center_of_mass(o_map, labels=o_patches, index=1)  # largest field COM, in bins
            o_COM_rho, o_COM_phi = cart2pol(*bins2meters(x, y, xy_range))  # largest field COM, in polar coords.

            # order should match metric_names defined above
            collected.append([o_map, f_map, sparsity, selectivity, spat_info, peak_FR,
                              patches, (f_COM_rho, f_COM_phi), (pfr_rho, pfr_phi), occ_info,
                              o_patches, (o_COM_rho, o_COM_phi)])

        for i in range(len(collected[0])):  # iterate over metrics
            dataset = np.array([x[i] for x in collected])  # metric data for each epoch
            create_dataset(h5name, '/units/%s' % unit_name, metric_names[i], dataset)

def write_best_match_rotation(sessionpath):
    filebase = os.path.basename(os.path.normpath(sessionpath))
    h5name = os.path.join(sessionpath, filebase + '.h5')

    with h5py.File(h5name, 'r') as f:
        unit_names = [name for name in f['units']]

    for unit_name in unit_names:
        # assuming maps for first 3 epochs are already there
        with h5py.File(h5name, 'r') as f:
            maps = np.array(f['units'][unit_name][H5NAMES.f_maps['name']])

        corr_profiles = np.zeros((3, 360))
        for i, idxs in enumerate([(0, 1), (1, 2), (0, 2)]):
            corr_profiles[i], phi = best_match_rotation_polar(maps[idxs[0]], maps[idxs[1]])

        create_dataset(h5name, '/units/%s' % unit_name, H5NAMES.best_m_rot, corr_profiles)
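
# Minimal usage sketch (an assumption, not part of the original file): the functions above
# appear designed to be run in this order on a single session folder. This assumes the
# session's .json parameters file shares the folder's base name, so that pack() and the
# write_* functions resolve to the same .h5 file. The path below is a placeholder.
if __name__ == '__main__':
    session_path = '/path/to/session'  # placeholder session folder
    pack(session_path)
    write_units(session_path)
    write_spiking_metrics(session_path)
    write_spatial_metrics(session_path)
    write_best_match_rotation(session_path)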