nixrawio.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. """
  2. RawIO Class for NIX files
  3. The RawIO assumes all segments and all blocks have the same structure.
  4. It supports all kinds of NEO objects.
  5. Author: Chek Yin Choi
  6. """
  7. from __future__ import print_function, division, absolute_import
  8. from neo.rawio.baserawio import (BaseRawIO, _signal_channel_dtype,
  9. _unit_channel_dtype, _event_channel_dtype)
  10. import numpy as np
  11. try:
  12. import nixio as nix
  13. HAVE_NIX = True
  14. except ImportError:
  15. HAVE_NIX = False
  16. nix = None
  17. class NIXRawIO(BaseRawIO):
  18. extensions = ['nix']
  19. rawmode = 'one-file'
  20. def __init__(self, filename=''):
  21. BaseRawIO.__init__(self)
  22. self.filename = filename
  23. def _source_name(self):
  24. return self.filename
  25. def _parse_header(self):
  26. self.file = nix.File.open(self.filename, nix.FileMode.ReadOnly)
  27. sig_channels = []
  28. size_list = []
  29. for bl in self.file.blocks:
  30. for seg in bl.groups:
  31. for da_idx, da in enumerate(seg.data_arrays):
  32. if da.type == "neo.analogsignal":
  33. chan_id = da_idx
  34. ch_name = da.metadata['neo_name']
  35. units = str(da.unit)
  36. dtype = str(da.dtype)
  37. sr = 1 / da.dimensions[0].sampling_interval
  38. da_leng = da.size
  39. if da_leng not in size_list:
  40. size_list.append(da_leng)
  41. group_id = 0
  42. for sid, li_leng in enumerate(size_list):
  43. if li_leng == da_leng:
  44. group_id = sid
  45. # very important! group_id use to store channel groups!!!
  46. # use only for different signal length
  47. gain = 1
  48. offset = 0.
  49. sig_channels.append((ch_name, chan_id, sr, dtype,
  50. units, gain, offset, group_id))
  51. break
  52. break
  53. sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)
  54. unit_channels = []
  55. unit_name = ""
  56. unit_id = ""
  57. for bl in self.file.blocks:
  58. for seg in bl.groups:
  59. for mt in seg.multi_tags:
  60. if mt.type == "neo.spiketrain":
  61. unit_name = mt.metadata['neo_name']
  62. unit_id = mt.id
  63. if mt.features:
  64. wf_units = mt.features[0].data.unit
  65. wf_sampling_rate = 1 / mt.features[0].data.dimensions[
  66. 2].sampling_interval
  67. else:
  68. wf_units = None
  69. wf_sampling_rate = 0
  70. wf_gain = 1
  71. wf_offset = 0.
  72. if mt.features and "left_sweep" in mt.features[0].data.metadata:
  73. wf_left_sweep = mt.features[0].data.metadata["left_sweep"]
  74. else:
  75. wf_left_sweep = 0
  76. unit_channels.append((unit_name, unit_id, wf_units, wf_gain,
  77. wf_offset, wf_left_sweep, wf_sampling_rate))
  78. break
  79. break
  80. unit_channels = np.array(unit_channels, dtype=_unit_channel_dtype)
  81. event_channels = []
  82. event_count = 0
  83. epoch_count = 0
  84. for bl in self.file.blocks:
  85. for seg in bl.groups:
  86. for mt in seg.multi_tags:
  87. if mt.type == "neo.event":
  88. ev_name = mt.metadata['neo_name']
  89. ev_id = event_count
  90. event_count += 1
  91. ev_type = "event"
  92. event_channels.append((ev_name, ev_id, ev_type))
  93. if mt.type == "neo.epoch":
  94. ep_name = mt.metadata['neo_name']
  95. ep_id = epoch_count
  96. epoch_count += 1
  97. ep_type = "epoch"
  98. event_channels.append((ep_name, ep_id, ep_type))
  99. break
  100. break
  101. event_channels = np.array(event_channels, dtype=_event_channel_dtype)
  102. self.da_list = {'blocks': []}
  103. for block_index, blk in enumerate(self.file.blocks):
  104. d = {'segments': []}
  105. self.da_list['blocks'].append(d)
  106. for seg_index, seg in enumerate(blk.groups):
  107. d = {'signals': []}
  108. self.da_list['blocks'][block_index]['segments'].append(d)
  109. size_list = []
  110. data_list = []
  111. da_name_list = []
  112. for da in seg.data_arrays:
  113. if da.type == 'neo.analogsignal':
  114. size_list.append(da.size)
  115. data_list.append(da)
  116. da_name_list.append(da.metadata['neo_name'])
  117. self.da_list['blocks'][block_index]['segments'][seg_index]['data_size'] = size_list
  118. self.da_list['blocks'][block_index]['segments'][seg_index]['data'] = data_list
  119. self.da_list['blocks'][block_index]['segments'][seg_index]['ch_name'] = \
  120. da_name_list
  121. self.unit_list = {'blocks': []}
  122. for block_index, blk in enumerate(self.file.blocks):
  123. d = {'segments': []}
  124. self.unit_list['blocks'].append(d)
  125. for seg_index, seg in enumerate(blk.groups):
  126. d = {'spiketrains': [], 'spiketrains_id': [], 'spiketrains_unit': []}
  127. self.unit_list['blocks'][block_index]['segments'].append(d)
  128. st_idx = 0
  129. for st in seg.multi_tags:
  130. d = {'waveforms': []}
  131. self.unit_list[
  132. 'blocks'][block_index]['segments'][seg_index]['spiketrains_unit'].append(d)
  133. if st.type == 'neo.spiketrain':
  134. seg = self.unit_list['blocks'][block_index]['segments'][seg_index]
  135. seg['spiketrains'].append(st.positions)
  136. seg['spiketrains_id'].append(st.id)
  137. if st.features and st.features[0].data.type == "neo.waveforms":
  138. waveforms = st.features[0].data
  139. if waveforms:
  140. seg['spiketrains_unit'][st_idx]['waveforms'] = waveforms
  141. else:
  142. seg['spiketrains_unit'][st_idx]['waveforms'] = None
  143. # assume one spiketrain one waveform
  144. st_idx += 1
  145. self.header = {}
  146. self.header['nb_block'] = len(self.file.blocks)
  147. self.header['nb_segment'] = [len(bl.groups) for bl in self.file.blocks]
  148. self.header['signal_channels'] = sig_channels
  149. self.header['unit_channels'] = unit_channels
  150. self.header['event_channels'] = event_channels
  151. self._generate_minimal_annotations()
  152. def _segment_t_start(self, block_index, seg_index):
  153. t_start = 0
  154. for mt in self.file.blocks[block_index].groups[seg_index].multi_tags:
  155. if mt.type == "neo.spiketrain":
  156. t_start = mt.metadata['t_start']
  157. return t_start
  158. def _segment_t_stop(self, block_index, seg_index):
  159. t_stop = 0
  160. for mt in self.file.blocks[block_index].groups[seg_index].multi_tags:
  161. if mt.type == "neo.spiketrain":
  162. t_stop = mt.metadata['t_stop']
  163. return t_stop
  164. def _get_signal_size(self, block_index, seg_index, channel_indexes):
  165. if channel_indexes is None:
  166. channel_indexes = list(range(self.header['signal_channels'].size))
  167. ch_idx = channel_indexes[0]
  168. size = self.da_list['blocks'][block_index]['segments'][seg_index]['data_size'][ch_idx]
  169. return size # size is per signal, not the sum of all channel_indexes
  170. def _get_signal_t_start(self, block_index, seg_index, channel_indexes):
  171. if channel_indexes is None:
  172. channel_indexes = list(range(self.header['signal_channels'].size))
  173. ch_idx = channel_indexes[0]
  174. da = [da for da in self.file.blocks[block_index].groups[seg_index].data_arrays][ch_idx]
  175. sig_t_start = float(da.metadata['t_start'])
  176. return sig_t_start # assume same group_id always same t_start
  177. def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes):
  178. if channel_indexes is None:
  179. channel_indexes = list(range(self.header['signal_channels'].size))
  180. if i_start is None:
  181. i_start = 0
  182. if i_stop is None:
  183. for c in channel_indexes:
  184. i_stop = self.da_list['blocks'][block_index]['segments'][seg_index]['data_size'][c]
  185. break
  186. raw_signals_list = []
  187. da_list = self.da_list['blocks'][block_index]['segments'][seg_index]
  188. for idx in channel_indexes:
  189. da = da_list['data'][idx]
  190. raw_signals_list.append(da[i_start:i_stop])
  191. raw_signals = np.array(raw_signals_list)
  192. raw_signals = np.transpose(raw_signals)
  193. return raw_signals
  194. def _spike_count(self, block_index, seg_index, unit_index):
  195. count = 0
  196. head_id = self.header['unit_channels'][unit_index][1]
  197. for mt in self.file.blocks[block_index].groups[seg_index].multi_tags:
  198. for src in mt.sources:
  199. if mt.type == 'neo.spiketrain' and [src.type == "neo.unit"]:
  200. if head_id == src.id:
  201. return len(mt.positions)
  202. return count
  203. def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
  204. spike_dict = self.unit_list['blocks'][block_index]['segments'][seg_index]['spiketrains']
  205. spike_timestamps = spike_dict[unit_index]
  206. spike_timestamps = np.transpose(spike_timestamps)
  207. if t_start is not None or t_stop is not None:
  208. lim0 = t_start
  209. lim1 = t_stop
  210. mask = (spike_timestamps >= lim0) & (spike_timestamps <= lim1)
  211. spike_timestamps = spike_timestamps[mask]
  212. return spike_timestamps
  213. def _rescale_spike_timestamp(self, spike_timestamps, dtype):
  214. spike_times = spike_timestamps.astype(dtype)
  215. return spike_times
  216. def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, t_stop):
  217. # this must return a 3D numpy array (nb_spike, nb_channel, nb_sample)
  218. seg = self.unit_list['blocks'][block_index]['segments'][seg_index]
  219. waveforms = seg['spiketrains_unit'][unit_index]['waveforms']
  220. if not waveforms:
  221. return None
  222. raw_waveforms = np.array(waveforms)
  223. if t_start is not None:
  224. lim0 = t_start
  225. mask = (raw_waveforms >= lim0)
  226. raw_waveforms = np.where(mask, raw_waveforms, np.nan) # use nan to keep the shape
  227. if t_stop is not None:
  228. lim1 = t_stop
  229. mask = (raw_waveforms <= lim1)
  230. raw_waveforms = np.where(mask, raw_waveforms, np.nan)
  231. return raw_waveforms
  232. def _event_count(self, block_index, seg_index, event_channel_index):
  233. event_count = 0
  234. for event in self.file.blocks[block_index].groups[seg_index].multi_tags:
  235. if event.type == 'neo.event' or event.type == 'neo.epoch':
  236. if event_count == event_channel_index:
  237. return len(event.positions)
  238. else:
  239. event_count += 1
  240. return event_count
  241. def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
  242. timestamp = []
  243. labels = []
  244. durations = None
  245. if event_channel_index is None:
  246. raise IndexError
  247. for mt in self.file.blocks[block_index].groups[seg_index].multi_tags:
  248. if mt.type == "neo.event" or mt.type == "neo.epoch":
  249. labels.append(mt.positions.dimensions[0].labels)
  250. po = mt.positions
  251. if po.type == "neo.event.times" or po.type == "neo.epoch.times":
  252. timestamp.append(po)
  253. if self.header['event_channels'][event_channel_index]['type'] == b'epoch' \
  254. and mt.extents:
  255. if mt.extents.type == 'neo.epoch.durations':
  256. durations = np.array(mt.extents)
  257. break
  258. timestamp = timestamp[event_channel_index][:]
  259. timestamp = np.array(timestamp, dtype="float")
  260. labels = labels[event_channel_index][:]
  261. labels = np.array(labels, dtype='U')
  262. if t_start is not None:
  263. keep = timestamp >= t_start
  264. timestamp, labels = timestamp[keep], labels[keep]
  265. if t_stop is not None:
  266. keep = timestamp <= t_stop
  267. timestamp, labels = timestamp[keep], labels[keep]
  268. return timestamp, durations, labels # only the first fits in rescale
  269. def _rescale_event_timestamp(self, event_timestamps, dtype='float64'):
  270. ev_unit = ''
  271. for mt in self.file.blocks[0].groups[0].multi_tags:
  272. if mt.type == "neo.event":
  273. ev_unit = mt.positions.unit
  274. break
  275. if ev_unit == 'ms':
  276. event_timestamps /= 1000
  277. event_times = event_timestamps.astype(dtype)
  278. # supposing unit is second, other possibilities maybe mS microS...
  279. return event_times # return in seconds
  280. def _rescale_epoch_duration(self, raw_duration, dtype='float64'):
  281. ep_unit = ''
  282. for mt in self.file.blocks[0].groups[0].multi_tags:
  283. if mt.type == "neo.epoch":
  284. ep_unit = mt.positions.unit
  285. break
  286. if ep_unit == 'ms':
  287. raw_duration /= 1000
  288. durations = raw_duration.astype(dtype)
  289. # supposing unit is second, other possibilities maybe mS microS...
  290. return durations # return in seconds