123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520 |
- # -*- coding: utf-8 -*-
- """
- Tests for the spike train dissimilarity measures module.
- :copyright: Copyright 2016 by the Elephant team, see AUTHORS.txt.
- :license: Modified BSD, see LICENSE.txt for details.
- """
- import unittest
- from neo import SpikeTrain
- import numpy as np
- from numpy.testing import assert_array_equal, assert_array_almost_equal
- import scipy.integrate as spint
- from quantities import ms, s, Hz
- import elephant.kernels as kernels
- import elephant.spike_train_generation as stg
- import elephant.spike_train_dissimilarity as stds
class TimeScaleDependSpikeTrainDissimMeasures_TestCase(unittest.TestCase):
    """Tests for the time-scale dependent spike train dissimilarity measures.

    Covers ``victor_purpura_dist`` (default and ``'intuitive'`` algorithm)
    and ``van_rossum_dist``: input validation, elementary spike operations,
    limiting time scales, unordered input, metric properties on random
    spike trains, unit conversion and coincident spike times.
    """

    def setUp(self):
        # Hand-crafted spike trains (times in ms unless noted otherwise).
        self.st00 = SpikeTrain([], units='ms', t_stop=1000.0)
        self.st01 = SpikeTrain([1], units='ms', t_stop=1000.0)
        self.st02 = SpikeTrain([2], units='ms', t_stop=1000.0)
        self.st03 = SpikeTrain([2.9], units='ms', t_stop=1000.0)
        self.st04 = SpikeTrain([3.1], units='ms', t_stop=1000.0)
        self.st05 = SpikeTrain([5], units='ms', t_stop=1000.0)
        self.st06 = SpikeTrain([500], units='ms', t_stop=1000.0)
        self.st07 = SpikeTrain([12, 32], units='ms', t_stop=1000.0)
        self.st08 = SpikeTrain([32, 52], units='ms', t_stop=1000.0)
        self.st09 = SpikeTrain([42], units='ms', t_stop=1000.0)
        self.st10 = SpikeTrain([18, 60], units='ms', t_stop=1000.0)
        self.st11 = SpikeTrain([10, 20, 30, 40], units='ms', t_stop=1000.0)
        self.st12 = SpikeTrain([40, 30, 20, 10], units='ms', t_stop=1000.0)
        self.st13 = SpikeTrain([15, 25, 35, 45], units='ms', t_stop=1000.0)
        self.st14 = SpikeTrain([10, 20, 30, 40, 50], units='ms', t_stop=1000.0)
        # Same spike times as st14, but expressed in seconds.
        self.st15 = SpikeTrain([0.01, 0.02, 0.03, 0.04, 0.05],
                               units='s', t_stop=1000.0)
        self.st16 = SpikeTrain([12, 16, 28, 30, 42], units='ms', t_stop=1000.0)
        # Random spike trains used for the metric-property checks.
        self.st21 = stg.homogeneous_poisson_process(50*Hz, 0*ms, 1000*ms)
        self.st22 = stg.homogeneous_poisson_process(40*Hz, 0*ms, 1000*ms)
        self.st23 = stg.homogeneous_poisson_process(30*Hz, 0*ms, 1000*ms)
        self.rd_st_list = [self.st21, self.st22, self.st23]
        # Spike trains with duplicated (coincident) spike times.
        self.st31 = SpikeTrain([12.0], units='ms', t_stop=1000.0)
        self.st32 = SpikeTrain([12.0, 12.0], units='ms', t_stop=1000.0)
        self.st33 = SpikeTrain([20.0], units='ms', t_stop=1000.0)
        self.st34 = SpikeTrain([20.0, 20.0], units='ms', t_stop=1000.0)
        # Plain arrays and non-SpikeTrain quantities (invalid inputs).
        self.array1 = np.arange(1, 10)
        self.array2 = np.arange(1.2, 10)
        self.qarray1 = self.array1 * Hz
        self.qarray2 = self.array2 * Hz
        # Time scales tau_k and the matching cost factors q_k = 1 / tau_k.
        self.tau0 = 0.0 * ms
        self.q0 = np.inf / ms
        self.tau1 = 0.000000001 * ms
        self.q1 = 1.0 / self.tau1
        self.tau2 = 1.0 * ms
        self.q2 = 1.0 / self.tau2
        self.tau3 = 10.0 * ms
        self.q3 = 1.0 / self.tau3
        self.tau4 = 100.0 * ms
        self.q4 = 1.0 / self.tau4
        self.tau5 = 1000000000.0 * ms
        self.q5 = 1.0 / self.tau5
        self.tau6 = np.inf * ms
        self.q6 = 0.0 / ms
        self.tau7 = 0.01 * s
        self.q7 = 1.0 / self.tau7

    def _assert_metric_properties(self, dist_matrix):
        """Check metric axioms on the 3x3 distance matrix of ``rd_st_list``.

        Asserts non-negativity, that a zero distance only occurs between
        identical spike trains, and the three triangle inequalities.
        """
        for i in range(3):
            for j in range(3):
                self.assertGreaterEqual(dist_matrix[i, j], 0)
                if dist_matrix[i, j] == 0:
                    assert_array_equal(self.rd_st_list[i], self.rd_st_list[j])
        self.assertLessEqual(dist_matrix[0, 1],
                             dist_matrix[0, 2] + dist_matrix[1, 2])
        self.assertLessEqual(dist_matrix[0, 2],
                             dist_matrix[1, 2] + dist_matrix[0, 1])
        self.assertLessEqual(dist_matrix[1, 2],
                             dist_matrix[0, 1] + dist_matrix[0, 2])

    @staticmethod
    def _van_rossum_by_integration(t, spiketrain_a, spiketrain_b, tau):
        """Numerically evaluate the van Rossum distance on the time grid ``t``.

        Each spike train is filtered with a causal exponential kernel
        ``exp(-(t - t_spike) / tau)`` sampled on ``t`` (zero up to and
        including the spike time). The squared difference of the two traces
        is integrated with the trapezoidal rule, and the result is returned
        as ``sqrt(2 * integral / tau)``.
        """
        # cumtrapz/trapz were removed in SciPy 1.14; trapezoid exists
        # since SciPy 1.6, so fall back to trapz only on old releases.
        trapezoid = getattr(spint, 'trapezoid', None) or spint.trapz
        t_values = t.magnitude
        tau_value = tau.rescale(t.units).magnitude

        def filtered_trace(spiketrain):
            trace = np.zeros(t_values.shape)
            for spike_time in spiketrain.rescale(t.units).magnitude:
                # Evaluate the kernel only after the spike: this avoids
                # computing (and possibly overflowing) exp() where the
                # contribution is zero anyway.
                after = t_values > spike_time
                trace[after] += np.exp(
                    -(t_values[after] - spike_time) / tau_value)
            return trace

        diff_squared = (filtered_trace(spiketrain_a) -
                        filtered_trace(spiketrain_b)) ** 2
        integral = trapezoid(y=diff_squared, x=t_values)
        return float(np.sqrt(2.0 * integral / tau_value))

    def test_wrong_input(self):
        """Invalid arguments raise; an explicit kernel overrides q."""
        self.assertRaises(TypeError, stds.victor_purpura_dist,
                          [self.array1, self.array2], self.q3)
        self.assertRaises(TypeError, stds.victor_purpura_dist,
                          [self.qarray1, self.qarray2], self.q3)
        self.assertRaises(TypeError, stds.victor_purpura_dist,
                          [self.qarray1, self.qarray2], 5.0 * ms)
        self.assertRaises(TypeError, stds.victor_purpura_dist,
                          [self.array1, self.array2], self.q3,
                          algorithm='intuitive')
        self.assertRaises(TypeError, stds.victor_purpura_dist,
                          [self.qarray1, self.qarray2], self.q3,
                          algorithm='intuitive')
        self.assertRaises(TypeError, stds.victor_purpura_dist,
                          [self.qarray1, self.qarray2], 5.0 * ms,
                          algorithm='intuitive')
        self.assertRaises(TypeError, stds.van_rossum_dist,
                          [self.array1, self.array2], self.tau3)
        self.assertRaises(TypeError, stds.van_rossum_dist,
                          [self.qarray1, self.qarray2], self.tau3)
        self.assertRaises(TypeError, stds.van_rossum_dist,
                          [self.qarray1, self.qarray2], 5.0 * Hz)
        self.assertRaises(TypeError, stds.victor_purpura_dist,
                          [self.st11, self.st13], self.tau2)
        self.assertRaises(TypeError, stds.victor_purpura_dist,
                          [self.st11, self.st13], 5.0)
        self.assertRaises(TypeError, stds.victor_purpura_dist,
                          [self.st11, self.st13], self.tau2,
                          algorithm='intuitive')
        self.assertRaises(TypeError, stds.victor_purpura_dist,
                          [self.st11, self.st13], 5.0,
                          algorithm='intuitive')
        self.assertRaises(TypeError, stds.van_rossum_dist,
                          [self.st11, self.st13], self.q4)
        self.assertRaises(TypeError, stds.van_rossum_dist,
                          [self.st11, self.st13], 5.0)
        self.assertRaises(NotImplementedError, stds.victor_purpura_dist,
                          [self.st01, self.st02], self.q3,
                          kernel=kernels.Kernel(2.0 / self.q3))
        self.assertRaises(NotImplementedError, stds.victor_purpura_dist,
                          [self.st01, self.st02], self.q3,
                          kernel=kernels.SymmetricKernel(2.0 / self.q3))
        # When a kernel is given, the value of q must not matter.
        self.assertEqual(stds.victor_purpura_dist(
            [self.st01, self.st02], self.q1,
            kernel=kernels.TriangularKernel(
                2.0 / (np.sqrt(6.0) * self.q2)))[0, 1],
            stds.victor_purpura_dist(
                [self.st01, self.st02], self.q3,
                kernel=kernels.TriangularKernel(
                    2.0 / (np.sqrt(6.0) * self.q2)))[0, 1])
        self.assertEqual(stds.victor_purpura_dist(
            [self.st01, self.st02],
            kernel=kernels.TriangularKernel(
                2.0 / (np.sqrt(6.0) * self.q2)))[0, 1], 1.0)
        self.assertNotEqual(stds.victor_purpura_dist(
            [self.st01, self.st02],
            kernel=kernels.AlphaKernel(
                2.0 / (np.sqrt(6.0) * self.q2)))[0, 1], 1.0)
        self.assertRaises(NameError, stds.victor_purpura_dist,
                          [self.st11, self.st13], self.q2, algorithm='slow')

    def test_victor_purpura_distance_fast(self):
        """Victor-Purpura distance computed with the default algorithm."""
        # Tests of distances of simplest spike trains:
        self.assertEqual(stds.victor_purpura_dist(
            [self.st00, self.st00], self.q2)[0, 1], 0.0)
        self.assertEqual(stds.victor_purpura_dist(
            [self.st00, self.st01], self.q2)[0, 1], 1.0)
        self.assertEqual(stds.victor_purpura_dist(
            [self.st01, self.st00], self.q2)[0, 1], 1.0)
        self.assertEqual(stds.victor_purpura_dist(
            [self.st01, self.st01], self.q2)[0, 1], 0.0)
        # Tests of distances under elementary spike operations
        self.assertEqual(stds.victor_purpura_dist(
            [self.st01, self.st02], self.q2)[0, 1], 1.0)
        self.assertEqual(stds.victor_purpura_dist(
            [self.st01, self.st03], self.q2)[0, 1], 1.9)
        self.assertEqual(stds.victor_purpura_dist(
            [self.st01, self.st04], self.q2)[0, 1], 2.0)
        self.assertEqual(stds.victor_purpura_dist(
            [self.st01, self.st05], self.q2)[0, 1], 2.0)
        self.assertEqual(stds.victor_purpura_dist(
            [self.st00, self.st07], self.q2)[0, 1], 2.0)
        self.assertAlmostEqual(stds.victor_purpura_dist(
            [self.st07, self.st08], self.q4)[0, 1], 0.4)
        self.assertAlmostEqual(stds.victor_purpura_dist(
            [self.st07, self.st10], self.q3)[0, 1], 0.6 + 2)
        self.assertEqual(stds.victor_purpura_dist(
            [self.st11, self.st14], self.q2)[0, 1], 1)
        # Tests on timescales
        self.assertEqual(stds.victor_purpura_dist(
            [self.st11, self.st14], self.q1)[0, 1],
            stds.victor_purpura_dist(
                [self.st11, self.st14], self.q5)[0, 1])
        self.assertEqual(stds.victor_purpura_dist(
            [self.st07, self.st11], self.q0)[0, 1], 6.0)
        self.assertEqual(stds.victor_purpura_dist(
            [self.st07, self.st11], self.q1)[0, 1], 6.0)
        self.assertAlmostEqual(stds.victor_purpura_dist(
            [self.st07, self.st11], self.q5)[0, 1], 2.0, 5)
        self.assertEqual(stds.victor_purpura_dist(
            [self.st07, self.st11], self.q6)[0, 1], 2.0)
        # Tests on unordered spiketrains
        self.assertEqual(stds.victor_purpura_dist(
            [self.st11, self.st13], self.q4)[0, 1],
            stds.victor_purpura_dist(
                [self.st12, self.st13], self.q4)[0, 1])
        self.assertNotEqual(stds.victor_purpura_dist(
            [self.st11, self.st13], self.q4,
            sort=False)[0, 1],
            stds.victor_purpura_dist(
                [self.st12, self.st13], self.q4,
                sort=False)[0, 1])
        # Tests on metric properties with random spiketrains. Symmetry
        # (second metric axiom) is checked by an explicit swapped call,
        # because checking it on dist_matrix alone would be trivial.
        dist_matrix = stds.victor_purpura_dist(
            [self.st21, self.st22, self.st23], self.q3)
        assert_array_equal(stds.victor_purpura_dist(
            [self.st21, self.st22], self.q3),
            stds.victor_purpura_dist(
                [self.st22, self.st21], self.q3))
        self._assert_metric_properties(dist_matrix)
        # Tests on proper unit conversion
        self.assertAlmostEqual(
            stds.victor_purpura_dist([self.st14, self.st16], self.q3)[0, 1],
            stds.victor_purpura_dist([self.st15, self.st16], self.q3)[0, 1])
        self.assertAlmostEqual(
            stds.victor_purpura_dist([self.st16, self.st14], self.q3)[0, 1],
            stds.victor_purpura_dist([self.st16, self.st15], self.q3)[0, 1])
        self.assertEqual(
            stds.victor_purpura_dist([self.st01, self.st05], self.q3)[0, 1],
            stds.victor_purpura_dist([self.st01, self.st05], self.q7)[0, 1])
        # Tests on algorithmic behaviour for equal spike times
        self.assertEqual(
            stds.victor_purpura_dist([self.st31, self.st34], self.q3)[0, 1],
            0.8 + 1.0)
        self.assertEqual(
            stds.victor_purpura_dist([self.st31, self.st34], self.q3)[0, 1],
            stds.victor_purpura_dist([self.st32, self.st33], self.q3)[0, 1])
        self.assertEqual(
            stds.victor_purpura_dist(
                [self.st31, self.st33], self.q3)[0, 1] * 2.0,
            stds.victor_purpura_dist(
                [self.st32, self.st34], self.q3)[0, 1])
        # Tests on spike train list lengths smaller than 2
        self.assertEqual(stds.victor_purpura_dist(
            [self.st21], self.q3)[0, 0], 0)
        self.assertEqual(len(stds.victor_purpura_dist([], self.q3)), 0)

    def test_victor_purpura_distance_intuitive(self):
        """Victor-Purpura distance computed with algorithm='intuitive'."""
        # Tests of distances of simplest spike trains
        self.assertEqual(stds.victor_purpura_dist(
            [self.st00, self.st00], self.q2,
            algorithm='intuitive')[0, 1], 0.0)
        self.assertEqual(stds.victor_purpura_dist(
            [self.st00, self.st01], self.q2,
            algorithm='intuitive')[0, 1], 1.0)
        self.assertEqual(stds.victor_purpura_dist(
            [self.st01, self.st00], self.q2,
            algorithm='intuitive')[0, 1], 1.0)
        self.assertEqual(stds.victor_purpura_dist(
            [self.st01, self.st01], self.q2,
            algorithm='intuitive')[0, 1], 0.0)
        # Tests of distances under elementary spike operations
        self.assertEqual(stds.victor_purpura_dist(
            [self.st01, self.st02], self.q2,
            algorithm='intuitive')[0, 1], 1.0)
        self.assertEqual(stds.victor_purpura_dist(
            [self.st01, self.st03], self.q2,
            algorithm='intuitive')[0, 1], 1.9)
        self.assertEqual(stds.victor_purpura_dist(
            [self.st01, self.st04], self.q2,
            algorithm='intuitive')[0, 1], 2.0)
        self.assertEqual(stds.victor_purpura_dist(
            [self.st01, self.st05], self.q2,
            algorithm='intuitive')[0, 1], 2.0)
        self.assertEqual(stds.victor_purpura_dist(
            [self.st00, self.st07], self.q2,
            algorithm='intuitive')[0, 1], 2.0)
        self.assertAlmostEqual(stds.victor_purpura_dist(
            [self.st07, self.st08], self.q4,
            algorithm='intuitive')[0, 1], 0.4)
        self.assertAlmostEqual(stds.victor_purpura_dist(
            [self.st07, self.st10], self.q3,
            algorithm='intuitive')[0, 1], 2.6)
        self.assertEqual(stds.victor_purpura_dist(
            [self.st11, self.st14], self.q2,
            algorithm='intuitive')[0, 1], 1)
        # Tests on timescales
        self.assertEqual(stds.victor_purpura_dist(
            [self.st11, self.st14], self.q1,
            algorithm='intuitive')[0, 1],
            stds.victor_purpura_dist(
                [self.st11, self.st14], self.q5,
                algorithm='intuitive')[0, 1])
        self.assertEqual(stds.victor_purpura_dist(
            [self.st07, self.st11], self.q0,
            algorithm='intuitive')[0, 1], 6.0)
        self.assertEqual(stds.victor_purpura_dist(
            [self.st07, self.st11], self.q1,
            algorithm='intuitive')[0, 1], 6.0)
        self.assertAlmostEqual(stds.victor_purpura_dist(
            [self.st07, self.st11], self.q5,
            algorithm='intuitive')[0, 1], 2.0, 5)
        self.assertEqual(stds.victor_purpura_dist(
            [self.st07, self.st11], self.q6,
            algorithm='intuitive')[0, 1], 2.0)
        # Tests on unordered spiketrains
        self.assertEqual(stds.victor_purpura_dist(
            [self.st11, self.st13], self.q4,
            algorithm='intuitive')[0, 1],
            stds.victor_purpura_dist(
                [self.st12, self.st13], self.q4,
                algorithm='intuitive')[0, 1])
        self.assertNotEqual(stds.victor_purpura_dist(
            [self.st11, self.st13], self.q4,
            sort=False, algorithm='intuitive')[0, 1],
            stds.victor_purpura_dist(
                [self.st12, self.st13], self.q4,
                sort=False, algorithm='intuitive')[0, 1])
        # Tests on metric properties with random spiketrains. Symmetry
        # (second metric axiom) is checked by an explicit swapped call,
        # because checking it on dist_matrix alone would be trivial.
        dist_matrix = stds.victor_purpura_dist(
            [self.st21, self.st22, self.st23],
            self.q3, algorithm='intuitive')
        assert_array_equal(stds.victor_purpura_dist(
            [self.st21, self.st22], self.q3,
            algorithm='intuitive'),
            stds.victor_purpura_dist(
                [self.st22, self.st21], self.q3,
                algorithm='intuitive'))
        self._assert_metric_properties(dist_matrix)
        # Tests on proper unit conversion
        self.assertAlmostEqual(stds.victor_purpura_dist(
            [self.st14, self.st16], self.q3,
            algorithm='intuitive')[0, 1],
            stds.victor_purpura_dist(
                [self.st15, self.st16], self.q3,
                algorithm='intuitive')[0, 1])
        self.assertAlmostEqual(stds.victor_purpura_dist(
            [self.st16, self.st14], self.q3,
            algorithm='intuitive')[0, 1],
            stds.victor_purpura_dist(
                [self.st16, self.st15], self.q3,
                algorithm='intuitive')[0, 1])
        self.assertEqual(stds.victor_purpura_dist(
            [self.st01, self.st05], self.q3,
            algorithm='intuitive')[0, 1],
            stds.victor_purpura_dist(
                [self.st01, self.st05], self.q7,
                algorithm='intuitive')[0, 1])
        # Tests on algorithmic behaviour for equal spike times
        self.assertEqual(stds.victor_purpura_dist(
            [self.st31, self.st34], self.q3,
            algorithm='intuitive')[0, 1],
            0.8 + 1.0)
        self.assertEqual(stds.victor_purpura_dist(
            [self.st31, self.st34], self.q3,
            algorithm='intuitive')[0, 1],
            stds.victor_purpura_dist(
                [self.st32, self.st33], self.q3,
                algorithm='intuitive')[0, 1])
        self.assertEqual(stds.victor_purpura_dist(
            [self.st31, self.st33], self.q3,
            algorithm='intuitive')[0, 1] * 2.0,
            stds.victor_purpura_dist(
                [self.st32, self.st34], self.q3,
                algorithm='intuitive')[0, 1])
        # Tests on spike train list lengths smaller than 2
        self.assertEqual(stds.victor_purpura_dist(
            [self.st21], self.q3,
            algorithm='intuitive')[0, 0], 0)
        self.assertEqual(len(stds.victor_purpura_dist(
            [], self.q3, algorithm='intuitive')), 0)

    def test_victor_purpura_algorithm_comparison(self):
        """Both Victor-Purpura algorithms agree on random spike trains."""
        assert_array_almost_equal(
            stds.victor_purpura_dist([self.st21, self.st22, self.st23],
                                     self.q3),
            stds.victor_purpura_dist([self.st21, self.st22, self.st23],
                                     self.q3, algorithm='intuitive'))

    def test_van_rossum_distance(self):
        """van Rossum distance against closed forms and numeric integrals."""
        # Dense time grid (1e-5 ms spacing) for the numerical reference
        # integrals; built here so the other tests do not allocate it.
        t = np.linspace(0, 200, 20000001) * ms
        # Tests of distances of simplest spike trains
        self.assertEqual(stds.van_rossum_dist(
            [self.st00, self.st00], self.tau2)[0, 1], 0.0)
        self.assertEqual(stds.van_rossum_dist(
            [self.st00, self.st01], self.tau2)[0, 1], 1.0)
        self.assertEqual(stds.van_rossum_dist(
            [self.st01, self.st00], self.tau2)[0, 1], 1.0)
        self.assertEqual(stds.van_rossum_dist(
            [self.st01, self.st01], self.tau2)[0, 1], 0.0)
        # Tests of distances under elementary spike operations
        self.assertAlmostEqual(stds.van_rossum_dist(
            [self.st01, self.st02], self.tau2)[0, 1],
            float(np.sqrt(2*(1.0-np.exp(-np.absolute(
                ((self.st01[0]-self.st02[0]) /
                 self.tau2).simplified))))))
        self.assertAlmostEqual(stds.van_rossum_dist(
            [self.st01, self.st05], self.tau2)[0, 1],
            float(np.sqrt(2*(1.0-np.exp(-np.absolute(
                ((self.st01[0]-self.st05[0]) /
                 self.tau2).simplified))))))
        self.assertAlmostEqual(stds.van_rossum_dist(
            [self.st01, self.st05], self.tau2)[0, 1],
            np.sqrt(2.0), 1)
        self.assertAlmostEqual(stds.van_rossum_dist(
            [self.st01, self.st06], self.tau2)[0, 1],
            np.sqrt(2.0), 20)
        self.assertAlmostEqual(stds.van_rossum_dist(
            [self.st00, self.st07], self.tau1)[0, 1],
            np.sqrt(0 + 2))
        self.assertAlmostEqual(stds.van_rossum_dist(
            [self.st07, self.st08], self.tau4)[0, 1],
            float(np.sqrt(2*(1.0-np.exp(-np.absolute(
                ((self.st07[0]-self.st08[-1]) /
                 self.tau4).simplified))))))
        distance = self._van_rossum_by_integration(
            t, self.st08, self.st09, self.tau3)
        self.assertAlmostEqual(stds.van_rossum_dist(
            [self.st08, self.st09], self.tau3)[0, 1], distance, 5)
        self.assertAlmostEqual(stds.van_rossum_dist(
            [self.st11, self.st14], self.tau2)[0, 1], 1)
        # Tests on timescales
        self.assertAlmostEqual(
            stds.van_rossum_dist([self.st11, self.st14], self.tau1)[0, 1],
            stds.van_rossum_dist([self.st11, self.st14], self.tau5)[0, 1])
        self.assertAlmostEqual(
            stds.van_rossum_dist([self.st07, self.st11], self.tau0)[0, 1],
            np.sqrt(len(self.st07) + len(self.st11)))
        self.assertAlmostEqual(
            stds.van_rossum_dist([self.st07, self.st14], self.tau0)[0, 1],
            np.sqrt(len(self.st07) + len(self.st14)))
        self.assertAlmostEqual(
            stds.van_rossum_dist([self.st07, self.st11], self.tau1)[0, 1],
            np.sqrt(len(self.st07) + len(self.st11)))
        self.assertAlmostEqual(
            stds.van_rossum_dist([self.st07, self.st14], self.tau1)[0, 1],
            np.sqrt(len(self.st07) + len(self.st14)))
        self.assertAlmostEqual(
            stds.van_rossum_dist([self.st07, self.st11], self.tau5)[0, 1],
            np.absolute(len(self.st07) - len(self.st11)))
        self.assertAlmostEqual(
            stds.van_rossum_dist([self.st07, self.st14], self.tau5)[0, 1],
            np.absolute(len(self.st07) - len(self.st14)))
        self.assertAlmostEqual(
            stds.van_rossum_dist([self.st07, self.st11], self.tau6)[0, 1],
            np.absolute(len(self.st07) - len(self.st11)))
        self.assertAlmostEqual(
            stds.van_rossum_dist([self.st07, self.st14], self.tau6)[0, 1],
            np.absolute(len(self.st07) - len(self.st14)))
        # Tests on unordered spiketrains
        self.assertEqual(
            stds.van_rossum_dist([self.st11, self.st13], self.tau4)[0, 1],
            stds.van_rossum_dist([self.st12, self.st13], self.tau4)[0, 1])
        self.assertNotEqual(
            stds.van_rossum_dist([self.st11, self.st13],
                                 self.tau4, sort=False)[0, 1],
            stds.van_rossum_dist([self.st12, self.st13],
                                 self.tau4, sort=False)[0, 1])
        # Tests on metric properties with random spiketrains. Symmetry
        # (second metric axiom) is checked by an explicit swapped call,
        # because checking it on dist_matrix alone would be trivial.
        dist_matrix = stds.van_rossum_dist(
            [self.st21, self.st22, self.st23], self.tau3)
        assert_array_equal(
            stds.van_rossum_dist([self.st21, self.st22], self.tau3),
            stds.van_rossum_dist([self.st22, self.st21], self.tau3))
        self._assert_metric_properties(dist_matrix)
        # Tests on proper unit conversion
        self.assertAlmostEqual(
            stds.van_rossum_dist([self.st14, self.st16], self.tau3)[0, 1],
            stds.van_rossum_dist([self.st15, self.st16], self.tau3)[0, 1])
        self.assertAlmostEqual(
            stds.van_rossum_dist([self.st16, self.st14], self.tau3)[0, 1],
            stds.van_rossum_dist([self.st16, self.st15], self.tau3)[0, 1])
        self.assertEqual(
            stds.van_rossum_dist([self.st01, self.st05], self.tau3)[0, 1],
            stds.van_rossum_dist([self.st01, self.st05], self.tau7)[0, 1])
        # Tests on algorithmic behaviour for equal spike times
        distance = self._van_rossum_by_integration(
            t, self.st31, self.st34, self.tau3)
        self.assertAlmostEqual(stds.van_rossum_dist([self.st31, self.st34],
                                                    self.tau3)[0, 1],
                               distance, 5)
        self.assertEqual(stds.van_rossum_dist([self.st31, self.st34],
                                              self.tau3)[0, 1],
                         stds.van_rossum_dist([self.st32, self.st33],
                                              self.tau3)[0, 1])
        self.assertEqual(stds.van_rossum_dist([self.st31, self.st33],
                                              self.tau3)[0, 1] * 2.0,
                         stds.van_rossum_dist([self.st32, self.st34],
                                              self.tau3)[0, 1])
        # Tests on spike train list lengths smaller than 2
        self.assertEqual(stds.van_rossum_dist([self.st21], self.tau3)[0, 0], 0)
        self.assertEqual(len(stds.van_rossum_dist([], self.tau3)), 0)
if __name__ == '__main__':
    # Executed as a script: run the test cases defined in this module
    # (verbosity=1 is unittest.main's default, spelled out explicitly).
    unittest.main(verbosity=1)
|