test_asset.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. # -*- coding: utf-8 -*-
  2. """
  3. Unit tests for the ASSET analysis.
  4. :copyright: Copyright 2014-2016 by the Elephant team, see AUTHORS.txt.
  5. :license: Modified BSD, see LICENSE.txt for details.
  6. """
  7. import unittest
  8. import numpy as np
  9. import scipy.spatial
  10. import quantities as pq
  11. import neo
  12. try:
  13. import sklearn
  14. except ImportError:
  15. HAVE_SKLEARN = False
  16. else:
  17. import elephant.asset as asset
  18. HAVE_SKLEARN = True
  19. stretchedmetric2d = asset._stretched_metric_2d
  20. cluster = asset.cluster_matrix_entries
  21. @unittest.skipUnless(HAVE_SKLEARN, 'requires sklearn')
  22. class AssetTestCase(unittest.TestCase):
  23. def test_stretched_metric_2d_size(self):
  24. nr_points = 4
  25. x = np.arange(nr_points)
  26. D = stretchedmetric2d(x, x, stretch=1, ref_angle=45)
  27. self.assertEqual(D.shape, (nr_points, nr_points))
  28. def test_stretched_metric_2d_correct_stretching(self):
  29. x = (0, 1, 0)
  30. y = (0, 0, 1)
  31. stretch = 10
  32. ref_angle = 0
  33. D = stretchedmetric2d(x, y, stretch=stretch, ref_angle=ref_angle)
  34. self.assertEqual(D[0, 1], 1)
  35. self.assertEqual(D[0, 2], stretch)
  36. def test_stretched_metric_2d_symmetric(self):
  37. x = (1, 2, 2)
  38. y = (1, 2, 0)
  39. stretch = 10
  40. D = stretchedmetric2d(x, y, stretch=stretch, ref_angle=45)
  41. np.testing.assert_array_almost_equal(D, D.T, decimal=12)
  42. def test_stretched_metric_2d_equals_euclidean_if_stretch_1(self):
  43. x = np.arange(10)
  44. y = y = x ** 2 - 2 * x - 4
  45. # compute stretched distance matrix
  46. stretch = 1
  47. D = stretchedmetric2d(x, y, stretch=stretch, ref_angle=45)
  48. # Compute Euclidean distance matrix
  49. points = np.vstack([x, y]).T
  50. E = scipy.spatial.distance_matrix(points, points)
  51. # assert D == E
  52. np.testing.assert_array_almost_equal(D, E, decimal=12)
  53. def test_cluster_correct(self):
  54. mat = np.zeros((6, 6))
  55. mat[[2, 4, 5], [0, 0, 1]] = 1
  56. mat_clustered = cluster(mat, eps=4, min=2, stretch=6)
  57. mat_correct = np.zeros((6, 6))
  58. mat_correct[[4, 5], [0, 1]] = 1
  59. mat_correct[2, 0] = -1
  60. np.testing.assert_array_equal(mat_clustered, mat_correct)
  61. def test_cluster_symmetric(self):
  62. x = [0, 1, 2, 5, 6, 7]
  63. y = [3, 4, 5, 1, 2, 3]
  64. mat = np.zeros((10, 10))
  65. mat[x, y] = 1
  66. mat = mat + mat.T
  67. # compute stretched distance matrix
  68. mat_clustered = cluster(mat, eps=4, min=2, stretch=6)
  69. mat_equals_m1 = (mat_clustered == -1)
  70. mat_equals_0 = (mat_clustered == 0)
  71. mat_larger_0 = (mat_clustered > 0)
  72. np.testing.assert_array_equal(mat_equals_m1, mat_equals_m1.T)
  73. np.testing.assert_array_equal(mat_equals_0, mat_equals_0.T)
  74. np.testing.assert_array_equal(mat_larger_0, mat_larger_0.T)
  75. def test_sse_difference(self):
  76. a = {(1, 2): set([1, 2, 3]), (3, 4): set([5, 6]), (6, 7): set([0, 1])}
  77. b = {(1, 2): set([1, 2, 5]), (5, 6): set([0, 2]), (6, 7): set([0, 1])}
  78. diff_ab_pixelwise = {(3, 4): set([5, 6])}
  79. diff_ba_pixelwise = {(5, 6): set([0, 2])}
  80. diff_ab_linkwise = {(1, 2): set([3]), (3, 4): set([5, 6])}
  81. diff_ba_linkwise = {(1, 2): set([5]), (5, 6): set([0, 2])}
  82. self.assertEqual(
  83. asset.sse_difference(a, b, 'pixelwise'), diff_ab_pixelwise)
  84. self.assertEqual(
  85. asset.sse_difference(b, a, 'pixelwise'), diff_ba_pixelwise)
  86. self.assertEqual(
  87. asset.sse_difference(a, b, 'linkwise'), diff_ab_linkwise)
  88. self.assertEqual(
  89. asset.sse_difference(b, a, 'linkwise'), diff_ba_linkwise)
  90. def test_sse_intersection(self):
  91. a = {(1, 2): set([1, 2, 3]), (3, 4): set([5, 6]), (6, 7): set([0, 1])}
  92. b = {(1, 2): set([1, 2, 5]), (5, 6): set([0, 2]), (6, 7): set([0, 1])}
  93. inters_ab_pixelwise = {(1, 2): set([1, 2, 3]), (6, 7): set([0, 1])}
  94. inters_ba_pixelwise = {(1, 2): set([1, 2, 5]), (6, 7): set([0, 1])}
  95. inters_ab_linkwise = {(1, 2): set([1, 2]), (6, 7): set([0, 1])}
  96. inters_ba_linkwise = {(1, 2): set([1, 2]), (6, 7): set([0, 1])}
  97. self.assertEqual(
  98. asset.sse_intersection(a, b, 'pixelwise'), inters_ab_pixelwise)
  99. self.assertEqual(
  100. asset.sse_intersection(b, a, 'pixelwise'), inters_ba_pixelwise)
  101. self.assertEqual(
  102. asset.sse_intersection(a, b, 'linkwise'), inters_ab_linkwise)
  103. self.assertEqual(
  104. asset.sse_intersection(b, a, 'linkwise'), inters_ba_linkwise)
  105. def test_sse_relations(self):
  106. a = {(1, 2): set([1, 2, 3]), (3, 4): set([5, 6]), (6, 7): set([0, 1])}
  107. b = {(1, 2): set([1, 2, 5]), (5, 6): set([0, 2]), (6, 7): set([0, 1])}
  108. c = {(5, 6): set([0, 2])}
  109. d = {(3, 4): set([0, 1]), (5, 6): set([0, 1, 2])}
  110. self.assertTrue(asset.sse_isequal({}, {}))
  111. self.assertTrue(asset.sse_isequal(a, a))
  112. self.assertFalse(asset.sse_isequal(b, c))
  113. self.assertTrue(asset.sse_isdisjoint(a, c))
  114. self.assertTrue(asset.sse_isdisjoint(a, d))
  115. self.assertFalse(asset.sse_isdisjoint(a, b))
  116. self.assertTrue(asset.sse_issub(c, b))
  117. self.assertTrue(asset.sse_issub(c, d))
  118. self.assertFalse(asset.sse_issub(a, b))
  119. self.assertTrue(asset.sse_issuper(b, c))
  120. self.assertTrue(asset.sse_issuper(d, c))
  121. self.assertFalse(asset.sse_issuper(a, b))
  122. self.assertTrue(asset.sse_overlap(a, b))
  123. self.assertFalse(asset.sse_overlap(c, d))
  124. def test_mask_matrix(self):
  125. mat1 = np.array([[0, 1], [1, 2]])
  126. mat2 = np.array([[2, 1], [1, 3]])
  127. mask_1_2 = asset.mask_matrices([mat1, mat2], [1, 2])
  128. mask_1_2_correct = np.array([[False, False], [False, True]])
  129. self.assertTrue(np.all(mask_1_2 == mask_1_2_correct))
  130. self.assertIsInstance(mask_1_2[0, 0], np.bool_)
  131. def test_cluster_matrix_entries(self):
  132. mat = np.array([[False, False, True, False],
  133. [False, True, False, False],
  134. [True, False, False, True],
  135. [False, False, True, False]])
  136. clustered1 = asset.cluster_matrix_entries(
  137. mat, eps=1.5, min=2, stretch=1)
  138. clustered2 = asset.cluster_matrix_entries(
  139. mat, eps=1.5, min=3, stretch=1)
  140. clustered1_correctA = np.array([[0, 0, 1, 0],
  141. [0, 1, 0, 0],
  142. [1, 0, 0, 2],
  143. [0, 0, 2, 0]])
  144. clustered1_correctB = np.array([[0, 0, 2, 0],
  145. [0, 2, 0, 0],
  146. [2, 0, 0, 1],
  147. [0, 0, 1, 0]])
  148. clustered2_correct = np.array([[0, 0, 1, 0],
  149. [0, 1, 0, 0],
  150. [1, 0, 0, -1],
  151. [0, 0, -1, 0]])
  152. self.assertTrue(np.all(clustered1 == clustered1_correctA) or
  153. np.all(clustered1 == clustered1_correctB))
  154. self.assertTrue(np.all(clustered2 == clustered2_correct))
  155. def test_intersection_matrix(self):
  156. st1 = neo.SpikeTrain([1, 2, 4]*pq.ms, t_stop=6*pq.ms)
  157. st2 = neo.SpikeTrain([1, 3, 4]*pq.ms, t_stop=6*pq.ms)
  158. st3 = neo.SpikeTrain([2, 5]*pq.ms, t_start=1*pq.ms, t_stop=6*pq.ms)
  159. st4 = neo.SpikeTrain([1, 3, 6]*pq.ms, t_stop=8*pq.ms)
  160. binsize = 1 * pq.ms
  161. # Check that the routine works for correct input...
  162. # ...same t_start, t_stop on both time axes
  163. imat_1_2, xedges, yedges = asset.intersection_matrix(
  164. [st1, st2], binsize, dt=5*pq.ms)
  165. trueimat_1_2 = np.array([[0., 0., 0., 0., 0.],
  166. [0., 2., 1., 1., 2.],
  167. [0., 1., 1., 0., 1.],
  168. [0., 1., 0., 1., 1.],
  169. [0., 2., 1., 1., 2.]])
  170. self.assertTrue(np.all(xedges == np.arange(6)*pq.ms)) # correct bins
  171. self.assertTrue(np.all(yedges == np.arange(6)*pq.ms)) # correct bins
  172. self.assertTrue(np.all(imat_1_2 == trueimat_1_2)) # correct matrix
  173. # ...different t_start, t_stop on the two time axes
  174. imat_1_2, xedges, yedges = asset.intersection_matrix(
  175. [st1, st2], binsize, t_start_y=1*pq.ms, dt=5*pq.ms)
  176. trueimat_1_2 = np.array([[0., 0., 0., 0., 0.],
  177. [2., 1., 1., 2., 0.],
  178. [1., 1., 0., 1., 0.],
  179. [1., 0., 1., 1., 0.],
  180. [2., 1., 1., 2., 0.]])
  181. self.assertTrue(np.all(xedges == np.arange(6)*pq.ms)) # correct bins
  182. self.assertTrue(np.all(imat_1_2 == trueimat_1_2)) # correct matrix
  183. # Check that errors are raised correctly...
  184. # ...for dt too large compared to length of spike trains
  185. self.assertRaises(ValueError, asset.intersection_matrix,
  186. spiketrains=[st1, st2], binsize=binsize, dt=8*pq.ms)
  187. # ...for different SpikeTrain's t_starts
  188. self.assertRaises(ValueError, asset.intersection_matrix,
  189. spiketrains=[st1, st3], binsize=binsize, dt=8*pq.ms)
  190. # ...when the analysis is specified for a time span where the
  191. # spike trains are not defined (e.g. t_start_x < SpikeTrain.t_start)
  192. self.assertRaises(ValueError, asset.intersection_matrix,
  193. spiketrains=[st1, st2], binsize=binsize, dt=8*pq.ms,
  194. t_start_x=-2*pq.ms, t_start_y=-2*pq.ms)
  195. def suite():
  196. suite = unittest.makeSuite(AssetTestCase, 'test')
  197. return suite
  198. def run():
  199. runner = unittest.TextTestRunner(verbosity=2)
  200. runner.run(suite())
  201. if __name__ == "__main__":
  202. unittest.main()