test_nsdfio.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. # -*- coding: utf-8 -*-
  2. """
  3. Tests of neo.io.NSDFIO
  4. """
  5. # needed for python 3 compatibility
  6. from __future__ import absolute_import, division
  7. import numpy as np
  8. import quantities as pq
  9. from datetime import datetime
  10. import os
  11. import unittest
  12. from neo.io.nsdfio import HAVE_NSDF, NSDFIO
  13. from neo.test.iotest.common_io_test import BaseTestIO
  14. from neo.core import AnalogSignal, Segment, Block, ChannelIndex
  15. from neo.test.tools import assert_same_attributes, assert_same_annotations, assert_neo_object_is_compliant
  16. @unittest.skipUnless(HAVE_NSDF, "Requires NSDF")
  17. class CommonTests(BaseTestIO, unittest.TestCase):
  18. ioclass = NSDFIO
  19. read_and_write_is_bijective = False
  20. @unittest.skipUnless(HAVE_NSDF, "Requires NSDF")
  21. class NSDFIOTest(unittest.TestCase):
  22. """
  23. Base class for all NSDFIO tests.
  24. setUp and tearDown methods are responsible for respectively: setting up and cleaning after tests
  25. All create_{object} methods create and return an example {object}.
  26. """
  27. def setUp(self):
  28. self.filename = 'nsdfio_testfile.h5'
  29. self.io = NSDFIO(self.filename)
  30. def tearDown(self):
  31. os.remove(self.filename)
  32. def create_list_of_blocks(self):
  33. blocks = []
  34. for i in range(2):
  35. blocks.append(self.create_block(name='Block #{}'.format(i)))
  36. return blocks
  37. def create_block(self, name='Block'):
  38. block = Block()
  39. self._assign_basic_attributes(block, name=name)
  40. self._assign_datetime_attributes(block)
  41. self._assign_index_attribute(block)
  42. self._create_block_children(block)
  43. self._assign_annotations(block)
  44. return block
  45. def _create_block_children(self, block):
  46. for i in range(3):
  47. block.segments.append(self.create_segment(block, name='Segment #{}'.format(i)))
  48. for i in range(3):
  49. block.channel_indexes.append(self.create_channelindex(block, name='ChannelIndex #{}'.format(i),
  50. analogsignals=[seg.analogsignals[i] for seg in block.segments]))
  51. def create_segment(self, parent=None, name='Segment'):
  52. segment = Segment()
  53. segment.block = parent
  54. self._assign_basic_attributes(segment, name=name)
  55. self._assign_datetime_attributes(segment)
  56. self._assign_index_attribute(segment)
  57. self._create_segment_children(segment)
  58. self._assign_annotations(segment)
  59. return segment
  60. def _create_segment_children(self, segment):
  61. for i in range(2):
  62. segment.analogsignals.append(self.create_analogsignal(segment, name='Signal #{}'.format(i * 3)))
  63. segment.analogsignals.append(self.create_analogsignal2(segment, name='Signal #{}'.format(i * 3 + 1)))
  64. segment.analogsignals.append(self.create_analogsignal3(segment, name='Signal #{}'.format(i * 3 + 2)))
  65. def create_analogsignal(self, parent=None, name='AnalogSignal1'):
  66. signal = AnalogSignal([[1.0, 2.5], [2.2, 3.1], [3.2, 4.4]], units='mV',
  67. sampling_rate=100 * pq.Hz, t_start=2 * pq.min)
  68. signal.segment = parent
  69. self._assign_basic_attributes(signal, name=name)
  70. self._assign_annotations(signal)
  71. return signal
  72. def create_analogsignal2(self, parent=None, name='AnalogSignal2'):
  73. signal = AnalogSignal([[1], [2], [3], [4], [5]], units='mA',
  74. sampling_period=0.5 * pq.ms)
  75. signal.segment = parent
  76. self._assign_annotations(signal)
  77. return signal
  78. def create_analogsignal3(self, parent=None, name='AnalogSignal3'):
  79. signal = AnalogSignal([[1, 2, 3], [4, 5, 6]], units='mV',
  80. sampling_rate=2 * pq.kHz, t_start=100 * pq.s)
  81. signal.segment = parent
  82. self._assign_basic_attributes(signal, name=name)
  83. return signal
  84. def create_channelindex(self, parent=None, name='ChannelIndex', analogsignals=None):
  85. channels_num = min([signal.shape[1] for signal in analogsignals])
  86. channelindex = ChannelIndex(index=np.arange(channels_num),
  87. channel_names=['Channel{}'.format(i) for i in range(channels_num)],
  88. channel_ids=np.arange(channels_num),
  89. coordinates=([[1.87, -5.2, 4.0]] * channels_num) * pq.cm)
  90. for signal in analogsignals:
  91. channelindex.analogsignals.append(signal)
  92. self._assign_basic_attributes(channelindex, name)
  93. self._assign_annotations(channelindex)
  94. return channelindex
  95. def _assign_basic_attributes(self, object, name=None):
  96. if name is None:
  97. object.name = 'neo object'
  98. else:
  99. object.name = name
  100. object.description = 'Example of neo object'
  101. object.file_origin = 'datafile.pp'
  102. def _assign_datetime_attributes(self, object):
  103. object.file_datetime = datetime(2017, 6, 11, 14, 53, 23)
  104. object.rec_datetime = datetime(2017, 5, 29, 13, 12, 47)
  105. def _assign_index_attribute(self, object):
  106. object.index = 12
  107. def _assign_annotations(self, object):
  108. object.annotations = {'str': 'value',
  109. 'int': 56,
  110. 'float': 5.234}
  111. @unittest.skipUnless(HAVE_NSDF, "Requires NSDF")
  112. class NSDFIOTestWriteThenRead(NSDFIOTest):
  113. """
  114. Class for testing NSDFIO.
  115. It first creates example neo objects, then writes them to the file,
  116. reads the file and compares the result with the original ones.
  117. all test_{object} methods run "write then read" test for a/an {object}
  118. all compare_{object} methods check if the second {object} is a proper copy
  119. of the first one, read in suitable lazy and cascade mode
  120. """
  121. lazy_modes = [False, True]
  122. cascade_modes = [False, True]
  123. def test_list_of_blocks(self, lazy=False, cascade=True):
  124. blocks = self.create_list_of_blocks()
  125. self.io.write(blocks)
  126. for lazy in self.lazy_modes:
  127. for cascade in self.cascade_modes:
  128. blocks2 = self.io.read(lazy=lazy, cascade=cascade)
  129. self.compare_list_of_blocks(blocks, blocks2, lazy, cascade)
  130. def test_block(self, lazy=False, cascade=True):
  131. block = self.create_block()
  132. self.io.write_block(block)
  133. for lazy in self.lazy_modes:
  134. for cascade in self.cascade_modes:
  135. block2 = self.io.read_block(lazy=lazy, cascade=cascade)
  136. self.compare_blocks(block, block2, lazy, cascade)
  137. def test_segment(self, lazy=False, cascade=True):
  138. segment = self.create_segment()
  139. self.io.write_segment(segment)
  140. for lazy in self.lazy_modes:
  141. for cascade in self.cascade_modes:
  142. segment2 = self.io.read_segment(lazy=lazy, cascade=cascade)
  143. self.compare_segments(segment, segment2, lazy, cascade)
  144. def compare_list_of_blocks(self, blocks1, blocks2, lazy=False, cascade=True):
  145. assert len(blocks1) == len(blocks2)
  146. for block1, block2 in zip(blocks1, blocks2):
  147. self.compare_blocks(block1, block2, lazy, cascade)
  148. def compare_blocks(self, block1, block2, lazy=False, cascade=True):
  149. self._compare_objects(block1, block2)
  150. assert block2.file_datetime == datetime.fromtimestamp(os.stat(self.filename).st_mtime)
  151. assert_neo_object_is_compliant(block2)
  152. if cascade:
  153. self._compare_blocks_children(block1, block2, lazy=lazy)
  154. else:
  155. assert len(block2.segments) == 0
  156. def _compare_blocks_children(self, block1, block2, lazy):
  157. assert len(block1.segments) == len(block2.segments)
  158. for segment1, segment2 in zip(block1.segments, block2.segments):
  159. self.compare_segments(segment1, segment2, lazy=lazy)
  160. def compare_segments(self, segment1, segment2, lazy=False, cascade=True):
  161. self._compare_objects(segment1, segment2)
  162. assert segment2.file_datetime == datetime.fromtimestamp(os.stat(self.filename).st_mtime)
  163. if cascade:
  164. self._compare_segments_children(segment1, segment2, lazy=lazy)
  165. else:
  166. assert len(segment2.analogsignals) == 0
  167. def _compare_segments_children(self, segment1, segment2, lazy):
  168. assert len(segment1.analogsignals) == len(segment2.analogsignals)
  169. for signal1, signal2 in zip(segment1.analogsignals, segment2.analogsignals):
  170. self.compare_analogsignals(signal1, signal2, lazy=lazy)
  171. def compare_analogsignals(self, signal1, signal2, lazy=False, cascade=True):
  172. if not lazy:
  173. self._compare_objects(signal1, signal2)
  174. else:
  175. self._compare_objects(signal1, signal2, exclude_attr=['shape', 'signal'])
  176. assert signal2.lazy_shape == signal1.shape
  177. assert signal2.dtype == signal1.dtype
  178. def _compare_objects(self, object1, object2, exclude_attr=[]):
  179. assert object1.__class__.__name__ == object2.__class__.__name__
  180. assert object2.file_origin == self.filename
  181. assert_same_attributes(object1, object2, exclude=['file_origin', 'file_datetime'] + exclude_attr)
  182. assert_same_annotations(object1, object2)
  183. if __name__ == "__main__":
  184. unittest.main()