# test_pickleio.py
# -*- coding: utf-8 -*-
"""
Tests of the :class:`neo.io.pickleio.PickleIO` class.
"""
# needed for python 3 compatibility
from __future__ import absolute_import, division

import os
import shutil
import tempfile
import unittest

import numpy as np
import quantities as pq
from numpy.testing import assert_array_equal

from neo.core import Block, Segment, AnalogSignal, SpikeTrain, Unit, Epoch, Event, ChannelIndex, IrregularlySampledSignal
from neo.io import PickleIO
from neo.test.tools import assert_arrays_equal, assert_file_contents_equal
from neo.test.iotest.common_io_test import BaseTestIO

  16. NCELLS = 5
  17. class CommonTestPickleIO(BaseTestIO, unittest.TestCase):
  18. ioclass = PickleIO
  19. def test_readed_with_cascade_is_compliant(self):
  20. pass
  21. test_readed_with_cascade_is_compliant.__test__ = False # PickleIO does not support lazy loading
  22. def test_readed_with_lazy_is_compliant(self):
  23. pass
  24. test_readed_with_lazy_is_compliant.__test__ = False
  25. class TestPickleIO(unittest.TestCase):
  26. def test__issue_285(self):
  27. ##Spiketrain
  28. train = SpikeTrain([3, 4, 5] * pq.s, t_stop=10.0)
  29. unit = Unit()
  30. train.unit = unit
  31. unit.spiketrains.append(train)
  32. epoch = Epoch([0, 10, 20], [2, 2, 2], ["a", "b", "c"], units="ms")
  33. blk = Block()
  34. seg = Segment()
  35. seg.spiketrains.append(train)
  36. seg.epochs.append(epoch)
  37. epoch.segment = seg
  38. blk.segments.append(seg)
  39. reader = PickleIO(filename="blk.pkl")
  40. reader.write(blk)
  41. reader = PickleIO(filename="blk.pkl")
  42. r_blk = reader.read_block()
  43. r_seg = r_blk.segments[0]
  44. self.assertIsInstance(r_seg.spiketrains[0].unit, Unit)
  45. self.assertIsInstance(r_seg.epochs[0], Epoch)
  46. os.remove('blk.pkl')
  47. ##Epoch
  48. train = Epoch(times=np.arange(0, 30, 10)*pq.s,durations=[10, 5, 7]*pq.ms,labels=np.array(['btn0', 'btn1', 'btn2'], dtype='S'))
  49. train.segment = Segment()
  50. unit = Unit()
  51. unit.spiketrains.append(train)
  52. blk = Block()
  53. seg = Segment()
  54. seg.spiketrains.append(train)
  55. blk.segments.append(seg)
  56. reader = PickleIO(filename="blk.pkl")
  57. reader.write(blk)
  58. reader = PickleIO(filename="blk.pkl")
  59. r_blk = reader.read_block()
  60. r_seg = r_blk.segments[0]
  61. self.assertIsInstance(r_seg.spiketrains[0].segment, Segment)
  62. os.remove('blk.pkl')
  63. ##Event
  64. train = Event(np.arange(0, 30, 10)*pq.s,labels=np.array(['trig0', 'trig1', 'trig2'],dtype='S'))
  65. train.segment = Segment()
  66. unit = Unit()
  67. unit.spiketrains.append(train)
  68. blk = Block()
  69. seg = Segment()
  70. seg.spiketrains.append(train)
  71. blk.segments.append(seg)
  72. reader = PickleIO(filename="blk.pkl")
  73. reader.write(blk)
  74. reader = PickleIO(filename="blk.pkl")
  75. r_blk = reader.read_block()
  76. r_seg = r_blk.segments[0]
  77. self.assertIsInstance(r_seg.spiketrains[0].segment, Segment)
  78. os.remove('blk.pkl')
  79. ##IrregularlySampledSignal
  80. train = IrregularlySampledSignal([0.0, 1.23, 6.78], [1, 2, 3],units='mV', time_units='ms')
  81. train.segment = Segment()
  82. unit = Unit()
  83. train.channel_index = ChannelIndex(1)
  84. unit.spiketrains.append(train)
  85. blk = Block()
  86. seg = Segment()
  87. seg.spiketrains.append(train)
  88. blk.segments.append(seg)
  89. blk.segments[0].block = blk
  90. reader = PickleIO(filename="blk.pkl")
  91. reader.write(blk)
  92. reader = PickleIO(filename="blk.pkl")
  93. r_blk = reader.read_block()
  94. r_seg = r_blk.segments[0]
  95. self.assertIsInstance(r_seg.spiketrains[0].segment, Segment)
  96. self.assertIsInstance(r_seg.spiketrains[0].channel_index, ChannelIndex)
  97. os.remove('blk.pkl')
  98. if __name__ == '__main__':
  99. unittest.main()