# File: test_unitary_event_analysis.py (15 KB; the original listing had 348 numbered lines)

  1. """
  2. Unit tests for the Unitary Events analysis
  3. :copyright: Copyright 2016 by the Elephant team, see AUTHORS.txt.
  4. :license: Modified BSD, see LICENSE.txt for details.
  5. """
  6. import unittest
  7. import numpy as np
  8. import quantities as pq
  9. import types
  10. import elephant.unitary_event_analysis as ue
  11. import neo
  12. class UETestCase(unittest.TestCase):
  13. def setUp(self):
  14. sts1_with_trial = [[ 26., 48., 78., 144., 178.],
  15. [ 4., 45., 85., 123., 156., 185.],
  16. [ 22., 53., 73., 88., 120., 147., 167., 193.],
  17. [ 23., 49., 74., 116., 142., 166., 189.],
  18. [ 5., 34., 54., 80., 108., 128., 150., 181.],
  19. [ 18., 61., 107., 170.],
  20. [ 62., 98., 131., 161.],
  21. [ 37., 63., 86., 131., 168.],
  22. [ 39., 76., 100., 127., 153., 198.],
  23. [ 3., 35., 60., 88., 108., 141., 171., 184.],
  24. [ 39., 170.],
  25. [ 25., 68., 170.],
  26. [ 19., 57., 84., 116., 157., 192.],
  27. [ 17., 80., 131., 172.],
  28. [ 33., 65., 124., 162., 192.],
  29. [ 58., 87., 185.],
  30. [ 19., 101., 174.],
  31. [ 84., 118., 156., 198., 199.],
  32. [ 5., 55., 67., 96., 114., 148., 172., 199.],
  33. [ 61., 105., 131., 169., 195.],
  34. [ 26., 96., 129., 157.],
  35. [ 41., 85., 157., 199.],
  36. [ 6., 30., 53., 76., 109., 142., 167., 194.],
  37. [ 159.],
  38. [ 6., 51., 78., 113., 154., 183.],
  39. [ 138.],
  40. [ 23., 59., 154., 185.],
  41. [ 12., 14., 52., 54., 109., 145., 192.],
  42. [ 29., 61., 84., 122., 145., 168.],
  43. [ 26., 99.],
  44. [ 3., 31., 55., 85., 108., 158., 191.],
  45. [ 5., 37., 70., 119., 170.],
  46. [ 38., 79., 117., 157., 192.],
  47. [ 174.],
  48. [ 114.],
  49. []]
  50. sts2_with_trial = [[ 3., 119.],
  51. [ 54., 155., 183.],
  52. [ 35., 133.],
  53. [ 25., 100., 176.],
  54. [ 9., 98.],
  55. [ 6., 97., 198.],
  56. [ 7., 62., 148.],
  57. [ 100., 158.],
  58. [ 7., 62., 122., 179., 191.],
  59. [ 125., 182.],
  60. [ 30., 55., 127., 157., 196.],
  61. [ 27., 70., 173.],
  62. [ 82., 84., 198.],
  63. [ 11., 29., 137.],
  64. [ 5., 49., 61., 101., 142., 190.],
  65. [ 78., 162., 178.],
  66. [ 13., 14., 130., 172.],
  67. [ 22.],
  68. [ 16., 55., 109., 113., 175.],
  69. [ 17., 33., 63., 102., 144., 189., 190.],
  70. [ 58.],
  71. [ 27., 30., 99., 145., 176.],
  72. [ 10., 58., 116., 182.],
  73. [ 14., 68., 104., 126., 162., 194.],
  74. [ 56., 129., 196.],
  75. [ 50., 78., 105., 152., 190., 197.],
  76. [ 24., 66., 113., 117., 161.],
  77. [ 9., 31., 81., 95., 136., 154.],
  78. [ 10., 115., 185., 191.],
  79. [ 71., 140., 157.],
  80. [ 15., 27., 88., 102., 103., 151., 181., 188.],
  81. [ 51., 75., 95., 134., 195.],
  82. [ 18., 55., 75., 131., 186.],
  83. [ 10., 16., 41., 42., 75., 127.],
  84. [ 62., 76., 102., 145., 171., 183.],
  85. [ 66., 71., 85., 140., 154.]]
  86. self.sts1_neo = [neo.SpikeTrain(
  87. i*pq.ms,t_stop = 200*pq.ms) for i in sts1_with_trial]
  88. self.sts2_neo = [neo.SpikeTrain(
  89. i*pq.ms,t_stop = 200*pq.ms) for i in sts2_with_trial]
  90. self.binary_sts = np.array([[[1, 1, 1, 1, 0],
  91. [0, 1, 1, 1, 0],
  92. [0, 1, 1, 0, 1]],
  93. [[1, 1, 1, 1, 1],
  94. [0, 1, 1, 1, 1],
  95. [1, 1, 0, 1, 0]]])
  96. def test_hash_default(self):
  97. m = np.array([[0,0,0], [1,0,0], [0,1,0], [0,0,1], [1,1,0],
  98. [1,0,1],[0,1,1],[1,1,1]])
  99. expected = np.array([77,43,23])
  100. h = ue.hash_from_pattern(m, N=8)
  101. self.assertTrue(np.all(expected == h))
  102. def test_hash_default_longpattern(self):
  103. m = np.zeros((100,2))
  104. m[0,0] = 1
  105. expected = np.array([2**99,0])
  106. h = ue.hash_from_pattern(m, N=100)
  107. self.assertTrue(np.all(expected == h))
  108. def test_hash_ValueError_wrong_orientation(self):
  109. m = np.array([[0,0,0], [1,0,0], [0,1,0], [0,0,1], [1,1,0],
  110. [1,0,1],[0,1,1],[1,1,1]])
  111. self.assertRaises(ValueError, ue.hash_from_pattern, m, N=3)
  112. def test_hash_ValueError_wrong_entries(self):
  113. m = np.array([[0,0,0], [1,0,0], [0,2,0], [0,0,1], [1,1,0],
  114. [1,0,1],[0,1,1],[1,1,1]])
  115. self.assertRaises(ValueError, ue.hash_from_pattern, m, N=3)
  116. def test_hash_base_not_two(self):
  117. m = np.array([[0,0,0], [1,0,0], [0,1,0], [0,0,1], [1,1,0],
  118. [1,0,1],[0,1,1],[1,1,1]])
  119. m = m.T
  120. base = 3
  121. expected = np.array([0,9,3,1,12,10,4,13])
  122. h = ue.hash_from_pattern(m, N=3, base=base)
  123. self.assertTrue(np.all(expected == h))
  124. ## TODO: write a test for ValueError in inverse_hash_from_pattern
  125. def test_invhash_ValueError(self):
  126. self.assertRaises(ValueError, ue.inverse_hash_from_pattern, [128, 8], 4)
  127. def test_invhash_default_base(self):
  128. N = 3
  129. h = np.array([0, 4, 2, 1, 6, 5, 3, 7])
  130. expected = np.array([[0, 1, 0, 0, 1, 1, 0, 1],[0, 0, 1, 0, 1, 0, 1, 1],[0, 0, 0, 1, 0, 1, 1, 1]])
  131. m = ue.inverse_hash_from_pattern(h, N)
  132. self.assertTrue(np.all(expected == m))
  133. def test_invhash_base_not_two(self):
  134. N = 3
  135. h = np.array([1,4,13])
  136. base = 3
  137. expected = np.array([[0,0,1],[0,1,1],[1,1,1]])
  138. m = ue.inverse_hash_from_pattern(h, N, base)
  139. self.assertTrue(np.all(expected == m))
  140. def test_invhash_shape_mat(self):
  141. N = 8
  142. h = np.array([178, 212, 232])
  143. expected = np.array([[0,0,0], [1,0,0], [0,1,0], [0,0,1], [1,1,0],[1,0,1],[0,1,1],[1,1,1]])
  144. m = ue.inverse_hash_from_pattern(h, N)
  145. self.assertTrue(np.shape(m)[0] == N)
  146. def test_hash_invhash_consistency(self):
  147. m = np.array([[0, 0, 0],[1, 0, 0],[0, 1, 0],[0, 0, 1],[1, 1, 0],[1, 0, 1],[0, 1, 1],[1, 1, 1]])
  148. inv_h = ue.hash_from_pattern(m, N=8)
  149. m1 = ue.inverse_hash_from_pattern(inv_h, N = 8)
  150. self.assertTrue(np.all(m == m1))
  151. def test_n_emp_mat_default(self):
  152. mat = np.array([[0, 0, 0, 1, 1],[0, 0, 0, 0, 1],[1, 0, 1, 1, 1],[1, 0, 1, 1, 1]])
  153. N = 4
  154. pattern_hash = [3, 15]
  155. expected1 = np.array([ 2., 1.])
  156. expected2 = [[0, 2], [4]]
  157. nemp,nemp_indices = ue.n_emp_mat(mat,N,pattern_hash)
  158. self.assertTrue(np.all(nemp == expected1))
  159. for item_cnt,item in enumerate(nemp_indices):
  160. self.assertTrue(np.allclose(expected2[item_cnt],item))
  161. def test_n_emp_mat_sum_trial_default(self):
  162. mat = self.binary_sts
  163. pattern_hash = np.array([4,6])
  164. N = 3
  165. expected1 = np.array([ 1., 3.])
  166. expected2 = [[[0], [3]],[[],[2,4]]]
  167. n_emp, n_emp_idx = ue.n_emp_mat_sum_trial(mat, N,pattern_hash)
  168. self.assertTrue(np.all(n_emp == expected1))
  169. for item0_cnt,item0 in enumerate(n_emp_idx):
  170. for item1_cnt,item1 in enumerate(item0):
  171. self.assertTrue(np.allclose(expected2[item0_cnt][item1_cnt],item1))
  172. def test_n_emp_mat_sum_trial_ValueError(self):
  173. mat = np.array([[0,0,0], [1,0,0], [0,1,0], [0,0,1], [1,1,0],
  174. [1,0,1],[0,1,1],[1,1,1]])
  175. self.assertRaises(ValueError,ue.n_emp_mat_sum_trial,mat,N=2,pattern_hash = [3,6])
  176. def test_n_exp_mat_default(self):
  177. mat = np.array([[0, 0, 0, 1, 1],[0, 0, 0, 0, 1],[1, 0, 1, 1, 1],[1, 0, 1, 1, 1]])
  178. N = 4
  179. pattern_hash = [3, 11]
  180. expected = np.array([ 1.536, 1.024])
  181. nexp = ue.n_exp_mat(mat,N,pattern_hash)
  182. self.assertTrue(np.allclose(expected,nexp))
  183. def test_n_exp_mat_sum_trial_default(self):
  184. mat = self.binary_sts
  185. pattern_hash = np.array([5,6])
  186. N = 3
  187. expected = np.array([ 1.56, 2.56])
  188. n_exp = ue.n_exp_mat_sum_trial(mat, N,pattern_hash)
  189. self.assertTrue(np.allclose(n_exp,expected))
  190. def test_n_exp_mat_sum_trial_TrialAverage(self):
  191. mat = self.binary_sts
  192. pattern_hash = np.array([5,6])
  193. N = 3
  194. expected = np.array([ 1.62, 2.52])
  195. n_exp = ue.n_exp_mat_sum_trial(mat, N, pattern_hash, method='analytic_TrialAverage')
  196. self.assertTrue(np.allclose(n_exp,expected))
  197. def test_n_exp_mat_sum_trial_surrogate(self):
  198. mat = self.binary_sts
  199. pattern_hash = np.array([5])
  200. N = 3
  201. n_exp_anal = ue.n_exp_mat_sum_trial(mat, N, pattern_hash, method='analytic_TrialAverage')
  202. n_exp_surr = ue.n_exp_mat_sum_trial(mat, N, pattern_hash, method='surrogate_TrialByTrial',n_surr = 1000)
  203. self.assertLess((np.abs(n_exp_anal[0]-np.mean(n_exp_surr))/n_exp_anal[0]),0.1)
  204. def test_n_exp_mat_sum_trial_ValueError(self):
  205. mat = np.array([[0,0,0], [1,0,0], [0,1,0], [0,0,1], [1,1,0],
  206. [1,0,1],[0,1,1],[1,1,1]])
  207. self.assertRaises(ValueError,ue.n_exp_mat_sum_trial,mat,N=2,pattern_hash = [3,6])
  208. def test_gen_pval_anal_default(self):
  209. mat = np.array([[[1, 1, 1, 1, 0],
  210. [0, 1, 1, 1, 0],
  211. [0, 1, 1, 0, 1]],
  212. [[1, 1, 1, 1, 1],
  213. [0, 1, 1, 1, 1],
  214. [1, 1, 0, 1, 0]]])
  215. pattern_hash = np.array([5,6])
  216. N = 3
  217. expected = np.array([ 1.56, 2.56])
  218. pval_func,n_exp = ue.gen_pval_anal(mat, N,pattern_hash)
  219. self.assertTrue(np.allclose(n_exp,expected))
  220. self.assertTrue(isinstance(pval_func, types.FunctionType))
  221. def test_jointJ_default(self):
  222. p_val = np.array([0.31271072, 0.01175031])
  223. expected = np.array([0.3419968 , 1.92481736])
  224. self.assertTrue(np.allclose(ue.jointJ(p_val),expected))
  225. def test__rate_mat_avg_trial_default(self):
  226. mat = self.binary_sts
  227. expected = [0.9, 0.7,0.6]
  228. self.assertTrue(np.allclose(expected,ue._rate_mat_avg_trial(mat)))
  229. def test__bintime(self):
  230. t = 13*pq.ms
  231. binsize = 3*pq.ms
  232. expected = 4
  233. self.assertTrue(np.allclose(expected,ue._bintime(t,binsize)))
  234. def test__winpos(self):
  235. t_start = 10*pq.ms
  236. t_stop = 46*pq.ms
  237. winsize = 15*pq.ms
  238. winstep = 3*pq.ms
  239. expected = [ 10., 13., 16., 19., 22., 25., 28., 31.]*pq.ms
  240. self.assertTrue(
  241. np.allclose(
  242. ue._winpos(
  243. t_start, t_stop, winsize,
  244. winstep).rescale('ms').magnitude,
  245. expected.rescale('ms').magnitude))
  246. def test__UE_default(self):
  247. mat = self.binary_sts
  248. pattern_hash = np.array([4,6])
  249. N = 3
  250. expected_S = np.array([-0.26226523, 0.04959301])
  251. expected_idx = [[[0], [3]], [[], [2, 4]]]
  252. expected_nemp = np.array([ 1., 3.])
  253. expected_nexp = np.array([ 1.04, 2.56])
  254. expected_rate = np.array([ 0.9, 0.7, 0.6])
  255. S, rate_avg, n_exp, n_emp,indices = ue._UE(mat,N,pattern_hash)
  256. self.assertTrue(np.allclose(S ,expected_S))
  257. self.assertTrue(np.allclose(n_exp ,expected_nexp))
  258. self.assertTrue(np.allclose(n_emp ,expected_nemp))
  259. self.assertTrue(np.allclose(expected_rate ,rate_avg))
  260. for item0_cnt,item0 in enumerate(indices):
  261. for item1_cnt,item1 in enumerate(item0):
  262. self.assertTrue(np.allclose(expected_idx[item0_cnt][item1_cnt],item1))
  263. def test__UE_surrogate(self):
  264. mat = self.binary_sts
  265. pattern_hash = np.array([4])
  266. N = 3
  267. _, rate_avg_surr, _, n_emp_surr,indices_surr =\
  268. ue._UE(mat, N, pattern_hash, method='surrogate_TrialByTrial', n_surr=100)
  269. _, rate_avg, _, n_emp,indices =\
  270. ue._UE(mat, N, pattern_hash, method='analytic_TrialByTrial')
  271. self.assertTrue(np.allclose(n_emp ,n_emp_surr))
  272. self.assertTrue(np.allclose(rate_avg ,rate_avg_surr))
  273. for item0_cnt,item0 in enumerate(indices):
  274. for item1_cnt,item1 in enumerate(item0):
  275. self.assertTrue(np.allclose(indices_surr[item0_cnt][item1_cnt],item1))
  276. def test_jointJ_window_analysis(self):
  277. sts1 = self.sts1_neo
  278. sts2 = self.sts2_neo
  279. data = np.vstack((sts1,sts2)).T
  280. winsize = 100*pq.ms
  281. binsize = 5*pq.ms
  282. winstep = 20*pq.ms
  283. pattern_hash = [3]
  284. UE_dic = ue.jointJ_window_analysis(data, binsize, winsize, winstep, pattern_hash)
  285. expected_Js = np.array(
  286. [ 0.57953708, 0.47348757, 0.1729669 ,
  287. 0.01883295, -0.21934742,-0.80608759])
  288. expected_n_emp = np.array(
  289. [ 9., 9., 7., 7., 6., 6.])
  290. expected_n_exp = np.array(
  291. [ 6.5 , 6.85, 6.05, 6.6 , 6.45, 8.7 ])
  292. expected_rate = np.array(
  293. [[ 0.02166667, 0.01861111],
  294. [ 0.02277778, 0.01777778],
  295. [ 0.02111111, 0.01777778],
  296. [ 0.02277778, 0.01888889],
  297. [ 0.02305556, 0.01722222],
  298. [ 0.02388889, 0.02055556]])*pq.kHz
  299. expected_indecis_tril26 = [ 4., 4.]
  300. expected_indecis_tril4 = [ 1.]
  301. self.assertTrue(np.allclose(UE_dic['Js'] ,expected_Js))
  302. self.assertTrue(np.allclose(UE_dic['n_emp'] ,expected_n_emp))
  303. self.assertTrue(np.allclose(UE_dic['n_exp'] ,expected_n_exp))
  304. self.assertTrue(np.allclose(
  305. UE_dic['rate_avg'].rescale('Hz').magnitude ,
  306. expected_rate.rescale('Hz').magnitude))
  307. self.assertTrue(np.allclose(
  308. UE_dic['indices']['trial26'],expected_indecis_tril26))
  309. self.assertTrue(np.allclose(
  310. UE_dic['indices']['trial4'],expected_indecis_tril4))
  311. def suite():
  312. suite = unittest.makeSuite(UETestCase, 'test')
  313. return suite
  314. if __name__ == "__main__":
  315. runner = unittest.TextTestRunner(verbosity=2)
  316. runner.run(suite())