test_spike_train_dissimilarity.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. # -*- coding: utf-8 -*-
  2. """
  3. Tests for the spike train dissimilarity measures module.
  4. :copyright: Copyright 2016 by the Elephant team, see AUTHORS.txt.
  5. :license: Modified BSD, see LICENSE.txt for details.
  6. """
  7. import unittest
  8. from neo import SpikeTrain
  9. import numpy as np
  10. from numpy.testing import assert_array_equal, assert_array_almost_equal
  11. import scipy.integrate as spint
  12. from quantities import ms, s, Hz
  13. import elephant.kernels as kernels
  14. import elephant.spike_train_generation as stg
  15. import elephant.spike_train_dissimilarity as stds
  16. class TimeScaleDependSpikeTrainDissimMeasures_TestCase(unittest.TestCase):
  17. def setUp(self):
  18. self.st00 = SpikeTrain([], units='ms', t_stop=1000.0)
  19. self.st01 = SpikeTrain([1], units='ms', t_stop=1000.0)
  20. self.st02 = SpikeTrain([2], units='ms', t_stop=1000.0)
  21. self.st03 = SpikeTrain([2.9], units='ms', t_stop=1000.0)
  22. self.st04 = SpikeTrain([3.1], units='ms', t_stop=1000.0)
  23. self.st05 = SpikeTrain([5], units='ms', t_stop=1000.0)
  24. self.st06 = SpikeTrain([500], units='ms', t_stop=1000.0)
  25. self.st07 = SpikeTrain([12, 32], units='ms', t_stop=1000.0)
  26. self.st08 = SpikeTrain([32, 52], units='ms', t_stop=1000.0)
  27. self.st09 = SpikeTrain([42], units='ms', t_stop=1000.0)
  28. self.st10 = SpikeTrain([18, 60], units='ms', t_stop=1000.0)
  29. self.st11 = SpikeTrain([10, 20, 30, 40], units='ms', t_stop=1000.0)
  30. self.st12 = SpikeTrain([40, 30, 20, 10], units='ms', t_stop=1000.0)
  31. self.st13 = SpikeTrain([15, 25, 35, 45], units='ms', t_stop=1000.0)
  32. self.st14 = SpikeTrain([10, 20, 30, 40, 50], units='ms', t_stop=1000.0)
  33. self.st15 = SpikeTrain([0.01, 0.02, 0.03, 0.04, 0.05],
  34. units='s', t_stop=1000.0)
  35. self.st16 = SpikeTrain([12, 16, 28, 30, 42], units='ms', t_stop=1000.0)
  36. self.st21 = stg.homogeneous_poisson_process(50*Hz, 0*ms, 1000*ms)
  37. self.st22 = stg.homogeneous_poisson_process(40*Hz, 0*ms, 1000*ms)
  38. self.st23 = stg.homogeneous_poisson_process(30*Hz, 0*ms, 1000*ms)
  39. self.rd_st_list = [self.st21, self.st22, self.st23]
  40. self.st31 = SpikeTrain([12.0], units='ms', t_stop=1000.0)
  41. self.st32 = SpikeTrain([12.0, 12.0], units='ms', t_stop=1000.0)
  42. self.st33 = SpikeTrain([20.0], units='ms', t_stop=1000.0)
  43. self.st34 = SpikeTrain([20.0, 20.0], units='ms', t_stop=1000.0)
  44. self.array1 = np.arange(1, 10)
  45. self.array2 = np.arange(1.2, 10)
  46. self.qarray1 = self.array1 * Hz
  47. self.qarray2 = self.array2 * Hz
  48. self.tau0 = 0.0 * ms
  49. self.q0 = np.inf / ms
  50. self.tau1 = 0.000000001 * ms
  51. self.q1 = 1.0 / self.tau1
  52. self.tau2 = 1.0 * ms
  53. self.q2 = 1.0 / self.tau2
  54. self.tau3 = 10.0 * ms
  55. self.q3 = 1.0 / self.tau3
  56. self.tau4 = 100.0 * ms
  57. self.q4 = 1.0 / self.tau4
  58. self.tau5 = 1000000000.0 * ms
  59. self.q5 = 1.0 / self.tau5
  60. self.tau6 = np.inf * ms
  61. self.q6 = 0.0 / ms
  62. self.tau7 = 0.01 * s
  63. self.q7 = 1.0 / self.tau7
  64. self.t = np.linspace(0, 200, 20000001) * ms
  65. def test_wrong_input(self):
  66. self.assertRaises(TypeError, stds.victor_purpura_dist,
  67. [self.array1, self.array2], self.q3)
  68. self.assertRaises(TypeError, stds.victor_purpura_dist,
  69. [self.qarray1, self.qarray2], self.q3)
  70. self.assertRaises(TypeError, stds.victor_purpura_dist,
  71. [self.qarray1, self.qarray2], 5.0 * ms)
  72. self.assertRaises(TypeError, stds.victor_purpura_dist,
  73. [self.array1, self.array2], self.q3,
  74. algorithm='intuitive')
  75. self.assertRaises(TypeError, stds.victor_purpura_dist,
  76. [self.qarray1, self.qarray2], self.q3,
  77. algorithm='intuitive')
  78. self.assertRaises(TypeError, stds.victor_purpura_dist,
  79. [self.qarray1, self.qarray2], 5.0 * ms,
  80. algorithm='intuitive')
  81. self.assertRaises(TypeError, stds.van_rossum_dist,
  82. [self.array1, self.array2], self.tau3)
  83. self.assertRaises(TypeError, stds.van_rossum_dist,
  84. [self.qarray1, self.qarray2], self.tau3)
  85. self.assertRaises(TypeError, stds.van_rossum_dist,
  86. [self.qarray1, self.qarray2], 5.0 * Hz)
  87. self.assertRaises(TypeError, stds.victor_purpura_dist,
  88. [self.st11, self.st13], self.tau2)
  89. self.assertRaises(TypeError, stds.victor_purpura_dist,
  90. [self.st11, self.st13], 5.0)
  91. self.assertRaises(TypeError, stds.victor_purpura_dist,
  92. [self.st11, self.st13], self.tau2,
  93. algorithm='intuitive')
  94. self.assertRaises(TypeError, stds.victor_purpura_dist,
  95. [self.st11, self.st13], 5.0,
  96. algorithm='intuitive')
  97. self.assertRaises(TypeError, stds.van_rossum_dist,
  98. [self.st11, self.st13], self.q4)
  99. self.assertRaises(TypeError, stds.van_rossum_dist,
  100. [self.st11, self.st13], 5.0)
  101. self.assertRaises(NotImplementedError, stds.victor_purpura_dist,
  102. [self.st01, self.st02], self.q3,
  103. kernel=kernels.Kernel(2.0 / self.q3))
  104. self.assertRaises(NotImplementedError, stds.victor_purpura_dist,
  105. [self.st01, self.st02], self.q3,
  106. kernel=kernels.SymmetricKernel(2.0 / self.q3))
  107. self.assertEqual(stds.victor_purpura_dist(
  108. [self.st01, self.st02], self.q1,
  109. kernel=kernels.TriangularKernel(
  110. 2.0 / (np.sqrt(6.0) * self.q2)))[0, 1],
  111. stds.victor_purpura_dist(
  112. [self.st01, self.st02], self.q3,
  113. kernel=kernels.TriangularKernel(
  114. 2.0 / (np.sqrt(6.0) * self.q2)))[0, 1])
  115. self.assertEqual(stds.victor_purpura_dist(
  116. [self.st01, self.st02],
  117. kernel=kernels.TriangularKernel(
  118. 2.0 / (np.sqrt(6.0) * self.q2)))[0, 1], 1.0)
  119. self.assertNotEqual(stds.victor_purpura_dist(
  120. [self.st01, self.st02],
  121. kernel=kernels.AlphaKernel(
  122. 2.0 / (np.sqrt(6.0) * self.q2)))[0, 1], 1.0)
  123. self.assertRaises(NameError, stds.victor_purpura_dist,
  124. [self.st11, self.st13], self.q2, algorithm='slow')
  125. def test_victor_purpura_distance_fast(self):
  126. # Tests of distances of simplest spike trains:
  127. self.assertEqual(stds.victor_purpura_dist(
  128. [self.st00, self.st00], self.q2)[0, 1], 0.0)
  129. self.assertEqual(stds.victor_purpura_dist(
  130. [self.st00, self.st01], self.q2)[0, 1], 1.0)
  131. self.assertEqual(stds.victor_purpura_dist(
  132. [self.st01, self.st00], self.q2)[0, 1], 1.0)
  133. self.assertEqual(stds.victor_purpura_dist(
  134. [self.st01, self.st01], self.q2)[0, 1], 0.0)
  135. # Tests of distances under elementary spike operations
  136. self.assertEqual(stds.victor_purpura_dist(
  137. [self.st01, self.st02], self.q2)[0, 1], 1.0)
  138. self.assertEqual(stds.victor_purpura_dist(
  139. [self.st01, self.st03], self.q2)[0, 1], 1.9)
  140. self.assertEqual(stds.victor_purpura_dist(
  141. [self.st01, self.st04], self.q2)[0, 1], 2.0)
  142. self.assertEqual(stds.victor_purpura_dist(
  143. [self.st01, self.st05], self.q2)[0, 1], 2.0)
  144. self.assertEqual(stds.victor_purpura_dist(
  145. [self.st00, self.st07], self.q2)[0, 1], 2.0)
  146. self.assertAlmostEqual(stds.victor_purpura_dist(
  147. [self.st07, self.st08], self.q4)[0, 1], 0.4)
  148. self.assertAlmostEqual(stds.victor_purpura_dist(
  149. [self.st07, self.st10], self.q3)[0, 1], 0.6 + 2)
  150. self.assertEqual(stds.victor_purpura_dist(
  151. [self.st11, self.st14], self.q2)[0, 1], 1)
  152. # Tests on timescales
  153. self.assertEqual(stds.victor_purpura_dist(
  154. [self.st11, self.st14], self.q1)[0, 1],
  155. stds.victor_purpura_dist(
  156. [self.st11, self.st14], self.q5)[0, 1])
  157. self.assertEqual(stds.victor_purpura_dist(
  158. [self.st07, self.st11], self.q0)[0, 1], 6.0)
  159. self.assertEqual(stds.victor_purpura_dist(
  160. [self.st07, self.st11], self.q1)[0, 1], 6.0)
  161. self.assertAlmostEqual(stds.victor_purpura_dist(
  162. [self.st07, self.st11], self.q5)[0, 1], 2.0, 5)
  163. self.assertEqual(stds.victor_purpura_dist(
  164. [self.st07, self.st11], self.q6)[0, 1], 2.0)
  165. # Tests on unordered spiketrains
  166. self.assertEqual(stds.victor_purpura_dist(
  167. [self.st11, self.st13], self.q4)[0, 1],
  168. stds.victor_purpura_dist(
  169. [self.st12, self.st13], self.q4)[0, 1])
  170. self.assertNotEqual(stds.victor_purpura_dist(
  171. [self.st11, self.st13], self.q4,
  172. sort=False)[0, 1],
  173. stds.victor_purpura_dist(
  174. [self.st12, self.st13], self.q4,
  175. sort=False)[0, 1])
  176. # Tests on metric properties with random spiketrains
  177. # (explicit calculation of second metric axiom in particular case,
  178. # because from dist_matrix it is trivial)
  179. dist_matrix = stds.victor_purpura_dist(
  180. [self.st21, self.st22, self.st23], self.q3)
  181. for i in range(3):
  182. for j in range(3):
  183. self.assertGreaterEqual(dist_matrix[i, j], 0)
  184. if dist_matrix[i, j] == 0:
  185. assert_array_equal(self.rd_st_list[i], self.rd_st_list[j])
  186. assert_array_equal(stds.victor_purpura_dist(
  187. [self.st21, self.st22], self.q3),
  188. stds.victor_purpura_dist(
  189. [self.st22, self.st21], self.q3))
  190. self.assertLessEqual(dist_matrix[0, 1],
  191. dist_matrix[0, 2] + dist_matrix[1, 2])
  192. self.assertLessEqual(dist_matrix[0, 2],
  193. dist_matrix[1, 2] + dist_matrix[0, 1])
  194. self.assertLessEqual(dist_matrix[1, 2],
  195. dist_matrix[0, 1] + dist_matrix[0, 2])
  196. # Tests on proper unit conversion
  197. self.assertAlmostEqual(
  198. stds.victor_purpura_dist([self.st14, self.st16], self.q3)[0, 1],
  199. stds.victor_purpura_dist([self.st15, self.st16], self.q3)[0, 1])
  200. self.assertAlmostEqual(
  201. stds.victor_purpura_dist([self.st16, self.st14], self.q3)[0, 1],
  202. stds.victor_purpura_dist([self.st16, self.st15], self.q3)[0, 1])
  203. self.assertEqual(
  204. stds.victor_purpura_dist([self.st01, self.st05], self.q3)[0, 1],
  205. stds.victor_purpura_dist([self.st01, self.st05], self.q7)[0, 1])
  206. # Tests on algorithmic behaviour for equal spike times
  207. self.assertEqual(
  208. stds.victor_purpura_dist([self.st31, self.st34], self.q3)[0, 1],
  209. 0.8 + 1.0)
  210. self.assertEqual(
  211. stds.victor_purpura_dist([self.st31, self.st34], self.q3)[0, 1],
  212. stds.victor_purpura_dist([self.st32, self.st33], self.q3)[0, 1])
  213. self.assertEqual(
  214. stds.victor_purpura_dist(
  215. [self.st31, self.st33], self.q3)[0, 1] * 2.0,
  216. stds.victor_purpura_dist(
  217. [self.st32, self.st34], self.q3)[0, 1])
  218. # Tests on spike train list lengthes smaller than 2
  219. self.assertEqual(stds.victor_purpura_dist(
  220. [self.st21], self.q3)[0, 0], 0)
  221. self.assertEqual(len(stds.victor_purpura_dist([], self.q3)), 0)
  222. def test_victor_purpura_distance_intuitive(self):
  223. # Tests of distances of simplest spike trains
  224. self.assertEqual(stds.victor_purpura_dist(
  225. [self.st00, self.st00], self.q2,
  226. algorithm='intuitive')[0, 1], 0.0)
  227. self.assertEqual(stds.victor_purpura_dist(
  228. [self.st00, self.st01], self.q2,
  229. algorithm='intuitive')[0, 1], 1.0)
  230. self.assertEqual(stds.victor_purpura_dist(
  231. [self.st01, self.st00], self.q2,
  232. algorithm='intuitive')[0, 1], 1.0)
  233. self.assertEqual(stds.victor_purpura_dist(
  234. [self.st01, self.st01], self.q2,
  235. algorithm='intuitive')[0, 1], 0.0)
  236. # Tests of distances under elementary spike operations
  237. self.assertEqual(stds.victor_purpura_dist(
  238. [self.st01, self.st02], self.q2,
  239. algorithm='intuitive')[0, 1], 1.0)
  240. self.assertEqual(stds.victor_purpura_dist(
  241. [self.st01, self.st03], self.q2,
  242. algorithm='intuitive')[0, 1], 1.9)
  243. self.assertEqual(stds.victor_purpura_dist(
  244. [self.st01, self.st04], self.q2,
  245. algorithm='intuitive')[0, 1], 2.0)
  246. self.assertEqual(stds.victor_purpura_dist(
  247. [self.st01, self.st05], self.q2,
  248. algorithm='intuitive')[0, 1], 2.0)
  249. self.assertEqual(stds.victor_purpura_dist(
  250. [self.st00, self.st07], self.q2,
  251. algorithm='intuitive')[0, 1], 2.0)
  252. self.assertAlmostEqual(stds.victor_purpura_dist(
  253. [self.st07, self.st08], self.q4,
  254. algorithm='intuitive')[0, 1], 0.4)
  255. self.assertAlmostEqual(stds.victor_purpura_dist(
  256. [self.st07, self.st10], self.q3,
  257. algorithm='intuitive')[0, 1], 2.6)
  258. self.assertEqual(stds.victor_purpura_dist(
  259. [self.st11, self.st14], self.q2,
  260. algorithm='intuitive')[0, 1], 1)
  261. # Tests on timescales
  262. self.assertEqual(stds.victor_purpura_dist(
  263. [self.st11, self.st14], self.q1,
  264. algorithm='intuitive')[0, 1],
  265. stds.victor_purpura_dist(
  266. [self.st11, self.st14], self.q5,
  267. algorithm='intuitive')[0, 1])
  268. self.assertEqual(stds.victor_purpura_dist(
  269. [self.st07, self.st11], self.q0,
  270. algorithm='intuitive')[0, 1], 6.0)
  271. self.assertEqual(stds.victor_purpura_dist(
  272. [self.st07, self.st11], self.q1,
  273. algorithm='intuitive')[0, 1], 6.0)
  274. self.assertAlmostEqual(stds.victor_purpura_dist(
  275. [self.st07, self.st11], self.q5,
  276. algorithm='intuitive')[0, 1], 2.0, 5)
  277. self.assertEqual(stds.victor_purpura_dist(
  278. [self.st07, self.st11], self.q6,
  279. algorithm='intuitive')[0, 1], 2.0)
  280. # Tests on unordered spiketrains
  281. self.assertEqual(stds.victor_purpura_dist(
  282. [self.st11, self.st13], self.q4,
  283. algorithm='intuitive')[0, 1],
  284. stds.victor_purpura_dist(
  285. [self.st12, self.st13], self.q4,
  286. algorithm='intuitive')[0, 1])
  287. self.assertNotEqual(stds.victor_purpura_dist(
  288. [self.st11, self.st13], self.q4,
  289. sort=False, algorithm='intuitive')[0, 1],
  290. stds.victor_purpura_dist(
  291. [self.st12, self.st13], self.q4,
  292. sort=False, algorithm='intuitive')[0, 1])
  293. # Tests on metric properties with random spiketrains
  294. # (explicit calculation of second metric axiom in particular case,
  295. # because from dist_matrix it is trivial)
  296. dist_matrix = stds.victor_purpura_dist(
  297. [self.st21, self.st22, self.st23],
  298. self.q3, algorithm='intuitive')
  299. for i in range(3):
  300. for j in range(3):
  301. self.assertGreaterEqual(dist_matrix[i, j], 0)
  302. if dist_matrix[i, j] == 0:
  303. assert_array_equal(self.rd_st_list[i], self.rd_st_list[j])
  304. assert_array_equal(stds.victor_purpura_dist(
  305. [self.st21, self.st22], self.q3,
  306. algorithm='intuitive'),
  307. stds.victor_purpura_dist(
  308. [self.st22, self.st21], self.q3,
  309. algorithm='intuitive'))
  310. self.assertLessEqual(dist_matrix[0, 1],
  311. dist_matrix[0, 2] + dist_matrix[1, 2])
  312. self.assertLessEqual(dist_matrix[0, 2],
  313. dist_matrix[1, 2] + dist_matrix[0, 1])
  314. self.assertLessEqual(dist_matrix[1, 2],
  315. dist_matrix[0, 1] + dist_matrix[0, 2])
  316. # Tests on proper unit conversion
  317. self.assertAlmostEqual(stds.victor_purpura_dist(
  318. [self.st14, self.st16], self.q3,
  319. algorithm='intuitive')[0, 1],
  320. stds.victor_purpura_dist(
  321. [self.st15, self.st16], self.q3,
  322. algorithm='intuitive')[0, 1])
  323. self.assertAlmostEqual(stds.victor_purpura_dist(
  324. [self.st16, self.st14], self.q3,
  325. algorithm='intuitive')[0, 1],
  326. stds.victor_purpura_dist(
  327. [self.st16, self.st15], self.q3,
  328. algorithm='intuitive')[0, 1])
  329. self.assertEqual(stds.victor_purpura_dist(
  330. [self.st01, self.st05], self.q3,
  331. algorithm='intuitive')[0, 1],
  332. stds.victor_purpura_dist(
  333. [self.st01, self.st05], self.q7,
  334. algorithm='intuitive')[0, 1])
  335. # Tests on algorithmic behaviour for equal spike times
  336. self.assertEqual(stds.victor_purpura_dist(
  337. [self.st31, self.st34], self.q3,
  338. algorithm='intuitive')[0, 1],
  339. 0.8 + 1.0)
  340. self.assertEqual(stds.victor_purpura_dist(
  341. [self.st31, self.st34], self.q3,
  342. algorithm='intuitive')[0, 1],
  343. stds.victor_purpura_dist(
  344. [self.st32, self.st33], self.q3,
  345. algorithm='intuitive')[0, 1])
  346. self.assertEqual(stds.victor_purpura_dist(
  347. [self.st31, self.st33], self.q3,
  348. algorithm='intuitive')[0, 1] * 2.0,
  349. stds.victor_purpura_dist(
  350. [self.st32, self.st34], self.q3,
  351. algorithm='intuitive')[0, 1])
  352. # Tests on spike train list lengthes smaller than 2
  353. self.assertEqual(stds.victor_purpura_dist(
  354. [self.st21], self.q3,
  355. algorithm='intuitive')[0, 0], 0)
  356. self.assertEqual(len(stds.victor_purpura_dist(
  357. [], self.q3, algorithm='intuitive')), 0)
  358. def test_victor_purpura_algorithm_comparison(self):
  359. assert_array_almost_equal(
  360. stds.victor_purpura_dist([self.st21, self.st22, self.st23],
  361. self.q3),
  362. stds.victor_purpura_dist([self.st21, self.st22, self.st23],
  363. self.q3, algorithm='intuitive'))
  364. def test_van_rossum_distance(self):
  365. # Tests of distances of simplest spike trains
  366. self.assertEqual(stds.van_rossum_dist(
  367. [self.st00, self.st00], self.tau2)[0, 1], 0.0)
  368. self.assertEqual(stds.van_rossum_dist(
  369. [self.st00, self.st01], self.tau2)[0, 1], 1.0)
  370. self.assertEqual(stds.van_rossum_dist(
  371. [self.st01, self.st00], self.tau2)[0, 1], 1.0)
  372. self.assertEqual(stds.van_rossum_dist(
  373. [self.st01, self.st01], self.tau2)[0, 1], 0.0)
  374. # Tests of distances under elementary spike operations
  375. self.assertAlmostEqual(stds.van_rossum_dist(
  376. [self.st01, self.st02], self.tau2)[0, 1],
  377. float(np.sqrt(2*(1.0-np.exp(-np.absolute(
  378. ((self.st01[0]-self.st02[0]) /
  379. self.tau2).simplified))))))
  380. self.assertAlmostEqual(stds.van_rossum_dist(
  381. [self.st01, self.st05], self.tau2)[0, 1],
  382. float(np.sqrt(2*(1.0-np.exp(-np.absolute(
  383. ((self.st01[0]-self.st05[0]) /
  384. self.tau2).simplified))))))
  385. self.assertAlmostEqual(stds.van_rossum_dist(
  386. [self.st01, self.st05], self.tau2)[0, 1],
  387. np.sqrt(2.0), 1)
  388. self.assertAlmostEqual(stds.van_rossum_dist(
  389. [self.st01, self.st06], self.tau2)[0, 1],
  390. np.sqrt(2.0), 20)
  391. self.assertAlmostEqual(stds.van_rossum_dist(
  392. [self.st00, self.st07], self.tau1)[0, 1],
  393. np.sqrt(0 + 2))
  394. self.assertAlmostEqual(stds.van_rossum_dist(
  395. [self.st07, self.st08], self.tau4)[0, 1],
  396. float(np.sqrt(2*(1.0-np.exp(-np.absolute(
  397. ((self.st07[0]-self.st08[-1]) /
  398. self.tau4).simplified))))))
  399. f_minus_g_squared = (
  400. (self.t > self.st08[0]) * np.exp(
  401. -((self.t-self.st08[0])/self.tau3).simplified) +
  402. (self.t > self.st08[1]) * np.exp(
  403. -((self.t-self.st08[1])/self.tau3).simplified) -
  404. (self.t > self.st09[0]) * np.exp(
  405. -((self.t-self.st09[0])/self.tau3).simplified))**2
  406. distance = np.sqrt(2.0 * spint.cumtrapz(
  407. y=f_minus_g_squared, x=self.t.magnitude)[-1] /
  408. self.tau3.rescale(self.t.units).magnitude)
  409. self.assertAlmostEqual(stds.van_rossum_dist(
  410. [self.st08, self.st09], self.tau3)[0, 1], distance, 5)
  411. self.assertAlmostEqual(stds.van_rossum_dist(
  412. [self.st11, self.st14], self.tau2)[0, 1], 1)
  413. # Tests on timescales
  414. self.assertAlmostEqual(
  415. stds.van_rossum_dist([self.st11, self.st14], self.tau1)[0, 1],
  416. stds.van_rossum_dist([self.st11, self.st14], self.tau5)[0, 1])
  417. self.assertAlmostEqual(
  418. stds.van_rossum_dist([self.st07, self.st11], self.tau0)[0, 1],
  419. np.sqrt(len(self.st07) + len(self.st11)))
  420. self.assertAlmostEqual(
  421. stds.van_rossum_dist([self.st07, self.st14], self.tau0)[0, 1],
  422. np.sqrt(len(self.st07) + len(self.st14)))
  423. self.assertAlmostEqual(
  424. stds.van_rossum_dist([self.st07, self.st11], self.tau1)[0, 1],
  425. np.sqrt(len(self.st07) + len(self.st11)))
  426. self.assertAlmostEqual(
  427. stds.van_rossum_dist([self.st07, self.st14], self.tau1)[0, 1],
  428. np.sqrt(len(self.st07) + len(self.st14)))
  429. self.assertAlmostEqual(
  430. stds.van_rossum_dist([self.st07, self.st11], self.tau5)[0, 1],
  431. np.absolute(len(self.st07) - len(self.st11)))
  432. self.assertAlmostEqual(
  433. stds.van_rossum_dist([self.st07, self.st14], self.tau5)[0, 1],
  434. np.absolute(len(self.st07) - len(self.st14)))
  435. self.assertAlmostEqual(
  436. stds.van_rossum_dist([self.st07, self.st11], self.tau6)[0, 1],
  437. np.absolute(len(self.st07) - len(self.st11)))
  438. self.assertAlmostEqual(
  439. stds.van_rossum_dist([self.st07, self.st14], self.tau6)[0, 1],
  440. np.absolute(len(self.st07) - len(self.st14)))
  441. # Tests on unordered spiketrains
  442. self.assertEqual(
  443. stds.van_rossum_dist([self.st11, self.st13], self.tau4)[0, 1],
  444. stds.van_rossum_dist([self.st12, self.st13], self.tau4)[0, 1])
  445. self.assertNotEqual(
  446. stds.van_rossum_dist([self.st11, self.st13],
  447. self.tau4, sort=False)[0, 1],
  448. stds.van_rossum_dist([self.st12, self.st13],
  449. self.tau4, sort=False)[0, 1])
  450. # Tests on metric properties with random spiketrains
  451. # (explicit calculation of second metric axiom in particular case,
  452. # because from dist_matrix it is trivial)
  453. dist_matrix = stds.van_rossum_dist(
  454. [self.st21, self.st22, self.st23], self.tau3)
  455. for i in range(3):
  456. for j in range(3):
  457. self.assertGreaterEqual(dist_matrix[i, j], 0)
  458. if dist_matrix[i, j] == 0:
  459. assert_array_equal(self.rd_st_list[i], self.rd_st_list[j])
  460. assert_array_equal(
  461. stds.van_rossum_dist([self.st21, self.st22], self.tau3),
  462. stds.van_rossum_dist([self.st22, self.st21], self.tau3))
  463. self.assertLessEqual(dist_matrix[0, 1],
  464. dist_matrix[0, 2] + dist_matrix[1, 2])
  465. self.assertLessEqual(dist_matrix[0, 2],
  466. dist_matrix[1, 2] + dist_matrix[0, 1])
  467. self.assertLessEqual(dist_matrix[1, 2],
  468. dist_matrix[0, 1] + dist_matrix[0, 2])
  469. # Tests on proper unit conversion
  470. self.assertAlmostEqual(
  471. stds.van_rossum_dist([self.st14, self.st16], self.tau3)[0, 1],
  472. stds.van_rossum_dist([self.st15, self.st16], self.tau3)[0, 1])
  473. self.assertAlmostEqual(
  474. stds.van_rossum_dist([self.st16, self.st14], self.tau3)[0, 1],
  475. stds.van_rossum_dist([self.st16, self.st15], self.tau3)[0, 1])
  476. self.assertEqual(
  477. stds.van_rossum_dist([self.st01, self.st05], self.tau3)[0, 1],
  478. stds.van_rossum_dist([self.st01, self.st05], self.tau7)[0, 1])
  479. # Tests on algorithmic behaviour for equal spike times
  480. f_minus_g_squared = (
  481. (self.t > self.st31[0]) * np.exp(
  482. -((self.t-self.st31[0])/self.tau3).simplified) -
  483. (self.t > self.st34[0]) * np.exp(
  484. -((self.t-self.st34[0])/self.tau3).simplified) -
  485. (self.t > self.st34[1]) * np.exp(
  486. -((self.t-self.st34[1])/self.tau3).simplified))**2
  487. distance = np.sqrt(2.0 * spint.cumtrapz(
  488. y=f_minus_g_squared, x=self.t.magnitude)[-1] /
  489. self.tau3.rescale(self.t.units).magnitude)
  490. self.assertAlmostEqual(stds.van_rossum_dist([self.st31, self.st34],
  491. self.tau3)[0, 1],
  492. distance, 5)
  493. self.assertEqual(stds.van_rossum_dist([self.st31, self.st34],
  494. self.tau3)[0, 1],
  495. stds.van_rossum_dist([self.st32, self.st33],
  496. self.tau3)[0, 1])
  497. self.assertEqual(stds.van_rossum_dist([self.st31, self.st33],
  498. self.tau3)[0, 1] * 2.0,
  499. stds.van_rossum_dist([self.st32, self.st34],
  500. self.tau3)[0, 1])
  501. # Tests on spike train list lengthes smaller than 2
  502. self.assertEqual(stds.van_rossum_dist([self.st21], self.tau3)[0, 0], 0)
  503. self.assertEqual(len(stds.van_rossum_dist([], self.tau3)), 0)
  504. if __name__ == '__main__':
  505. unittest.main()