123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319 |
- # -*- coding: utf-8 -*-
- """
- Unit tests for the spike_train_surrogates module.
- :copyright: Copyright 2014-2016 by the Elephant team, see AUTHORS.txt.
- :license: Modified BSD, see LICENSE.txt for details.
- """
- import unittest
- import elephant.spike_train_surrogates as surr
- import numpy as np
- import quantities as pq
- import neo
# Seed NumPy's global random number generator once, at import time, so that
# repeated runs of this module draw the same random numbers (presumably the
# surrogate generators draw from this global RNG -- confirm).
np.random.seed(seed=0)
class SurrogatesTestCase(unittest.TestCase):
    """Unit tests for ``elephant.spike_train_surrogates``.

    For each surrogate generator the tests check the output format (a list of
    ``neo.SpikeTrain`` objects shaped like the original), the handling of an
    empty train, and one method-specific property (rounding to ``decimals``,
    ``edges=False`` bounds, preserved ISIs, or preserved bin membership).
    """

    # ------------------------------------------------------------------
    # Helpers (not collected as tests: names do not start with "test")
    # ------------------------------------------------------------------

    def _assert_same_format(self, surrog, st):
        """Assert ``surrog`` is a SpikeTrain with ``st``'s units, bounds, size.

        Parameters
        ----------
        surrog : neo.SpikeTrain
            One surrogate generated from ``st``.
        st : neo.SpikeTrain
            The original spike train.
        """
        self.assertIsInstance(surrog, neo.SpikeTrain)
        self.assertEqual(surrog.units, st.units)
        self.assertEqual(surrog.t_start, st.t_start)
        self.assertEqual(surrog.t_stop, st.t_stop)
        self.assertEqual(len(surrog), len(st))

    def _assert_surrogate_list(self, surrs, st, nr_surr):
        """Assert ``surrs`` is a list of ``nr_surr`` surrogates shaped like st.

        Every element (not only the first) is checked with
        ``_assert_same_format``.
        """
        self.assertIsInstance(surrs, list)
        self.assertEqual(len(surrs), nr_surr)
        for surrog in surrs:
            self._assert_same_format(surrog, st)

    def _assert_within_bounds(self, surrogates, st):
        """Assert every spike of every surrogate lies in [t_start, t_stop]."""
        for surrog in surrogates:
            for spike in surrog:
                self.assertGreaterEqual(spike, st.t_start)
                self.assertLessEqual(spike, st.t_stop)

    @staticmethod
    def _sorted_isis(train):
        """Return the sorted inter-spike intervals of ``train``.

        The first interval is measured from ``train.t_start`` to the first
        spike; the rest are the successive differences of the spike times.
        ``train`` must contain at least one spike.
        """
        first_isi = train[0] - train.t_start
        other_isis = list(np.diff(train.view(pq.Quantity)))
        return np.sort([first_isi] + other_isis)

    @staticmethod
    def _bin_ids(train, binsize):
        """Return, for each spike, the integer index of its ``binsize`` bin.

        Spike times are divided by ``binsize`` directly (so bins are counted
        from time 0) and truncated to integers via ``dtype=int``.
        """
        ratio = (train.view(pq.Quantity) / binsize).rescale(pq.dimensionless)
        return np.array(ratio.magnitude, dtype=int)

    # ------------------------------------------------------------------
    # dither_spikes
    # ------------------------------------------------------------------

    def test_dither_spikes_output_format(self):
        """dither_spikes returns ``n`` SpikeTrains shaped like the input."""
        st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
        nr_surr = 2
        dither = 10 * pq.ms
        surrs = surr.dither_spikes(st, dither=dither, n=nr_surr)
        self._assert_surrogate_list(surrs, st, nr_surr)

    def test_dither_spikes_empty_train(self):
        """Dithering an empty train yields an empty surrogate."""
        st = neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms)
        dither = 10 * pq.ms
        surrog = surr.dither_spikes(st, dither=dither, n=1)[0]
        self.assertEqual(len(surrog), 0)

    def test_dither_spikes_output_decimals(self):
        """With ``decimals=3``, count spikes that land on an integer ms value.

        The expected count is the number of exactly-zero values in a replayed
        ``random_sample`` draw under the same seed (assumed to mirror the
        draws dither_spikes makes internally -- confirm).
        """
        st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
        nr_surr = 2
        dither = 10 * pq.ms
        np.random.seed(42)
        surrs = surr.dither_spikes(st, dither=dither, decimals=3, n=nr_surr)
        # Replay the RNG from the same seed to build the expected count.
        np.random.seed(42)
        dither_values = np.random.random_sample((nr_surr, len(st)))
        expected_non_dithered = np.sum(dither_values == 0)
        # A spike "was not dithered" when its fractional ms part is zero.
        observed_non_dithered = sum(
            1
            for surrog in surrs
            for spike in surrog
            if spike - int(spike) * pq.ms == spike - spike
        )
        self.assertEqual(observed_non_dithered, expected_non_dithered)

    def test_dither_spikes_false_edges(self):
        """With ``edges=False`` every spike stays inside [t_start, t_stop]."""
        st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
        nr_surr = 2
        dither = 10 * pq.ms
        surrs = surr.dither_spikes(st, dither=dither, n=nr_surr, edges=False)
        self._assert_within_bounds(surrs, st)

    # ------------------------------------------------------------------
    # randomise_spikes
    # ------------------------------------------------------------------

    def test_randomise_spikes_output_format(self):
        """randomise_spikes returns ``n`` SpikeTrains shaped like the input."""
        st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
        nr_surr = 2
        surrs = surr.randomise_spikes(st, n=nr_surr)
        self._assert_surrogate_list(surrs, st, nr_surr)

    def test_randomise_spikes_empty_train(self):
        """Randomising an empty train yields an empty surrogate."""
        st = neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms)
        surrog = surr.randomise_spikes(st, n=1)[0]
        self.assertEqual(len(surrog), 0)

    def test_randomise_spikes_output_decimals(self):
        """With ``decimals=3`` no spike has an integer ms value."""
        st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
        nr_surr = 2
        surrs = surr.randomise_spikes(st, n=nr_surr, decimals=3)
        for surrog in surrs:
            for spike in surrog:
                self.assertNotEqual(spike - int(spike) * pq.ms,
                                    spike - spike)

    # ------------------------------------------------------------------
    # shuffle_isis
    # ------------------------------------------------------------------

    def test_shuffle_isis_output_format(self):
        """shuffle_isis returns ``n`` SpikeTrains shaped like the input."""
        st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
        nr_surr = 2
        surrs = surr.shuffle_isis(st, n=nr_surr)
        self._assert_surrogate_list(surrs, st, nr_surr)

    def test_shuffle_isis_empty_train(self):
        """Shuffling the ISIs of an empty train yields an empty surrogate."""
        st = neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms)
        surrog = surr.shuffle_isis(st, n=1)[0]
        self.assertEqual(len(surrog), 0)

    def test_shuffle_isis_same_isis(self):
        """The surrogate has the same multiset of ISIs as the original."""
        st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
        surrog = surr.shuffle_isis(st, n=1)[0]
        np.testing.assert_array_equal(self._sorted_isis(st),
                                      self._sorted_isis(surrog))

    def test_shuffle_isis_output_decimals(self):
        """ISIs are still preserved when a ``decimals`` value is passed."""
        st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
        surrog = surr.shuffle_isis(st, n=1, decimals=95)[0]
        np.testing.assert_array_equal(self._sorted_isis(st),
                                      self._sorted_isis(surrog))

    # ------------------------------------------------------------------
    # dither_spike_train
    # ------------------------------------------------------------------

    def test_dither_spike_train_output_format(self):
        """dither_spike_train returns ``n`` SpikeTrains shaped like input."""
        st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
        nr_surr = 2
        shift = 10 * pq.ms
        surrs = surr.dither_spike_train(st, shift=shift, n=nr_surr)
        self._assert_surrogate_list(surrs, st, nr_surr)

    def test_dither_spike_train_empty_train(self):
        """Shifting an empty train yields an empty surrogate."""
        st = neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms)
        shift = 10 * pq.ms
        surrog = surr.dither_spike_train(st, shift=shift, n=1)[0]
        self.assertEqual(len(surrog), 0)

    def test_dither_spike_train_output_decimals(self):
        """With ``decimals=3`` no shifted spike has an integer ms value."""
        st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
        nr_surr = 2
        shift = 10 * pq.ms
        surrs = surr.dither_spike_train(st, shift=shift, n=nr_surr, decimals=3)
        for surrog in surrs:
            for spike in surrog:
                self.assertNotEqual(spike - int(spike) * pq.ms,
                                    spike - spike)

    def test_dither_spike_train_false_edges(self):
        """With ``edges=False`` every spike stays inside [t_start, t_stop]."""
        st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
        nr_surr = 2
        shift = 10 * pq.ms
        surrs = surr.dither_spike_train(
            st, shift=shift, n=nr_surr, edges=False)
        self._assert_within_bounds(surrs, st)

    # ------------------------------------------------------------------
    # jitter_spikes
    # ------------------------------------------------------------------

    def test_jitter_spikes_output_format(self):
        """jitter_spikes returns ``n`` SpikeTrains shaped like the input."""
        st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
        nr_surr = 2
        binsize = 100 * pq.ms
        surrs = surr.jitter_spikes(st, binsize=binsize, n=nr_surr)
        self._assert_surrogate_list(surrs, st, nr_surr)

    def test_jitter_spikes_empty_train(self):
        """Jittering an empty train yields an empty surrogate."""
        st = neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms)
        binsize = 75 * pq.ms
        surrog = surr.jitter_spikes(st, binsize=binsize, n=1)[0]
        self.assertEqual(len(surrog), 0)

    def test_jitter_spikes_same_bins(self):
        """Each jittered spike stays in the same bin as its original."""
        st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
        binsize = 100 * pq.ms
        surrog = surr.jitter_spikes(st, binsize=binsize, n=1)[0]
        # Regression: the original and surrogate trains once had different
        # numbers of spikes. Check this first so a mismatch is reported
        # clearly rather than as an array-shape error.
        self.assertEqual(len(st), len(surrog))
        np.testing.assert_array_equal(self._bin_ids(st, binsize),
                                      self._bin_ids(surrog, binsize))

    def test_jitter_spikes_unequal_binsize(self):
        """Bins are preserved when t_stop is not a multiple of binsize."""
        st = neo.SpikeTrain([90, 150, 180, 480] * pq.ms, t_stop=500 * pq.ms)
        binsize = 75 * pq.ms
        surrog = surr.jitter_spikes(st, binsize=binsize, n=1)[0]
        np.testing.assert_array_equal(self._bin_ids(st, binsize),
                                      self._bin_ids(surrog, binsize))

    # ------------------------------------------------------------------
    # surrogates (dispatcher)
    # ------------------------------------------------------------------

    def test_surr_method(self):
        """surrogates() dispatches by name and rejects 'spike_shifting'."""
        st = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
        nr_surr = 2
        surrs = surr.surrogates(st, dt=3 * pq.ms, n=nr_surr,
                                surr_method='shuffle_isis', edges=False)
        # surrogates() is expected to reject this method name with ValueError.
        self.assertRaises(ValueError, surr.surrogates, st, n=1,
                          surr_method='spike_shifting',
                          dt=None, decimals=None, edges=True)
        nr_surr2 = 4
        surrs2 = surr.surrogates(st, dt=5 * pq.ms, n=nr_surr2,
                                 surr_method='dither_spike_train', edges=True)

        self.assertEqual(len(surrs), nr_surr)
        for surrog in surrs:
            self._assert_same_format(surrog, st)

        self.assertEqual(len(surrs2), nr_surr2)
        for surrog in surrs2:
            self._assert_same_format(surrog, st)
def suite():
    """Return a TestSuite with every ``test*`` method of SurrogatesTestCase.

    ``unittest.makeSuite`` was deprecated in Python 3.11 and removed in 3.13,
    so the equivalent ``TestLoader.loadTestsFromTestCase`` is used instead
    (its default method prefix is ``'test'``, matching the old call).
    """
    loader = unittest.TestLoader()
    return loader.loadTestsFromTestCase(SurrogatesTestCase)
if __name__ == "__main__":
    import sys

    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite())
    # Exit non-zero when any test fails so shell scripts and CI notice it
    # (sys.exit(True) exits with status 1, sys.exit(False) with status 0).
    sys.exit(not result.wasSuccessful())
|