test_spike_train_generation.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580
  1. # -*- coding: utf-8 -*-
  2. """
  3. Unit tests for the spike_train_generation module.
  4. :copyright: Copyright 2014-2016 by the Elephant team, see AUTHORS.txt.
  5. :license: Modified BSD, see LICENSE.txt for details.
  6. """
  7. from __future__ import division
  8. import unittest
  9. import os
  10. import warnings
  11. import neo
  12. import numpy as np
  13. from numpy.testing.utils import assert_array_almost_equal
  14. from scipy.stats import kstest, expon
  15. from quantities import ms, second, Hz, kHz, mV, dimensionless
  16. import elephant.spike_train_generation as stgen
  17. from elephant.statistics import isi
  18. def pdiff(a, b):
  19. """Difference between a and b as a fraction of a
  20. i.e. abs((a - b)/a)
  21. """
  22. return abs((a - b)/a)
  23. class AnalogSignalThresholdDetectionTestCase(unittest.TestCase):
  24. def setUp(self):
  25. pass
  26. def test_threshold_detection(self):
  27. # Test whether spikes are extracted at the correct times from
  28. # an analog signal.
  29. # Load membrane potential simulated using Brian2
  30. # according to make_spike_extraction_test_data.py.
  31. curr_dir = os.path.dirname(os.path.realpath(__file__))
  32. npz_file_loc = os.path.join(curr_dir,'spike_extraction_test_data.npz')
  33. iom2 = neo.io.PyNNNumpyIO(npz_file_loc)
  34. data = iom2.read()
  35. vm = data[0].segments[0].analogsignals[0]
  36. spike_train = stgen.threshold_detection(vm)
  37. try:
  38. len(spike_train)
  39. except TypeError: # Handles an error in Neo related to some zero length
  40. # spike trains being treated as unsized objects.
  41. warnings.warn(("The spike train may be an unsized object. This may be related "
  42. "to an issue in Neo with some zero-length SpikeTrain objects. "
  43. "Bypassing this by creating an empty SpikeTrain object."))
  44. spike_train = neo.core.SpikeTrain([],t_start=spike_train.t_start,
  45. t_stop=spike_train.t_stop,
  46. units=spike_train.units)
  47. # Correct values determined previously.
  48. true_spike_train = [0.0123, 0.0354, 0.0712, 0.1191,
  49. 0.1694, 0.22, 0.2711]
  50. # Does threshold_detection gives the correct number of spikes?
  51. self.assertEqual(len(spike_train),len(true_spike_train))
  52. # Does threshold_detection gives the correct times for the spikes?
  53. try:
  54. assert_array_almost_equal(spike_train,spike_train)
  55. except AttributeError: # If numpy version too old to have allclose
  56. self.assertTrue(np.array_equal(spike_train,spike_train))
  57. class AnalogSignalPeakDetectionTestCase(unittest.TestCase):
  58. def setUp(self):
  59. curr_dir = os.path.dirname(os.path.realpath(__file__))
  60. npz_file_loc = os.path.join(curr_dir, 'spike_extraction_test_data.npz')
  61. iom2 = neo.io.PyNNNumpyIO(npz_file_loc)
  62. data = iom2.read()
  63. self.vm = data[0].segments[0].analogsignals[0]
  64. self.true_time_stamps = [0.0124, 0.0354, 0.0713, 0.1192, 0.1695,
  65. 0.2201, 0.2711] * second
  66. def test_peak_detection_time_stamps(self):
  67. # Test with default arguments
  68. result = stgen.peak_detection(self.vm)
  69. self.assertEqual(len(self.true_time_stamps), len(result))
  70. self.assertIsInstance(result, neo.core.SpikeTrain)
  71. try:
  72. assert_array_almost_equal(result, self.true_time_stamps)
  73. except AttributeError:
  74. self.assertTrue(np.array_equal(result, self.true_time_stamps))
  75. def test_peak_detection_threshold(self):
  76. # Test for empty SpikeTrain when threshold is too high
  77. result = stgen.peak_detection(self.vm, threshold=30 * mV)
  78. self.assertEqual(len(result), 0)
  79. class AnalogSignalSpikeExtractionTestCase(unittest.TestCase):
  80. def setUp(self):
  81. curr_dir = os.path.dirname(os.path.realpath(__file__))
  82. npz_file_loc = os.path.join(curr_dir, 'spike_extraction_test_data.npz')
  83. iom2 = neo.io.PyNNNumpyIO(npz_file_loc)
  84. data = iom2.read()
  85. self.vm = data[0].segments[0].analogsignals[0]
  86. self.first_spike = np.array([-0.04084546, -0.03892033, -0.03664779,
  87. -0.03392689, -0.03061474, -0.02650277,
  88. -0.0212756, -0.01443531, -0.00515365,
  89. 0.00803962, 0.02797951, -0.07,
  90. -0.06974495, -0.06950466, -0.06927778,
  91. -0.06906314, -0.06885969, -0.06866651,
  92. -0.06848277, -0.06830773, -0.06814071,
  93. -0.06798113, -0.06782843, -0.06768213,
  94. -0.06754178, -0.06740699, -0.06727737,
  95. -0.06715259, -0.06703235, -0.06691635])
  96. def test_spike_extraction_waveform(self):
  97. spike_train = stgen.spike_extraction(self.vm.reshape(-1),
  98. extr_interval = (-1*ms, 2*ms))
  99. try:
  100. assert_array_almost_equal(spike_train.waveforms[0][0].magnitude.reshape(-1),
  101. self.first_spike)
  102. except AttributeError:
  103. self.assertTrue(
  104. np.array_equal(spike_train.waveforms[0][0].magnitude,
  105. self.first_spike))
  106. class HomogeneousPoissonProcessTestCase(unittest.TestCase):
  107. def setUp(self):
  108. pass
  109. def test_statistics(self):
  110. # This is a statistical test that has a non-zero chance of failure
  111. # during normal operation. Thus, we set the random seed to a value that
  112. # creates a realization passing the test.
  113. np.random.seed(seed=12345)
  114. for rate in [123.0*Hz, 0.123*kHz]:
  115. for t_stop in [2345*ms, 2.345*second]:
  116. spiketrain = stgen.homogeneous_poisson_process(rate, t_stop=t_stop)
  117. intervals = isi(spiketrain)
  118. expected_spike_count = int((rate * t_stop).simplified)
  119. self.assertLess(pdiff(expected_spike_count, spiketrain.size), 0.2) # should fail about 1 time in 1000
  120. expected_mean_isi = (1/rate)
  121. self.assertLess(pdiff(expected_mean_isi, intervals.mean()), 0.2)
  122. expected_first_spike = 0*ms
  123. self.assertLess(spiketrain[0] - expected_first_spike, 7*expected_mean_isi)
  124. expected_last_spike = t_stop
  125. self.assertLess(expected_last_spike - spiketrain[-1], 7*expected_mean_isi)
  126. # Kolmogorov-Smirnov test
  127. D, p = kstest(intervals.rescale(t_stop.units),
  128. "expon",
  129. args=(0, expected_mean_isi.rescale(t_stop.units)), # args are (loc, scale)
  130. alternative='two-sided')
  131. self.assertGreater(p, 0.001)
  132. self.assertLess(D, 0.12)
  133. def test_low_rates(self):
  134. spiketrain = stgen.homogeneous_poisson_process(0*Hz, t_stop=1000*ms)
  135. self.assertEqual(spiketrain.size, 0)
  136. # not really a test, just making sure that all code paths are covered
  137. for i in range(10):
  138. spiketrain = stgen.homogeneous_poisson_process(1*Hz, t_stop=1000*ms)
  139. def test_buffer_overrun(self):
  140. np.random.seed(6085) # this seed should produce a buffer overrun
  141. t_stop=1000*ms
  142. rate = 10*Hz
  143. spiketrain = stgen.homogeneous_poisson_process(rate, t_stop=t_stop)
  144. expected_last_spike = t_stop
  145. expected_mean_isi = (1/rate).rescale(ms)
  146. self.assertLess(expected_last_spike - spiketrain[-1], 4*expected_mean_isi)
  147. class HomogeneousGammaProcessTestCase(unittest.TestCase):
  148. def setUp(self):
  149. pass
  150. def test_statistics(self):
  151. # This is a statistical test that has a non-zero chance of failure
  152. # during normal operation. Thus, we set the random seed to a value that
  153. # creates a realization passing the test.
  154. np.random.seed(seed=12345)
  155. a = 3.0
  156. for b in (67.0*Hz, 0.067*kHz):
  157. for t_stop in (2345*ms, 2.345*second):
  158. spiketrain = stgen.homogeneous_gamma_process(a, b, t_stop=t_stop)
  159. intervals = isi(spiketrain)
  160. expected_spike_count = int((b/a * t_stop).simplified)
  161. self.assertLess(pdiff(expected_spike_count, spiketrain.size), 0.25) # should fail about 1 time in 1000
  162. expected_mean_isi = (a/b).rescale(ms)
  163. self.assertLess(pdiff(expected_mean_isi, intervals.mean()), 0.3)
  164. expected_first_spike = 0*ms
  165. self.assertLess(spiketrain[0] - expected_first_spike, 4*expected_mean_isi)
  166. expected_last_spike = t_stop
  167. self.assertLess(expected_last_spike - spiketrain[-1], 4*expected_mean_isi)
  168. # Kolmogorov-Smirnov test
  169. D, p = kstest(intervals.rescale(t_stop.units),
  170. "gamma",
  171. args=(a, 0, (1/b).rescale(t_stop.units)), # args are (a, loc, scale)
  172. alternative='two-sided')
  173. self.assertGreater(p, 0.001)
  174. self.assertLess(D, 0.25)
  175. class _n_poisson_TestCase(unittest.TestCase):
  176. def setUp(self):
  177. self.n = 4
  178. self.rate = 10*Hz
  179. self.rates = range(1, self.n + 1)*Hz
  180. self.t_stop = 10000*ms
  181. def test_poisson(self):
  182. # Check the output types for input rate + n number of neurons
  183. pp = stgen._n_poisson(rate=self.rate, t_stop=self.t_stop, n=self.n)
  184. self.assertIsInstance(pp, list)
  185. self.assertIsInstance(pp[0], neo.core.spiketrain.SpikeTrain)
  186. self.assertEqual(pp[0].simplified.units, 1000*ms)
  187. self.assertEqual(len(pp), self.n)
  188. # Check the output types for input list of rates
  189. pp = stgen._n_poisson(rate=self.rates, t_stop=self.t_stop)
  190. self.assertIsInstance(pp, list)
  191. self.assertIsInstance(pp[0], neo.core.spiketrain.SpikeTrain)
  192. self.assertEqual(pp[0].simplified.units, 1000*ms)
  193. self.assertEqual(len(pp), self.n)
  194. def test_poisson_error(self):
  195. # Dimensionless rate
  196. self.assertRaises(
  197. ValueError, stgen._n_poisson, rate=5, t_stop=self.t_stop)
  198. # Negative rate
  199. self.assertRaises(
  200. ValueError, stgen._n_poisson, rate=-5*Hz, t_stop=self.t_stop)
  201. # Negative value when rate is a list
  202. self.assertRaises(
  203. ValueError, stgen._n_poisson, rate=[-5, 3]*Hz, t_stop=self.t_stop)
  204. # Negative n
  205. self.assertRaises(
  206. ValueError, stgen._n_poisson, rate=self.rate, t_stop=self.t_stop,
  207. n=-1)
  208. # t_start>t_stop
  209. self.assertRaises(
  210. ValueError, stgen._n_poisson, rate=self.rate, t_start=4*ms,
  211. t_stop=3*ms, n=3)
  212. class singleinteractionprocess_TestCase(unittest.TestCase):
  213. def setUp(self):
  214. self.n = 4
  215. self.rate = 10*Hz
  216. self.rates = range(1, self.n + 1)*Hz
  217. self.t_stop = 10000*ms
  218. self.rate_c = 1*Hz
  219. def test_sip(self):
  220. # Generate an example SIP mode
  221. sip, coinc = stgen.single_interaction_process(
  222. n=self.n, t_stop=self.t_stop, rate=self.rate,
  223. rate_c=self.rate_c, return_coinc=True)
  224. # Check the output types
  225. self.assertEqual(type(sip), list)
  226. self.assertEqual(type(sip[0]), neo.core.spiketrain.SpikeTrain)
  227. self.assertEqual(type(coinc[0]), neo.core.spiketrain.SpikeTrain)
  228. self.assertEqual(sip[0].simplified.units, 1000*ms)
  229. self.assertEqual(coinc[0].simplified.units, 1000*ms)
  230. # Check the output length
  231. self.assertEqual(len(sip), self.n)
  232. self.assertEqual(
  233. len(coinc[0]), (self.rate_c*self.t_stop).rescale(dimensionless))
  234. # Generate an example SIP mode giving a list of rates as imput
  235. sip, coinc = stgen.single_interaction_process(
  236. t_stop=self.t_stop, rate=self.rates,
  237. rate_c=self.rate_c, return_coinc=True)
  238. # Check the output types
  239. self.assertEqual(type(sip), list)
  240. self.assertEqual(type(sip[0]), neo.core.spiketrain.SpikeTrain)
  241. self.assertEqual(type(coinc[0]), neo.core.spiketrain.SpikeTrain)
  242. self.assertEqual(sip[0].simplified.units, 1000*ms)
  243. self.assertEqual(coinc[0].simplified.units, 1000*ms)
  244. # Check the output length
  245. self.assertEqual(len(sip), self.n)
  246. self.assertEqual(
  247. len(coinc[0]), (self.rate_c*self.t_stop).rescale(dimensionless))
  248. # Generate an example SIP mode stochastic number of coincidences
  249. sip = stgen.single_interaction_process(
  250. n=self.n, t_stop=self.t_stop, rate=self.rate,
  251. rate_c=self.rate_c, coincidences='stochastic', return_coinc=False)
  252. # Check the output types
  253. self.assertEqual(type(sip), list)
  254. self.assertEqual(type(sip[0]), neo.core.spiketrain.SpikeTrain)
  255. self.assertEqual(sip[0].simplified.units, 1000*ms)
  256. def test_sip_error(self):
  257. # Negative rate
  258. self.assertRaises(
  259. ValueError, stgen.single_interaction_process, n=self.n, rate=-5*Hz,
  260. rate_c=self.rate_c, t_stop=self.t_stop)
  261. # Negative coincidence rate
  262. self.assertRaises(
  263. ValueError, stgen.single_interaction_process, n=self.n,
  264. rate=self.rate, rate_c=-3*Hz, t_stop=self.t_stop)
  265. # Negative value when rate is a list
  266. self.assertRaises(
  267. ValueError, stgen.single_interaction_process, n=self.n,
  268. rate=[-5, 3, 4, 2]*Hz, rate_c=self.rate_c, t_stop=self.t_stop)
  269. # Negative n
  270. self.assertRaises(
  271. ValueError, stgen.single_interaction_process, n=-1,
  272. rate=self.rate, rate_c=self.rate_c, t_stop=self.t_stop)
  273. # Rate_c < rate
  274. self.assertRaises(
  275. ValueError, stgen.single_interaction_process, n=self.n,
  276. rate=self.rate, rate_c=self.rate + 1*Hz, t_stop=self.t_stop)
  277. class cppTestCase(unittest.TestCase):
  278. def test_cpp_hom(self):
  279. # testing output with generic inputs
  280. A = [0, .9, .1]
  281. t_stop = 10 * 1000 * ms
  282. t_start = 5 * 1000 * ms
  283. rate = 3 * Hz
  284. cpp_hom = stgen.cpp(rate, A, t_stop, t_start=t_start)
  285. # testing the ouput formats
  286. self.assertEqual(
  287. [type(train) for train in cpp_hom], [neo.SpikeTrain]*len(cpp_hom))
  288. self.assertEqual(cpp_hom[0].simplified.units, 1000 * ms)
  289. self.assertEqual(type(cpp_hom), list)
  290. # testing quantities format of the output
  291. self.assertEqual(
  292. [train.simplified.units for train in cpp_hom], [1000 * ms]*len(
  293. cpp_hom))
  294. # testing output t_start t_stop
  295. for st in cpp_hom:
  296. self.assertEqual(st.t_stop, t_stop)
  297. self.assertEqual(st.t_start, t_start)
  298. self.assertEqual(len(cpp_hom), len(A) - 1)
  299. # testing the units
  300. A = [0, 0.9, 0.1]
  301. t_stop = 10000*ms
  302. t_start = 5 * 1000 * ms
  303. rate = 3 * Hz
  304. cpp_unit = stgen.cpp(rate, A, t_stop, t_start=t_start)
  305. self.assertEqual(cpp_unit[0].units, t_stop.units)
  306. self.assertEqual(cpp_unit[0].t_stop.units, t_stop.units)
  307. self.assertEqual(cpp_unit[0].t_start.units, t_stop.units)
  308. # testing output without copy of spikes
  309. A = [1]
  310. t_stop = 10 * 1000 * ms
  311. t_start = 5 * 1000 * ms
  312. rate = 3 * Hz
  313. cpp_hom_empty = stgen.cpp(rate, A, t_stop, t_start=t_start)
  314. self.assertEqual(
  315. [len(train) for train in cpp_hom_empty], [0]*len(cpp_hom_empty))
  316. # testing output with rate equal to 0
  317. A = [0, .9, .1]
  318. t_stop = 10 * 1000 * ms
  319. t_start = 5 * 1000 * ms
  320. rate = 0 * Hz
  321. cpp_hom_empty_r = stgen.cpp(rate, A, t_stop, t_start=t_start)
  322. self.assertEqual(
  323. [len(train) for train in cpp_hom_empty_r], [0]*len(
  324. cpp_hom_empty_r))
  325. # testing output with same spike trains in output
  326. A = [0, 0, 1]
  327. t_stop = 10 * 1000 * ms
  328. t_start = 5 * 1000 * ms
  329. rate = 3 * Hz
  330. cpp_hom_eq = stgen.cpp(rate, A, t_stop, t_start=t_start)
  331. self.assertTrue(
  332. np.allclose(cpp_hom_eq[0].magnitude, cpp_hom_eq[1].magnitude))
  333. def test_cpp_hom_errors(self):
  334. # testing raises of ValueError (wrong inputs)
  335. # testing empty amplitude
  336. self.assertRaises(
  337. ValueError, stgen.cpp, A=[], t_stop=10*1000 * ms, rate=3*Hz)
  338. # testing sum of amplitude>1
  339. self.assertRaises(
  340. ValueError, stgen.cpp, A=[1, 1, 1], t_stop=10*1000 * ms, rate=3*Hz)
  341. # testing negative value in the amplitude
  342. self.assertRaises(
  343. ValueError, stgen.cpp, A=[-1, 1, 1], t_stop=10*1000 * ms,
  344. rate=3*Hz)
  345. # test negative rate
  346. self.assertRaises(
  347. AssertionError, stgen.cpp, A=[0, 1, 0], t_stop=10*1000 * ms,
  348. rate=-3*Hz)
  349. # test wrong unit for rate
  350. self.assertRaises(
  351. ValueError, stgen.cpp, A=[0, 1, 0], t_stop=10*1000 * ms,
  352. rate=3*1000 * ms)
  353. # testing raises of AttributeError (missing input units)
  354. # Testing missing unit to t_stop
  355. self.assertRaises(
  356. ValueError, stgen.cpp, A=[0, 1, 0], t_stop=10, rate=3*Hz)
  357. # Testing missing unit to t_start
  358. self.assertRaises(
  359. ValueError, stgen.cpp, A=[0, 1, 0], t_stop=10*1000 * ms, rate=3*Hz,
  360. t_start=3)
  361. # testing rate missing unit
  362. self.assertRaises(
  363. AttributeError, stgen.cpp, A=[0, 1, 0], t_stop=10*1000 * ms,
  364. rate=3)
  365. def test_cpp_het(self):
  366. # testing output with generic inputs
  367. A = [0, .9, .1]
  368. t_stop = 10 * 1000 * ms
  369. t_start = 5 * 1000 * ms
  370. rate = [3, 4] * Hz
  371. cpp_het = stgen.cpp(rate, A, t_stop, t_start=t_start)
  372. # testing the ouput formats
  373. self.assertEqual(
  374. [type(train) for train in cpp_het], [neo.SpikeTrain]*len(cpp_het))
  375. self.assertEqual(cpp_het[0].simplified.units, 1000 * ms)
  376. self.assertEqual(type(cpp_het), list)
  377. # testing units
  378. self.assertEqual(
  379. [train.simplified.units for train in cpp_het], [1000 * ms]*len(
  380. cpp_het))
  381. # testing output t_start and t_stop
  382. for st in cpp_het:
  383. self.assertEqual(st.t_stop, t_stop)
  384. self.assertEqual(st.t_start, t_start)
  385. # testing the number of output spiketrains
  386. self.assertEqual(len(cpp_het), len(A) - 1)
  387. self.assertEqual(len(cpp_het), len(rate))
  388. # testing the units
  389. A = [0, 0.9, 0.1]
  390. t_stop = 10000*ms
  391. t_start = 5 * 1000 * ms
  392. rate = [3, 4] * Hz
  393. cpp_unit = stgen.cpp(rate, A, t_stop, t_start=t_start)
  394. self.assertEqual(cpp_unit[0].units, t_stop.units)
  395. self.assertEqual(cpp_unit[0].t_stop.units, t_stop.units)
  396. self.assertEqual(cpp_unit[0].t_start.units, t_stop.units)
  397. # testing without copying any spikes
  398. A = [1, 0, 0]
  399. t_stop = 10 * 1000 * ms
  400. t_start = 5 * 1000 * ms
  401. rate = [3, 4] * Hz
  402. cpp_het_empty = stgen.cpp(rate, A, t_stop, t_start=t_start)
  403. self.assertEqual(len(cpp_het_empty[0]), 0)
  404. # testing output with rate equal to 0
  405. A = [0, .9, .1]
  406. t_stop = 10 * 1000 * ms
  407. t_start = 5 * 1000 * ms
  408. rate = [0, 0] * Hz
  409. cpp_het_empty_r = stgen.cpp(rate, A, t_stop, t_start=t_start)
  410. self.assertEqual(
  411. [len(train) for train in cpp_het_empty_r], [0]*len(
  412. cpp_het_empty_r))
  413. # testing completely sync spiketrains
  414. A = [0, 0, 1]
  415. t_stop = 10 * 1000 * ms
  416. t_start = 5 * 1000 * ms
  417. rate = [3, 3] * Hz
  418. cpp_het_eq = stgen.cpp(rate, A, t_stop, t_start=t_start)
  419. self.assertTrue(np.allclose(
  420. cpp_het_eq[0].magnitude, cpp_het_eq[1].magnitude))
  421. def test_cpp_het_err(self):
  422. # testing raises of ValueError (wrong inputs)
  423. # testing empty amplitude
  424. self.assertRaises(
  425. ValueError, stgen.cpp, A=[], t_stop=10*1000 * ms, rate=[3, 4]*Hz)
  426. # testing sum amplitude>1
  427. self.assertRaises(
  428. ValueError, stgen.cpp, A=[1, 1, 1], t_stop=10*1000 * ms,
  429. rate=[3, 4]*Hz)
  430. # testing amplitude negative value
  431. self.assertRaises(
  432. ValueError, stgen.cpp, A=[-1, 1, 1], t_stop=10*1000 * ms,
  433. rate=[3, 4]*Hz)
  434. # testing negative rate
  435. self.assertRaises(
  436. ValueError, stgen.cpp, A=[0, 1, 0], t_stop=10*1000 * ms,
  437. rate=[-3, 4]*Hz)
  438. # testing empty rate
  439. self.assertRaises(
  440. ValueError, stgen.cpp, A=[0, 1, 0], t_stop=10*1000 * ms, rate=[]*Hz)
  441. # testing empty amplitude
  442. self.assertRaises(
  443. ValueError, stgen.cpp, A=[], t_stop=10*1000 * ms, rate=[3, 4]*Hz)
  444. # testing different len(A)-1 and len(rate)
  445. self.assertRaises(
  446. ValueError, stgen.cpp, A=[0, 1], t_stop=10*1000 * ms, rate=[3, 4]*Hz)
  447. # testing rate with different unit from Hz
  448. self.assertRaises(
  449. ValueError, stgen.cpp, A=[0, 1], t_stop=10*1000 * ms,
  450. rate=[3, 4]*1000 * ms)
  451. # Testing analytical constrain between amplitude and rate
  452. self.assertRaises(
  453. ValueError, stgen.cpp, A=[0, 0, 1], t_stop=10*1000 * ms,
  454. rate=[3, 4]*Hz, t_start=3)
  455. # testing raises of AttributeError (missing input units)
  456. # Testing missing unit to t_stop
  457. self.assertRaises(
  458. ValueError, stgen.cpp, A=[0, 1, 0], t_stop=10, rate=[3, 4]*Hz)
  459. # Testing missing unit to t_start
  460. self.assertRaises(
  461. ValueError, stgen.cpp, A=[0, 1, 0], t_stop=10*1000 * ms,
  462. rate=[3, 4]*Hz, t_start=3)
  463. # Testing missing unit to rate
  464. self.assertRaises(
  465. AttributeError, stgen.cpp, A=[0, 1, 0], t_stop=10*1000 * ms,
  466. rate=[3, 4])
  467. def test_cpp_jttered(self):
  468. # testing output with generic inputs
  469. A = [0, .9, .1]
  470. t_stop = 10 * 1000 * ms
  471. t_start = 5 * 1000 * ms
  472. rate = 3 * Hz
  473. cpp_shift = stgen.cpp(
  474. rate, A, t_stop, t_start=t_start, shift=3*ms)
  475. # testing the ouput formats
  476. self.assertEqual(
  477. [type(train) for train in cpp_shift], [neo.SpikeTrain]*len(
  478. cpp_shift))
  479. self.assertEqual(cpp_shift[0].simplified.units, 1000 * ms)
  480. self.assertEqual(type(cpp_shift), list)
  481. # testing quantities format of the output
  482. self.assertEqual(
  483. [train.simplified.units for train in cpp_shift],
  484. [1000 * ms]*len(cpp_shift))
  485. # testing output t_start t_stop
  486. for st in cpp_shift:
  487. self.assertEqual(st.t_stop, t_stop)
  488. self.assertEqual(st.t_start, t_start)
  489. self.assertEqual(len(cpp_shift), len(A) - 1)
  490. if __name__ == '__main__':
  491. unittest.main()