test_spiketrain.py 76 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779
  1. # -*- coding: utf-8 -*-
  2. """
  3. Tests of the neo.core.spiketrain.SpikeTrain class and related functions
  4. """
  5. # needed for python 3 compatibility
  6. from __future__ import absolute_import
  7. import sys
  8. try:
  9. import unittest2 as unittest
  10. except ImportError:
  11. import unittest
  12. import numpy as np
  13. from numpy.testing import assert_array_equal
  14. import quantities as pq
  15. try:
  16. from IPython.lib.pretty import pretty
  17. except ImportError as err:
  18. HAVE_IPYTHON = False
  19. else:
  20. HAVE_IPYTHON = True
  21. from neo.core.spiketrain import (check_has_dimensions_time, SpikeTrain,
  22. _check_time_in_range, _new_spiketrain)
  23. from neo.core import Segment, Unit
  24. from neo.test.tools import (assert_arrays_equal,
  25. assert_arrays_almost_equal,
  26. assert_neo_object_is_compliant)
  27. from neo.test.generate_datasets import (get_fake_value, get_fake_values,
  28. fake_neo, TEST_ANNOTATIONS)
  29. class Test__generate_datasets(unittest.TestCase):
  30. def setUp(self):
  31. np.random.seed(0)
  32. self.annotations = dict([(str(x), TEST_ANNOTATIONS[x]) for x in
  33. range(len(TEST_ANNOTATIONS))])
  34. def test__get_fake_values(self):
  35. self.annotations['seed'] = 0
  36. waveforms = get_fake_value('waveforms', pq.Quantity, seed=3, dim=3)
  37. shape = waveforms.shape[0]
  38. times = get_fake_value('times', pq.Quantity, seed=0, dim=1, shape=waveforms.shape[0])
  39. t_start = get_fake_value('t_start', pq.Quantity, seed=1, dim=0)
  40. t_stop = get_fake_value('t_stop', pq.Quantity, seed=2, dim=0)
  41. left_sweep = get_fake_value('left_sweep', pq.Quantity, seed=4, dim=0)
  42. sampling_rate = get_fake_value('sampling_rate', pq.Quantity,
  43. seed=5, dim=0)
  44. name = get_fake_value('name', str, seed=6, obj=SpikeTrain)
  45. description = get_fake_value('description', str,
  46. seed=7, obj='SpikeTrain')
  47. file_origin = get_fake_value('file_origin', str)
  48. attrs1 = {'name': name,
  49. 'description': description,
  50. 'file_origin': file_origin}
  51. attrs2 = attrs1.copy()
  52. attrs2.update(self.annotations)
  53. res11 = get_fake_values(SpikeTrain, annotate=False, seed=0)
  54. res12 = get_fake_values('SpikeTrain', annotate=False, seed=0)
  55. res21 = get_fake_values(SpikeTrain, annotate=True, seed=0)
  56. res22 = get_fake_values('SpikeTrain', annotate=True, seed=0)
  57. assert_arrays_equal(res11.pop('times'), times)
  58. assert_arrays_equal(res12.pop('times'), times)
  59. assert_arrays_equal(res21.pop('times'), times)
  60. assert_arrays_equal(res22.pop('times'), times)
  61. assert_arrays_equal(res11.pop('t_start'), t_start)
  62. assert_arrays_equal(res12.pop('t_start'), t_start)
  63. assert_arrays_equal(res21.pop('t_start'), t_start)
  64. assert_arrays_equal(res22.pop('t_start'), t_start)
  65. assert_arrays_equal(res11.pop('t_stop'), t_stop)
  66. assert_arrays_equal(res12.pop('t_stop'), t_stop)
  67. assert_arrays_equal(res21.pop('t_stop'), t_stop)
  68. assert_arrays_equal(res22.pop('t_stop'), t_stop)
  69. assert_arrays_equal(res11.pop('waveforms'), waveforms)
  70. assert_arrays_equal(res12.pop('waveforms'), waveforms)
  71. assert_arrays_equal(res21.pop('waveforms'), waveforms)
  72. assert_arrays_equal(res22.pop('waveforms'), waveforms)
  73. assert_arrays_equal(res11.pop('left_sweep'), left_sweep)
  74. assert_arrays_equal(res12.pop('left_sweep'), left_sweep)
  75. assert_arrays_equal(res21.pop('left_sweep'), left_sweep)
  76. assert_arrays_equal(res22.pop('left_sweep'), left_sweep)
  77. assert_arrays_equal(res11.pop('sampling_rate'), sampling_rate)
  78. assert_arrays_equal(res12.pop('sampling_rate'), sampling_rate)
  79. assert_arrays_equal(res21.pop('sampling_rate'), sampling_rate)
  80. assert_arrays_equal(res22.pop('sampling_rate'), sampling_rate)
  81. self.assertEqual(res11, attrs1)
  82. self.assertEqual(res12, attrs1)
  83. self.assertEqual(res21, attrs2)
  84. self.assertEqual(res22, attrs2)
  85. def test__fake_neo__cascade(self):
  86. self.annotations['seed'] = None
  87. obj_type = 'SpikeTrain'
  88. cascade = True
  89. res = fake_neo(obj_type=obj_type, cascade=cascade)
  90. self.assertTrue(isinstance(res, SpikeTrain))
  91. assert_neo_object_is_compliant(res)
  92. self.assertEqual(res.annotations, self.annotations)
  93. def test__fake_neo__nocascade(self):
  94. self.annotations['seed'] = None
  95. obj_type = SpikeTrain
  96. cascade = False
  97. res = fake_neo(obj_type=obj_type, cascade=cascade)
  98. self.assertTrue(isinstance(res, SpikeTrain))
  99. assert_neo_object_is_compliant(res)
  100. self.assertEqual(res.annotations, self.annotations)
  101. class Testcheck_has_dimensions_time(unittest.TestCase):
  102. def test__check_has_dimensions_time(self):
  103. a = np.arange(3) * pq.ms
  104. b = np.arange(3) * pq.mV
  105. c = np.arange(3) * pq.mA
  106. d = np.arange(3) * pq.minute
  107. check_has_dimensions_time(a)
  108. self.assertRaises(ValueError, check_has_dimensions_time, b)
  109. self.assertRaises(ValueError, check_has_dimensions_time, c)
  110. check_has_dimensions_time(d)
  111. self.assertRaises(ValueError, check_has_dimensions_time, a, b, c, d)
  112. class Testcheck_time_in_range(unittest.TestCase):
  113. def test__check_time_in_range_empty_array(self):
  114. value = np.array([])
  115. t_start = 0*pq.s
  116. t_stop = 10*pq.s
  117. _check_time_in_range(value, t_start=t_start, t_stop=t_stop)
  118. _check_time_in_range(value, t_start=t_start, t_stop=t_stop, view=False)
  119. _check_time_in_range(value, t_start=t_start, t_stop=t_stop, view=True)
  120. def test__check_time_in_range_exact(self):
  121. value = np.array([0., 5., 10.])*pq.s
  122. t_start = 0.*pq.s
  123. t_stop = 10.*pq.s
  124. _check_time_in_range(value, t_start=t_start, t_stop=t_stop)
  125. _check_time_in_range(value, t_start=t_start, t_stop=t_stop, view=False)
  126. _check_time_in_range(value, t_start=t_start, t_stop=t_stop, view=True)
  127. def test__check_time_in_range_scale(self):
  128. value = np.array([0., 5000., 10000.])*pq.ms
  129. t_start = 0.*pq.s
  130. t_stop = 10.*pq.s
  131. _check_time_in_range(value, t_start=t_start, t_stop=t_stop)
  132. _check_time_in_range(value, t_start=t_start, t_stop=t_stop, view=False)
  133. def test__check_time_in_range_inside(self):
  134. value = np.array([0.1, 5., 9.9])*pq.s
  135. t_start = 0.*pq.s
  136. t_stop = 10.*pq.s
  137. _check_time_in_range(value, t_start=t_start, t_stop=t_stop)
  138. _check_time_in_range(value, t_start=t_start, t_stop=t_stop, view=False)
  139. _check_time_in_range(value, t_start=t_start, t_stop=t_stop, view=True)
  140. def test__check_time_in_range_below(self):
  141. value = np.array([-0.1, 5., 10.])*pq.s
  142. t_start = 0.*pq.s
  143. t_stop = 10.*pq.s
  144. self.assertRaises(ValueError, _check_time_in_range, value,
  145. t_start=t_start, t_stop=t_stop)
  146. self.assertRaises(ValueError, _check_time_in_range, value,
  147. t_start=t_start, t_stop=t_stop, view=False)
  148. self.assertRaises(ValueError, _check_time_in_range, value,
  149. t_start=t_start, t_stop=t_stop, view=True)
  150. def test__check_time_in_range_below_scale(self):
  151. value = np.array([-1., 5000., 10000.])*pq.ms
  152. t_start = 0.*pq.s
  153. t_stop = 10.*pq.s
  154. self.assertRaises(ValueError, _check_time_in_range, value,
  155. t_start=t_start, t_stop=t_stop)
  156. self.assertRaises(ValueError, _check_time_in_range, value,
  157. t_start=t_start, t_stop=t_stop, view=False)
  158. def test__check_time_in_range_above(self):
  159. value = np.array([0., 5., 10.1])*pq.s
  160. t_start = 0.*pq.s
  161. t_stop = 10.*pq.s
  162. self.assertRaises(ValueError, _check_time_in_range, value,
  163. t_start=t_start, t_stop=t_stop)
  164. self.assertRaises(ValueError, _check_time_in_range, value,
  165. t_start=t_start, t_stop=t_stop, view=False)
  166. self.assertRaises(ValueError, _check_time_in_range, value,
  167. t_start=t_start, t_stop=t_stop, view=True)
  168. def test__check_time_in_range_above_scale(self):
  169. value = np.array([0., 5000., 10001.])*pq.ms
  170. t_start = 0.*pq.s
  171. t_stop = 10.*pq.s
  172. self.assertRaises(ValueError, _check_time_in_range, value,
  173. t_start=t_start, t_stop=t_stop)
  174. self.assertRaises(ValueError, _check_time_in_range, value,
  175. t_start=t_start, t_stop=t_stop, view=False)
  176. def test__check_time_in_range_above_below(self):
  177. value = np.array([-0.1, 5., 10.1])*pq.s
  178. t_start = 0.*pq.s
  179. t_stop = 10.*pq.s
  180. self.assertRaises(ValueError, _check_time_in_range, value,
  181. t_start=t_start, t_stop=t_stop)
  182. self.assertRaises(ValueError, _check_time_in_range, value,
  183. t_start=t_start, t_stop=t_stop, view=False)
  184. self.assertRaises(ValueError, _check_time_in_range, value,
  185. t_start=t_start, t_stop=t_stop, view=True)
  186. def test__check_time_in_range_above_below_scale(self):
  187. value = np.array([-1., 5000., 10001.])*pq.ms
  188. t_start = 0.*pq.s
  189. t_stop = 10.*pq.s
  190. self.assertRaises(ValueError, _check_time_in_range, value,
  191. t_start=t_start, t_stop=t_stop)
  192. self.assertRaises(ValueError, _check_time_in_range, value,
  193. t_start=t_start, t_stop=t_stop, view=False)
  194. class TestConstructor(unittest.TestCase):
  195. def result_spike_check(self, train, st_out, t_start_out, t_stop_out,
  196. dtype, units):
  197. assert_arrays_equal(train, st_out)
  198. assert_arrays_equal(train, train.times)
  199. assert_neo_object_is_compliant(train)
  200. self.assertEqual(train.t_start, t_start_out)
  201. self.assertEqual(train.t_start, train.times.t_start)
  202. self.assertEqual(train.t_stop, t_stop_out)
  203. self.assertEqual(train.t_stop, train.times.t_stop)
  204. self.assertEqual(train.units, units)
  205. self.assertEqual(train.units, train.times.units)
  206. self.assertEqual(train.t_start.units, units)
  207. self.assertEqual(train.t_start.units, train.times.t_start.units)
  208. self.assertEqual(train.t_stop.units, units)
  209. self.assertEqual(train.t_stop.units, train.times.t_stop.units)
  210. self.assertEqual(train.dtype, dtype)
  211. self.assertEqual(train.dtype, train.times.dtype)
  212. self.assertEqual(train.t_stop.dtype, dtype)
  213. self.assertEqual(train.t_stop.dtype, train.times.t_stop.dtype)
  214. self.assertEqual(train.t_start.dtype, dtype)
  215. self.assertEqual(train.t_start.dtype, train.times.t_start.dtype)
  216. def test__create_minimal(self):
  217. t_start = 0.0
  218. t_stop = 10.0
  219. train1 = SpikeTrain([]*pq.s, t_stop)
  220. train2 = _new_spiketrain(SpikeTrain, []*pq.s, t_stop)
  221. dtype = np.float64
  222. units = 1 * pq.s
  223. t_start_out = t_start * units
  224. t_stop_out = t_stop * units
  225. st_out = [] * units
  226. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  227. dtype, units)
  228. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  229. dtype, units)
  230. def test__create_empty(self):
  231. t_start = 0.0
  232. t_stop = 10.0
  233. train1 = SpikeTrain([], t_start=t_start, t_stop=t_stop, units='s')
  234. train2 = _new_spiketrain(SpikeTrain, [], t_start=t_start,
  235. t_stop=t_stop, units='s')
  236. dtype = np.float64
  237. units = 1 * pq.s
  238. t_start_out = t_start * units
  239. t_stop_out = t_stop * units
  240. st_out = [] * units
  241. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  242. dtype, units)
  243. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  244. dtype, units)
  245. def test__create_empty_no_t_start(self):
  246. t_start = 0.0
  247. t_stop = 10.0
  248. train1 = SpikeTrain([], t_stop=t_stop, units='s')
  249. train2 = _new_spiketrain(SpikeTrain, [], t_stop=t_stop, units='s')
  250. dtype = np.float64
  251. units = 1 * pq.s
  252. t_start_out = t_start * units
  253. t_stop_out = t_stop * units
  254. st_out = [] * units
  255. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  256. dtype, units)
  257. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  258. dtype, units)
  259. def test__create_from_list(self):
  260. times = range(10)
  261. t_start = 0.0*pq.s
  262. t_stop = 10000.0*pq.ms
  263. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units="ms")
  264. train2 = _new_spiketrain(SpikeTrain, times,
  265. t_start=t_start, t_stop=t_stop, units="ms")
  266. dtype = np.float64
  267. units = 1 * pq.ms
  268. t_start_out = t_start
  269. t_stop_out = t_stop
  270. st_out = times * units
  271. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  272. dtype, units)
  273. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  274. dtype, units)
  275. def test__create_from_list_set_dtype(self):
  276. times = range(10)
  277. t_start = 0.0*pq.s
  278. t_stop = 10000.0*pq.ms
  279. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop,
  280. units="ms", dtype='f4')
  281. train2 = _new_spiketrain(SpikeTrain, times,
  282. t_start=t_start, t_stop=t_stop,
  283. units="ms", dtype='f4')
  284. dtype = np.float32
  285. units = 1 * pq.ms
  286. t_start_out = t_start.astype(dtype)
  287. t_stop_out = t_stop.astype(dtype)
  288. st_out = pq.Quantity(times, units=units, dtype=dtype)
  289. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  290. dtype, units)
  291. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  292. dtype, units)
  293. def test__create_from_list_no_start_stop_units(self):
  294. times = range(10)
  295. t_start = 0.0
  296. t_stop = 10000.0
  297. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units="ms")
  298. train2 = _new_spiketrain(SpikeTrain, times,
  299. t_start=t_start, t_stop=t_stop, units="ms")
  300. dtype = np.float64
  301. units = 1 * pq.ms
  302. t_start_out = t_start * units
  303. t_stop_out = t_stop * units
  304. st_out = times * units
  305. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  306. dtype, units)
  307. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  308. dtype, units)
  309. def test__create_from_list_no_start_stop_units_set_dtype(self):
  310. times = range(10)
  311. t_start = 0.0
  312. t_stop = 10000.0
  313. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop,
  314. units="ms", dtype='f4')
  315. train2 = _new_spiketrain(SpikeTrain, times,
  316. t_start=t_start, t_stop=t_stop,
  317. units="ms", dtype='f4')
  318. dtype = np.float32
  319. units = 1 * pq.ms
  320. t_start_out = pq.Quantity(t_start, units=units, dtype=dtype)
  321. t_stop_out = pq.Quantity(t_stop, units=units, dtype=dtype)
  322. st_out = pq.Quantity(times, units=units, dtype=dtype)
  323. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  324. dtype, units)
  325. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  326. dtype, units)
  327. def test__create_from_array(self):
  328. times = np.arange(10)
  329. t_start = 0.0*pq.s
  330. t_stop = 10000.0*pq.ms
  331. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units="s")
  332. train2 = _new_spiketrain(SpikeTrain, times,
  333. t_start=t_start, t_stop=t_stop, units="s")
  334. dtype = np.int
  335. units = 1 * pq.s
  336. t_start_out = t_start
  337. t_stop_out = t_stop
  338. st_out = times * units
  339. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  340. dtype, units)
  341. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  342. dtype, units)
  343. def test__create_from_array_with_dtype(self):
  344. times = np.arange(10, dtype='f4')
  345. t_start = 0.0*pq.s
  346. t_stop = 10000.0*pq.ms
  347. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units="s")
  348. train2 = _new_spiketrain(SpikeTrain, times,
  349. t_start=t_start, t_stop=t_stop, units="s")
  350. dtype = times.dtype
  351. units = 1 * pq.s
  352. t_start_out = t_start
  353. t_stop_out = t_stop
  354. st_out = times * units
  355. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  356. dtype, units)
  357. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  358. dtype, units)
  359. def test__create_from_array_set_dtype(self):
  360. times = np.arange(10)
  361. t_start = 0.0*pq.s
  362. t_stop = 10000.0*pq.ms
  363. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop,
  364. units="s", dtype='f4')
  365. train2 = _new_spiketrain(SpikeTrain, times,
  366. t_start=t_start, t_stop=t_stop,
  367. units="s", dtype='f4')
  368. dtype = np.float32
  369. units = 1 * pq.s
  370. t_start_out = t_start.astype(dtype)
  371. t_stop_out = t_stop.astype(dtype)
  372. st_out = times.astype(dtype) * units
  373. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  374. dtype, units)
  375. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  376. dtype, units)
  377. def test__create_from_array_no_start_stop_units(self):
  378. times = np.arange(10)
  379. t_start = 0.0
  380. t_stop = 10000.0
  381. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units="s")
  382. train2 = _new_spiketrain(SpikeTrain, times,
  383. t_start=t_start, t_stop=t_stop, units="s")
  384. dtype = np.int
  385. units = 1 * pq.s
  386. t_start_out = t_start * units
  387. t_stop_out = t_stop * units
  388. st_out = times * units
  389. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  390. dtype, units)
  391. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  392. dtype, units)
  393. def test__create_from_array_no_start_stop_units_with_dtype(self):
  394. times = np.arange(10, dtype='f4')
  395. t_start = 0.0
  396. t_stop = 10000.0
  397. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units="s")
  398. train2 = _new_spiketrain(SpikeTrain, times,
  399. t_start=t_start, t_stop=t_stop, units="s")
  400. dtype = np.float32
  401. units = 1 * pq.s
  402. t_start_out = t_start * units
  403. t_stop_out = t_stop * units
  404. st_out = times * units
  405. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  406. dtype, units)
  407. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  408. dtype, units)
  409. def test__create_from_array_no_start_stop_units_set_dtype(self):
  410. times = np.arange(10)
  411. t_start = 0.0
  412. t_stop = 10000.0
  413. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop,
  414. units="s", dtype='f4')
  415. train2 = _new_spiketrain(SpikeTrain, times,
  416. t_start=t_start, t_stop=t_stop,
  417. units="s", dtype='f4')
  418. dtype = np.float32
  419. units = 1 * pq.s
  420. t_start_out = pq.Quantity(t_start, units=units, dtype=dtype)
  421. t_stop_out = pq.Quantity(t_stop, units=units, dtype=dtype)
  422. st_out = times.astype(dtype) * units
  423. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  424. dtype, units)
  425. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  426. dtype, units)
  427. def test__create_from_quantity_array(self):
  428. times = np.arange(10) * pq.ms
  429. t_start = 0.0*pq.s
  430. t_stop = 12.0*pq.ms
  431. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop)
  432. train2 = _new_spiketrain(SpikeTrain, times,
  433. t_start=t_start, t_stop=t_stop)
  434. dtype = np.float64
  435. units = 1 * pq.ms
  436. t_start_out = t_start
  437. t_stop_out = t_stop
  438. st_out = times
  439. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  440. dtype, units)
  441. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  442. dtype, units)
  443. def test__create_from_quantity_array_with_dtype(self):
  444. times = np.arange(10, dtype='f4') * pq.ms
  445. t_start = 0.0*pq.s
  446. t_stop = 12.0*pq.ms
  447. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop)
  448. train2 = _new_spiketrain(SpikeTrain, times,
  449. t_start=t_start, t_stop=t_stop)
  450. dtype = np.float32
  451. units = 1 * pq.ms
  452. t_start_out = t_start.astype(dtype)
  453. t_stop_out = t_stop.astype(dtype)
  454. st_out = times.astype(dtype)
  455. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  456. dtype, units)
  457. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  458. dtype, units)
  459. def test__create_from_quantity_array_set_dtype(self):
  460. times = np.arange(10) * pq.ms
  461. t_start = 0.0*pq.s
  462. t_stop = 12.0*pq.ms
  463. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop,
  464. dtype='f4')
  465. train2 = _new_spiketrain(SpikeTrain, times,
  466. t_start=t_start, t_stop=t_stop,
  467. dtype='f4')
  468. dtype = np.float32
  469. units = 1 * pq.ms
  470. t_start_out = t_start.astype(dtype)
  471. t_stop_out = t_stop.astype(dtype)
  472. st_out = times.astype(dtype)
  473. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  474. dtype, units)
  475. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  476. dtype, units)
  477. def test__create_from_quantity_array_no_start_stop_units(self):
  478. times = np.arange(10) * pq.ms
  479. t_start = 0.0
  480. t_stop = 12.0
  481. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop)
  482. train2 = _new_spiketrain(SpikeTrain, times,
  483. t_start=t_start, t_stop=t_stop)
  484. dtype = np.float64
  485. units = 1 * pq.ms
  486. t_start_out = t_start * units
  487. t_stop_out = t_stop * units
  488. st_out = times
  489. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  490. dtype, units)
  491. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  492. dtype, units)
  493. def test__create_from_quantity_array_no_start_stop_units_with_dtype(self):
  494. times = np.arange(10, dtype='f4') * pq.ms
  495. t_start = 0.0
  496. t_stop = 12.0
  497. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop)
  498. train2 = _new_spiketrain(SpikeTrain, times,
  499. t_start=t_start, t_stop=t_stop)
  500. dtype = np.float32
  501. units = 1 * pq.ms
  502. t_start_out = pq.Quantity(t_start, units=units, dtype=dtype)
  503. t_stop_out = pq.Quantity(t_stop, units=units, dtype=dtype)
  504. st_out = times.astype(dtype)
  505. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  506. dtype, units)
  507. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  508. dtype, units)
  509. def test__create_from_quantity_array_no_start_stop_units_set_dtype(self):
  510. times = np.arange(10) * pq.ms
  511. t_start = 0.0
  512. t_stop = 12.0
  513. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop,
  514. dtype='f4')
  515. train2 = _new_spiketrain(SpikeTrain, times,
  516. t_start=t_start, t_stop=t_stop,
  517. dtype='f4')
  518. dtype = np.float32
  519. units = 1 * pq.ms
  520. t_start_out = pq.Quantity(t_start, units=units, dtype=dtype)
  521. t_stop_out = pq.Quantity(t_stop, units=units, dtype=dtype)
  522. st_out = times.astype(dtype)
  523. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  524. dtype, units)
  525. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  526. dtype, units)
  527. def test__create_from_quantity_array_units(self):
  528. times = np.arange(10) * pq.ms
  529. t_start = 0.0*pq.s
  530. t_stop = 12.0*pq.ms
  531. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units='s')
  532. train2 = _new_spiketrain(SpikeTrain, times,
  533. t_start=t_start, t_stop=t_stop, units='s')
  534. dtype = np.float64
  535. units = 1 * pq.s
  536. t_start_out = t_start
  537. t_stop_out = t_stop
  538. st_out = times
  539. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  540. dtype, units)
  541. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  542. dtype, units)
  543. def test__create_from_quantity_array_units_with_dtype(self):
  544. times = np.arange(10, dtype='f4') * pq.ms
  545. t_start = 0.0*pq.s
  546. t_stop = 12.0*pq.ms
  547. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop,
  548. units='s')
  549. train2 = _new_spiketrain(SpikeTrain, times,
  550. t_start=t_start, t_stop=t_stop, units='s')
  551. dtype = np.float32
  552. units = 1 * pq.s
  553. t_start_out = t_start.astype(dtype)
  554. t_stop_out = t_stop.rescale(units).astype(dtype)
  555. st_out = times.rescale(units).astype(dtype)
  556. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  557. dtype, units)
  558. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  559. dtype, units)
  560. def test__create_from_quantity_array_units_set_dtype(self):
  561. times = np.arange(10) * pq.ms
  562. t_start = 0.0*pq.s
  563. t_stop = 12.0*pq.ms
  564. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop,
  565. units='s', dtype='f4')
  566. train2 = _new_spiketrain(SpikeTrain, times,
  567. t_start=t_start, t_stop=t_stop,
  568. units='s', dtype='f4')
  569. dtype = np.float32
  570. units = 1 * pq.s
  571. t_start_out = t_start.astype(dtype)
  572. t_stop_out = t_stop.rescale(units).astype(dtype)
  573. st_out = times.rescale(units).astype(dtype)
  574. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  575. dtype, units)
  576. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  577. dtype, units)
  578. def test__create_from_quantity_array_units_no_start_stop_units(self):
  579. times = np.arange(10) * pq.ms
  580. t_start = 0.0
  581. t_stop = 12.0
  582. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units='s')
  583. train2 = _new_spiketrain(SpikeTrain, times,
  584. t_start=t_start, t_stop=t_stop, units='s')
  585. dtype = np.float64
  586. units = 1 * pq.s
  587. t_start_out = pq.Quantity(t_start, units=units, dtype=dtype)
  588. t_stop_out = pq.Quantity(t_stop, units=units, dtype=dtype)
  589. st_out = times
  590. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  591. dtype, units)
  592. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  593. dtype, units)
  594. def test__create_from_quantity_units_no_start_stop_units_set_dtype(self):
  595. times = np.arange(10) * pq.ms
  596. t_start = 0.0
  597. t_stop = 12.0
  598. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop,
  599. units='s', dtype='f4')
  600. train2 = _new_spiketrain(SpikeTrain, times,
  601. t_start=t_start, t_stop=t_stop,
  602. units='s', dtype='f4')
  603. dtype = np.float32
  604. units = 1 * pq.s
  605. t_start_out = pq.Quantity(t_start, units=units, dtype=dtype)
  606. t_stop_out = pq.Quantity(t_stop, units=units, dtype=dtype)
  607. st_out = times.rescale(units).astype(dtype)
  608. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  609. dtype, units)
  610. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  611. dtype, units)
  612. def test__create_from_list_without_units_should_raise_ValueError(self):
  613. times = range(10)
  614. t_start = 0.0*pq.s
  615. t_stop = 10000.0*pq.ms
  616. self.assertRaises(ValueError, SpikeTrain, times,
  617. t_start=t_start, t_stop=t_stop)
  618. self.assertRaises(ValueError, _new_spiketrain, SpikeTrain, times,
  619. t_start=t_start, t_stop=t_stop)
  620. def test__create_from_array_without_units_should_raise_ValueError(self):
  621. times = np.arange(10)
  622. t_start = 0.0*pq.s
  623. t_stop = 10000.0*pq.ms
  624. self.assertRaises(ValueError, SpikeTrain, times,
  625. t_start=t_start, t_stop=t_stop)
  626. self.assertRaises(ValueError, _new_spiketrain, SpikeTrain, times,
  627. t_start=t_start, t_stop=t_stop)
  628. def test__create_from_array_with_incompatible_units_ValueError(self):
  629. times = np.arange(10) * pq.km
  630. t_start = 0.0*pq.s
  631. t_stop = 10000.0*pq.ms
  632. self.assertRaises(ValueError, SpikeTrain, times,
  633. t_start=t_start, t_stop=t_stop)
  634. self.assertRaises(ValueError, _new_spiketrain, SpikeTrain, times,
  635. t_start=t_start, t_stop=t_stop)
  636. def test__create_with_times_outside_tstart_tstop_ValueError(self):
  637. t_start = 23
  638. t_stop = 77
  639. train1 = SpikeTrain(np.arange(t_start, t_stop), units='ms',
  640. t_start=t_start, t_stop=t_stop)
  641. train2 = _new_spiketrain(SpikeTrain,
  642. np.arange(t_start, t_stop), units='ms',
  643. t_start=t_start, t_stop=t_stop)
  644. assert_neo_object_is_compliant(train1)
  645. assert_neo_object_is_compliant(train2)
  646. self.assertRaises(ValueError, SpikeTrain,
  647. np.arange(t_start-5, t_stop), units='ms',
  648. t_start=t_start, t_stop=t_stop)
  649. self.assertRaises(ValueError, _new_spiketrain, SpikeTrain,
  650. np.arange(t_start-5, t_stop), units='ms',
  651. t_start=t_start, t_stop=t_stop)
  652. self.assertRaises(ValueError, SpikeTrain,
  653. np.arange(t_start, t_stop+5), units='ms',
  654. t_start=t_start, t_stop=t_stop)
  655. self.assertRaises(ValueError, _new_spiketrain, SpikeTrain,
  656. np.arange(t_start, t_stop+5), units='ms',
  657. t_start=t_start, t_stop=t_stop)
  658. def test__create_with_len_times_different_size_than_waveform_shape1_ValueError(self):
  659. self.assertRaises(ValueError, SpikeTrain,
  660. times=np.arange(10), units='s',
  661. t_stop=4, waveforms=np.ones((10,6,50)))
  662. def test_defaults(self):
  663. # default recommended attributes
  664. train1 = SpikeTrain([3, 4, 5], units='sec', t_stop=10.0)
  665. train2 = _new_spiketrain(SpikeTrain, [3, 4, 5],
  666. units='sec', t_stop=10.0)
  667. assert_neo_object_is_compliant(train1)
  668. assert_neo_object_is_compliant(train2)
  669. self.assertEqual(train1.dtype, np.float)
  670. self.assertEqual(train2.dtype, np.float)
  671. self.assertEqual(train1.sampling_rate, 1.0 * pq.Hz)
  672. self.assertEqual(train2.sampling_rate, 1.0 * pq.Hz)
  673. self.assertEqual(train1.waveforms, None)
  674. self.assertEqual(train2.waveforms, None)
  675. self.assertEqual(train1.left_sweep, None)
  676. self.assertEqual(train2.left_sweep, None)
  677. def test_default_tstart(self):
  678. # t start defaults to zero
  679. train11 = SpikeTrain([3, 4, 5]*pq.s, t_stop=8000*pq.ms)
  680. train21 = _new_spiketrain(SpikeTrain, [3, 4, 5]*pq.s,
  681. t_stop=8000*pq.ms)
  682. assert_neo_object_is_compliant(train11)
  683. assert_neo_object_is_compliant(train21)
  684. self.assertEqual(train11.t_start, 0.*pq.s)
  685. self.assertEqual(train21.t_start, 0.*pq.s)
  686. # unless otherwise specified
  687. train12 = SpikeTrain([3, 4, 5]*pq.s, t_start=2.0, t_stop=8)
  688. train22 = _new_spiketrain(SpikeTrain, [3, 4, 5]*pq.s,
  689. t_start=2.0, t_stop=8)
  690. assert_neo_object_is_compliant(train12)
  691. assert_neo_object_is_compliant(train22)
  692. self.assertEqual(train12.t_start, 2.*pq.s)
  693. self.assertEqual(train22.t_start, 2.*pq.s)
  694. def test_tstop_units_conversion(self):
  695. train11 = SpikeTrain([3, 5, 4]*pq.s, t_stop=10)
  696. train21 = _new_spiketrain(SpikeTrain, [3, 5, 4]*pq.s, t_stop=10)
  697. assert_neo_object_is_compliant(train11)
  698. assert_neo_object_is_compliant(train21)
  699. self.assertEqual(train11.t_stop, 10.*pq.s)
  700. self.assertEqual(train21.t_stop, 10.*pq.s)
  701. train12 = SpikeTrain([3, 5, 4]*pq.s, t_stop=10000.*pq.ms)
  702. train22 = _new_spiketrain(SpikeTrain, [3, 5, 4]*pq.s,
  703. t_stop=10000.*pq.ms)
  704. assert_neo_object_is_compliant(train12)
  705. assert_neo_object_is_compliant(train22)
  706. self.assertEqual(train12.t_stop, 10.*pq.s)
  707. self.assertEqual(train22.t_stop, 10.*pq.s)
  708. train13 = SpikeTrain([3, 5, 4], units='sec', t_stop=10000.*pq.ms)
  709. train23 = _new_spiketrain(SpikeTrain, [3, 5, 4],
  710. units='sec', t_stop=10000.*pq.ms)
  711. assert_neo_object_is_compliant(train13)
  712. assert_neo_object_is_compliant(train23)
  713. self.assertEqual(train13.t_stop, 10.*pq.s)
  714. self.assertEqual(train23.t_stop, 10.*pq.s)
  715. class TestSorting(unittest.TestCase):
  716. def test_sort(self):
  717. waveforms = np.array([[[0., 1.]], [[2., 3.]], [[4., 5.]]]) * pq.mV
  718. train = SpikeTrain([3, 4, 5]*pq.s, waveforms=waveforms, name='n',
  719. t_stop=10.0)
  720. assert_neo_object_is_compliant(train)
  721. train.sort()
  722. assert_neo_object_is_compliant(train)
  723. assert_arrays_equal(train, [3, 4, 5]*pq.s)
  724. assert_arrays_equal(train.waveforms, waveforms)
  725. self.assertEqual(train.name, 'n')
  726. self.assertEqual(train.t_stop, 10.0 * pq.s)
  727. train = SpikeTrain([3, 5, 4]*pq.s, waveforms=waveforms, name='n',
  728. t_stop=10.0)
  729. assert_neo_object_is_compliant(train)
  730. train.sort()
  731. assert_neo_object_is_compliant(train)
  732. assert_arrays_equal(train, [3, 4, 5]*pq.s)
  733. assert_arrays_equal(train.waveforms, waveforms[[0, 2, 1]])
  734. self.assertEqual(train.name, 'n')
  735. self.assertEqual(train.t_start, 0.0 * pq.s)
  736. self.assertEqual(train.t_stop, 10.0 * pq.s)
  737. class TestSlice(unittest.TestCase):
  738. def setUp(self):
  739. self.waveforms1 = np.array([[[0., 1.],
  740. [0.1, 1.1]],
  741. [[2., 3.],
  742. [2.1, 3.1]],
  743. [[4., 5.],
  744. [4.1, 5.1]]]) * pq.mV
  745. self.data1 = np.array([3, 4, 5])
  746. self.data1quant = self.data1*pq.s
  747. self.train1 = SpikeTrain(self.data1quant, waveforms=self.waveforms1,
  748. name='n', arb='arbb', t_stop=10.0)
  749. def test_compliant(self):
  750. assert_neo_object_is_compliant(self.train1)
  751. def test_slice(self):
  752. # slice spike train, keep sliced spike times
  753. result = self.train1[1:2]
  754. assert_arrays_equal(self.train1[1:2], result)
  755. targwaveforms = np.array([[[2., 3.],
  756. [2.1, 3.1]]])
  757. # but keep everything else pristine
  758. assert_neo_object_is_compliant(result)
  759. self.assertEqual(self.train1.name, result.name)
  760. self.assertEqual(self.train1.description, result.description)
  761. self.assertEqual(self.train1.annotations, result.annotations)
  762. self.assertEqual(self.train1.file_origin, result.file_origin)
  763. self.assertEqual(self.train1.dtype, result.dtype)
  764. self.assertEqual(self.train1.t_start, result.t_start)
  765. self.assertEqual(self.train1.t_stop, result.t_stop)
  766. # except we update the waveforms
  767. assert_arrays_equal(self.train1.waveforms[1:2], result.waveforms)
  768. assert_arrays_equal(targwaveforms, result.waveforms)
  769. def test_slice_to_end(self):
  770. # slice spike train, keep sliced spike times
  771. result = self.train1[1:]
  772. assert_arrays_equal(self.train1[1:], result)
  773. targwaveforms = np.array([[[2., 3.],
  774. [2.1, 3.1]],
  775. [[4., 5.],
  776. [4.1, 5.1]]]) * pq.mV
  777. # but keep everything else pristine
  778. assert_neo_object_is_compliant(result)
  779. self.assertEqual(self.train1.name, result.name)
  780. self.assertEqual(self.train1.description, result.description)
  781. self.assertEqual(self.train1.annotations, result.annotations)
  782. self.assertEqual(self.train1.file_origin, result.file_origin)
  783. self.assertEqual(self.train1.dtype, result.dtype)
  784. self.assertEqual(self.train1.t_start, result.t_start)
  785. self.assertEqual(self.train1.t_stop, result.t_stop)
  786. # except we update the waveforms
  787. assert_arrays_equal(self.train1.waveforms[1:], result.waveforms)
  788. assert_arrays_equal(targwaveforms, result.waveforms)
  789. def test_slice_from_beginning(self):
  790. # slice spike train, keep sliced spike times
  791. result = self.train1[:2]
  792. assert_arrays_equal(self.train1[:2], result)
  793. targwaveforms = np.array([[[0., 1.],
  794. [0.1, 1.1]],
  795. [[2., 3.],
  796. [2.1, 3.1]]]) * pq.mV
  797. # but keep everything else pristine
  798. assert_neo_object_is_compliant(result)
  799. self.assertEqual(self.train1.name, result.name)
  800. self.assertEqual(self.train1.description, result.description)
  801. self.assertEqual(self.train1.annotations, result.annotations)
  802. self.assertEqual(self.train1.file_origin, result.file_origin)
  803. self.assertEqual(self.train1.dtype, result.dtype)
  804. self.assertEqual(self.train1.t_start, result.t_start)
  805. self.assertEqual(self.train1.t_stop, result.t_stop)
  806. # except we update the waveforms
  807. assert_arrays_equal(self.train1.waveforms[:2], result.waveforms)
  808. assert_arrays_equal(targwaveforms, result.waveforms)
  809. def test_slice_negative_idxs(self):
  810. # slice spike train, keep sliced spike times
  811. result = self.train1[:-1]
  812. assert_arrays_equal(self.train1[:-1], result)
  813. targwaveforms = np.array([[[0., 1.],
  814. [0.1, 1.1]],
  815. [[2., 3.],
  816. [2.1, 3.1]]]) * pq.mV
  817. # but keep everything else pristine
  818. assert_neo_object_is_compliant(result)
  819. self.assertEqual(self.train1.name, result.name)
  820. self.assertEqual(self.train1.description, result.description)
  821. self.assertEqual(self.train1.annotations, result.annotations)
  822. self.assertEqual(self.train1.file_origin, result.file_origin)
  823. self.assertEqual(self.train1.dtype, result.dtype)
  824. self.assertEqual(self.train1.t_start, result.t_start)
  825. self.assertEqual(self.train1.t_stop, result.t_stop)
  826. # except we update the waveforms
  827. assert_arrays_equal(self.train1.waveforms[:-1], result.waveforms)
  828. assert_arrays_equal(targwaveforms, result.waveforms)
  829. class TestTimeSlice(unittest.TestCase):
  830. def setUp(self):
  831. self.waveforms1 = np.array([[[0., 1.],
  832. [0.1, 1.1]],
  833. [[2., 3.],
  834. [2.1, 3.1]],
  835. [[4., 5.],
  836. [4.1, 5.1]],
  837. [[6., 7.],
  838. [6.1, 7.1]],
  839. [[8., 9.],
  840. [8.1, 9.1]],
  841. [[10., 11.],
  842. [10.1, 11.1]]]) * pq.mV
  843. self.data1 = np.array([0.1, 0.5, 1.2, 3.3, 6.4, 7])
  844. self.data1quant = self.data1*pq.ms
  845. self.train1 = SpikeTrain(self.data1quant, t_stop=10.0*pq.ms,
  846. waveforms=self.waveforms1)
  847. def test_compliant(self):
  848. assert_neo_object_is_compliant(self.train1)
  849. def test_time_slice_typical(self):
  850. # time_slice spike train, keep sliced spike times
  851. # this is the typical time slice falling somewhere
  852. # in the middle of spikes
  853. t_start = 0.12 * pq.ms
  854. t_stop = 3.5 * pq.ms
  855. result = self.train1.time_slice(t_start, t_stop)
  856. targ = SpikeTrain([0.5, 1.2, 3.3] * pq.ms, t_stop=3.3)
  857. assert_arrays_equal(result, targ)
  858. targwaveforms = np.array([[[2., 3.],
  859. [2.1, 3.1]],
  860. [[4., 5.],
  861. [4.1, 5.1]],
  862. [[6., 7.],
  863. [6.1, 7.1]]]) * pq.mV
  864. assert_arrays_equal(targwaveforms, result.waveforms)
  865. # but keep everything else pristine
  866. assert_neo_object_is_compliant(result)
  867. self.assertEqual(self.train1.name, result.name)
  868. self.assertEqual(self.train1.description, result.description)
  869. self.assertEqual(self.train1.annotations, result.annotations)
  870. self.assertEqual(self.train1.file_origin, result.file_origin)
  871. self.assertEqual(self.train1.dtype, result.dtype)
  872. self.assertEqual(t_start, result.t_start)
  873. self.assertEqual(t_stop, result.t_stop)
  874. def test_time_slice_differnt_units(self):
  875. # time_slice spike train, keep sliced spike times
  876. t_start = 0.00012 * pq.s
  877. t_stop = 0.0035 * pq.s
  878. result = self.train1.time_slice(t_start, t_stop)
  879. targ = SpikeTrain([0.5, 1.2, 3.3] * pq.ms, t_stop=3.3)
  880. assert_arrays_equal(result, targ)
  881. targwaveforms = np.array([[[2., 3.],
  882. [2.1, 3.1]],
  883. [[4., 5.],
  884. [4.1, 5.1]],
  885. [[6., 7.],
  886. [6.1, 7.1]]]) * pq.mV
  887. assert_arrays_equal(targwaveforms, result.waveforms)
  888. # but keep everything else pristine
  889. assert_neo_object_is_compliant(result)
  890. self.assertEqual(self.train1.name, result.name)
  891. self.assertEqual(self.train1.description, result.description)
  892. self.assertEqual(self.train1.annotations, result.annotations)
  893. self.assertEqual(self.train1.file_origin, result.file_origin)
  894. self.assertEqual(self.train1.dtype, result.dtype)
  895. self.assertEqual(t_start, result.t_start)
  896. self.assertEqual(t_stop, result.t_stop)
  897. def test_time_slice_matching_ends(self):
  898. # time_slice spike train, keep sliced spike times
  899. t_start = 0.1 * pq.ms
  900. t_stop = 7.0 * pq.ms
  901. result = self.train1.time_slice(t_start, t_stop)
  902. assert_arrays_equal(self.train1, result)
  903. assert_arrays_equal(self.waveforms1, result.waveforms)
  904. # but keep everything else pristine
  905. assert_neo_object_is_compliant(result)
  906. self.assertEqual(self.train1.name, result.name)
  907. self.assertEqual(self.train1.description, result.description)
  908. self.assertEqual(self.train1.annotations, result.annotations)
  909. self.assertEqual(self.train1.file_origin, result.file_origin)
  910. self.assertEqual(self.train1.dtype, result.dtype)
  911. self.assertEqual(t_start, result.t_start)
  912. self.assertEqual(t_stop, result.t_stop)
  913. def test_time_slice_out_of_boundries(self):
  914. self.train1.t_start = 0.1*pq.ms
  915. assert_neo_object_is_compliant(self.train1)
  916. # time_slice spike train, keep sliced spike times
  917. t_start = 0.01 * pq.ms
  918. t_stop = 70.0 * pq.ms
  919. result = self.train1.time_slice(t_start, t_stop)
  920. assert_arrays_equal(self.train1, result)
  921. assert_arrays_equal(self.waveforms1, result.waveforms)
  922. # but keep everything else pristine
  923. assert_neo_object_is_compliant(result)
  924. self.assertEqual(self.train1.name, result.name)
  925. self.assertEqual(self.train1.description, result.description)
  926. self.assertEqual(self.train1.annotations, result.annotations)
  927. self.assertEqual(self.train1.file_origin, result.file_origin)
  928. self.assertEqual(self.train1.dtype, result.dtype)
  929. self.assertEqual(self.train1.t_start, result.t_start)
  930. self.assertEqual(self.train1.t_stop, result.t_stop)
  931. def test_time_slice_empty(self):
  932. waveforms = np.array([[[]]]) * pq.mV
  933. train = SpikeTrain([] * pq.ms, t_stop=10.0, waveforms=waveforms)
  934. assert_neo_object_is_compliant(train)
  935. # time_slice spike train, keep sliced spike times
  936. t_start = 0.01 * pq.ms
  937. t_stop = 70.0 * pq.ms
  938. result = train.time_slice(t_start, t_stop)
  939. assert_arrays_equal(train, result)
  940. assert_arrays_equal(waveforms[:-1], result.waveforms)
  941. # but keep everything else pristine
  942. assert_neo_object_is_compliant(result)
  943. self.assertEqual(train.name, result.name)
  944. self.assertEqual(train.description, result.description)
  945. self.assertEqual(train.annotations, result.annotations)
  946. self.assertEqual(train.file_origin, result.file_origin)
  947. self.assertEqual(train.dtype, result.dtype)
  948. self.assertEqual(t_start, result.t_start)
  949. self.assertEqual(train.t_stop, result.t_stop)
  950. def test_time_slice_none_stop(self):
  951. # time_slice spike train, keep sliced spike times
  952. t_start = 1 * pq.ms
  953. result = self.train1.time_slice(t_start, None)
  954. assert_arrays_equal([1.2, 3.3, 6.4, 7] * pq.ms, result)
  955. targwaveforms = np.array([[[4., 5.],
  956. [4.1, 5.1]],
  957. [[6., 7.],
  958. [6.1, 7.1]],
  959. [[8., 9.],
  960. [8.1, 9.1]],
  961. [[10., 11.],
  962. [10.1, 11.1]]]) * pq.mV
  963. assert_arrays_equal(targwaveforms, result.waveforms)
  964. # but keep everything else pristine
  965. assert_neo_object_is_compliant(result)
  966. self.assertEqual(self.train1.name, result.name)
  967. self.assertEqual(self.train1.description, result.description)
  968. self.assertEqual(self.train1.annotations, result.annotations)
  969. self.assertEqual(self.train1.file_origin, result.file_origin)
  970. self.assertEqual(self.train1.dtype, result.dtype)
  971. self.assertEqual(t_start, result.t_start)
  972. self.assertEqual(self.train1.t_stop, result.t_stop)
  973. def test_time_slice_none_start(self):
  974. # time_slice spike train, keep sliced spike times
  975. t_stop = 1 * pq.ms
  976. result = self.train1.time_slice(None, t_stop)
  977. assert_arrays_equal([0.1, 0.5] * pq.ms, result)
  978. targwaveforms = np.array([[[0., 1.],
  979. [0.1, 1.1]],
  980. [[2., 3.],
  981. [2.1, 3.1]]]) * pq.mV
  982. assert_arrays_equal(targwaveforms, result.waveforms)
  983. # but keep everything else pristine
  984. assert_neo_object_is_compliant(result)
  985. self.assertEqual(self.train1.name, result.name)
  986. self.assertEqual(self.train1.description, result.description)
  987. self.assertEqual(self.train1.annotations, result.annotations)
  988. self.assertEqual(self.train1.file_origin, result.file_origin)
  989. self.assertEqual(self.train1.dtype, result.dtype)
  990. self.assertEqual(self.train1.t_start, result.t_start)
  991. self.assertEqual(t_stop, result.t_stop)
  992. def test_time_slice_none_both(self):
  993. self.train1.t_start = 0.1*pq.ms
  994. assert_neo_object_is_compliant(self.train1)
  995. # time_slice spike train, keep sliced spike times
  996. result = self.train1.time_slice(None, None)
  997. assert_arrays_equal(self.train1, result)
  998. assert_arrays_equal(self.waveforms1, result.waveforms)
  999. # but keep everything else pristine
  1000. assert_neo_object_is_compliant(result)
  1001. self.assertEqual(self.train1.name, result.name)
  1002. self.assertEqual(self.train1.description, result.description)
  1003. self.assertEqual(self.train1.annotations, result.annotations)
  1004. self.assertEqual(self.train1.file_origin, result.file_origin)
  1005. self.assertEqual(self.train1.dtype, result.dtype)
  1006. self.assertEqual(self.train1.t_start, result.t_start)
  1007. self.assertEqual(self.train1.t_stop, result.t_stop)
  1008. class TestDuplicateWithNewData(unittest.TestCase):
  1009. def setUp(self):
  1010. self.waveforms = np.array([[[0., 1.],
  1011. [0.1, 1.1]],
  1012. [[2., 3.],
  1013. [2.1, 3.1]],
  1014. [[4., 5.],
  1015. [4.1, 5.1]],
  1016. [[6., 7.],
  1017. [6.1, 7.1]],
  1018. [[8., 9.],
  1019. [8.1, 9.1]],
  1020. [[10., 11.],
  1021. [10.1, 11.1]]]) * pq.mV
  1022. self.data = np.array([0.1, 0.5, 1.2, 3.3, 6.4, 7])
  1023. self.dataquant = self.data*pq.ms
  1024. self.train = SpikeTrain(self.dataquant, t_stop=10.0*pq.ms,
  1025. waveforms=self.waveforms)
  1026. def test_duplicate_with_new_data(self):
  1027. signal1 = self.train
  1028. new_t_start = -10*pq.s
  1029. new_t_stop = 10*pq.s
  1030. new_data = np.sort(np.random.uniform(new_t_start.magnitude,
  1031. new_t_stop.magnitude,
  1032. len(self.train))) * pq.ms
  1033. signal1b = signal1.duplicate_with_new_data(new_data,
  1034. t_start=new_t_start,
  1035. t_stop=new_t_stop)
  1036. assert_arrays_almost_equal(np.asarray(signal1b),
  1037. np.asarray(new_data), 1e-12)
  1038. self.assertEqual(signal1b.t_start, new_t_start)
  1039. self.assertEqual(signal1b.t_stop, new_t_stop)
  1040. self.assertEqual(signal1b.sampling_rate, signal1.sampling_rate)
  1041. class TestAttributesAnnotations(unittest.TestCase):
  1042. def test_set_universally_recommended_attributes(self):
  1043. train = SpikeTrain([3, 4, 5], units='sec', name='Name',
  1044. description='Desc', file_origin='crack.txt',
  1045. t_stop=99.9)
  1046. assert_neo_object_is_compliant(train)
  1047. self.assertEqual(train.name, 'Name')
  1048. self.assertEqual(train.description, 'Desc')
  1049. self.assertEqual(train.file_origin, 'crack.txt')
  1050. def test_autoset_universally_recommended_attributes(self):
  1051. train = SpikeTrain([3, 4, 5]*pq.s, t_stop=10.0)
  1052. assert_neo_object_is_compliant(train)
  1053. self.assertEqual(train.name, None)
  1054. self.assertEqual(train.description, None)
  1055. self.assertEqual(train.file_origin, None)
  1056. def test_annotations(self):
  1057. train = SpikeTrain([3, 4, 5]*pq.s, t_stop=11.1)
  1058. assert_neo_object_is_compliant(train)
  1059. self.assertEqual(train.annotations, {})
  1060. train = SpikeTrain([3, 4, 5]*pq.s, t_stop=11.1, ratname='Phillippe')
  1061. assert_neo_object_is_compliant(train)
  1062. self.assertEqual(train.annotations, {'ratname': 'Phillippe'})
  1063. class TestChanging(unittest.TestCase):
  1064. def test_change_with_copy_default(self):
  1065. # Default is copy = True
  1066. # Changing spike train does not change data
  1067. # Data source is quantity
  1068. data = [3, 4, 5] * pq.s
  1069. train = SpikeTrain(data, t_stop=100.0)
  1070. train[0] = 99 * pq.s
  1071. assert_neo_object_is_compliant(train)
  1072. self.assertEqual(train[0], 99*pq.s)
  1073. self.assertEqual(data[0], 3*pq.s)
  1074. def test_change_with_copy_false(self):
  1075. # Changing spike train also changes data, because it is a view
  1076. # Data source is quantity
  1077. data = [3, 4, 5] * pq.s
  1078. train = SpikeTrain(data, copy=False, t_stop=100.0)
  1079. train[0] = 99 * pq.s
  1080. assert_neo_object_is_compliant(train)
  1081. self.assertEqual(train[0], 99*pq.s)
  1082. self.assertEqual(data[0], 99*pq.s)
  1083. def test_change_with_copy_false_and_fake_rescale(self):
  1084. # Changing spike train also changes data, because it is a view
  1085. # Data source is quantity
  1086. data = [3000, 4000, 5000] * pq.ms
  1087. # even though we specify units, it still returns a view
  1088. train = SpikeTrain(data, units='ms', copy=False, t_stop=100000)
  1089. train[0] = 99000 * pq.ms
  1090. assert_neo_object_is_compliant(train)
  1091. self.assertEqual(train[0], 99000*pq.ms)
  1092. self.assertEqual(data[0], 99000*pq.ms)
  1093. def test_change_with_copy_false_and_rescale_true(self):
  1094. # When rescaling, a view cannot be returned
  1095. # Changing spike train also changes data, because it is a view
  1096. data = [3, 4, 5] * pq.s
  1097. self.assertRaises(ValueError, SpikeTrain, data, units='ms',
  1098. copy=False, t_stop=10000)
  1099. def test_init_with_rescale(self):
  1100. data = [3, 4, 5] * pq.s
  1101. train = SpikeTrain(data, units='ms', t_stop=6000)
  1102. assert_neo_object_is_compliant(train)
  1103. self.assertEqual(train[0], 3000*pq.ms)
  1104. self.assertEqual(train._dimensionality, pq.ms._dimensionality)
  1105. self.assertEqual(train.t_stop, 6000*pq.ms)
  1106. def test_change_with_copy_true(self):
  1107. # Changing spike train does not change data
  1108. # Data source is quantity
  1109. data = [3, 4, 5] * pq.s
  1110. train = SpikeTrain(data, copy=True, t_stop=100)
  1111. train[0] = 99 * pq.s
  1112. assert_neo_object_is_compliant(train)
  1113. self.assertEqual(train[0], 99*pq.s)
  1114. self.assertEqual(data[0], 3*pq.s)
  1115. def test_change_with_copy_default_and_data_not_quantity(self):
  1116. # Default is copy = True
  1117. # Changing spike train does not change data
  1118. # Data source is array
  1119. # Array and quantity are tested separately because copy default
  1120. # is different for these two.
  1121. data = [3, 4, 5]
  1122. train = SpikeTrain(data, units='sec', t_stop=100)
  1123. train[0] = 99 * pq.s
  1124. assert_neo_object_is_compliant(train)
  1125. self.assertEqual(train[0], 99*pq.s)
  1126. self.assertEqual(data[0], 3*pq.s)
  1127. def test_change_with_copy_false_and_data_not_quantity(self):
  1128. # Changing spike train also changes data, because it is a view
  1129. # Data source is array
  1130. # Array and quantity are tested separately because copy default
  1131. # is different for these two.
  1132. data = np.array([3, 4, 5])
  1133. train = SpikeTrain(data, units='sec', copy=False, dtype=np.int,
  1134. t_stop=101)
  1135. train[0] = 99 * pq.s
  1136. assert_neo_object_is_compliant(train)
  1137. self.assertEqual(train[0], 99*pq.s)
  1138. self.assertEqual(data[0], 99)
  1139. def test_change_with_copy_false_and_dtype_change(self):
  1140. # You cannot change dtype and request a view
  1141. data = np.array([3, 4, 5])
  1142. self.assertRaises(ValueError, SpikeTrain, data, units='sec',
  1143. copy=False, t_stop=101, dtype=np.float64)
  1144. def test_change_with_copy_true_and_data_not_quantity(self):
  1145. # Changing spike train does not change data
  1146. # Data source is array
  1147. # Array and quantity are tested separately because copy default
  1148. # is different for these two.
  1149. data = [3, 4, 5]
  1150. train = SpikeTrain(data, units='sec', copy=True, t_stop=123.4)
  1151. train[0] = 99 * pq.s
  1152. assert_neo_object_is_compliant(train)
  1153. self.assertEqual(train[0], 99*pq.s)
  1154. self.assertEqual(data[0], 3)
  1155. def test_changing_slice_changes_original_spiketrain(self):
  1156. # If we slice a spiketrain and then change the slice, the
  1157. # original spiketrain should change.
  1158. # Whether the original data source changes is dependent on the
  1159. # copy parameter.
  1160. # This is compatible with both np and quantity default behavior.
  1161. data = [3, 4, 5] * pq.s
  1162. train = SpikeTrain(data, copy=True, t_stop=99.9)
  1163. result = train[1:3]
  1164. result[0] = 99 * pq.s
  1165. assert_neo_object_is_compliant(train)
  1166. self.assertEqual(train[1], 99*pq.s)
  1167. self.assertEqual(result[0], 99*pq.s)
  1168. self.assertEqual(data[1], 4*pq.s)
  1169. def test_changing_slice_changes_original_spiketrain_with_copy_false(self):
  1170. # If we slice a spiketrain and then change the slice, the
  1171. # original spiketrain should change.
  1172. # Whether the original data source changes is dependent on the
  1173. # copy parameter.
  1174. # This is compatible with both np and quantity default behavior.
  1175. data = [3, 4, 5] * pq.s
  1176. train = SpikeTrain(data, copy=False, t_stop=100.0)
  1177. result = train[1:3]
  1178. result[0] = 99 * pq.s
  1179. assert_neo_object_is_compliant(train)
  1180. assert_neo_object_is_compliant(result)
  1181. self.assertEqual(train[1], 99*pq.s)
  1182. self.assertEqual(result[0], 99*pq.s)
  1183. self.assertEqual(data[1], 99*pq.s)
  1184. def test__changing_spiketime_should_check_time_in_range(self):
  1185. data = [3, 4, 5] * pq.ms
  1186. train = SpikeTrain(data, copy=False, t_start=0.5, t_stop=10.0)
  1187. assert_neo_object_is_compliant(train)
  1188. self.assertRaises(ValueError, train.__setitem__, 0, 10.1*pq.ms)
  1189. self.assertRaises(ValueError, train.__setitem__, 1, 5.0*pq.s)
  1190. self.assertRaises(ValueError, train.__setitem__, 2, 5.0*pq.s)
  1191. self.assertRaises(ValueError, train.__setitem__, 0, 0)
  1192. def test__changing_multiple_spiketimes(self):
  1193. data = [3, 4, 5] * pq.ms
  1194. train = SpikeTrain(data, copy=False, t_start=0.5, t_stop=10.0)
  1195. train[:] = [7, 8, 9] * pq.ms
  1196. assert_neo_object_is_compliant(train)
  1197. assert_arrays_equal(train, np.array([7, 8, 9]))
  1198. def test__changing_multiple_spiketimes_should_check_time_in_range(self):
  1199. data = [3, 4, 5] * pq.ms
  1200. train = SpikeTrain(data, copy=False, t_start=0.5, t_stop=10.0)
  1201. assert_neo_object_is_compliant(train)
  1202. if sys.version_info[0] == 2:
  1203. self.assertRaises(ValueError, train.__setslice__,
  1204. 0, 3, [3, 4, 11] * pq.ms)
  1205. self.assertRaises(ValueError, train.__setslice__,
  1206. 0, 3, [0, 4, 5] * pq.ms)
  1207. def test__adding_time(self):
  1208. data = [3, 4, 5] * pq.ms
  1209. train = SpikeTrain(data, copy=False, t_start=0.5, t_stop=10.0)
  1210. assert_neo_object_is_compliant(train)
  1211. self.assertRaises(ValueError, train.__add__, 10 * pq.ms)
  1212. assert_arrays_equal(train + 1 * pq.ms, data + 1 * pq.ms)
  1213. def test__subtracting_time(self):
  1214. data = [3, 4, 5] * pq.ms
  1215. train = SpikeTrain(data, copy=False, t_start=0.5, t_stop=10.0)
  1216. assert_neo_object_is_compliant(train)
  1217. self.assertRaises(ValueError, train.__sub__, 10 * pq.ms)
  1218. assert_arrays_equal(train - 1 * pq.ms, data - 1 * pq.ms)
  1219. def test__rescale(self):
  1220. data = [3, 4, 5] * pq.ms
  1221. train = SpikeTrain(data, t_start=0.5, t_stop=10.0)
  1222. result = train.rescale(pq.s)
  1223. assert_neo_object_is_compliant(train)
  1224. assert_neo_object_is_compliant(result)
  1225. assert_arrays_equal(train, result)
  1226. self.assertEqual(result.units, 1 * pq.s)
  1227. def test__rescale_same_units(self):
  1228. data = [3, 4, 5] * pq.ms
  1229. train = SpikeTrain(data, t_start=0.5, t_stop=10.0)
  1230. result = train.rescale(pq.ms)
  1231. assert_neo_object_is_compliant(train)
  1232. assert_arrays_equal(train, result)
  1233. self.assertEqual(result.units, 1 * pq.ms)
  1234. def test__rescale_incompatible_units_ValueError(self):
  1235. data = [3, 4, 5] * pq.ms
  1236. train = SpikeTrain(data, t_start=0.5, t_stop=10.0)
  1237. assert_neo_object_is_compliant(train)
  1238. self.assertRaises(ValueError, train.rescale, pq.m)
  1239. class TestPropertiesMethods(unittest.TestCase):
  1240. def setUp(self):
  1241. self.data1 = [3, 4, 5]
  1242. self.data1quant = self.data1 * pq.ms
  1243. self.waveforms1 = np.array([[[0., 1.],
  1244. [0.1, 1.1]],
  1245. [[2., 3.],
  1246. [2.1, 3.1]],
  1247. [[4., 5.],
  1248. [4.1, 5.1]]]) * pq.mV
  1249. self.t_start1 = 0.5
  1250. self.t_stop1 = 10.0
  1251. self.t_start1quant = self.t_start1 * pq.ms
  1252. self.t_stop1quant = self.t_stop1 * pq.ms
  1253. self.sampling_rate1 = .1*pq.Hz
  1254. self.left_sweep1 = 2.*pq.s
  1255. self.name1 = 'train 1'
  1256. self.description1 = 'a test object'
  1257. self.ann1 = {'targ0': [1, 2], 'targ1': 1.1}
  1258. self.train1 = SpikeTrain(self.data1quant,
  1259. t_start=self.t_start1, t_stop=self.t_stop1,
  1260. waveforms=self.waveforms1,
  1261. left_sweep=self.left_sweep1,
  1262. sampling_rate=self.sampling_rate1,
  1263. name=self.name1,
  1264. description=self.description1,
  1265. **self.ann1)
  1266. def test__compliant(self):
  1267. assert_neo_object_is_compliant(self.train1)
  1268. def test__repr(self):
  1269. result = repr(self.train1)
  1270. targ = '<SpikeTrain(array([ 3., 4., 5.]) * ms, [0.5 ms, 10.0 ms])>'
  1271. self.assertEqual(result, targ)
  1272. def test__duration(self):
  1273. result1 = self.train1.duration
  1274. self.train1.t_start = None
  1275. assert_neo_object_is_compliant(self.train1)
  1276. result2 = self.train1.duration
  1277. self.train1.t_start = self.t_start1quant
  1278. self.train1.t_stop = None
  1279. assert_neo_object_is_compliant(self.train1)
  1280. result3 = self.train1.duration
  1281. self.assertEqual(result1, 9.5 * pq.ms)
  1282. self.assertEqual(result1.units, 1. * pq.ms)
  1283. self.assertEqual(result2, None)
  1284. self.assertEqual(result3, None)
  1285. def test__spike_duration(self):
  1286. result1 = self.train1.spike_duration
  1287. self.train1.sampling_rate = None
  1288. assert_neo_object_is_compliant(self.train1)
  1289. result2 = self.train1.spike_duration
  1290. self.train1.sampling_rate = self.sampling_rate1
  1291. self.train1.waveforms = None
  1292. assert_neo_object_is_compliant(self.train1)
  1293. result3 = self.train1.spike_duration
  1294. self.assertEqual(result1, 20./pq.Hz)
  1295. self.assertEqual(result1.units, 1./pq.Hz)
  1296. self.assertEqual(result2, None)
  1297. self.assertEqual(result3, None)
  1298. def test__sampling_period(self):
  1299. result1 = self.train1.sampling_period
  1300. self.train1.sampling_rate = None
  1301. assert_neo_object_is_compliant(self.train1)
  1302. result2 = self.train1.sampling_period
  1303. self.train1.sampling_rate = self.sampling_rate1
  1304. self.train1.sampling_period = 10.*pq.ms
  1305. assert_neo_object_is_compliant(self.train1)
  1306. result3a = self.train1.sampling_period
  1307. result3b = self.train1.sampling_rate
  1308. self.train1.sampling_period = None
  1309. result4a = self.train1.sampling_period
  1310. result4b = self.train1.sampling_rate
  1311. self.assertEqual(result1, 10./pq.Hz)
  1312. self.assertEqual(result1.units, 1./pq.Hz)
  1313. self.assertEqual(result2, None)
  1314. self.assertEqual(result3a, 10.*pq.ms)
  1315. self.assertEqual(result3a.units, 1.*pq.ms)
  1316. self.assertEqual(result3b, .1/pq.ms)
  1317. self.assertEqual(result3b.units, 1./pq.ms)
  1318. self.assertEqual(result4a, None)
  1319. self.assertEqual(result4b, None)
  1320. def test__right_sweep(self):
  1321. result1 = self.train1.right_sweep
  1322. self.train1.left_sweep = None
  1323. assert_neo_object_is_compliant(self.train1)
  1324. result2 = self.train1.right_sweep
  1325. self.train1.left_sweep = self.left_sweep1
  1326. self.train1.sampling_rate = None
  1327. assert_neo_object_is_compliant(self.train1)
  1328. result3 = self.train1.right_sweep
  1329. self.train1.sampling_rate = self.sampling_rate1
  1330. self.train1.waveforms = None
  1331. assert_neo_object_is_compliant(self.train1)
  1332. result4 = self.train1.right_sweep
  1333. self.assertEqual(result1, 22.*pq.s)
  1334. self.assertEqual(result1.units, 1.*pq.s)
  1335. self.assertEqual(result2, None)
  1336. self.assertEqual(result3, None)
  1337. self.assertEqual(result4, None)
  1338. def test__children(self):
  1339. segment = Segment(name='seg1')
  1340. segment.spiketrains = [self.train1]
  1341. segment.create_many_to_one_relationship()
  1342. unit = Unit(name='unit1')
  1343. unit.spiketrains = [self.train1]
  1344. unit.create_many_to_one_relationship()
  1345. self.assertEqual(self.train1._single_parent_objects,
  1346. ('Segment', 'Unit'))
  1347. self.assertEqual(self.train1._multi_parent_objects, ())
  1348. self.assertEqual(self.train1._single_parent_containers,
  1349. ('segment', 'unit'))
  1350. self.assertEqual(self.train1._multi_parent_containers, ())
  1351. self.assertEqual(self.train1._parent_objects,
  1352. ('Segment', 'Unit'))
  1353. self.assertEqual(self.train1._parent_containers,
  1354. ('segment', 'unit'))
  1355. self.assertEqual(len(self.train1.parents), 2)
  1356. self.assertEqual(self.train1.parents[0].name, 'seg1')
  1357. self.assertEqual(self.train1.parents[1].name, 'unit1')
  1358. assert_neo_object_is_compliant(self.train1)
  1359. @unittest.skipUnless(HAVE_IPYTHON, "requires IPython")
  1360. def test__pretty(self):
  1361. res = pretty(self.train1)
  1362. targ = ("SpikeTrain\n" +
  1363. "name: '%s'\ndescription: '%s'\nannotations: %s" %
  1364. (self.name1, self.description1, pretty(self.ann1)))
  1365. self.assertEqual(res, targ)
  1366. class TestMiscellaneous(unittest.TestCase):
  1367. def test__different_dtype_for_t_start_and_array(self):
  1368. data = np.array([0, 9.9999999], dtype=np.float64) * pq.s
  1369. data16 = data.astype(np.float16)
  1370. data32 = data.astype(np.float32)
  1371. data64 = data.astype(np.float64)
  1372. t_start = data[0]
  1373. t_stop = data[1]
  1374. t_start16 = data[0].astype(dtype=np.float16)
  1375. t_stop16 = data[1].astype(dtype=np.float16)
  1376. t_start32 = data[0].astype(dtype=np.float32)
  1377. t_stop32 = data[1].astype(dtype=np.float32)
  1378. t_start64 = data[0].astype(dtype=np.float64)
  1379. t_stop64 = data[1].astype(dtype=np.float64)
  1380. t_start_custom = 0.0
  1381. t_stop_custom = 10.0
  1382. t_start_custom16 = np.array(t_start_custom, dtype=np.float16)
  1383. t_stop_custom16 = np.array(t_stop_custom, dtype=np.float16)
  1384. t_start_custom32 = np.array(t_start_custom, dtype=np.float32)
  1385. t_stop_custom32 = np.array(t_stop_custom, dtype=np.float32)
  1386. t_start_custom64 = np.array(t_start_custom, dtype=np.float64)
  1387. t_stop_custom64 = np.array(t_stop_custom, dtype=np.float64)
  1388. #This is OK.
  1389. train = SpikeTrain(data64, copy=True, t_start=t_start, t_stop=t_stop)
  1390. assert_neo_object_is_compliant(train)
  1391. train = SpikeTrain(data16, copy=True, t_start=t_start, t_stop=t_stop,
  1392. dtype=np.float16)
  1393. assert_neo_object_is_compliant(train)
  1394. train = SpikeTrain(data16, copy=True, t_start=t_start, t_stop=t_stop,
  1395. dtype=np.float32)
  1396. assert_neo_object_is_compliant(train)
  1397. train = SpikeTrain(data32, copy=True, t_start=t_start, t_stop=t_stop,
  1398. dtype=np.float16)
  1399. assert_neo_object_is_compliant(train)
  1400. train = SpikeTrain(data32, copy=True, t_start=t_start, t_stop=t_stop,
  1401. dtype=np.float32)
  1402. assert_neo_object_is_compliant(train)
  1403. train = SpikeTrain(data32, copy=True,
  1404. t_start=t_start16, t_stop=t_stop16)
  1405. assert_neo_object_is_compliant(train)
  1406. train = SpikeTrain(data32, copy=True,
  1407. t_start=t_start16, t_stop=t_stop16,
  1408. dtype=np.float16)
  1409. assert_neo_object_is_compliant(train)
  1410. train = SpikeTrain(data32, copy=True,
  1411. t_start=t_start16, t_stop=t_stop16,
  1412. dtype=np.float32)
  1413. assert_neo_object_is_compliant(train)
  1414. train = SpikeTrain(data32, copy=True,
  1415. t_start=t_start16, t_stop=t_stop16,
  1416. dtype=np.float64)
  1417. assert_neo_object_is_compliant(train)
  1418. train = SpikeTrain(data32, copy=True,
  1419. t_start=t_start32, t_stop=t_stop32)
  1420. assert_neo_object_is_compliant(train)
  1421. train = SpikeTrain(data32, copy=True,
  1422. t_start=t_start32, t_stop=t_stop32,
  1423. dtype=np.float16)
  1424. assert_neo_object_is_compliant(train)
  1425. train = SpikeTrain(data32, copy=True,
  1426. t_start=t_start32, t_stop=t_stop32,
  1427. dtype=np.float32)
  1428. assert_neo_object_is_compliant(train)
  1429. train = SpikeTrain(data32, copy=True,
  1430. t_start=t_start32, t_stop=t_stop32,
  1431. dtype=np.float64)
  1432. assert_neo_object_is_compliant(train)
  1433. train = SpikeTrain(data32, copy=True,
  1434. t_start=t_start64, t_stop=t_stop64,
  1435. dtype=np.float16)
  1436. assert_neo_object_is_compliant(train)
  1437. train = SpikeTrain(data32, copy=True,
  1438. t_start=t_start64, t_stop=t_stop64,
  1439. dtype=np.float32)
  1440. assert_neo_object_is_compliant(train)
  1441. train = SpikeTrain(data16, copy=True,
  1442. t_start=t_start_custom, t_stop=t_stop_custom)
  1443. assert_neo_object_is_compliant(train)
  1444. train = SpikeTrain(data16, copy=True,
  1445. t_start=t_start_custom, t_stop=t_stop_custom,
  1446. dtype=np.float16)
  1447. assert_neo_object_is_compliant(train)
  1448. train = SpikeTrain(data16, copy=True,
  1449. t_start=t_start_custom, t_stop=t_stop_custom,
  1450. dtype=np.float32)
  1451. assert_neo_object_is_compliant(train)
  1452. train = SpikeTrain(data16, copy=True,
  1453. t_start=t_start_custom, t_stop=t_stop_custom,
  1454. dtype=np.float64)
  1455. assert_neo_object_is_compliant(train)
  1456. train = SpikeTrain(data32, copy=True,
  1457. t_start=t_start_custom, t_stop=t_stop_custom)
  1458. assert_neo_object_is_compliant(train)
  1459. train = SpikeTrain(data32, copy=True,
  1460. t_start=t_start_custom, t_stop=t_stop_custom,
  1461. dtype=np.float16)
  1462. assert_neo_object_is_compliant(train)
  1463. train = SpikeTrain(data32, copy=True,
  1464. t_start=t_start_custom, t_stop=t_stop_custom,
  1465. dtype=np.float32)
  1466. assert_neo_object_is_compliant(train)
  1467. train = SpikeTrain(data32, copy=True,
  1468. t_start=t_start_custom, t_stop=t_stop_custom,
  1469. dtype=np.float64)
  1470. assert_neo_object_is_compliant(train)
  1471. train = SpikeTrain(data16, copy=True,
  1472. t_start=t_start_custom, t_stop=t_stop_custom)
  1473. assert_neo_object_is_compliant(train)
  1474. train = SpikeTrain(data16, copy=True,
  1475. t_start=t_start_custom, t_stop=t_stop_custom,
  1476. dtype=np.float16)
  1477. assert_neo_object_is_compliant(train)
  1478. train = SpikeTrain(data16, copy=True,
  1479. t_start=t_start_custom, t_stop=t_stop_custom,
  1480. dtype=np.float32)
  1481. assert_neo_object_is_compliant(train)
  1482. train = SpikeTrain(data16, copy=True,
  1483. t_start=t_start_custom, t_stop=t_stop_custom,
  1484. dtype=np.float64)
  1485. assert_neo_object_is_compliant(train)
  1486. train = SpikeTrain(data32, copy=True,
  1487. t_start=t_start_custom16, t_stop=t_stop_custom16)
  1488. assert_neo_object_is_compliant(train)
  1489. train = SpikeTrain(data32, copy=True,
  1490. t_start=t_start_custom16, t_stop=t_stop_custom16,
  1491. dtype=np.float16)
  1492. assert_neo_object_is_compliant(train)
  1493. train = SpikeTrain(data32, copy=True,
  1494. t_start=t_start_custom16, t_stop=t_stop_custom16,
  1495. dtype=np.float32)
  1496. assert_neo_object_is_compliant(train)
  1497. train = SpikeTrain(data32, copy=True,
  1498. t_start=t_start_custom16, t_stop=t_stop_custom16,
  1499. dtype=np.float64)
  1500. assert_neo_object_is_compliant(train)
  1501. train = SpikeTrain(data32, copy=True,
  1502. t_start=t_start_custom32, t_stop=t_stop_custom32)
  1503. assert_neo_object_is_compliant(train)
  1504. train = SpikeTrain(data32, copy=True,
  1505. t_start=t_start_custom32, t_stop=t_stop_custom32,
  1506. dtype=np.float16)
  1507. assert_neo_object_is_compliant(train)
  1508. train = SpikeTrain(data32, copy=True,
  1509. t_start=t_start_custom32, t_stop=t_stop_custom32,
  1510. dtype=np.float32)
  1511. assert_neo_object_is_compliant(train)
  1512. train = SpikeTrain(data32, copy=True,
  1513. t_start=t_start_custom32, t_stop=t_stop_custom32,
  1514. dtype=np.float64)
  1515. assert_neo_object_is_compliant(train)
  1516. train = SpikeTrain(data32, copy=True,
  1517. t_start=t_start_custom64, t_stop=t_stop_custom64)
  1518. assert_neo_object_is_compliant(train)
  1519. train = SpikeTrain(data32, copy=True,
  1520. t_start=t_start_custom64, t_stop=t_stop_custom64,
  1521. dtype=np.float16)
  1522. assert_neo_object_is_compliant(train)
  1523. train = SpikeTrain(data32, copy=True,
  1524. t_start=t_start_custom64, t_stop=t_stop_custom64,
  1525. dtype=np.float32)
  1526. assert_neo_object_is_compliant(train)
  1527. train = SpikeTrain(data32, copy=True,
  1528. t_start=t_start_custom64, t_stop=t_stop_custom64,
  1529. dtype=np.float64)
  1530. assert_neo_object_is_compliant(train)
  1531. #This use to bug - see ticket #38
  1532. train = SpikeTrain(data16, copy=True, t_start=t_start, t_stop=t_stop)
  1533. assert_neo_object_is_compliant(train)
  1534. train = SpikeTrain(data16, copy=True, t_start=t_start, t_stop=t_stop,
  1535. dtype=np.float64)
  1536. assert_neo_object_is_compliant(train)
  1537. train = SpikeTrain(data32, copy=True, t_start=t_start, t_stop=t_stop)
  1538. assert_neo_object_is_compliant(train)
  1539. train = SpikeTrain(data32, copy=True, t_start=t_start, t_stop=t_stop,
  1540. dtype=np.float64)
  1541. assert_neo_object_is_compliant(train)
  1542. train = SpikeTrain(data32, copy=True,
  1543. t_start=t_start64, t_stop=t_stop64)
  1544. assert_neo_object_is_compliant(train)
  1545. train = SpikeTrain(data32, copy=True,
  1546. t_start=t_start64, t_stop=t_stop64,
  1547. dtype=np.float64)
  1548. assert_neo_object_is_compliant(train)
  1549. def test_as_array(self):
  1550. data = np.arange(10.0)
  1551. st = SpikeTrain(data, t_stop=10.0, units='ms')
  1552. st_as_arr = st.as_array()
  1553. self.assertIsInstance(st_as_arr, np.ndarray)
  1554. assert_array_equal(data, st_as_arr)
  1555. def test_as_quantity(self):
  1556. data = np.arange(10.0)
  1557. st = SpikeTrain(data, t_stop=10.0, units='ms')
  1558. st_as_q = st.as_quantity()
  1559. self.assertIsInstance(st_as_q, pq.Quantity)
  1560. assert_array_equal(data * pq.ms, st_as_q)
  1561. if __name__ == "__main__":
  1562. unittest.main()