test_spike_train_surrogates.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. # -*- coding: utf-8 -*-
"""
Unit tests for the spike_train_surrogates module.

:copyright: Copyright 2014-2016 by the Elephant team, see AUTHORS.txt.
:license: Modified BSD, see LICENSE.txt for details.
"""
  7. import unittest
  8. import elephant.spike_train_surrogates as surr
  9. import numpy as np
  10. import quantities as pq
  11. import neo
  12. np.random.seed(0)
  13. class SurrogatesTestCase(unittest.TestCase):
  14. def test_dither_spikes_output_format(self):
  15. st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
  16. nr_surr = 2
  17. dither = 10 * pq.ms
  18. surrs = surr.dither_spikes(st, dither=dither, n=nr_surr)
  19. self.assertIsInstance(surrs, list)
  20. self.assertEqual(len(surrs), nr_surr)
  21. for surrog in surrs:
  22. self.assertIsInstance(surrs[0], neo.SpikeTrain)
  23. self.assertEqual(surrog.units, st.units)
  24. self.assertEqual(surrog.t_start, st.t_start)
  25. self.assertEqual(surrog.t_stop, st.t_stop)
  26. self.assertEqual(len(surrog), len(st))
  27. def test_dither_spikes_empty_train(self):
  28. st = neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms)
  29. dither = 10 * pq.ms
  30. surrog = surr.dither_spikes(st, dither=dither, n=1)[0]
  31. self.assertEqual(len(surrog), 0)
  32. def test_dither_spikes_output_decimals(self):
  33. st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
  34. nr_surr = 2
  35. dither = 10 * pq.ms
  36. np.random.seed(42)
  37. surrs = surr.dither_spikes(st, dither=dither, decimals=3, n=nr_surr)
  38. np.random.seed(42)
  39. dither_values = np.random.random_sample((nr_surr, len(st)))
  40. expected_non_dithered = np.sum(dither_values==0)
  41. observed_non_dithered = 0
  42. for surrog in surrs:
  43. for i in range(len(surrog)):
  44. if surrog[i] - int(surrog[i]) * pq.ms == surrog[i] - surrog[i]:
  45. observed_non_dithered += 1
  46. self.assertEqual(observed_non_dithered, expected_non_dithered)
  47. def test_dither_spikes_false_edges(self):
  48. st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
  49. nr_surr = 2
  50. dither = 10 * pq.ms
  51. surrs = surr.dither_spikes(st, dither=dither, n=nr_surr, edges=False)
  52. for surrog in surrs:
  53. for i in range(len(surrog)):
  54. self.assertLessEqual(surrog[i], st.t_stop)
  55. def test_randomise_spikes_output_format(self):
  56. st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
  57. nr_surr = 2
  58. surrs = surr.randomise_spikes(st, n=nr_surr)
  59. self.assertIsInstance(surrs, list)
  60. self.assertEqual(len(surrs), nr_surr)
  61. for surrog in surrs:
  62. self.assertIsInstance(surrs[0], neo.SpikeTrain)
  63. self.assertEqual(surrog.units, st.units)
  64. self.assertEqual(surrog.t_start, st.t_start)
  65. self.assertEqual(surrog.t_stop, st.t_stop)
  66. self.assertEqual(len(surrog), len(st))
  67. def test_randomise_spikes_empty_train(self):
  68. st = neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms)
  69. surrog = surr.randomise_spikes(st, n=1)[0]
  70. self.assertEqual(len(surrog), 0)
  71. def test_randomise_spikes_output_decimals(self):
  72. st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
  73. nr_surr = 2
  74. surrs = surr.randomise_spikes(st, n=nr_surr, decimals=3)
  75. for surrog in surrs:
  76. for i in range(len(surrog)):
  77. self.assertNotEqual(surrog[i] - int(surrog[i]) * pq.ms,
  78. surrog[i] - surrog[i])
  79. def test_shuffle_isis_output_format(self):
  80. st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
  81. nr_surr = 2
  82. surrs = surr.shuffle_isis(st, n=nr_surr)
  83. self.assertIsInstance(surrs, list)
  84. self.assertEqual(len(surrs), nr_surr)
  85. for surrog in surrs:
  86. self.assertIsInstance(surrs[0], neo.SpikeTrain)
  87. self.assertEqual(surrog.units, st.units)
  88. self.assertEqual(surrog.t_start, st.t_start)
  89. self.assertEqual(surrog.t_stop, st.t_stop)
  90. self.assertEqual(len(surrog), len(st))
  91. def test_shuffle_isis_empty_train(self):
  92. st = neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms)
  93. surrog = surr.shuffle_isis(st, n=1)[0]
  94. self.assertEqual(len(surrog), 0)
  95. def test_shuffle_isis_same_isis(self):
  96. st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
  97. surrog = surr.shuffle_isis(st, n=1)[0]
  98. st_pq = st.view(pq.Quantity)
  99. surr_pq = surrog.view(pq.Quantity)
  100. isi0_orig = st[0] - st.t_start
  101. ISIs_orig = np.sort([isi0_orig] + [isi for isi in np.diff(st_pq)])
  102. isi0_surr = surrog[0] - surrog.t_start
  103. ISIs_surr = np.sort([isi0_surr] + [isi for isi in np.diff(surr_pq)])
  104. self.assertTrue(np.all(ISIs_orig == ISIs_surr))
  105. def test_shuffle_isis_output_decimals(self):
  106. st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
  107. surrog = surr.shuffle_isis(st, n=1, decimals=95)[0]
  108. st_pq = st.view(pq.Quantity)
  109. surr_pq = surrog.view(pq.Quantity)
  110. isi0_orig = st[0] - st.t_start
  111. ISIs_orig = np.sort([isi0_orig] + [isi for isi in np.diff(st_pq)])
  112. isi0_surr = surrog[0] - surrog.t_start
  113. ISIs_surr = np.sort([isi0_surr] + [isi for isi in np.diff(surr_pq)])
  114. self.assertTrue(np.all(ISIs_orig == ISIs_surr))
  115. def test_dither_spike_train_output_format(self):
  116. st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
  117. nr_surr = 2
  118. shift = 10 * pq.ms
  119. surrs = surr.dither_spike_train(st, shift=shift, n=nr_surr)
  120. self.assertIsInstance(surrs, list)
  121. self.assertEqual(len(surrs), nr_surr)
  122. for surrog in surrs:
  123. self.assertIsInstance(surrs[0], neo.SpikeTrain)
  124. self.assertEqual(surrog.units, st.units)
  125. self.assertEqual(surrog.t_start, st.t_start)
  126. self.assertEqual(surrog.t_stop, st.t_stop)
  127. self.assertEqual(len(surrog), len(st))
  128. def test_dither_spike_train_empty_train(self):
  129. st = neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms)
  130. shift = 10 * pq.ms
  131. surrog = surr.dither_spike_train(st, shift=shift, n=1)[0]
  132. self.assertEqual(len(surrog), 0)
  133. def test_dither_spike_train_output_decimals(self):
  134. st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
  135. nr_surr = 2
  136. shift = 10 * pq.ms
  137. surrs = surr.dither_spike_train(st, shift=shift, n=nr_surr, decimals=3)
  138. for surrog in surrs:
  139. for i in range(len(surrog)):
  140. self.assertNotEqual(surrog[i] - int(surrog[i]) * pq.ms,
  141. surrog[i] - surrog[i])
  142. def test_dither_spike_train_false_edges(self):
  143. st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
  144. nr_surr = 2
  145. shift = 10 * pq.ms
  146. surrs = surr.dither_spike_train(
  147. st, shift=shift, n=nr_surr, edges=False)
  148. for surrog in surrs:
  149. for i in range(len(surrog)):
  150. self.assertLessEqual(surrog[i], st.t_stop)
  151. def test_jitter_spikes_output_format(self):
  152. st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
  153. nr_surr = 2
  154. binsize = 100 * pq.ms
  155. surrs = surr.jitter_spikes(st, binsize=binsize, n=nr_surr)
  156. self.assertIsInstance(surrs, list)
  157. self.assertEqual(len(surrs), nr_surr)
  158. for surrog in surrs:
  159. self.assertIsInstance(surrs[0], neo.SpikeTrain)
  160. self.assertEqual(surrog.units, st.units)
  161. self.assertEqual(surrog.t_start, st.t_start)
  162. self.assertEqual(surrog.t_stop, st.t_stop)
  163. self.assertEqual(len(surrog), len(st))
  164. def test_jitter_spikes_empty_train(self):
  165. st = neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms)
  166. binsize = 75 * pq.ms
  167. surrog = surr.jitter_spikes(st, binsize=binsize, n=1)[0]
  168. self.assertEqual(len(surrog), 0)
  169. def test_jitter_spikes_same_bins(self):
  170. st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
  171. binsize = 100 * pq.ms
  172. surrog = surr.jitter_spikes(st, binsize=binsize, n=1)[0]
  173. bin_ids_orig = np.array((st.view(pq.Quantity) / binsize).rescale(
  174. pq.dimensionless).magnitude, dtype=int)
  175. bin_ids_surr = np.array((surrog.view(pq.Quantity) / binsize).rescale(
  176. pq.dimensionless).magnitude, dtype=int)
  177. self.assertTrue(np.all(bin_ids_orig == bin_ids_surr))
  178. # Bug encountered when the original and surrogate trains have
  179. # different number of spikes
  180. self.assertEqual(len(st), len(surrog))
  181. def test_jitter_spikes_unequal_binsize(self):
  182. st = neo.SpikeTrain([90, 150, 180, 480] * pq.ms, t_stop=500 * pq.ms)
  183. binsize = 75 * pq.ms
  184. surrog = surr.jitter_spikes(st, binsize=binsize, n=1)[0]
  185. bin_ids_orig = np.array((st.view(pq.Quantity) / binsize).rescale(
  186. pq.dimensionless).magnitude, dtype=int)
  187. bin_ids_surr = np.array((surrog.view(pq.Quantity) / binsize).rescale(
  188. pq.dimensionless).magnitude, dtype=int)
  189. self.assertTrue(np.all(bin_ids_orig == bin_ids_surr))
  190. def test_surr_method(self):
  191. st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
  192. nr_surr = 2
  193. surrs = surr.surrogates(st, dt=3 * pq.ms, n=nr_surr,
  194. surr_method='shuffle_isis', edges=False)
  195. self.assertRaises(ValueError, surr.surrogates, st, n=1,
  196. surr_method='spike_shifting',
  197. dt=None, decimals=None, edges=True)
  198. self.assertTrue(len(surrs) == nr_surr)
  199. nr_surr2 = 4
  200. surrs2 = surr.surrogates(st, dt=5 * pq.ms, n=nr_surr2,
  201. surr_method='dither_spike_train', edges=True)
  202. for surrog in surrs:
  203. self.assertTrue(isinstance(surrs[0], neo.SpikeTrain))
  204. self.assertEqual(surrog.units, st.units)
  205. self.assertEqual(surrog.t_start, st.t_start)
  206. self.assertEqual(surrog.t_stop, st.t_stop)
  207. self.assertEqual(len(surrog), len(st))
  208. self.assertTrue(len(surrs) == nr_surr)
  209. for surrog in surrs2:
  210. self.assertTrue(isinstance(surrs2[0], neo.SpikeTrain))
  211. self.assertEqual(surrog.units, st.units)
  212. self.assertEqual(surrog.t_start, st.t_start)
  213. self.assertEqual(surrog.t_stop, st.t_stop)
  214. self.assertEqual(len(surrog), len(st))
  215. self.assertTrue(len(surrs2) == nr_surr2)
  216. def suite():
  217. suite = unittest.makeSuite(SurrogatesTestCase, 'test')
  218. return suite
  219. if __name__ == "__main__":
  220. runner = unittest.TextTestRunner(verbosity=2)
  221. runner.run(suite())