test_pynnio.py 8.1 KB


  1. # -*- coding: utf-8 -*-
  2. """
  3. Tests of the neo.io.pynnio.PyNNNumpyIO and neo.io.pynnio.PyNNTextIO classes
  4. """
  5. # needed for python 3 compatibility
  6. from __future__ import absolute_import, division
  7. import os
  8. import unittest
  9. import numpy as np
  10. import quantities as pq
  11. from neo.core import Segment, AnalogSignal, SpikeTrain
  12. from neo.io import PyNNNumpyIO, PyNNTextIO
  13. from numpy.testing import assert_array_equal
  14. from neo.test.tools import assert_arrays_equal, assert_file_contents_equal
  15. from neo.test.iotest.common_io_test import BaseTestIO
  16. #class CommonTestPyNNNumpyIO(BaseTestIO, unittest.TestCase):
  17. # ioclass = PyNNNumpyIO
# Number of cells in the synthetic test data: build_test_data() writes one
# 101-sample record per cell and the read tests expect NCELLS channels/trains.
NCELLS = 5
  19. class CommonTestPyNNTextIO(BaseTestIO, unittest.TestCase):
  20. ioclass = PyNNTextIO
  21. read_and_write_is_bijective = False
  22. def read_test_file(filename):
  23. contents = np.load(filename)
  24. data = contents["data"]
  25. metadata = {}
  26. for name, value in contents['metadata']:
  27. try:
  28. metadata[name] = eval(value)
  29. except Exception:
  30. metadata[name] = value
  31. return data, metadata
  32. read_test_file.__test__ = False
  33. class BaseTestPyNNIO(object):
  34. __test__ = False
  35. def tearDown(self):
  36. if os.path.exists(self.test_file):
  37. os.remove(self.test_file)
  38. def test_write_segment(self):
  39. in_ = self.io_cls(self.test_file)
  40. write_test_file = "write_test.%s" % self.file_extension
  41. out = self.io_cls(write_test_file)
  42. out.write_segment(in_.read_segment(lazy=False, cascade=True))
  43. assert_file_contents_equal(self.test_file, write_test_file)
  44. if os.path.exists(write_test_file):
  45. os.remove(write_test_file)
  46. def build_test_data(self, variable='v'):
  47. metadata = {
  48. 'size': NCELLS,
  49. 'first_index': 0,
  50. 'first_id': 0,
  51. 'n': 505,
  52. 'variable': variable,
  53. 'last_id': NCELLS - 1,
  54. 'last_index': NCELLS - 1,
  55. 'dt': 0.1,
  56. 'label': "population0",
  57. }
  58. if variable == 'v':
  59. metadata['units'] = 'mV'
  60. elif variable == 'spikes':
  61. metadata['units'] = 'ms'
  62. data = np.empty((505, 2))
  63. for i in range(NCELLS):
  64. # signal
  65. data[i*101:(i+1)*101, 0] = np.arange(i, i+101, dtype=float)
  66. # index
  67. data[i*101:(i+1)*101, 1] = i*np.ones((101,), dtype=float)
  68. return data, metadata
  69. build_test_data.__test__ = False
  70. class BaseTestPyNNIO_Signals(BaseTestPyNNIO):
  71. def setUp(self):
  72. self.test_file = "test_file_v.%s" % self.file_extension
  73. self.write_test_file("v")
  74. def test_read_segment_containing_analogsignals_using_eager_cascade(self):
  75. # eager == not lazy
  76. io = self.io_cls(self.test_file)
  77. segment = io.read_segment(lazy=False, cascade=True)
  78. self.assertIsInstance(segment, Segment)
  79. self.assertEqual(len(segment.analogsignals), 1)
  80. as0 = segment.analogsignals[0]
  81. self.assertIsInstance(as0, AnalogSignal)
  82. self.assertEqual(as0.shape, (101, NCELLS))
  83. assert_array_equal(as0[:, 0],
  84. AnalogSignal(np.arange(0, 101, dtype=float),
  85. sampling_period=0.1*pq.ms,
  86. t_start=0*pq.s,
  87. units=pq.mV))
  88. as4 = as0[:, 4]
  89. self.assertIsInstance(as4, AnalogSignal)
  90. assert_array_equal(as4,
  91. AnalogSignal(np.arange(4, 105, dtype=float),
  92. sampling_period=0.1*pq.ms,
  93. t_start=0*pq.s,
  94. units=pq.mV))
  95. # test annotations (stuff from file metadata)
  96. def test_read_analogsignal_using_eager(self):
  97. io = self.io_cls(self.test_file)
  98. sig = io.read_analogsignal(lazy=False)
  99. self.assertIsInstance(sig, AnalogSignal)
  100. assert_array_equal(sig[:, 3],
  101. AnalogSignal(np.arange(3, 104, dtype=float),
  102. sampling_period=0.1*pq.ms,
  103. t_start=0*pq.s,
  104. units=pq.mV))
  105. # should test annotations: 'channel_index', etc.
  106. def test_read_spiketrain_should_fail_with_analogsignal_file(self):
  107. io = self.io_cls(self.test_file)
  108. self.assertRaises(TypeError, io.read_spiketrain, channel_index=0)
  109. class BaseTestPyNNIO_Spikes(BaseTestPyNNIO):
  110. def setUp(self):
  111. self.test_file = "test_file_spikes.%s" % self.file_extension
  112. self.write_test_file("spikes")
  113. def test_read_segment_containing_spiketrains_using_eager_cascade(self):
  114. io = self.io_cls(self.test_file)
  115. segment = io.read_segment(lazy=False, cascade=True)
  116. self.assertIsInstance(segment, Segment)
  117. self.assertEqual(len(segment.spiketrains), NCELLS)
  118. st0 = segment.spiketrains[0]
  119. self.assertIsInstance(st0, SpikeTrain)
  120. assert_arrays_equal(st0,
  121. SpikeTrain(np.arange(0, 101, dtype=float),
  122. t_start=0*pq.s,
  123. t_stop=101*pq.ms,
  124. units=pq.ms))
  125. st4 = segment.spiketrains[4]
  126. self.assertIsInstance(st4, SpikeTrain)
  127. assert_arrays_equal(st4,
  128. SpikeTrain(np.arange(4, 105, dtype=float),
  129. t_start=0*pq.s,
  130. t_stop=105*pq.ms,
  131. units=pq.ms))
  132. # test annotations (stuff from file metadata)
  133. def test_read_spiketrain_using_eager(self):
  134. io = self.io_cls(self.test_file)
  135. st3 = io.read_spiketrain(lazy=False, channel_index=3)
  136. self.assertIsInstance(st3, SpikeTrain)
  137. assert_arrays_equal(st3,
  138. SpikeTrain(np.arange(3, 104, dtype=float),
  139. t_start=0*pq.s,
  140. t_stop=104*pq.s,
  141. units=pq.ms))
  142. # should test annotations: 'channel_index', etc.
  143. def test_read_analogsignal_should_fail_with_spiketrain_file(self):
  144. io = self.io_cls(self.test_file)
  145. self.assertRaises(TypeError, io.read_analogsignal, channel_index=2)
  146. class BaseTestPyNNNumpyIO(object):
  147. io_cls = PyNNNumpyIO
  148. file_extension = "npz"
  149. def write_test_file(self, variable='v', check=False):
  150. data, metadata = self.build_test_data(variable)
  151. metadata_array = np.array(sorted(metadata.items()))
  152. np.savez(self.test_file, data=data, metadata=metadata_array)
  153. if check:
  154. data1, metadata1 = read_test_file(self.test_file)
  155. assert metadata == metadata1, "%s != %s" % (metadata, metadata1)
  156. assert data.shape == data1.shape == (505, 2), \
  157. "%s, %s, (505, 2)" % (data.shape, data1.shape)
  158. assert (data == data1).all()
  159. assert metadata["n"] == 505
  160. write_test_file.__test__ = False
  161. class BaseTestPyNNTextIO(object):
  162. io_cls = PyNNTextIO
  163. file_extension = "txt"
  164. def write_test_file(self, variable='v', check=False):
  165. data, metadata = self.build_test_data(variable)
  166. with open(self.test_file, 'wb') as f:
  167. for item in sorted(metadata.items()):
  168. f.write(("# %s = %s\n" % item).encode('utf8'))
  169. np.savetxt(f, data)
  170. if check:
  171. raise NotImplementedError
  172. write_test_file.__test__ = False
  173. class TestPyNNNumpyIO_Signals(BaseTestPyNNNumpyIO, BaseTestPyNNIO_Signals,
  174. unittest.TestCase):
  175. __test__ = True
  176. class TestPyNNNumpyIO_Spikes(BaseTestPyNNNumpyIO, BaseTestPyNNIO_Spikes,
  177. unittest.TestCase):
  178. __test__ = True
  179. class TestPyNNTextIO_Signals(BaseTestPyNNTextIO, BaseTestPyNNIO_Signals,
  180. unittest.TestCase):
  181. __test__ = True
  182. class TestPyNNTextIO_Spikes(BaseTestPyNNTextIO, BaseTestPyNNIO_Spikes,
  183. unittest.TestCase):
  184. __test__ = True
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()