test_nixio.py 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256
  1. # -*- coding: utf-8 -*-
  2. # Copyright (c) 2016, German Neuroinformatics Node (G-Node)
  3. # Achilleas Koutsou <achilleas.k@gmail.com>
  4. #
  5. # All rights reserved.
  6. #
  7. # Redistribution and use in source and binary forms, with or without
  8. # modification, are permitted under the terms of the BSD License. See
  9. # LICENSE file in the root of the Project.
  10. """
  11. Tests for neo.io.nixio
  12. """
  13. import os
  14. from datetime import datetime
  15. import unittest
  16. try:
  17. from unittest import mock
  18. except ImportError:
  19. import mock
  20. import string
  21. import numpy as np
  22. import quantities as pq
  23. from neo.core import (Block, Segment, ChannelIndex, AnalogSignal,
  24. IrregularlySampledSignal, Unit, SpikeTrain, Event, Epoch)
  25. from neo.test.iotest.common_io_test import BaseTestIO
  26. try:
  27. import nixio as nix
  28. HAVE_NIX = True
  29. except ImportError:
  30. HAVE_NIX = False
  31. from neo.io.nixio import NixIO
  32. from neo.io.nixio import string_types
  33. @unittest.skipUnless(HAVE_NIX, "Requires NIX")
  34. class NixIOTest(unittest.TestCase):
  35. filename = None
  36. io = None
  37. def compare_blocks(self, neoblocks, nixblocks):
  38. for neoblock, nixblock in zip(neoblocks, nixblocks):
  39. self.compare_attr(neoblock, nixblock)
  40. self.assertEqual(len(neoblock.segments), len(nixblock.groups))
  41. for idx, neoseg in enumerate(neoblock.segments):
  42. nixgrp = nixblock.groups[neoseg.annotations["nix_name"]]
  43. self.compare_segment_group(neoseg, nixgrp)
  44. for idx, neochx in enumerate(neoblock.channel_indexes):
  45. nixsrc = nixblock.sources[neochx.annotations["nix_name"]]
  46. self.compare_chx_source(neochx, nixsrc)
  47. self.check_refs(neoblock, nixblock)
  48. def compare_chx_source(self, neochx, nixsrc):
  49. self.compare_attr(neochx, nixsrc)
  50. nix_channels = list(src for src in nixsrc.sources
  51. if src.type == "neo.channelindex")
  52. self.assertEqual(len(neochx.index), len(nix_channels))
  53. if len(neochx.channel_ids):
  54. nix_chanids = list(src.metadata["channel_id"] for src
  55. in nixsrc.sources
  56. if src.type == "neo.channelindex")
  57. self.assertEqual(len(neochx.channel_ids), len(nix_chanids))
  58. for nixchan in nix_channels:
  59. nixchanidx = nixchan.metadata["index"]
  60. try:
  61. neochanpos = list(neochx.index).index(nixchanidx)
  62. except ValueError:
  63. self.fail("Channel indexes do not match.")
  64. if len(neochx.channel_names):
  65. neochanname = neochx.channel_names[neochanpos]
  66. if ((not isinstance(neochanname, str)) and
  67. isinstance(neochanname, bytes)):
  68. neochanname = neochanname.decode()
  69. nixchanname = nixchan.metadata["neo_name"]
  70. self.assertEqual(neochanname, nixchanname)
  71. if len(neochx.channel_ids):
  72. neochanid = neochx.channel_ids[neochanpos]
  73. nixchanid = nixchan.metadata["channel_id"]
  74. self.assertEqual(neochanid, nixchanid)
  75. elif "channel_id" in nixchan.metadata:
  76. self.fail("Channel ID not loaded")
  77. nix_units = list(src for src in nixsrc.sources
  78. if src.type == "neo.unit")
  79. self.assertEqual(len(neochx.units), len(nix_units))
  80. for neounit in neochx.units:
  81. nixunit = nixsrc.sources[neounit.annotations["nix_name"]]
  82. self.compare_attr(neounit, nixunit)
  83. def check_refs(self, neoblock, nixblock):
  84. """
  85. Checks whether the references between objects that are not nested are
  86. mapped correctly (e.g., SpikeTrains referenced by a Unit).
  87. :param neoblock: A Neo block
  88. :param nixblock: The corresponding NIX block
  89. """
  90. for idx, neochx in enumerate(neoblock.channel_indexes):
  91. nixchx = nixblock.sources[neochx.annotations["nix_name"]]
  92. # AnalogSignals referencing CHX
  93. neoasigs = list(sig.annotations["nix_name"]
  94. for sig in neochx.analogsignals)
  95. nixasigs = list(set(da.metadata.name for da in nixblock.data_arrays
  96. if da.type == "neo.analogsignal" and
  97. nixchx in da.sources))
  98. self.assertEqual(len(neoasigs), len(nixasigs),
  99. neochx.analogsignals)
  100. # IrregularlySampledSignals referencing CHX
  101. neoisigs = list(sig.annotations["nix_name"] for sig in
  102. neochx.irregularlysampledsignals)
  103. nixisigs = list(
  104. set(da.metadata.name for da in nixblock.data_arrays
  105. if da.type == "neo.irregularlysampledsignal" and
  106. nixchx in da.sources)
  107. )
  108. self.assertEqual(len(neoisigs), len(nixisigs))
  109. # SpikeTrains referencing CHX and Units
  110. for sidx, neounit in enumerate(neochx.units):
  111. nixunit = nixchx.sources[neounit.annotations["nix_name"]]
  112. neosts = list(st.annotations["nix_name"]
  113. for st in neounit.spiketrains)
  114. nixsts = list(mt for mt in nixblock.multi_tags
  115. if mt.type == "neo.spiketrain" and
  116. nixunit.name in mt.sources)
  117. # SpikeTrains must also reference CHX
  118. for nixst in nixsts:
  119. self.assertIn(nixchx.name, nixst.sources)
  120. nixsts = list(st.name for st in nixsts)
  121. self.assertEqual(len(neosts), len(nixsts))
  122. for neoname in neosts:
  123. if neoname:
  124. self.assertIn(neoname, nixsts)
  125. # Events and Epochs must reference all Signals in the Group (NIX only)
  126. for nixgroup in nixblock.groups:
  127. nixevep = list(mt for mt in nixgroup.multi_tags
  128. if mt.type in ["neo.event", "neo.epoch"])
  129. nixsigs = list(da.name for da in nixgroup.data_arrays
  130. if da.type in ["neo.analogsignal",
  131. "neo.irregularlysampledsignal"])
  132. for nee in nixevep:
  133. for ns in nixsigs:
  134. self.assertIn(ns, nee.references)
  135. def compare_segment_group(self, neoseg, nixgroup):
  136. self.compare_attr(neoseg, nixgroup)
  137. neo_signals = neoseg.analogsignals + neoseg.irregularlysampledsignals
  138. self.compare_signals_das(neo_signals, nixgroup.data_arrays)
  139. neo_eests = neoseg.epochs + neoseg.events + neoseg.spiketrains
  140. self.compare_eests_mtags(neo_eests, nixgroup.multi_tags)
  141. def compare_signals_das(self, neosignals, data_arrays):
  142. for sig in neosignals:
  143. if self.io._find_lazy_loaded(sig) is not None:
  144. sig = self.io.load_lazy_object(sig)
  145. dalist = list()
  146. nixname = sig.annotations["nix_name"]
  147. for da in data_arrays:
  148. if da.metadata.name == nixname:
  149. dalist.append(da)
  150. _, nsig = np.shape(sig)
  151. self.assertEqual(nsig, len(dalist))
  152. self.compare_signal_dalist(sig, dalist)
  153. def compare_signal_dalist(self, neosig, nixdalist):
  154. """
  155. Check if a Neo Analog or IrregularlySampledSignal matches a list of
  156. NIX DataArrays.
  157. :param neosig: Neo Analog or IrregularlySampledSignal
  158. :param nixdalist: List of DataArrays
  159. """
  160. nixmd = nixdalist[0].metadata
  161. self.assertTrue(all(nixmd == da.metadata for da in nixdalist))
  162. neounit = str(neosig.dimensionality)
  163. for sig, da in zip(np.transpose(neosig),
  164. sorted(nixdalist, key=lambda d: d.name)):
  165. self.compare_attr(neosig, da)
  166. np.testing.assert_almost_equal(sig.magnitude, da)
  167. self.assertEqual(neounit, da.unit)
  168. timedim = da.dimensions[0]
  169. if isinstance(neosig, AnalogSignal):
  170. self.assertEqual(timedim.dimension_type,
  171. nix.DimensionType.Sample)
  172. self.assertEqual(
  173. pq.Quantity(timedim.sampling_interval, timedim.unit),
  174. neosig.sampling_period
  175. )
  176. self.assertEqual(timedim.offset, neosig.t_start.magnitude)
  177. if "t_start.units" in da.metadata.props:
  178. self.assertEqual(da.metadata["t_start.units"],
  179. str(neosig.t_start.dimensionality))
  180. elif isinstance(neosig, IrregularlySampledSignal):
  181. self.assertEqual(timedim.dimension_type,
  182. nix.DimensionType.Range)
  183. np.testing.assert_almost_equal(neosig.times.magnitude,
  184. timedim.ticks)
  185. self.assertEqual(timedim.unit,
  186. str(neosig.times.dimensionality))
  187. def compare_eests_mtags(self, eestlist, mtaglist):
  188. self.assertEqual(len(eestlist), len(mtaglist))
  189. for eest in eestlist:
  190. if self.io._find_lazy_loaded(eest) is not None:
  191. eest = self.io.load_lazy_object(eest)
  192. mtag = mtaglist[eest.annotations["nix_name"]]
  193. if isinstance(eest, Epoch):
  194. self.compare_epoch_mtag(eest, mtag)
  195. elif isinstance(eest, Event):
  196. self.compare_event_mtag(eest, mtag)
  197. elif isinstance(eest, SpikeTrain):
  198. self.compare_spiketrain_mtag(eest, mtag)
  199. def compare_epoch_mtag(self, epoch, mtag):
  200. self.assertEqual(mtag.type, "neo.epoch")
  201. self.compare_attr(epoch, mtag)
  202. np.testing.assert_almost_equal(epoch.times.magnitude, mtag.positions)
  203. np.testing.assert_almost_equal(epoch.durations.magnitude, mtag.extents)
  204. self.assertEqual(mtag.positions.unit,
  205. str(epoch.times.units.dimensionality))
  206. self.assertEqual(mtag.extents.unit,
  207. str(epoch.durations.units.dimensionality))
  208. for neol, nixl in zip(epoch.labels,
  209. mtag.positions.dimensions[0].labels):
  210. # Dirty. Should find the root cause instead
  211. if isinstance(neol, bytes):
  212. neol = neol.decode()
  213. if isinstance(nixl, bytes):
  214. nixl = nixl.decode()
  215. self.assertEqual(neol, nixl)
  216. def compare_event_mtag(self, event, mtag):
  217. self.assertEqual(mtag.type, "neo.event")
  218. self.compare_attr(event, mtag)
  219. np.testing.assert_almost_equal(event.times.magnitude, mtag.positions)
  220. self.assertEqual(mtag.positions.unit, str(event.units.dimensionality))
  221. for neol, nixl in zip(event.labels,
  222. mtag.positions.dimensions[0].labels):
  223. # Dirty. Should find the root cause instead
  224. # Only happens in 3.2
  225. if isinstance(neol, bytes):
  226. neol = neol.decode()
  227. if isinstance(nixl, bytes):
  228. nixl = nixl.decode()
  229. self.assertEqual(neol, nixl)
  230. def compare_spiketrain_mtag(self, spiketrain, mtag):
  231. self.assertEqual(mtag.type, "neo.spiketrain")
  232. self.compare_attr(spiketrain, mtag)
  233. np.testing.assert_almost_equal(spiketrain.times.magnitude,
  234. mtag.positions)
  235. if len(mtag.features):
  236. neowf = spiketrain.waveforms
  237. nixwf = mtag.features[0].data
  238. self.assertEqual(np.shape(neowf), np.shape(nixwf))
  239. self.assertEqual(nixwf.unit, str(neowf.units.dimensionality))
  240. np.testing.assert_almost_equal(neowf.magnitude, nixwf)
  241. self.assertEqual(nixwf.dimensions[0].dimension_type,
  242. nix.DimensionType.Set)
  243. self.assertEqual(nixwf.dimensions[1].dimension_type,
  244. nix.DimensionType.Set)
  245. self.assertEqual(nixwf.dimensions[2].dimension_type,
  246. nix.DimensionType.Sample)
  247. def compare_attr(self, neoobj, nixobj):
  248. if isinstance(neoobj, (AnalogSignal, IrregularlySampledSignal)):
  249. nix_name = ".".join(nixobj.name.split(".")[:-1])
  250. else:
  251. nix_name = nixobj.name
  252. self.assertEqual(neoobj.annotations["nix_name"], nix_name)
  253. self.assertEqual(neoobj.description, nixobj.definition)
  254. if hasattr(neoobj, "rec_datetime") and neoobj.rec_datetime:
  255. self.assertEqual(neoobj.rec_datetime,
  256. datetime.fromtimestamp(nixobj.created_at))
  257. if hasattr(neoobj, "file_datetime") and neoobj.file_datetime:
  258. self.assertEqual(neoobj.file_datetime,
  259. datetime.fromtimestamp(
  260. nixobj.metadata["file_datetime"]))
  261. if neoobj.annotations:
  262. nixmd = nixobj.metadata
  263. for k, v, in neoobj.annotations.items():
  264. if k == "nix_name":
  265. continue
  266. if isinstance(v, pq.Quantity):
  267. self.assertEqual(nixmd.props[str(k)].unit,
  268. str(v.dimensionality))
  269. np.testing.assert_almost_equal(nixmd[str(k)],
  270. v.magnitude)
  271. else:
  272. self.assertEqual(nixmd[str(k)], v)
  273. @classmethod
  274. def create_full_nix_file(cls, filename):
  275. nixfile = nix.File.open(filename, nix.FileMode.Overwrite,
  276. backend="h5py")
  277. nix_block_a = nixfile.create_block(cls.rword(10), "neo.block")
  278. nix_block_a.definition = cls.rsentence(5, 10)
  279. nix_block_b = nixfile.create_block(cls.rword(10), "neo.block")
  280. nix_block_b.definition = cls.rsentence(3, 3)
  281. nix_block_a.metadata = nixfile.create_section(
  282. nix_block_a.name, nix_block_a.name+".metadata"
  283. )
  284. nix_block_b.metadata = nixfile.create_section(
  285. nix_block_b.name, nix_block_b.name+".metadata"
  286. )
  287. nix_blocks = [nix_block_a, nix_block_b]
  288. for blk in nix_blocks:
  289. for ind in range(3):
  290. group = blk.create_group(cls.rword(), "neo.segment")
  291. group.definition = cls.rsentence(10, 15)
  292. group_md = blk.metadata.create_section(group.name,
  293. group.name+".metadata")
  294. group.metadata = group_md
  295. blk = nix_blocks[0]
  296. group = blk.groups[0]
  297. allspiketrains = list()
  298. allsignalgroups = list()
  299. # analogsignals
  300. for n in range(3):
  301. siggroup = list()
  302. asig_name = "{}_asig{}".format(cls.rword(10), n)
  303. asig_definition = cls.rsentence(5, 5)
  304. asig_md = group.metadata.create_section(asig_name,
  305. asig_name+".metadata")
  306. for idx in range(3):
  307. da_asig = blk.create_data_array(
  308. "{}.{}".format(asig_name, idx),
  309. "neo.analogsignal",
  310. data=cls.rquant(100, 1)
  311. )
  312. da_asig.definition = asig_definition
  313. da_asig.unit = "mV"
  314. da_asig.metadata = asig_md
  315. timedim = da_asig.append_sampled_dimension(0.01)
  316. timedim.unit = "ms"
  317. timedim.label = "time"
  318. timedim.offset = 10
  319. da_asig.append_set_dimension()
  320. group.data_arrays.append(da_asig)
  321. siggroup.append(da_asig)
  322. allsignalgroups.append(siggroup)
  323. # irregularlysampledsignals
  324. for n in range(2):
  325. siggroup = list()
  326. isig_name = "{}_isig{}".format(cls.rword(10), n)
  327. isig_definition = cls.rsentence(12, 12)
  328. isig_md = group.metadata.create_section(isig_name,
  329. isig_name+".metadata")
  330. isig_times = cls.rquant(200, 1, True)
  331. for idx in range(10):
  332. da_isig = blk.create_data_array(
  333. "{}.{}".format(isig_name, idx),
  334. "neo.irregularlysampledsignal",
  335. data=cls.rquant(200, 1)
  336. )
  337. da_isig.definition = isig_definition
  338. da_isig.unit = "mV"
  339. da_isig.metadata = isig_md
  340. timedim = da_isig.append_range_dimension(isig_times)
  341. timedim.unit = "s"
  342. timedim.label = "time"
  343. da_isig.append_set_dimension()
  344. group.data_arrays.append(da_isig)
  345. siggroup.append(da_isig)
  346. allsignalgroups.append(siggroup)
  347. # SpikeTrains with Waveforms
  348. for n in range(4):
  349. stname = "{}-st{}".format(cls.rword(20), n)
  350. times = cls.rquant(400, 1, True)
  351. times_da = blk.create_data_array(
  352. "{}.times".format(stname),
  353. "neo.spiketrain.times",
  354. data=times
  355. )
  356. times_da.unit = "ms"
  357. mtag_st = blk.create_multi_tag(stname,
  358. "neo.spiketrain",
  359. times_da)
  360. group.multi_tags.append(mtag_st)
  361. mtag_st.definition = cls.rsentence(20, 30)
  362. mtag_st_md = group.metadata.create_section(
  363. mtag_st.name, mtag_st.name+".metadata"
  364. )
  365. mtag_st.metadata = mtag_st_md
  366. mtag_st_md.create_property(
  367. "t_stop", nix.Value(times[-1]+1.0)
  368. )
  369. waveforms = cls.rquant((10, 8, 5), 1)
  370. wfname = "{}.waveforms".format(mtag_st.name)
  371. wfda = blk.create_data_array(wfname, "neo.waveforms",
  372. data=waveforms)
  373. wfda.unit = "mV"
  374. mtag_st.create_feature(wfda, nix.LinkType.Indexed)
  375. wfda.append_set_dimension() # spike dimension
  376. wfda.append_set_dimension() # channel dimension
  377. wftimedim = wfda.append_sampled_dimension(0.1)
  378. wftimedim.unit = "ms"
  379. wftimedim.label = "time"
  380. wfda.metadata = mtag_st_md.create_section(
  381. wfname, "neo.waveforms.metadata"
  382. )
  383. wfda.metadata.create_property("left_sweep",
  384. [nix.Value(20)]*5)
  385. allspiketrains.append(mtag_st)
  386. # Epochs
  387. for n in range(3):
  388. epname = "{}-ep{}".format(cls.rword(5), n)
  389. times = cls.rquant(5, 1, True)
  390. times_da = blk.create_data_array(
  391. "{}.times".format(epname),
  392. "neo.epoch.times",
  393. data=times
  394. )
  395. times_da.unit = "s"
  396. extents = cls.rquant(5, 1)
  397. extents_da = blk.create_data_array(
  398. "{}.durations".format(epname),
  399. "neo.epoch.durations",
  400. data=extents
  401. )
  402. extents_da.unit = "s"
  403. mtag_ep = blk.create_multi_tag(
  404. epname, "neo.epoch", times_da
  405. )
  406. mtag_ep.metadata = group.metadata.create_section(
  407. epname, epname+".metadata"
  408. )
  409. group.multi_tags.append(mtag_ep)
  410. mtag_ep.definition = cls.rsentence(2)
  411. mtag_ep.extents = extents_da
  412. label_dim = mtag_ep.positions.append_set_dimension()
  413. label_dim.labels = cls.rsentence(5).split(" ")
  414. # reference all signals in the group
  415. for siggroup in allsignalgroups:
  416. mtag_ep.references.extend(siggroup)
  417. # Events
  418. for n in range(2):
  419. evname = "{}-ev{}".format(cls.rword(5), n)
  420. times = cls.rquant(5, 1, True)
  421. times_da = blk.create_data_array(
  422. "{}.times".format(evname),
  423. "neo.event.times",
  424. data=times
  425. )
  426. times_da.unit = "s"
  427. mtag_ev = blk.create_multi_tag(
  428. evname, "neo.event", times_da
  429. )
  430. mtag_ev.metadata = group.metadata.create_section(
  431. evname, evname+".metadata"
  432. )
  433. group.multi_tags.append(mtag_ev)
  434. mtag_ev.definition = cls.rsentence(2)
  435. label_dim = mtag_ev.positions.append_set_dimension()
  436. label_dim.labels = cls.rsentence(5).split(" ")
  437. # reference all signals in the group
  438. for siggroup in allsignalgroups:
  439. mtag_ev.references.extend(siggroup)
  440. # CHX
  441. nixchx = blk.create_source(cls.rword(10),
  442. "neo.channelindex")
  443. nixchx.metadata = nix_blocks[0].metadata.create_section(
  444. nixchx.name, "neo.channelindex.metadata"
  445. )
  446. chantype = "neo.channelindex"
  447. # 3 channels
  448. for idx, chan in enumerate([2, 5, 9]):
  449. channame = "{}.ChannelIndex{}".format(nixchx.name, idx)
  450. nixrc = nixchx.create_source(channame, chantype)
  451. nixrc.definition = cls.rsentence(13)
  452. nixrc.metadata = nixchx.metadata.create_section(
  453. nixrc.name, "neo.channelindex.metadata"
  454. )
  455. nixrc.metadata.create_property("index", nix.Value(chan))
  456. nixrc.metadata.create_property("channel_id", nix.Value(chan+1))
  457. dims = tuple(map(nix.Value, cls.rquant(3, 1)))
  458. nixrc.metadata.create_property("coordinates", dims)
  459. nixrc.metadata.create_property("coordinates.units",
  460. nix.Value("um"))
  461. nunits = 1
  462. stsperunit = np.array_split(allspiketrains, nunits)
  463. for idx in range(nunits):
  464. unitname = "{}-unit{}".format(cls.rword(5), idx)
  465. nixunit = nixchx.create_source(unitname, "neo.unit")
  466. nixunit.metadata = nixchx.metadata.create_section(
  467. unitname, unitname+".metadata"
  468. )
  469. nixunit.definition = cls.rsentence(4, 10)
  470. for st in stsperunit[idx]:
  471. st.sources.append(nixchx)
  472. st.sources.append(nixunit)
  473. # pick a few signal groups to reference this CHX
  474. randsiggroups = np.random.choice(allsignalgroups, 5, False)
  475. for siggroup in randsiggroups:
  476. for sig in siggroup:
  477. sig.sources.append(nixchx)
  478. return nixfile
  479. @staticmethod
  480. def rdate():
  481. return datetime(year=np.random.randint(1980, 2020),
  482. month=np.random.randint(1, 13),
  483. day=np.random.randint(1, 29))
  484. @classmethod
  485. def populate_dates(cls, obj):
  486. obj.file_datetime = cls.rdate()
  487. obj.rec_datetime = cls.rdate()
  488. @staticmethod
  489. def rword(n=10):
  490. return "".join(np.random.choice(list(string.ascii_letters), n))
  491. @classmethod
  492. def rsentence(cls, n=3, maxwl=10):
  493. return " ".join(cls.rword(np.random.randint(1, maxwl))
  494. for _ in range(n))
  495. @classmethod
  496. def rdict(cls, nitems):
  497. rd = dict()
  498. for _ in range(nitems):
  499. key = cls.rword()
  500. value = cls.rword() if np.random.choice((0, 1)) \
  501. else np.random.uniform()
  502. rd[key] = value
  503. return rd
  504. @staticmethod
  505. def rquant(shape, unit, incr=False):
  506. try:
  507. dim = len(shape)
  508. except TypeError:
  509. dim = 1
  510. if incr and dim > 1:
  511. raise TypeError("Shape of quantity array may only be "
  512. "one-dimensional when incremental values are "
  513. "requested.")
  514. arr = np.random.random(shape)
  515. if incr:
  516. arr = np.array(np.cumsum(arr))
  517. return arr*unit
  518. @classmethod
  519. def create_all_annotated(cls):
  520. times = cls.rquant(1, pq.s)
  521. signal = cls.rquant(1, pq.V)
  522. blk = Block()
  523. blk.annotate(**cls.rdict(3))
  524. seg = Segment()
  525. seg.annotate(**cls.rdict(4))
  526. blk.segments.append(seg)
  527. asig = AnalogSignal(signal=signal, sampling_rate=pq.Hz)
  528. asig.annotate(**cls.rdict(2))
  529. seg.analogsignals.append(asig)
  530. isig = IrregularlySampledSignal(times=times, signal=signal,
  531. time_units=pq.s)
  532. isig.annotate(**cls.rdict(2))
  533. seg.irregularlysampledsignals.append(isig)
  534. epoch = Epoch(times=times, durations=times)
  535. epoch.annotate(**cls.rdict(4))
  536. seg.epochs.append(epoch)
  537. event = Event(times=times)
  538. event.annotate(**cls.rdict(4))
  539. seg.events.append(event)
  540. spiketrain = SpikeTrain(times=times, t_stop=pq.s, units=pq.s)
  541. d = cls.rdict(6)
  542. d["quantity"] = pq.Quantity(10, "mV")
  543. d["qarray"] = pq.Quantity(range(10), "mA")
  544. spiketrain.annotate(**d)
  545. seg.spiketrains.append(spiketrain)
  546. chx = ChannelIndex(name="achx", index=[1, 2], channel_ids=[0, 10])
  547. chx.annotate(**cls.rdict(5))
  548. blk.channel_indexes.append(chx)
  549. unit = Unit()
  550. unit.annotate(**cls.rdict(2))
  551. chx.units.append(unit)
  552. return blk
  553. @unittest.skipUnless(HAVE_NIX, "Requires NIX")
  554. class NixIOWriteTest(NixIOTest):
  555. def setUp(self):
  556. self.filename = "nixio_testfile_write.h5"
  557. self.writer = NixIO(self.filename, "ow")
  558. self.io = self.writer
  559. self.reader = nix.File.open(self.filename,
  560. nix.FileMode.ReadOnly,
  561. backend="h5py")
  562. def tearDown(self):
  563. self.writer.close()
  564. self.reader.close()
  565. os.remove(self.filename)
  566. def write_and_compare(self, blocks):
  567. self.writer.write_all_blocks(blocks)
  568. self.compare_blocks(self.writer.read_all_blocks(), self.reader.blocks)
  569. self.compare_blocks(blocks, self.reader.blocks)
  570. def test_block_write(self):
  571. block = Block(name=self.rword(),
  572. description=self.rsentence())
  573. self.write_and_compare([block])
  574. block.annotate(**self.rdict(5))
  575. self.write_and_compare([block])
  576. def test_segment_write(self):
  577. block = Block(name=self.rword())
  578. segment = Segment(name=self.rword(), description=self.rword())
  579. block.segments.append(segment)
  580. self.write_and_compare([block])
  581. segment.annotate(**self.rdict(2))
  582. self.write_and_compare([block])
  583. def test_channel_index_write(self):
  584. block = Block(name=self.rword())
  585. chx = ChannelIndex(name=self.rword(),
  586. description=self.rsentence(),
  587. channel_ids=[10, 20, 30, 50, 80, 130],
  588. index=[1, 2, 3, 5, 8, 13])
  589. block.channel_indexes.append(chx)
  590. self.write_and_compare([block])
  591. chx.annotate(**self.rdict(3))
  592. self.write_and_compare([block])
  593. chx.channel_names = ["one", "two", "three", "five",
  594. "eight", "xiii"]
  595. self.write_and_compare([block])
  596. def test_signals_write(self):
  597. block = Block()
  598. seg = Segment()
  599. block.segments.append(seg)
  600. asig = AnalogSignal(signal=self.rquant((10, 3), pq.mV),
  601. sampling_rate=pq.Quantity(10, "Hz"))
  602. seg.analogsignals.append(asig)
  603. self.write_and_compare([block])
  604. anotherblock = Block("ir signal block")
  605. seg = Segment("ir signal seg")
  606. anotherblock.segments.append(seg)
  607. irsig = IrregularlySampledSignal(
  608. signal=np.random.random((20, 3)),
  609. times=self.rquant(20, pq.ms, True),
  610. units=pq.A
  611. )
  612. seg.irregularlysampledsignals.append(irsig)
  613. self.write_and_compare([block, anotherblock])
  614. block.segments[0].analogsignals.append(
  615. AnalogSignal(signal=[10.0, 1.0, 3.0], units=pq.S,
  616. sampling_period=pq.Quantity(3, "s"),
  617. dtype=np.double, name="signal42",
  618. description="this is an analogsignal",
  619. t_start=45 * pq.ms),
  620. )
  621. self.write_and_compare([block, anotherblock])
  622. block.segments[0].irregularlysampledsignals.append(
  623. IrregularlySampledSignal(times=np.random.random(10),
  624. signal=np.random.random((10, 3)),
  625. units="mV", time_units="s",
  626. dtype=np.float,
  627. name="some sort of signal",
  628. description="the signal is described")
  629. )
  630. self.write_and_compare([block, anotherblock])
  631. def test_epoch_write(self):
  632. block = Block()
  633. seg = Segment()
  634. block.segments.append(seg)
  635. epoch = Epoch(times=[1, 1, 10, 3]*pq.ms, durations=[3, 3, 3, 1]*pq.ms,
  636. labels=np.array(["one", "two", "three", "four"]),
  637. name="test epoch", description="an epoch for testing")
  638. seg.epochs.append(epoch)
  639. self.write_and_compare([block])
  640. def test_event_write(self):
  641. block = Block()
  642. seg = Segment()
  643. block.segments.append(seg)
  644. event = Event(times=np.arange(0, 30, 10)*pq.s,
  645. labels=np.array(["0", "1", "2"]),
  646. name="event name",
  647. description="event description")
  648. seg.events.append(event)
  649. self.write_and_compare([block])
  650. def test_spiketrain_write(self):
  651. block = Block()
  652. seg = Segment()
  653. block.segments.append(seg)
  654. spiketrain = SpikeTrain(times=[3, 4, 5]*pq.s, t_stop=10.0,
  655. name="spikes!", description="sssssspikes")
  656. seg.spiketrains.append(spiketrain)
  657. self.write_and_compare([block])
  658. waveforms = self.rquant((3, 5, 10), pq.mV)
  659. spiketrain = SpikeTrain(times=[1, 1.1, 1.2]*pq.ms, t_stop=1.5*pq.s,
  660. name="spikes with wf",
  661. description="spikes for waveform test",
  662. waveforms=waveforms)
  663. seg.spiketrains.append(spiketrain)
  664. self.write_and_compare([block])
  665. spiketrain.left_sweep = np.random.random(10)*pq.ms
  666. self.write_and_compare([block])
  667. def test_metadata_structure_write(self):
  668. neoblk = self.create_all_annotated()
  669. self.io.write_block(neoblk)
  670. blk = self.io.nix_file.blocks[0]
  671. blkmd = blk.metadata
  672. self.assertEqual(blk.name, blkmd.name)
  673. grp = blk.groups[0] # segment
  674. self.assertIn(grp.name, blkmd.sections)
  675. grpmd = blkmd.sections[grp.name]
  676. for da in grp.data_arrays: # signals
  677. name = ".".join(da.name.split(".")[:-1])
  678. self.assertIn(name, grpmd.sections)
  679. for mtag in grp.multi_tags: # spiketrains, events, and epochs
  680. self.assertIn(mtag.name, grpmd.sections)
  681. srcchx = blk.sources[0] # chx
  682. self.assertIn(srcchx.name, blkmd.sections)
  683. for srcunit in blk.sources: # units
  684. self.assertIn(srcunit.name, blkmd.sections)
  685. self.write_and_compare([neoblk])
  686. def test_anonymous_objects_write(self):
  687. nblocks = 2
  688. nsegs = 2
  689. nanasig = 4
  690. nirrseg = 2
  691. nepochs = 3
  692. nevents = 4
  693. nspiketrains = 3
  694. nchx = 5
  695. nunits = 10
  696. times = self.rquant(1, pq.s)
  697. signal = self.rquant(1, pq.V)
  698. blocks = []
  699. for blkidx in range(nblocks):
  700. blk = Block()
  701. blocks.append(blk)
  702. for segidx in range(nsegs):
  703. seg = Segment()
  704. blk.segments.append(seg)
  705. for anaidx in range(nanasig):
  706. seg.analogsignals.append(AnalogSignal(signal=signal,
  707. sampling_rate=pq.Hz))
  708. for irridx in range(nirrseg):
  709. seg.irregularlysampledsignals.append(
  710. IrregularlySampledSignal(times=times,
  711. signal=signal,
  712. time_units=pq.s)
  713. )
  714. for epidx in range(nepochs):
  715. seg.epochs.append(Epoch(times=times, durations=times))
  716. for evidx in range(nevents):
  717. seg.events.append(Event(times=times))
  718. for stidx in range(nspiketrains):
  719. seg.spiketrains.append(SpikeTrain(times=times,
  720. t_stop=times[-1]+pq.s,
  721. units=pq.s))
  722. for chidx in range(nchx):
  723. chx = ChannelIndex(name="chx{}".format(chidx),
  724. index=[1, 2],
  725. channel_ids=[11, 22])
  726. blk.channel_indexes.append(chx)
  727. for unidx in range(nunits):
  728. unit = Unit()
  729. chx.units.append(unit)
  730. self.writer.write_all_blocks(blocks)
  731. self.compare_blocks(blocks, self.reader.blocks)
  732. def test_multiref_write(self):
  733. blk = Block("blk1")
  734. signal = AnalogSignal(name="sig1", signal=[0, 1, 2], units="mV",
  735. sampling_period=pq.Quantity(1, "ms"))
  736. for idx in range(3):
  737. segname = "seg" + str(idx)
  738. seg = Segment(segname)
  739. blk.segments.append(seg)
  740. seg.analogsignals.append(signal)
  741. chidx = ChannelIndex([10, 20, 29])
  742. seg = blk.segments[0]
  743. st = SpikeTrain(name="choochoo", times=[10, 11, 80], t_stop=1000,
  744. units="s")
  745. seg.spiketrains.append(st)
  746. blk.channel_indexes.append(chidx)
  747. for idx in range(6):
  748. unit = Unit("unit" + str(idx))
  749. chidx.units.append(unit)
  750. unit.spiketrains.append(st)
  751. self.writer.write_block(blk)
  752. self.compare_blocks([blk], self.reader.blocks)
  753. def test_to_value(self):
  754. section = self.io.nix_file.create_section("Metadata value test",
  755. "Test")
  756. writeprop = self.io._write_property
  757. # quantity
  758. qvalue = pq.Quantity(10, "mV")
  759. writeprop(section, "qvalue", qvalue)
  760. self.assertEqual(section["qvalue"], 10)
  761. self.assertEqual(section.props["qvalue"].unit, "mV")
  762. # datetime
  763. dt = self.rdate()
  764. writeprop(section, "dt", dt)
  765. self.assertEqual(datetime.fromtimestamp(section["dt"]), dt)
  766. # string
  767. randstr = self.rsentence()
  768. writeprop(section, "randstr", randstr)
  769. self.assertEqual(section["randstr"], randstr)
  770. # bytes
  771. bytestring = b"bytestring"
  772. writeprop(section, "randbytes", bytestring)
  773. self.assertEqual(section["randbytes"], bytestring.decode())
  774. # iterables
  775. randlist = np.random.random(10).tolist()
  776. writeprop(section, "randlist", randlist)
  777. self.assertEqual(randlist, section["randlist"])
  778. randarray = np.random.random(10)
  779. writeprop(section, "randarray", randarray)
  780. np.testing.assert_almost_equal(randarray, section["randarray"])
  781. # numpy item
  782. npval = np.float64(2398)
  783. writeprop(section, "npval", npval)
  784. self.assertEqual(npval, section["npval"])
  785. # number
  786. val = 42
  787. writeprop(section, "val", val)
  788. self.assertEqual(val, section["val"])
  789. @unittest.skipUnless(HAVE_NIX, "Requires NIX")
  790. class NixIOReadTest(NixIOTest):
  791. filename = "testfile_readtest.h5"
  792. nixfile = None
  793. nix_blocks = None
  794. original_methods = dict()
  795. @classmethod
  796. def setUpClass(cls):
  797. if HAVE_NIX:
  798. cls.nixfile = cls.create_full_nix_file(cls.filename)
  799. def setUp(self):
  800. self.io = NixIO(self.filename, "ro")
  801. self.original_methods["_read_cascade"] = self.io._read_cascade
  802. self.original_methods["_update_maps"] = self.io._update_maps
  803. @classmethod
  804. def tearDownClass(cls):
  805. if HAVE_NIX:
  806. cls.nixfile.close()
  807. os.remove(cls.filename)
  808. def tearDown(self):
  809. self.io.close()
  810. def test_all_read(self):
  811. neo_blocks = self.io.read_all_blocks(cascade=True, lazy=False)
  812. nix_blocks = self.io.nix_file.blocks
  813. self.compare_blocks(neo_blocks, nix_blocks)
  814. def test_lazyload_fullcascade_read(self):
  815. neo_blocks = self.io.read_all_blocks(cascade=True, lazy=True)
  816. nix_blocks = self.io.nix_file.blocks
  817. # data objects should be empty
  818. for block in neo_blocks:
  819. for seg in block.segments:
  820. for asig in seg.analogsignals:
  821. self.assertEqual(len(asig), 0)
  822. for isig in seg.irregularlysampledsignals:
  823. self.assertEqual(len(isig), 0)
  824. for epoch in seg.epochs:
  825. self.assertEqual(len(epoch), 0)
  826. for event in seg.events:
  827. self.assertEqual(len(event), 0)
  828. for st in seg.spiketrains:
  829. self.assertEqual(len(st), 0)
  830. self.compare_blocks(neo_blocks, nix_blocks)
  831. def test_lazyload_lazycascade_read(self):
  832. neo_blocks = self.io.read_all_blocks(cascade="lazy", lazy=True)
  833. nix_blocks = self.io.nix_file.blocks
  834. self.compare_blocks(neo_blocks, nix_blocks)
  835. def test_lazycascade_read(self):
  836. def getitem(self, index):
  837. return self._data.__getitem__(index)
  838. from neo.io.nixio import LazyList
  839. getitem_original = LazyList.__getitem__
  840. LazyList.__getitem__ = getitem
  841. neo_blocks = self.io.read_all_blocks(cascade="lazy", lazy=False)
  842. for block in neo_blocks:
  843. self.assertIsInstance(block.segments, LazyList)
  844. self.assertIsInstance(block.channel_indexes, LazyList)
  845. for seg in block.segments:
  846. self.assertIsInstance(seg, string_types)
  847. for chx in block.channel_indexes:
  848. self.assertIsInstance(chx, string_types)
  849. LazyList.__getitem__ = getitem_original
  850. def test_load_lazy_cascade(self):
  851. from neo.io.nixio import LazyList
  852. neo_blocks = self.io.read_all_blocks(cascade="lazy", lazy=False)
  853. for block in neo_blocks:
  854. self.assertIsInstance(block.segments, LazyList)
  855. self.assertIsInstance(block.channel_indexes, LazyList)
  856. name = block.annotations["nix_name"]
  857. block = self.io.load_lazy_cascade("/" + name, lazy=False)
  858. self.assertIsInstance(block.segments, list)
  859. self.assertIsInstance(block.channel_indexes, list)
  860. for seg in block.segments:
  861. self.assertIsInstance(seg.analogsignals, list)
  862. self.assertIsInstance(seg.irregularlysampledsignals, list)
  863. self.assertIsInstance(seg.epochs, list)
  864. self.assertIsInstance(seg.events, list)
  865. self.assertIsInstance(seg.spiketrains, list)
  866. def test_nocascade_read(self):
  867. self.io._read_cascade = mock.Mock()
  868. neo_blocks = self.io.read_all_blocks(cascade=False)
  869. self.io._read_cascade.assert_not_called()
  870. for block in neo_blocks:
  871. self.assertEqual(len(block.segments), 0)
  872. nix_block = self.io.nix_file.blocks[block.annotations["nix_name"]]
  873. self.compare_attr(block, nix_block)
  874. def test_lazy_load_subschema(self):
  875. blk = self.io.nix_file.blocks[0]
  876. segpath = "/" + blk.name + "/segments/" + blk.groups[0].name
  877. segment = self.io.load_lazy_cascade(segpath, lazy=True)
  878. self.assertIsInstance(segment, Segment)
  879. self.assertEqual(segment.annotations["nix_name"], blk.groups[0].name)
  880. self.assertIs(segment.block, None)
  881. self.assertEqual(len(segment.analogsignals[0]), 0)
  882. segment = self.io.load_lazy_cascade(segpath, lazy=False)
  883. self.assertEqual(np.shape(segment.analogsignals[0]), (100, 3))
  884. @unittest.skipUnless(HAVE_NIX, "Requires NIX")
  885. class NixIOHashTest(NixIOTest):
  886. def setUp(self):
  887. self.hash = NixIO._hash_object
  888. def _hash_test(self, objtype, argfuncs):
  889. attr = {}
  890. for arg, func in argfuncs.items():
  891. attr[arg] = func()
  892. obj_one = objtype(**attr)
  893. obj_two = objtype(**attr)
  894. hash_one = self.hash(obj_one)
  895. hash_two = self.hash(obj_two)
  896. self.assertEqual(hash_one, hash_two)
  897. for arg, func in argfuncs.items():
  898. chattr = attr.copy()
  899. chattr[arg] = func()
  900. obj_two = objtype(**chattr)
  901. hash_two = self.hash(obj_two)
  902. self.assertNotEqual(
  903. hash_one, hash_two,
  904. "Hash test failed with different '{}'".format(arg)
  905. )
  906. def test_block_seg_hash(self):
  907. argfuncs = {"name": self.rword,
  908. "description": self.rsentence,
  909. "rec_datetime": self.rdate,
  910. "file_datetime": self.rdate,
  911. # annotations
  912. self.rword(): self.rword,
  913. self.rword(): lambda: self.rquant((10, 10), pq.mV)}
  914. self._hash_test(Block, argfuncs)
  915. self._hash_test(Segment, argfuncs)
  916. self._hash_test(Unit, argfuncs)
  917. def test_chx_hash(self):
  918. argfuncs = {"name": self.rword,
  919. "description": self.rsentence,
  920. "index": lambda: np.random.random(10).tolist(),
  921. "channel_names": lambda: self.rsentence(10).split(" "),
  922. "coordinates": lambda: [(np.random.random() * pq.cm,
  923. np.random.random() * pq.cm,
  924. np.random.random() * pq.cm)]*10,
  925. # annotations
  926. self.rword(): self.rword,
  927. self.rword(): lambda: self.rquant((10, 10), pq.mV)}
  928. self._hash_test(ChannelIndex, argfuncs)
  929. def test_analogsignal_hash(self):
  930. argfuncs = {"name": self.rword,
  931. "description": self.rsentence,
  932. "signal": lambda: self.rquant((10, 10), pq.mV),
  933. "sampling_rate": lambda: np.random.random() * pq.Hz,
  934. "t_start": lambda: np.random.random() * pq.sec,
  935. "t_stop": lambda: np.random.random() * pq.sec,
  936. # annotations
  937. self.rword(): self.rword,
  938. self.rword(): lambda: self.rquant((10, 10), pq.mV)}
  939. self._hash_test(AnalogSignal, argfuncs)
  940. def test_irregularsignal_hash(self):
  941. argfuncs = {"name": self.rword,
  942. "description": self.rsentence,
  943. "signal": lambda: self.rquant((10, 10), pq.mV),
  944. "times": lambda: self.rquant(10, pq.ms, True),
  945. # annotations
  946. self.rword(): self.rword,
  947. self.rword(): lambda: self.rquant((10, 10), pq.mV)}
  948. self._hash_test(IrregularlySampledSignal, argfuncs)
  949. def test_event_hash(self):
  950. argfuncs = {"name": self.rword,
  951. "description": self.rsentence,
  952. "times": lambda: self.rquant(10, pq.ms),
  953. "durations": lambda: self.rquant(10, pq.ms),
  954. "labels": lambda: self.rsentence(10).split(" "),
  955. # annotations
  956. self.rword(): self.rword,
  957. self.rword(): lambda: self.rquant((10, 10), pq.mV)}
  958. self._hash_test(Event, argfuncs)
  959. self._hash_test(Epoch, argfuncs)
  960. def test_spiketrain_hash(self):
  961. argfuncs = {"name": self.rword,
  962. "description": self.rsentence,
  963. "times": lambda: self.rquant(10, pq.ms, True),
  964. "t_start": lambda: -np.random.random() * pq.sec,
  965. "t_stop": lambda: np.random.random() * 100 * pq.sec,
  966. "waveforms": lambda: self.rquant((10, 10, 20), pq.mV),
  967. # annotations
  968. self.rword(): self.rword,
  969. self.rword(): lambda: self.rquant((10, 10), pq.mV)}
  970. self._hash_test(SpikeTrain, argfuncs)
  971. @unittest.skipUnless(HAVE_NIX, "Requires NIX")
  972. class NixIOPartialWriteTest(NixIOTest):
  973. filename = "testfile_partialwrite.h5"
  974. nixfile = None
  975. neo_blocks = None
  976. original_methods = dict()
  977. @classmethod
  978. def setUpClass(cls):
  979. if HAVE_NIX:
  980. cls.nixfile = cls.create_full_nix_file(cls.filename)
  981. def setUp(self):
  982. self.io = NixIO(self.filename, "rw")
  983. self.neo_blocks = self.io.read_all_blocks()
  984. self.original_methods["_write_attr_annotations"] =\
  985. self.io._write_attr_annotations
  986. @classmethod
  987. def tearDownClass(cls):
  988. if HAVE_NIX:
  989. cls.nixfile.close()
  990. os.remove(cls.filename)
  991. def tearDown(self):
  992. self.restore_methods()
  993. self.io.close()
  994. def restore_methods(self):
  995. for name, method in self.original_methods.items():
  996. setattr(self.io, name, self.original_methods[name])
  997. def _mock_write_attr(self, objclass):
  998. typestr = str(objclass.__name__).lower()
  999. self.io._write_attr_annotations = mock.Mock(
  1000. wraps=self.io._write_attr_annotations,
  1001. side_effect=self.check_obj_type("neo.{}".format(typestr))
  1002. )
  1003. neo_blocks = self.neo_blocks
  1004. self.modify_objects(neo_blocks, excludes=[objclass])
  1005. self.io.write_all_blocks(neo_blocks)
  1006. self.restore_methods()
  1007. def check_obj_type(self, typestring):
  1008. neq = self.assertNotEqual
  1009. def side_effect_func(*args, **kwargs):
  1010. obj = kwargs.get("nixobj", args[0])
  1011. if isinstance(obj, list):
  1012. for sig in obj:
  1013. neq(sig.type, typestring)
  1014. else:
  1015. neq(obj.type, typestring)
  1016. return side_effect_func
  1017. @classmethod
  1018. def modify_objects(cls, objs, excludes=()):
  1019. excludes = tuple(excludes)
  1020. for obj in objs:
  1021. if not (excludes and isinstance(obj, excludes)):
  1022. obj.description = cls.rsentence()
  1023. for container in getattr(obj, "_child_containers", []):
  1024. children = getattr(obj, container)
  1025. cls.modify_objects(children, excludes)
  1026. def test_partial(self):
  1027. for objclass in NixIO.supported_objects:
  1028. self._mock_write_attr(objclass)
  1029. self.compare_blocks(self.neo_blocks, self.io.nix_file.blocks)
  1030. def test_no_modifications(self):
  1031. self.io._write_attr_annotations = mock.Mock()
  1032. self.io.write_all_blocks(self.neo_blocks)
  1033. self.io._write_attr_annotations.assert_not_called()
  1034. self.compare_blocks(self.neo_blocks, self.io.nix_file.blocks)
  1035. # clearing hashes and checking again
  1036. for k in self.io._object_hashes.keys():
  1037. self.io._object_hashes[k] = None
  1038. self.io.write_all_blocks(self.neo_blocks)
  1039. self.io._write_attr_annotations.assert_not_called()
  1040. # changing hashes to force rewrite
  1041. for k in self.io._object_hashes.keys():
  1042. self.io._object_hashes[k] = "_"
  1043. self.io.write_all_blocks(self.neo_blocks)
  1044. callcount = self.io._write_attr_annotations.call_count
  1045. self.assertEqual(callcount, len(self.io._object_hashes))
  1046. self.compare_blocks(self.neo_blocks, self.io.nix_file.blocks)
  1047. @unittest.skipUnless(HAVE_NIX, "Requires NIX")
  1048. class NixIOContextTests(NixIOTest):
  1049. filename = "context_test.h5"
  1050. def test_context_write(self):
  1051. neoblock = Block(name=self.rword(), description=self.rsentence())
  1052. with NixIO(self.filename, "ow") as iofile:
  1053. iofile.write_block(neoblock)
  1054. nixfile = nix.File.open(self.filename, nix.FileMode.ReadOnly,
  1055. backend="h5py")
  1056. self.compare_blocks([neoblock], nixfile.blocks)
  1057. nixfile.close()
  1058. neoblock.annotate(**self.rdict(5))
  1059. with NixIO(self.filename, "rw") as iofile:
  1060. iofile.write_block(neoblock)
  1061. nixfile = nix.File.open(self.filename, nix.FileMode.ReadOnly,
  1062. backend="h5py")
  1063. self.compare_blocks([neoblock], nixfile.blocks)
  1064. nixfile.close()
  1065. os.remove(self.filename)
  1066. def test_context_read(self):
  1067. nixfile = nix.File.open(self.filename, nix.FileMode.Overwrite,
  1068. backend="h5py")
  1069. name_one = self.rword()
  1070. name_two = self.rword()
  1071. nixfile.create_block(name_one, "neo.block")
  1072. nixfile.create_block(name_two, "neo.block")
  1073. nixfile.close()
  1074. with NixIO(self.filename, "ro") as iofile:
  1075. blocks = iofile.read_all_blocks()
  1076. self.assertEqual(blocks[0].annotations["nix_name"], name_one)
  1077. self.assertEqual(blocks[1].annotations["nix_name"], name_two)
  1078. os.remove(self.filename)
@unittest.skipUnless(HAVE_NIX, "Requires NIX")
class CommonTests(BaseTestIO, unittest.TestCase):
    # Configures the BaseTestIO test case (defined outside this chunk) for
    # NixIO; skipped entirely when NIX is not available.
    ioclass = NixIO
    # NOTE(review): presumably read by BaseTestIO to decide whether a
    # write/read round trip must reproduce the objects exactly - confirm
    # against BaseTestIO.
    read_and_write_is_bijective = False