test_spiketrain.py 82 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905
  1. # -*- coding: utf-8 -*-
  2. """
  3. Tests of the neo.core.spiketrain.SpikeTrain class and related functions
  4. """
  5. # needed for python 3 compatibility
  6. from __future__ import absolute_import
  7. import sys
  8. import unittest
  9. import numpy as np
  10. from numpy.testing import assert_array_equal
  11. import quantities as pq
  12. try:
  13. from IPython.lib.pretty import pretty
  14. except ImportError as err:
  15. HAVE_IPYTHON = False
  16. else:
  17. HAVE_IPYTHON = True
  18. from neo.core.spiketrain import (check_has_dimensions_time, SpikeTrain,
  19. _check_time_in_range, _new_spiketrain)
  20. from neo.core import Segment, Unit
  21. from neo.core.baseneo import MergeError
  22. from neo.test.tools import (assert_arrays_equal,
  23. assert_arrays_almost_equal,
  24. assert_neo_object_is_compliant)
  25. from neo.test.generate_datasets import (get_fake_value, get_fake_values,
  26. fake_neo, TEST_ANNOTATIONS)
  27. class Test__generate_datasets(unittest.TestCase):
  28. def setUp(self):
  29. np.random.seed(0)
  30. self.annotations = dict([(str(x), TEST_ANNOTATIONS[x]) for x in
  31. range(len(TEST_ANNOTATIONS))])
  32. def test__get_fake_values(self):
  33. self.annotations['seed'] = 0
  34. waveforms = get_fake_value('waveforms', pq.Quantity, seed=3, dim=3)
  35. shape = waveforms.shape[0]
  36. times = get_fake_value('times', pq.Quantity, seed=0, dim=1,
  37. shape=waveforms.shape[0])
  38. t_start = get_fake_value('t_start', pq.Quantity, seed=1, dim=0)
  39. t_stop = get_fake_value('t_stop', pq.Quantity, seed=2, dim=0)
  40. left_sweep = get_fake_value('left_sweep', pq.Quantity, seed=4, dim=0)
  41. sampling_rate = get_fake_value('sampling_rate', pq.Quantity,
  42. seed=5, dim=0)
  43. name = get_fake_value('name', str, seed=6, obj=SpikeTrain)
  44. description = get_fake_value('description', str,
  45. seed=7, obj='SpikeTrain')
  46. file_origin = get_fake_value('file_origin', str)
  47. attrs1 = {'name': name,
  48. 'description': description,
  49. 'file_origin': file_origin}
  50. attrs2 = attrs1.copy()
  51. attrs2.update(self.annotations)
  52. res11 = get_fake_values(SpikeTrain, annotate=False, seed=0)
  53. res12 = get_fake_values('SpikeTrain', annotate=False, seed=0)
  54. res21 = get_fake_values(SpikeTrain, annotate=True, seed=0)
  55. res22 = get_fake_values('SpikeTrain', annotate=True, seed=0)
  56. assert_arrays_equal(res11.pop('times'), times)
  57. assert_arrays_equal(res12.pop('times'), times)
  58. assert_arrays_equal(res21.pop('times'), times)
  59. assert_arrays_equal(res22.pop('times'), times)
  60. assert_arrays_equal(res11.pop('t_start'), t_start)
  61. assert_arrays_equal(res12.pop('t_start'), t_start)
  62. assert_arrays_equal(res21.pop('t_start'), t_start)
  63. assert_arrays_equal(res22.pop('t_start'), t_start)
  64. assert_arrays_equal(res11.pop('t_stop'), t_stop)
  65. assert_arrays_equal(res12.pop('t_stop'), t_stop)
  66. assert_arrays_equal(res21.pop('t_stop'), t_stop)
  67. assert_arrays_equal(res22.pop('t_stop'), t_stop)
  68. assert_arrays_equal(res11.pop('waveforms'), waveforms)
  69. assert_arrays_equal(res12.pop('waveforms'), waveforms)
  70. assert_arrays_equal(res21.pop('waveforms'), waveforms)
  71. assert_arrays_equal(res22.pop('waveforms'), waveforms)
  72. assert_arrays_equal(res11.pop('left_sweep'), left_sweep)
  73. assert_arrays_equal(res12.pop('left_sweep'), left_sweep)
  74. assert_arrays_equal(res21.pop('left_sweep'), left_sweep)
  75. assert_arrays_equal(res22.pop('left_sweep'), left_sweep)
  76. assert_arrays_equal(res11.pop('sampling_rate'), sampling_rate)
  77. assert_arrays_equal(res12.pop('sampling_rate'), sampling_rate)
  78. assert_arrays_equal(res21.pop('sampling_rate'), sampling_rate)
  79. assert_arrays_equal(res22.pop('sampling_rate'), sampling_rate)
  80. self.assertEqual(res11, attrs1)
  81. self.assertEqual(res12, attrs1)
  82. self.assertEqual(res21, attrs2)
  83. self.assertEqual(res22, attrs2)
  84. def test__fake_neo__cascade(self):
  85. self.annotations['seed'] = None
  86. obj_type = 'SpikeTrain'
  87. cascade = True
  88. res = fake_neo(obj_type=obj_type, cascade=cascade)
  89. self.assertTrue(isinstance(res, SpikeTrain))
  90. assert_neo_object_is_compliant(res)
  91. self.assertEqual(res.annotations, self.annotations)
  92. def test__fake_neo__nocascade(self):
  93. self.annotations['seed'] = None
  94. obj_type = SpikeTrain
  95. cascade = False
  96. res = fake_neo(obj_type=obj_type, cascade=cascade)
  97. self.assertTrue(isinstance(res, SpikeTrain))
  98. assert_neo_object_is_compliant(res)
  99. self.assertEqual(res.annotations, self.annotations)
  100. class Testcheck_has_dimensions_time(unittest.TestCase):
  101. def test__check_has_dimensions_time(self):
  102. a = np.arange(3) * pq.ms
  103. b = np.arange(3) * pq.mV
  104. c = np.arange(3) * pq.mA
  105. d = np.arange(3) * pq.minute
  106. check_has_dimensions_time(a)
  107. self.assertRaises(ValueError, check_has_dimensions_time, b)
  108. self.assertRaises(ValueError, check_has_dimensions_time, c)
  109. check_has_dimensions_time(d)
  110. self.assertRaises(ValueError, check_has_dimensions_time, a, b, c, d)
  111. class Testcheck_time_in_range(unittest.TestCase):
  112. def test__check_time_in_range_empty_array(self):
  113. value = np.array([])
  114. t_start = 0 * pq.s
  115. t_stop = 10 * pq.s
  116. _check_time_in_range(value, t_start=t_start, t_stop=t_stop)
  117. _check_time_in_range(value, t_start=t_start, t_stop=t_stop, view=False)
  118. _check_time_in_range(value, t_start=t_start, t_stop=t_stop, view=True)
  119. def test__check_time_in_range_exact(self):
  120. value = np.array([0., 5., 10.]) * pq.s
  121. t_start = 0. * pq.s
  122. t_stop = 10. * pq.s
  123. _check_time_in_range(value, t_start=t_start, t_stop=t_stop)
  124. _check_time_in_range(value, t_start=t_start, t_stop=t_stop, view=False)
  125. _check_time_in_range(value, t_start=t_start, t_stop=t_stop, view=True)
  126. def test__check_time_in_range_scale(self):
  127. value = np.array([0., 5000., 10000.]) * pq.ms
  128. t_start = 0. * pq.s
  129. t_stop = 10. * pq.s
  130. _check_time_in_range(value, t_start=t_start, t_stop=t_stop)
  131. _check_time_in_range(value, t_start=t_start, t_stop=t_stop, view=False)
  132. def test__check_time_in_range_inside(self):
  133. value = np.array([0.1, 5., 9.9]) * pq.s
  134. t_start = 0. * pq.s
  135. t_stop = 10. * pq.s
  136. _check_time_in_range(value, t_start=t_start, t_stop=t_stop)
  137. _check_time_in_range(value, t_start=t_start, t_stop=t_stop, view=False)
  138. _check_time_in_range(value, t_start=t_start, t_stop=t_stop, view=True)
  139. def test__check_time_in_range_below(self):
  140. value = np.array([-0.1, 5., 10.]) * pq.s
  141. t_start = 0. * pq.s
  142. t_stop = 10. * pq.s
  143. self.assertRaises(ValueError, _check_time_in_range, value,
  144. t_start=t_start, t_stop=t_stop)
  145. self.assertRaises(ValueError, _check_time_in_range, value,
  146. t_start=t_start, t_stop=t_stop, view=False)
  147. self.assertRaises(ValueError, _check_time_in_range, value,
  148. t_start=t_start, t_stop=t_stop, view=True)
  149. def test__check_time_in_range_below_scale(self):
  150. value = np.array([-1., 5000., 10000.]) * pq.ms
  151. t_start = 0. * pq.s
  152. t_stop = 10. * pq.s
  153. self.assertRaises(ValueError, _check_time_in_range, value,
  154. t_start=t_start, t_stop=t_stop)
  155. self.assertRaises(ValueError, _check_time_in_range, value,
  156. t_start=t_start, t_stop=t_stop, view=False)
  157. def test__check_time_in_range_above(self):
  158. value = np.array([0., 5., 10.1]) * pq.s
  159. t_start = 0. * pq.s
  160. t_stop = 10. * pq.s
  161. self.assertRaises(ValueError, _check_time_in_range, value,
  162. t_start=t_start, t_stop=t_stop)
  163. self.assertRaises(ValueError, _check_time_in_range, value,
  164. t_start=t_start, t_stop=t_stop, view=False)
  165. self.assertRaises(ValueError, _check_time_in_range, value,
  166. t_start=t_start, t_stop=t_stop, view=True)
  167. def test__check_time_in_range_above_scale(self):
  168. value = np.array([0., 5000., 10001.]) * pq.ms
  169. t_start = 0. * pq.s
  170. t_stop = 10. * pq.s
  171. self.assertRaises(ValueError, _check_time_in_range, value,
  172. t_start=t_start, t_stop=t_stop)
  173. self.assertRaises(ValueError, _check_time_in_range, value,
  174. t_start=t_start, t_stop=t_stop, view=False)
  175. def test__check_time_in_range_above_below(self):
  176. value = np.array([-0.1, 5., 10.1]) * pq.s
  177. t_start = 0. * pq.s
  178. t_stop = 10. * pq.s
  179. self.assertRaises(ValueError, _check_time_in_range, value,
  180. t_start=t_start, t_stop=t_stop)
  181. self.assertRaises(ValueError, _check_time_in_range, value,
  182. t_start=t_start, t_stop=t_stop, view=False)
  183. self.assertRaises(ValueError, _check_time_in_range, value,
  184. t_start=t_start, t_stop=t_stop, view=True)
  185. def test__check_time_in_range_above_below_scale(self):
  186. value = np.array([-1., 5000., 10001.]) * pq.ms
  187. t_start = 0. * pq.s
  188. t_stop = 10. * pq.s
  189. self.assertRaises(ValueError, _check_time_in_range, value,
  190. t_start=t_start, t_stop=t_stop)
  191. self.assertRaises(ValueError, _check_time_in_range, value,
  192. t_start=t_start, t_stop=t_stop, view=False)
  193. class TestConstructor(unittest.TestCase):
  194. def result_spike_check(self, train, st_out, t_start_out, t_stop_out,
  195. dtype, units):
  196. assert_arrays_equal(train, st_out)
  197. assert_arrays_equal(train, train.times)
  198. assert_neo_object_is_compliant(train)
  199. self.assertEqual(train.t_start, t_start_out)
  200. self.assertEqual(train.t_start, train.times.t_start)
  201. self.assertEqual(train.t_stop, t_stop_out)
  202. self.assertEqual(train.t_stop, train.times.t_stop)
  203. self.assertEqual(train.units, units)
  204. self.assertEqual(train.units, train.times.units)
  205. self.assertEqual(train.t_start.units, units)
  206. self.assertEqual(train.t_start.units, train.times.t_start.units)
  207. self.assertEqual(train.t_stop.units, units)
  208. self.assertEqual(train.t_stop.units, train.times.t_stop.units)
  209. self.assertEqual(train.dtype, dtype)
  210. self.assertEqual(train.dtype, train.times.dtype)
  211. self.assertEqual(train.t_stop.dtype, dtype)
  212. self.assertEqual(train.t_stop.dtype, train.times.t_stop.dtype)
  213. self.assertEqual(train.t_start.dtype, dtype)
  214. self.assertEqual(train.t_start.dtype, train.times.t_start.dtype)
  215. def test__create_minimal(self):
  216. t_start = 0.0
  217. t_stop = 10.0
  218. train1 = SpikeTrain([] * pq.s, t_stop)
  219. train2 = _new_spiketrain(SpikeTrain, [] * pq.s, t_stop)
  220. dtype = np.float64
  221. units = 1 * pq.s
  222. t_start_out = t_start * units
  223. t_stop_out = t_stop * units
  224. st_out = [] * units
  225. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  226. dtype, units)
  227. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  228. dtype, units)
  229. def test__create_empty(self):
  230. t_start = 0.0
  231. t_stop = 10.0
  232. train1 = SpikeTrain([], t_start=t_start, t_stop=t_stop, units='s')
  233. train2 = _new_spiketrain(SpikeTrain, [], t_start=t_start,
  234. t_stop=t_stop, units='s')
  235. dtype = np.float64
  236. units = 1 * pq.s
  237. t_start_out = t_start * units
  238. t_stop_out = t_stop * units
  239. st_out = [] * units
  240. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  241. dtype, units)
  242. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  243. dtype, units)
  244. def test__create_empty_no_t_start(self):
  245. t_start = 0.0
  246. t_stop = 10.0
  247. train1 = SpikeTrain([], t_stop=t_stop, units='s')
  248. train2 = _new_spiketrain(SpikeTrain, [], t_stop=t_stop, units='s')
  249. dtype = np.float64
  250. units = 1 * pq.s
  251. t_start_out = t_start * units
  252. t_stop_out = t_stop * units
  253. st_out = [] * units
  254. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  255. dtype, units)
  256. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  257. dtype, units)
  258. def test__create_from_list(self):
  259. times = range(10)
  260. t_start = 0.0 * pq.s
  261. t_stop = 10000.0 * pq.ms
  262. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units="ms")
  263. train2 = _new_spiketrain(SpikeTrain, times,
  264. t_start=t_start, t_stop=t_stop, units="ms")
  265. dtype = np.float64
  266. units = 1 * pq.ms
  267. t_start_out = t_start
  268. t_stop_out = t_stop
  269. st_out = times * units
  270. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  271. dtype, units)
  272. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  273. dtype, units)
  274. def test__create_from_list_set_dtype(self):
  275. times = range(10)
  276. t_start = 0.0 * pq.s
  277. t_stop = 10000.0 * pq.ms
  278. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop,
  279. units="ms", dtype='f4')
  280. train2 = _new_spiketrain(SpikeTrain, times,
  281. t_start=t_start, t_stop=t_stop,
  282. units="ms", dtype='f4')
  283. dtype = np.float32
  284. units = 1 * pq.ms
  285. t_start_out = t_start.astype(dtype)
  286. t_stop_out = t_stop.astype(dtype)
  287. st_out = pq.Quantity(times, units=units, dtype=dtype)
  288. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  289. dtype, units)
  290. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  291. dtype, units)
  292. def test__create_from_list_no_start_stop_units(self):
  293. times = range(10)
  294. t_start = 0.0
  295. t_stop = 10000.0
  296. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units="ms")
  297. train2 = _new_spiketrain(SpikeTrain, times,
  298. t_start=t_start, t_stop=t_stop, units="ms")
  299. dtype = np.float64
  300. units = 1 * pq.ms
  301. t_start_out = t_start * units
  302. t_stop_out = t_stop * units
  303. st_out = times * units
  304. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  305. dtype, units)
  306. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  307. dtype, units)
  308. def test__create_from_list_no_start_stop_units_set_dtype(self):
  309. times = range(10)
  310. t_start = 0.0
  311. t_stop = 10000.0
  312. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop,
  313. units="ms", dtype='f4')
  314. train2 = _new_spiketrain(SpikeTrain, times,
  315. t_start=t_start, t_stop=t_stop,
  316. units="ms", dtype='f4')
  317. dtype = np.float32
  318. units = 1 * pq.ms
  319. t_start_out = pq.Quantity(t_start, units=units, dtype=dtype)
  320. t_stop_out = pq.Quantity(t_stop, units=units, dtype=dtype)
  321. st_out = pq.Quantity(times, units=units, dtype=dtype)
  322. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  323. dtype, units)
  324. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  325. dtype, units)
  326. def test__create_from_array(self):
  327. times = np.arange(10)
  328. t_start = 0.0 * pq.s
  329. t_stop = 10000.0 * pq.ms
  330. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units="s")
  331. train2 = _new_spiketrain(SpikeTrain, times,
  332. t_start=t_start, t_stop=t_stop, units="s")
  333. dtype = np.int
  334. units = 1 * pq.s
  335. t_start_out = t_start
  336. t_stop_out = t_stop
  337. st_out = times * units
  338. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  339. dtype, units)
  340. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  341. dtype, units)
  342. def test__create_from_array_with_dtype(self):
  343. times = np.arange(10, dtype='f4')
  344. t_start = 0.0 * pq.s
  345. t_stop = 10000.0 * pq.ms
  346. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units="s")
  347. train2 = _new_spiketrain(SpikeTrain, times,
  348. t_start=t_start, t_stop=t_stop, units="s")
  349. dtype = times.dtype
  350. units = 1 * pq.s
  351. t_start_out = t_start
  352. t_stop_out = t_stop
  353. st_out = times * units
  354. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  355. dtype, units)
  356. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  357. dtype, units)
  358. def test__create_from_array_set_dtype(self):
  359. times = np.arange(10)
  360. t_start = 0.0 * pq.s
  361. t_stop = 10000.0 * pq.ms
  362. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop,
  363. units="s", dtype='f4')
  364. train2 = _new_spiketrain(SpikeTrain, times,
  365. t_start=t_start, t_stop=t_stop,
  366. units="s", dtype='f4')
  367. dtype = np.float32
  368. units = 1 * pq.s
  369. t_start_out = t_start.astype(dtype)
  370. t_stop_out = t_stop.astype(dtype)
  371. st_out = times.astype(dtype) * units
  372. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  373. dtype, units)
  374. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  375. dtype, units)
  376. def test__create_from_array_no_start_stop_units(self):
  377. times = np.arange(10)
  378. t_start = 0.0
  379. t_stop = 10000.0
  380. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units="s")
  381. train2 = _new_spiketrain(SpikeTrain, times,
  382. t_start=t_start, t_stop=t_stop, units="s")
  383. dtype = np.int
  384. units = 1 * pq.s
  385. t_start_out = t_start * units
  386. t_stop_out = t_stop * units
  387. st_out = times * units
  388. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  389. dtype, units)
  390. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  391. dtype, units)
  392. def test__create_from_array_no_start_stop_units_with_dtype(self):
  393. times = np.arange(10, dtype='f4')
  394. t_start = 0.0
  395. t_stop = 10000.0
  396. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units="s")
  397. train2 = _new_spiketrain(SpikeTrain, times,
  398. t_start=t_start, t_stop=t_stop, units="s")
  399. dtype = np.float32
  400. units = 1 * pq.s
  401. t_start_out = t_start * units
  402. t_stop_out = t_stop * units
  403. st_out = times * units
  404. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  405. dtype, units)
  406. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  407. dtype, units)
  408. def test__create_from_array_no_start_stop_units_set_dtype(self):
  409. times = np.arange(10)
  410. t_start = 0.0
  411. t_stop = 10000.0
  412. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop,
  413. units="s", dtype='f4')
  414. train2 = _new_spiketrain(SpikeTrain, times,
  415. t_start=t_start, t_stop=t_stop,
  416. units="s", dtype='f4')
  417. dtype = np.float32
  418. units = 1 * pq.s
  419. t_start_out = pq.Quantity(t_start, units=units, dtype=dtype)
  420. t_stop_out = pq.Quantity(t_stop, units=units, dtype=dtype)
  421. st_out = times.astype(dtype) * units
  422. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  423. dtype, units)
  424. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  425. dtype, units)
  426. def test__create_from_quantity_array(self):
  427. times = np.arange(10) * pq.ms
  428. t_start = 0.0 * pq.s
  429. t_stop = 12.0 * pq.ms
  430. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop)
  431. train2 = _new_spiketrain(SpikeTrain, times,
  432. t_start=t_start, t_stop=t_stop)
  433. dtype = np.float64
  434. units = 1 * pq.ms
  435. t_start_out = t_start
  436. t_stop_out = t_stop
  437. st_out = times
  438. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  439. dtype, units)
  440. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  441. dtype, units)
  442. def test__create_from_quantity_array_with_dtype(self):
  443. times = np.arange(10, dtype='f4') * pq.ms
  444. t_start = 0.0 * pq.s
  445. t_stop = 12.0 * pq.ms
  446. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop)
  447. train2 = _new_spiketrain(SpikeTrain, times,
  448. t_start=t_start, t_stop=t_stop)
  449. dtype = np.float32
  450. units = 1 * pq.ms
  451. t_start_out = t_start.astype(dtype)
  452. t_stop_out = t_stop.astype(dtype)
  453. st_out = times.astype(dtype)
  454. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  455. dtype, units)
  456. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  457. dtype, units)
  458. def test__create_from_quantity_array_set_dtype(self):
  459. times = np.arange(10) * pq.ms
  460. t_start = 0.0 * pq.s
  461. t_stop = 12.0 * pq.ms
  462. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop,
  463. dtype='f4')
  464. train2 = _new_spiketrain(SpikeTrain, times,
  465. t_start=t_start, t_stop=t_stop,
  466. dtype='f4')
  467. dtype = np.float32
  468. units = 1 * pq.ms
  469. t_start_out = t_start.astype(dtype)
  470. t_stop_out = t_stop.astype(dtype)
  471. st_out = times.astype(dtype)
  472. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  473. dtype, units)
  474. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  475. dtype, units)
  476. def test__create_from_quantity_array_no_start_stop_units(self):
  477. times = np.arange(10) * pq.ms
  478. t_start = 0.0
  479. t_stop = 12.0
  480. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop)
  481. train2 = _new_spiketrain(SpikeTrain, times,
  482. t_start=t_start, t_stop=t_stop)
  483. dtype = np.float64
  484. units = 1 * pq.ms
  485. t_start_out = t_start * units
  486. t_stop_out = t_stop * units
  487. st_out = times
  488. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  489. dtype, units)
  490. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  491. dtype, units)
  492. def test__create_from_quantity_array_no_start_stop_units_with_dtype(self):
  493. times = np.arange(10, dtype='f4') * pq.ms
  494. t_start = 0.0
  495. t_stop = 12.0
  496. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop)
  497. train2 = _new_spiketrain(SpikeTrain, times,
  498. t_start=t_start, t_stop=t_stop)
  499. dtype = np.float32
  500. units = 1 * pq.ms
  501. t_start_out = pq.Quantity(t_start, units=units, dtype=dtype)
  502. t_stop_out = pq.Quantity(t_stop, units=units, dtype=dtype)
  503. st_out = times.astype(dtype)
  504. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  505. dtype, units)
  506. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  507. dtype, units)
  508. def test__create_from_quantity_array_no_start_stop_units_set_dtype(self):
  509. times = np.arange(10) * pq.ms
  510. t_start = 0.0
  511. t_stop = 12.0
  512. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop,
  513. dtype='f4')
  514. train2 = _new_spiketrain(SpikeTrain, times,
  515. t_start=t_start, t_stop=t_stop,
  516. dtype='f4')
  517. dtype = np.float32
  518. units = 1 * pq.ms
  519. t_start_out = pq.Quantity(t_start, units=units, dtype=dtype)
  520. t_stop_out = pq.Quantity(t_stop, units=units, dtype=dtype)
  521. st_out = times.astype(dtype)
  522. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  523. dtype, units)
  524. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  525. dtype, units)
  526. def test__create_from_quantity_array_units(self):
  527. times = np.arange(10) * pq.ms
  528. t_start = 0.0 * pq.s
  529. t_stop = 12.0 * pq.ms
  530. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units='s')
  531. train2 = _new_spiketrain(SpikeTrain, times,
  532. t_start=t_start, t_stop=t_stop, units='s')
  533. dtype = np.float64
  534. units = 1 * pq.s
  535. t_start_out = t_start
  536. t_stop_out = t_stop
  537. st_out = times
  538. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  539. dtype, units)
  540. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  541. dtype, units)
  542. def test__create_from_quantity_array_units_with_dtype(self):
  543. times = np.arange(10, dtype='f4') * pq.ms
  544. t_start = 0.0 * pq.s
  545. t_stop = 12.0 * pq.ms
  546. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop,
  547. units='s')
  548. train2 = _new_spiketrain(SpikeTrain, times,
  549. t_start=t_start, t_stop=t_stop, units='s')
  550. dtype = np.float32
  551. units = 1 * pq.s
  552. t_start_out = t_start.astype(dtype)
  553. t_stop_out = t_stop.rescale(units).astype(dtype)
  554. st_out = times.rescale(units).astype(dtype)
  555. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  556. dtype, units)
  557. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  558. dtype, units)
  559. def test__create_from_quantity_array_units_set_dtype(self):
  560. times = np.arange(10) * pq.ms
  561. t_start = 0.0 * pq.s
  562. t_stop = 12.0 * pq.ms
  563. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop,
  564. units='s', dtype='f4')
  565. train2 = _new_spiketrain(SpikeTrain, times,
  566. t_start=t_start, t_stop=t_stop,
  567. units='s', dtype='f4')
  568. dtype = np.float32
  569. units = 1 * pq.s
  570. t_start_out = t_start.astype(dtype)
  571. t_stop_out = t_stop.rescale(units).astype(dtype)
  572. st_out = times.rescale(units).astype(dtype)
  573. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  574. dtype, units)
  575. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  576. dtype, units)
  577. def test__create_from_quantity_array_units_no_start_stop_units(self):
  578. times = np.arange(10) * pq.ms
  579. t_start = 0.0
  580. t_stop = 12.0
  581. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units='s')
  582. train2 = _new_spiketrain(SpikeTrain, times,
  583. t_start=t_start, t_stop=t_stop, units='s')
  584. dtype = np.float64
  585. units = 1 * pq.s
  586. t_start_out = pq.Quantity(t_start, units=units, dtype=dtype)
  587. t_stop_out = pq.Quantity(t_stop, units=units, dtype=dtype)
  588. st_out = times
  589. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  590. dtype, units)
  591. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  592. dtype, units)
  593. def test__create_from_quantity_units_no_start_stop_units_set_dtype(self):
  594. times = np.arange(10) * pq.ms
  595. t_start = 0.0
  596. t_stop = 12.0
  597. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop,
  598. units='s', dtype='f4')
  599. train2 = _new_spiketrain(SpikeTrain, times,
  600. t_start=t_start, t_stop=t_stop,
  601. units='s', dtype='f4')
  602. dtype = np.float32
  603. units = 1 * pq.s
  604. t_start_out = pq.Quantity(t_start, units=units, dtype=dtype)
  605. t_stop_out = pq.Quantity(t_stop, units=units, dtype=dtype)
  606. st_out = times.rescale(units).astype(dtype)
  607. self.result_spike_check(train1, st_out, t_start_out, t_stop_out,
  608. dtype, units)
  609. self.result_spike_check(train2, st_out, t_start_out, t_stop_out,
  610. dtype, units)
  611. def test__create_from_list_without_units_should_raise_ValueError(self):
  612. times = range(10)
  613. t_start = 0.0 * pq.s
  614. t_stop = 10000.0 * pq.ms
  615. self.assertRaises(ValueError, SpikeTrain, times,
  616. t_start=t_start, t_stop=t_stop)
  617. self.assertRaises(ValueError, _new_spiketrain, SpikeTrain, times,
  618. t_start=t_start, t_stop=t_stop)
  619. def test__create_from_array_without_units_should_raise_ValueError(self):
  620. times = np.arange(10)
  621. t_start = 0.0 * pq.s
  622. t_stop = 10000.0 * pq.ms
  623. self.assertRaises(ValueError, SpikeTrain, times,
  624. t_start=t_start, t_stop=t_stop)
  625. self.assertRaises(ValueError, _new_spiketrain, SpikeTrain, times,
  626. t_start=t_start, t_stop=t_stop)
  627. def test__create_from_array_with_incompatible_units_ValueError(self):
  628. times = np.arange(10) * pq.km
  629. t_start = 0.0 * pq.s
  630. t_stop = 10000.0 * pq.ms
  631. self.assertRaises(ValueError, SpikeTrain, times,
  632. t_start=t_start, t_stop=t_stop)
  633. self.assertRaises(ValueError, _new_spiketrain, SpikeTrain, times,
  634. t_start=t_start, t_stop=t_stop)
  635. def test__create_with_times_outside_tstart_tstop_ValueError(self):
  636. t_start = 23
  637. t_stop = 77
  638. train1 = SpikeTrain(np.arange(t_start, t_stop), units='ms',
  639. t_start=t_start, t_stop=t_stop)
  640. train2 = _new_spiketrain(SpikeTrain,
  641. np.arange(t_start, t_stop), units='ms',
  642. t_start=t_start, t_stop=t_stop)
  643. assert_neo_object_is_compliant(train1)
  644. assert_neo_object_is_compliant(train2)
  645. self.assertRaises(ValueError, SpikeTrain,
  646. np.arange(t_start - 5, t_stop), units='ms',
  647. t_start=t_start, t_stop=t_stop)
  648. self.assertRaises(ValueError, _new_spiketrain, SpikeTrain,
  649. np.arange(t_start - 5, t_stop), units='ms',
  650. t_start=t_start, t_stop=t_stop)
  651. self.assertRaises(ValueError, SpikeTrain,
  652. np.arange(t_start, t_stop + 5), units='ms',
  653. t_start=t_start, t_stop=t_stop)
  654. self.assertRaises(ValueError, _new_spiketrain, SpikeTrain,
  655. np.arange(t_start, t_stop + 5), units='ms',
  656. t_start=t_start, t_stop=t_stop)
  657. def test__create_with_len_times_different_size_than_waveform_shape1_ValueError(
  658. self):
  659. self.assertRaises(ValueError, SpikeTrain,
  660. times=np.arange(10), units='s',
  661. t_stop=4, waveforms=np.ones((10, 6, 50)))
  662. def test_defaults(self):
  663. # default recommended attributes
  664. train1 = SpikeTrain([3, 4, 5], units='sec', t_stop=10.0)
  665. train2 = _new_spiketrain(SpikeTrain, [3, 4, 5],
  666. units='sec', t_stop=10.0)
  667. assert_neo_object_is_compliant(train1)
  668. assert_neo_object_is_compliant(train2)
  669. self.assertEqual(train1.dtype, np.float)
  670. self.assertEqual(train2.dtype, np.float)
  671. self.assertEqual(train1.sampling_rate, 1.0 * pq.Hz)
  672. self.assertEqual(train2.sampling_rate, 1.0 * pq.Hz)
  673. self.assertEqual(train1.waveforms, None)
  674. self.assertEqual(train2.waveforms, None)
  675. self.assertEqual(train1.left_sweep, None)
  676. self.assertEqual(train2.left_sweep, None)
  677. def test_default_tstart(self):
  678. # t start defaults to zero
  679. train11 = SpikeTrain([3, 4, 5] * pq.s, t_stop=8000 * pq.ms)
  680. train21 = _new_spiketrain(SpikeTrain, [3, 4, 5] * pq.s,
  681. t_stop=8000 * pq.ms)
  682. assert_neo_object_is_compliant(train11)
  683. assert_neo_object_is_compliant(train21)
  684. self.assertEqual(train11.t_start, 0. * pq.s)
  685. self.assertEqual(train21.t_start, 0. * pq.s)
  686. # unless otherwise specified
  687. train12 = SpikeTrain([3, 4, 5] * pq.s, t_start=2.0, t_stop=8)
  688. train22 = _new_spiketrain(SpikeTrain, [3, 4, 5] * pq.s,
  689. t_start=2.0, t_stop=8)
  690. assert_neo_object_is_compliant(train12)
  691. assert_neo_object_is_compliant(train22)
  692. self.assertEqual(train12.t_start, 2. * pq.s)
  693. self.assertEqual(train22.t_start, 2. * pq.s)
  694. def test_tstop_units_conversion(self):
  695. train11 = SpikeTrain([3, 5, 4] * pq.s, t_stop=10)
  696. train21 = _new_spiketrain(SpikeTrain, [3, 5, 4] * pq.s, t_stop=10)
  697. assert_neo_object_is_compliant(train11)
  698. assert_neo_object_is_compliant(train21)
  699. self.assertEqual(train11.t_stop, 10. * pq.s)
  700. self.assertEqual(train21.t_stop, 10. * pq.s)
  701. train12 = SpikeTrain([3, 5, 4] * pq.s, t_stop=10000. * pq.ms)
  702. train22 = _new_spiketrain(SpikeTrain, [3, 5, 4] * pq.s,
  703. t_stop=10000. * pq.ms)
  704. assert_neo_object_is_compliant(train12)
  705. assert_neo_object_is_compliant(train22)
  706. self.assertEqual(train12.t_stop, 10. * pq.s)
  707. self.assertEqual(train22.t_stop, 10. * pq.s)
  708. train13 = SpikeTrain([3, 5, 4], units='sec', t_stop=10000. * pq.ms)
  709. train23 = _new_spiketrain(SpikeTrain, [3, 5, 4],
  710. units='sec', t_stop=10000. * pq.ms)
  711. assert_neo_object_is_compliant(train13)
  712. assert_neo_object_is_compliant(train23)
  713. self.assertEqual(train13.t_stop, 10. * pq.s)
  714. self.assertEqual(train23.t_stop, 10. * pq.s)
  715. class TestSorting(unittest.TestCase):
  716. def test_sort(self):
  717. waveforms = np.array([[[0., 1.]], [[2., 3.]], [[4., 5.]]]) * pq.mV
  718. train = SpikeTrain([3, 4, 5] * pq.s, waveforms=waveforms, name='n',
  719. t_stop=10.0)
  720. assert_neo_object_is_compliant(train)
  721. train.sort()
  722. assert_neo_object_is_compliant(train)
  723. assert_arrays_equal(train, [3, 4, 5] * pq.s)
  724. assert_arrays_equal(train.waveforms, waveforms)
  725. self.assertEqual(train.name, 'n')
  726. self.assertEqual(train.t_stop, 10.0 * pq.s)
  727. train = SpikeTrain([3, 5, 4] * pq.s, waveforms=waveforms, name='n',
  728. t_stop=10.0)
  729. assert_neo_object_is_compliant(train)
  730. train.sort()
  731. assert_neo_object_is_compliant(train)
  732. assert_arrays_equal(train, [3, 4, 5] * pq.s)
  733. assert_arrays_equal(train.waveforms, waveforms[[0, 2, 1]])
  734. self.assertEqual(train.name, 'n')
  735. self.assertEqual(train.t_start, 0.0 * pq.s)
  736. self.assertEqual(train.t_stop, 10.0 * pq.s)
  737. class TestSlice(unittest.TestCase):
  738. def setUp(self):
  739. self.waveforms1 = np.array([[[0., 1.],
  740. [0.1, 1.1]],
  741. [[2., 3.],
  742. [2.1, 3.1]],
  743. [[4., 5.],
  744. [4.1, 5.1]]]) * pq.mV
  745. self.data1 = np.array([3, 4, 5])
  746. self.data1quant = self.data1 * pq.s
  747. self.train1 = SpikeTrain(self.data1quant, waveforms=self.waveforms1,
  748. name='n', arb='arbb', t_stop=10.0)
  749. def test_compliant(self):
  750. assert_neo_object_is_compliant(self.train1)
  751. def test_slice(self):
  752. # slice spike train, keep sliced spike times
  753. result = self.train1[1:2]
  754. assert_arrays_equal(self.train1[1:2], result)
  755. targwaveforms = np.array([[[2., 3.],
  756. [2.1, 3.1]]]) * pq.mV
  757. # but keep everything else pristine
  758. assert_neo_object_is_compliant(result)
  759. self.assertEqual(self.train1.name, result.name)
  760. self.assertEqual(self.train1.description, result.description)
  761. self.assertEqual(self.train1.annotations, result.annotations)
  762. self.assertEqual(self.train1.file_origin, result.file_origin)
  763. self.assertEqual(self.train1.dtype, result.dtype)
  764. self.assertEqual(self.train1.t_start, result.t_start)
  765. self.assertEqual(self.train1.t_stop, result.t_stop)
  766. # except we update the waveforms
  767. assert_arrays_equal(self.train1.waveforms[1:2], result.waveforms)
  768. assert_arrays_equal(targwaveforms, result.waveforms)
  769. def test_slice_to_end(self):
  770. # slice spike train, keep sliced spike times
  771. result = self.train1[1:]
  772. assert_arrays_equal(self.train1[1:], result)
  773. targwaveforms = np.array([[[2., 3.],
  774. [2.1, 3.1]],
  775. [[4., 5.],
  776. [4.1, 5.1]]]) * pq.mV
  777. # but keep everything else pristine
  778. assert_neo_object_is_compliant(result)
  779. self.assertEqual(self.train1.name, result.name)
  780. self.assertEqual(self.train1.description, result.description)
  781. self.assertEqual(self.train1.annotations, result.annotations)
  782. self.assertEqual(self.train1.file_origin, result.file_origin)
  783. self.assertEqual(self.train1.dtype, result.dtype)
  784. self.assertEqual(self.train1.t_start, result.t_start)
  785. self.assertEqual(self.train1.t_stop, result.t_stop)
  786. # except we update the waveforms
  787. assert_arrays_equal(self.train1.waveforms[1:], result.waveforms)
  788. assert_arrays_equal(targwaveforms, result.waveforms)
  789. def test_slice_from_beginning(self):
  790. # slice spike train, keep sliced spike times
  791. result = self.train1[:2]
  792. assert_arrays_equal(self.train1[:2], result)
  793. targwaveforms = np.array([[[0., 1.],
  794. [0.1, 1.1]],
  795. [[2., 3.],
  796. [2.1, 3.1]]]) * pq.mV
  797. # but keep everything else pristine
  798. assert_neo_object_is_compliant(result)
  799. self.assertEqual(self.train1.name, result.name)
  800. self.assertEqual(self.train1.description, result.description)
  801. self.assertEqual(self.train1.annotations, result.annotations)
  802. self.assertEqual(self.train1.file_origin, result.file_origin)
  803. self.assertEqual(self.train1.dtype, result.dtype)
  804. self.assertEqual(self.train1.t_start, result.t_start)
  805. self.assertEqual(self.train1.t_stop, result.t_stop)
  806. # except we update the waveforms
  807. assert_arrays_equal(self.train1.waveforms[:2], result.waveforms)
  808. assert_arrays_equal(targwaveforms, result.waveforms)
  809. def test_slice_negative_idxs(self):
  810. # slice spike train, keep sliced spike times
  811. result = self.train1[:-1]
  812. assert_arrays_equal(self.train1[:-1], result)
  813. targwaveforms = np.array([[[0., 1.],
  814. [0.1, 1.1]],
  815. [[2., 3.],
  816. [2.1, 3.1]]]) * pq.mV
  817. # but keep everything else pristine
  818. assert_neo_object_is_compliant(result)
  819. self.assertEqual(self.train1.name, result.name)
  820. self.assertEqual(self.train1.description, result.description)
  821. self.assertEqual(self.train1.annotations, result.annotations)
  822. self.assertEqual(self.train1.file_origin, result.file_origin)
  823. self.assertEqual(self.train1.dtype, result.dtype)
  824. self.assertEqual(self.train1.t_start, result.t_start)
  825. self.assertEqual(self.train1.t_stop, result.t_stop)
  826. # except we update the waveforms
  827. assert_arrays_equal(self.train1.waveforms[:-1], result.waveforms)
  828. assert_arrays_equal(targwaveforms, result.waveforms)
  829. class TestTimeSlice(unittest.TestCase):
  830. def setUp(self):
  831. self.waveforms1 = np.array([[[0., 1.],
  832. [0.1, 1.1]],
  833. [[2., 3.],
  834. [2.1, 3.1]],
  835. [[4., 5.],
  836. [4.1, 5.1]],
  837. [[6., 7.],
  838. [6.1, 7.1]],
  839. [[8., 9.],
  840. [8.1, 9.1]],
  841. [[10., 11.],
  842. [10.1, 11.1]]]) * pq.mV
  843. self.data1 = np.array([0.1, 0.5, 1.2, 3.3, 6.4, 7])
  844. self.data1quant = self.data1 * pq.ms
  845. self.train1 = SpikeTrain(self.data1quant, t_stop=10.0 * pq.ms,
  846. waveforms=self.waveforms1)
  847. def test_compliant(self):
  848. assert_neo_object_is_compliant(self.train1)
  849. def test_time_slice_typical(self):
  850. # time_slice spike train, keep sliced spike times
  851. # this is the typical time slice falling somewhere
  852. # in the middle of spikes
  853. t_start = 0.12 * pq.ms
  854. t_stop = 3.5 * pq.ms
  855. result = self.train1.time_slice(t_start, t_stop)
  856. targ = SpikeTrain([0.5, 1.2, 3.3] * pq.ms, t_stop=3.3)
  857. assert_arrays_equal(result, targ)
  858. targwaveforms = np.array([[[2., 3.],
  859. [2.1, 3.1]],
  860. [[4., 5.],
  861. [4.1, 5.1]],
  862. [[6., 7.],
  863. [6.1, 7.1]]]) * pq.mV
  864. assert_arrays_equal(targwaveforms, result.waveforms)
  865. # but keep everything else pristine
  866. assert_neo_object_is_compliant(result)
  867. self.assertEqual(self.train1.name, result.name)
  868. self.assertEqual(self.train1.description, result.description)
  869. self.assertEqual(self.train1.annotations, result.annotations)
  870. self.assertEqual(self.train1.file_origin, result.file_origin)
  871. self.assertEqual(self.train1.dtype, result.dtype)
  872. self.assertEqual(t_start, result.t_start)
  873. self.assertEqual(t_stop, result.t_stop)
  874. def test_time_slice_differnt_units(self):
  875. # time_slice spike train, keep sliced spike times
  876. t_start = 0.00012 * pq.s
  877. t_stop = 0.0035 * pq.s
  878. result = self.train1.time_slice(t_start, t_stop)
  879. targ = SpikeTrain([0.5, 1.2, 3.3] * pq.ms, t_stop=3.3)
  880. assert_arrays_equal(result, targ)
  881. targwaveforms = np.array([[[2., 3.],
  882. [2.1, 3.1]],
  883. [[4., 5.],
  884. [4.1, 5.1]],
  885. [[6., 7.],
  886. [6.1, 7.1]]]) * pq.mV
  887. assert_arrays_equal(targwaveforms, result.waveforms)
  888. # but keep everything else pristine
  889. assert_neo_object_is_compliant(result)
  890. self.assertEqual(self.train1.name, result.name)
  891. self.assertEqual(self.train1.description, result.description)
  892. self.assertEqual(self.train1.annotations, result.annotations)
  893. self.assertEqual(self.train1.file_origin, result.file_origin)
  894. self.assertEqual(self.train1.dtype, result.dtype)
  895. self.assertEqual(t_start, result.t_start)
  896. self.assertEqual(t_stop, result.t_stop)
  897. def test_time_slice_matching_ends(self):
  898. # time_slice spike train, keep sliced spike times
  899. t_start = 0.1 * pq.ms
  900. t_stop = 7.0 * pq.ms
  901. result = self.train1.time_slice(t_start, t_stop)
  902. assert_arrays_equal(self.train1, result)
  903. assert_arrays_equal(self.waveforms1, result.waveforms)
  904. # but keep everything else pristine
  905. assert_neo_object_is_compliant(result)
  906. self.assertEqual(self.train1.name, result.name)
  907. self.assertEqual(self.train1.description, result.description)
  908. self.assertEqual(self.train1.annotations, result.annotations)
  909. self.assertEqual(self.train1.file_origin, result.file_origin)
  910. self.assertEqual(self.train1.dtype, result.dtype)
  911. self.assertEqual(t_start, result.t_start)
  912. self.assertEqual(t_stop, result.t_stop)
  913. def test_time_slice_out_of_boundries(self):
  914. self.train1.t_start = 0.1 * pq.ms
  915. assert_neo_object_is_compliant(self.train1)
  916. # time_slice spike train, keep sliced spike times
  917. t_start = 0.01 * pq.ms
  918. t_stop = 70.0 * pq.ms
  919. result = self.train1.time_slice(t_start, t_stop)
  920. assert_arrays_equal(self.train1, result)
  921. assert_arrays_equal(self.waveforms1, result.waveforms)
  922. # but keep everything else pristine
  923. assert_neo_object_is_compliant(result)
  924. self.assertEqual(self.train1.name, result.name)
  925. self.assertEqual(self.train1.description, result.description)
  926. self.assertEqual(self.train1.annotations, result.annotations)
  927. self.assertEqual(self.train1.file_origin, result.file_origin)
  928. self.assertEqual(self.train1.dtype, result.dtype)
  929. self.assertEqual(self.train1.t_start, result.t_start)
  930. self.assertEqual(self.train1.t_stop, result.t_stop)
  931. def test_time_slice_empty(self):
  932. waveforms = np.array([[[]]]) * pq.mV
  933. train = SpikeTrain([] * pq.ms, t_stop=10.0, waveforms=waveforms)
  934. assert_neo_object_is_compliant(train)
  935. # time_slice spike train, keep sliced spike times
  936. t_start = 0.01 * pq.ms
  937. t_stop = 70.0 * pq.ms
  938. result = train.time_slice(t_start, t_stop)
  939. assert_arrays_equal(train, result)
  940. assert_arrays_equal(waveforms[:-1], result.waveforms)
  941. # but keep everything else pristine
  942. assert_neo_object_is_compliant(result)
  943. self.assertEqual(train.name, result.name)
  944. self.assertEqual(train.description, result.description)
  945. self.assertEqual(train.annotations, result.annotations)
  946. self.assertEqual(train.file_origin, result.file_origin)
  947. self.assertEqual(train.dtype, result.dtype)
  948. self.assertEqual(t_start, result.t_start)
  949. self.assertEqual(train.t_stop, result.t_stop)
  950. def test_time_slice_none_stop(self):
  951. # time_slice spike train, keep sliced spike times
  952. t_start = 1 * pq.ms
  953. result = self.train1.time_slice(t_start, None)
  954. assert_arrays_equal([1.2, 3.3, 6.4, 7] * pq.ms, result)
  955. targwaveforms = np.array([[[4., 5.],
  956. [4.1, 5.1]],
  957. [[6., 7.],
  958. [6.1, 7.1]],
  959. [[8., 9.],
  960. [8.1, 9.1]],
  961. [[10., 11.],
  962. [10.1, 11.1]]]) * pq.mV
  963. assert_arrays_equal(targwaveforms, result.waveforms)
  964. # but keep everything else pristine
  965. assert_neo_object_is_compliant(result)
  966. self.assertEqual(self.train1.name, result.name)
  967. self.assertEqual(self.train1.description, result.description)
  968. self.assertEqual(self.train1.annotations, result.annotations)
  969. self.assertEqual(self.train1.file_origin, result.file_origin)
  970. self.assertEqual(self.train1.dtype, result.dtype)
  971. self.assertEqual(t_start, result.t_start)
  972. self.assertEqual(self.train1.t_stop, result.t_stop)
  973. def test_time_slice_none_start(self):
  974. # time_slice spike train, keep sliced spike times
  975. t_stop = 1 * pq.ms
  976. result = self.train1.time_slice(None, t_stop)
  977. assert_arrays_equal([0.1, 0.5] * pq.ms, result)
  978. targwaveforms = np.array([[[0., 1.],
  979. [0.1, 1.1]],
  980. [[2., 3.],
  981. [2.1, 3.1]]]) * pq.mV
  982. assert_arrays_equal(targwaveforms, result.waveforms)
  983. # but keep everything else pristine
  984. assert_neo_object_is_compliant(result)
  985. self.assertEqual(self.train1.name, result.name)
  986. self.assertEqual(self.train1.description, result.description)
  987. self.assertEqual(self.train1.annotations, result.annotations)
  988. self.assertEqual(self.train1.file_origin, result.file_origin)
  989. self.assertEqual(self.train1.dtype, result.dtype)
  990. self.assertEqual(self.train1.t_start, result.t_start)
  991. self.assertEqual(t_stop, result.t_stop)
  992. def test_time_slice_none_both(self):
  993. self.train1.t_start = 0.1 * pq.ms
  994. assert_neo_object_is_compliant(self.train1)
  995. # time_slice spike train, keep sliced spike times
  996. result = self.train1.time_slice(None, None)
  997. assert_arrays_equal(self.train1, result)
  998. assert_arrays_equal(self.waveforms1, result.waveforms)
  999. # but keep everything else pristine
  1000. assert_neo_object_is_compliant(result)
  1001. self.assertEqual(self.train1.name, result.name)
  1002. self.assertEqual(self.train1.description, result.description)
  1003. self.assertEqual(self.train1.annotations, result.annotations)
  1004. self.assertEqual(self.train1.file_origin, result.file_origin)
  1005. self.assertEqual(self.train1.dtype, result.dtype)
  1006. self.assertEqual(self.train1.t_start, result.t_start)
  1007. self.assertEqual(self.train1.t_stop, result.t_stop)
  1008. class TestMerge(unittest.TestCase):
  1009. def setUp(self):
  1010. self.waveforms1 = np.array([[[0., 1.],
  1011. [0.1, 1.1]],
  1012. [[2., 3.],
  1013. [2.1, 3.1]],
  1014. [[4., 5.],
  1015. [4.1, 5.1]],
  1016. [[6., 7.],
  1017. [6.1, 7.1]],
  1018. [[8., 9.],
  1019. [8.1, 9.1]],
  1020. [[10., 11.],
  1021. [10.1, 11.1]]]) * pq.mV
  1022. self.data1 = np.array([0.1, 0.5, 1.2, 3.3, 6.4, 7])
  1023. self.data1quant = self.data1 * pq.ms
  1024. self.train1 = SpikeTrain(self.data1quant, t_stop=10.0 * pq.ms,
  1025. waveforms=self.waveforms1)
  1026. self.waveforms2 = np.array([[[0., 1.],
  1027. [0.1, 1.1]],
  1028. [[2., 3.],
  1029. [2.1, 3.1]],
  1030. [[4., 5.],
  1031. [4.1, 5.1]],
  1032. [[6., 7.],
  1033. [6.1, 7.1]],
  1034. [[8., 9.],
  1035. [8.1, 9.1]],
  1036. [[10., 11.],
  1037. [10.1, 11.1]]]) * pq.mV
  1038. self.data2 = np.array([0.1, 0.5, 1.2, 3.3, 6.4, 7])
  1039. self.data2quant = self.data1 * pq.ms
  1040. self.train2 = SpikeTrain(self.data1quant, t_stop=10.0 * pq.ms,
  1041. waveforms=self.waveforms1)
  1042. self.segment = Segment()
  1043. self.segment.spiketrains.extend([self.train1, self.train2])
  1044. self.train1.segment = self.segment
  1045. self.train2.segment = self.segment
  1046. def test_compliant(self):
  1047. assert_neo_object_is_compliant(self.train1)
  1048. assert_neo_object_is_compliant(self.train2)
  1049. def test_merge_typical(self):
  1050. self.train1.waveforms = None
  1051. self.train2.waveforms = None
  1052. result = self.train1.merge(self.train2)
  1053. assert_neo_object_is_compliant(result)
  1054. def test_merge_with_waveforms(self):
  1055. result = self.train1.merge(self.train2)
  1056. assert_neo_object_is_compliant(result)
  1057. def test_correct_shape(self):
  1058. result = self.train1.merge(self.train2)
  1059. self.assertEqual(len(result.shape), 1)
  1060. self.assertEqual(result.shape[0],
  1061. self.train1.shape[0] + self.train2.shape[0])
  1062. def test_correct_times(self):
  1063. result = self.train1.merge(self.train2)
  1064. expected = sorted(np.concatenate((self.train1.times,
  1065. self.train2.times)))
  1066. np.testing.assert_array_equal(result, expected)
  1067. def test_rescaling_units(self):
  1068. train3 = self.train1.duplicate_with_new_data(
  1069. self.train1.times.magnitude * pq.microsecond)
  1070. train3.segment = self.train1.segment
  1071. result = train3.merge(self.train2)
  1072. time_unit = result.units
  1073. expected = sorted(np.concatenate((train3.rescale(time_unit).times,
  1074. self.train2.rescale(
  1075. time_unit).times)))
  1076. expected = expected * time_unit
  1077. np.testing.assert_array_equal(result.rescale(time_unit), expected)
  1078. def test_sampling_rate(self):
  1079. result = self.train1.merge(self.train2)
  1080. self.assertEqual(result.sampling_rate, self.train1.sampling_rate)
  1081. def test_neo_relations(self):
  1082. result = self.train1.merge(self.train2)
  1083. self.assertEqual(self.train1.segment, result.segment)
  1084. self.assertTrue(result in result.segment.spiketrains)
  1085. def test_missing_waveforms_error(self):
  1086. self.train1.waveforms = None
  1087. with self.assertRaises(MergeError):
  1088. self.train1.merge(self.train2)
  1089. with self.assertRaises(MergeError):
  1090. self.train2.merge(self.train1)
  1091. def test_incompatible_t_start(self):
  1092. train3 = self.train1.duplicate_with_new_data(self.train1,
  1093. t_start=-1 * pq.s)
  1094. train3.segment = self.train1.segment
  1095. with self.assertRaises(MergeError):
  1096. train3.merge(self.train2)
  1097. with self.assertRaises(MergeError):
  1098. self.train2.merge(train3)
  1099. class TestDuplicateWithNewData(unittest.TestCase):
  1100. def setUp(self):
  1101. self.waveforms = np.array([[[0., 1.],
  1102. [0.1, 1.1]],
  1103. [[2., 3.],
  1104. [2.1, 3.1]],
  1105. [[4., 5.],
  1106. [4.1, 5.1]],
  1107. [[6., 7.],
  1108. [6.1, 7.1]],
  1109. [[8., 9.],
  1110. [8.1, 9.1]],
  1111. [[10., 11.],
  1112. [10.1, 11.1]]]) * pq.mV
  1113. self.data = np.array([0.1, 0.5, 1.2, 3.3, 6.4, 7])
  1114. self.dataquant = self.data * pq.ms
  1115. self.train = SpikeTrain(self.dataquant, t_stop=10.0 * pq.ms,
  1116. waveforms=self.waveforms)
  1117. def test_duplicate_with_new_data(self):
  1118. signal1 = self.train
  1119. new_t_start = -10 * pq.s
  1120. new_t_stop = 10 * pq.s
  1121. new_data = np.sort(np.random.uniform(new_t_start.magnitude,
  1122. new_t_stop.magnitude,
  1123. len(self.train))) * pq.ms
  1124. signal1b = signal1.duplicate_with_new_data(new_data,
  1125. t_start=new_t_start,
  1126. t_stop=new_t_stop)
  1127. assert_arrays_almost_equal(np.asarray(signal1b),
  1128. np.asarray(new_data), 1e-12)
  1129. self.assertEqual(signal1b.t_start, new_t_start)
  1130. self.assertEqual(signal1b.t_stop, new_t_stop)
  1131. self.assertEqual(signal1b.sampling_rate, signal1.sampling_rate)
  1132. def test_deep_copy_attributes(self):
  1133. signal1 = self.train
  1134. new_t_start = -10*pq.s
  1135. new_t_stop = 10*pq.s
  1136. new_data = np.sort(np.random.uniform(new_t_start.magnitude,
  1137. new_t_stop.magnitude,
  1138. len(self.train))) * pq.ms
  1139. signal1b = signal1.duplicate_with_new_data(new_data,
  1140. t_start=new_t_start,
  1141. t_stop=new_t_stop)
  1142. signal1.annotate(new_annotation='for signal 1')
  1143. self.assertTrue('new_annotation' not in signal1b.annotations)
  1144. class TestAttributesAnnotations(unittest.TestCase):
  1145. def test_set_universally_recommended_attributes(self):
  1146. train = SpikeTrain([3, 4, 5], units='sec', name='Name',
  1147. description='Desc', file_origin='crack.txt',
  1148. t_stop=99.9)
  1149. assert_neo_object_is_compliant(train)
  1150. self.assertEqual(train.name, 'Name')
  1151. self.assertEqual(train.description, 'Desc')
  1152. self.assertEqual(train.file_origin, 'crack.txt')
  1153. def test_autoset_universally_recommended_attributes(self):
  1154. train = SpikeTrain([3, 4, 5] * pq.s, t_stop=10.0)
  1155. assert_neo_object_is_compliant(train)
  1156. self.assertEqual(train.name, None)
  1157. self.assertEqual(train.description, None)
  1158. self.assertEqual(train.file_origin, None)
  1159. def test_annotations(self):
  1160. train = SpikeTrain([3, 4, 5] * pq.s, t_stop=11.1)
  1161. assert_neo_object_is_compliant(train)
  1162. self.assertEqual(train.annotations, {})
  1163. train = SpikeTrain([3, 4, 5] * pq.s, t_stop=11.1, ratname='Phillippe')
  1164. assert_neo_object_is_compliant(train)
  1165. self.assertEqual(train.annotations, {'ratname': 'Phillippe'})
  1166. class TestChanging(unittest.TestCase):
  1167. def test_change_with_copy_default(self):
  1168. # Default is copy = True
  1169. # Changing spike train does not change data
  1170. # Data source is quantity
  1171. data = [3, 4, 5] * pq.s
  1172. train = SpikeTrain(data, t_stop=100.0)
  1173. train[0] = 99 * pq.s
  1174. assert_neo_object_is_compliant(train)
  1175. self.assertEqual(train[0], 99 * pq.s)
  1176. self.assertEqual(data[0], 3 * pq.s)
  1177. def test_change_with_copy_false(self):
  1178. # Changing spike train also changes data, because it is a view
  1179. # Data source is quantity
  1180. data = [3, 4, 5] * pq.s
  1181. train = SpikeTrain(data, copy=False, t_stop=100.0)
  1182. train[0] = 99 * pq.s
  1183. assert_neo_object_is_compliant(train)
  1184. self.assertEqual(train[0], 99 * pq.s)
  1185. self.assertEqual(data[0], 99 * pq.s)
  1186. def test_change_with_copy_false_and_fake_rescale(self):
  1187. # Changing spike train also changes data, because it is a view
  1188. # Data source is quantity
  1189. data = [3000, 4000, 5000] * pq.ms
  1190. # even though we specify units, it still returns a view
  1191. train = SpikeTrain(data, units='ms', copy=False, t_stop=100000)
  1192. train[0] = 99000 * pq.ms
  1193. assert_neo_object_is_compliant(train)
  1194. self.assertEqual(train[0], 99000 * pq.ms)
  1195. self.assertEqual(data[0], 99000 * pq.ms)
  1196. def test_change_with_copy_false_and_rescale_true(self):
  1197. # When rescaling, a view cannot be returned
  1198. # Changing spike train also changes data, because it is a view
  1199. data = [3, 4, 5] * pq.s
  1200. self.assertRaises(ValueError, SpikeTrain, data, units='ms',
  1201. copy=False, t_stop=10000)
  1202. def test_init_with_rescale(self):
  1203. data = [3, 4, 5] * pq.s
  1204. train = SpikeTrain(data, units='ms', t_stop=6000)
  1205. assert_neo_object_is_compliant(train)
  1206. self.assertEqual(train[0], 3000 * pq.ms)
  1207. self.assertEqual(train._dimensionality, pq.ms._dimensionality)
  1208. self.assertEqual(train.t_stop, 6000 * pq.ms)
  1209. def test_change_with_copy_true(self):
  1210. # Changing spike train does not change data
  1211. # Data source is quantity
  1212. data = [3, 4, 5] * pq.s
  1213. train = SpikeTrain(data, copy=True, t_stop=100)
  1214. train[0] = 99 * pq.s
  1215. assert_neo_object_is_compliant(train)
  1216. self.assertEqual(train[0], 99 * pq.s)
  1217. self.assertEqual(data[0], 3 * pq.s)
  1218. def test_change_with_copy_default_and_data_not_quantity(self):
  1219. # Default is copy = True
  1220. # Changing spike train does not change data
  1221. # Data source is array
  1222. # Array and quantity are tested separately because copy default
  1223. # is different for these two.
  1224. data = [3, 4, 5]
  1225. train = SpikeTrain(data, units='sec', t_stop=100)
  1226. train[0] = 99 * pq.s
  1227. assert_neo_object_is_compliant(train)
  1228. self.assertEqual(train[0], 99 * pq.s)
  1229. self.assertEqual(data[0], 3 * pq.s)
  1230. def test_change_with_copy_false_and_data_not_quantity(self):
  1231. # Changing spike train also changes data, because it is a view
  1232. # Data source is array
  1233. # Array and quantity are tested separately because copy default
  1234. # is different for these two.
  1235. data = np.array([3, 4, 5])
  1236. train = SpikeTrain(data, units='sec', copy=False, dtype=np.int,
  1237. t_stop=101)
  1238. train[0] = 99 * pq.s
  1239. assert_neo_object_is_compliant(train)
  1240. self.assertEqual(train[0], 99 * pq.s)
  1241. self.assertEqual(data[0], 99)
  1242. def test_change_with_copy_false_and_dtype_change(self):
  1243. # You cannot change dtype and request a view
  1244. data = np.array([3, 4, 5])
  1245. self.assertRaises(ValueError, SpikeTrain, data, units='sec',
  1246. copy=False, t_stop=101, dtype=np.float64)
  1247. def test_change_with_copy_true_and_data_not_quantity(self):
  1248. # Changing spike train does not change data
  1249. # Data source is array
  1250. # Array and quantity are tested separately because copy default
  1251. # is different for these two.
  1252. data = [3, 4, 5]
  1253. train = SpikeTrain(data, units='sec', copy=True, t_stop=123.4)
  1254. train[0] = 99 * pq.s
  1255. assert_neo_object_is_compliant(train)
  1256. self.assertEqual(train[0], 99 * pq.s)
  1257. self.assertEqual(data[0], 3)
  1258. def test_changing_slice_changes_original_spiketrain(self):
  1259. # If we slice a spiketrain and then change the slice, the
  1260. # original spiketrain should change.
  1261. # Whether the original data source changes is dependent on the
  1262. # copy parameter.
  1263. # This is compatible with both np and quantity default behavior.
  1264. data = [3, 4, 5] * pq.s
  1265. train = SpikeTrain(data, copy=True, t_stop=99.9)
  1266. result = train[1:3]
  1267. result[0] = 99 * pq.s
  1268. assert_neo_object_is_compliant(train)
  1269. self.assertEqual(train[1], 99 * pq.s)
  1270. self.assertEqual(result[0], 99 * pq.s)
  1271. self.assertEqual(data[1], 4 * pq.s)
  1272. def test_changing_slice_changes_original_spiketrain_with_copy_false(self):
  1273. # If we slice a spiketrain and then change the slice, the
  1274. # original spiketrain should change.
  1275. # Whether the original data source changes is dependent on the
  1276. # copy parameter.
  1277. # This is compatible with both np and quantity default behavior.
  1278. data = [3, 4, 5] * pq.s
  1279. train = SpikeTrain(data, copy=False, t_stop=100.0)
  1280. result = train[1:3]
  1281. result[0] = 99 * pq.s
  1282. assert_neo_object_is_compliant(train)
  1283. assert_neo_object_is_compliant(result)
  1284. self.assertEqual(train[1], 99 * pq.s)
  1285. self.assertEqual(result[0], 99 * pq.s)
  1286. self.assertEqual(data[1], 99 * pq.s)
  1287. def test__changing_spiketime_should_check_time_in_range(self):
  1288. data = [3, 4, 5] * pq.ms
  1289. train = SpikeTrain(data, copy=False, t_start=0.5, t_stop=10.0)
  1290. assert_neo_object_is_compliant(train)
  1291. self.assertRaises(ValueError, train.__setitem__, 0, 10.1 * pq.ms)
  1292. self.assertRaises(ValueError, train.__setitem__, 1, 5.0 * pq.s)
  1293. self.assertRaises(ValueError, train.__setitem__, 2, 5.0 * pq.s)
  1294. self.assertRaises(ValueError, train.__setitem__, 0, 0)
  1295. def test__changing_multiple_spiketimes(self):
  1296. data = [3, 4, 5] * pq.ms
  1297. train = SpikeTrain(data, copy=False, t_start=0.5, t_stop=10.0)
  1298. train[:] = [7, 8, 9] * pq.ms
  1299. assert_neo_object_is_compliant(train)
  1300. assert_arrays_equal(train, np.array([7, 8, 9]))
  1301. def test__changing_multiple_spiketimes_should_check_time_in_range(self):
  1302. data = [3, 4, 5] * pq.ms
  1303. train = SpikeTrain(data, copy=False, t_start=0.5, t_stop=10.0)
  1304. assert_neo_object_is_compliant(train)
  1305. if sys.version_info[0] == 2:
  1306. self.assertRaises(ValueError, train.__setslice__,
  1307. 0, 3, [3, 4, 11] * pq.ms)
  1308. self.assertRaises(ValueError, train.__setslice__,
  1309. 0, 3, [0, 4, 5] * pq.ms)
  1310. def test__adding_time(self):
  1311. data = [3, 4, 5] * pq.ms
  1312. train = SpikeTrain(data, copy=False, t_start=0.5, t_stop=10.0)
  1313. assert_neo_object_is_compliant(train)
  1314. self.assertRaises(ValueError, train.__add__, 10 * pq.ms)
  1315. assert_arrays_equal(train + 1 * pq.ms, data + 1 * pq.ms)
  1316. def test__subtracting_time(self):
  1317. data = [3, 4, 5] * pq.ms
  1318. train = SpikeTrain(data, copy=False, t_start=0.5, t_stop=10.0)
  1319. assert_neo_object_is_compliant(train)
  1320. self.assertRaises(ValueError, train.__sub__, 10 * pq.ms)
  1321. assert_arrays_equal(train - 1 * pq.ms, data - 1 * pq.ms)
  1322. def test__rescale(self):
  1323. data = [3, 4, 5] * pq.ms
  1324. train = SpikeTrain(data, t_start=0.5, t_stop=10.0)
  1325. train.segment = Segment()
  1326. train.unit = Unit()
  1327. result = train.rescale(pq.s)
  1328. assert_neo_object_is_compliant(train)
  1329. assert_neo_object_is_compliant(result)
  1330. assert_arrays_equal(train, result)
  1331. self.assertEqual(result.units, 1 * pq.s)
  1332. self.assertIs(result.segment, train.segment)
  1333. self.assertIs(result.unit, train.unit)
  1334. def test__rescale_same_units(self):
  1335. data = [3, 4, 5] * pq.ms
  1336. train = SpikeTrain(data, t_start=0.5, t_stop=10.0)
  1337. result = train.rescale(pq.ms)
  1338. assert_neo_object_is_compliant(train)
  1339. assert_arrays_equal(train, result)
  1340. self.assertEqual(result.units, 1 * pq.ms)
  1341. def test__rescale_incompatible_units_ValueError(self):
  1342. data = [3, 4, 5] * pq.ms
  1343. train = SpikeTrain(data, t_start=0.5, t_stop=10.0)
  1344. assert_neo_object_is_compliant(train)
  1345. self.assertRaises(ValueError, train.rescale, pq.m)
  1346. class TestPropertiesMethods(unittest.TestCase):
  1347. def setUp(self):
  1348. self.data1 = [3, 4, 5]
  1349. self.data1quant = self.data1 * pq.ms
  1350. self.waveforms1 = np.array([[[0., 1.],
  1351. [0.1, 1.1]],
  1352. [[2., 3.],
  1353. [2.1, 3.1]],
  1354. [[4., 5.],
  1355. [4.1, 5.1]]]) * pq.mV
  1356. self.t_start1 = 0.5
  1357. self.t_stop1 = 10.0
  1358. self.t_start1quant = self.t_start1 * pq.ms
  1359. self.t_stop1quant = self.t_stop1 * pq.ms
  1360. self.sampling_rate1 = .1 * pq.Hz
  1361. self.left_sweep1 = 2. * pq.s
  1362. self.name1 = 'train 1'
  1363. self.description1 = 'a test object'
  1364. self.ann1 = {'targ0': [1, 2], 'targ1': 1.1}
  1365. self.train1 = SpikeTrain(self.data1quant,
  1366. t_start=self.t_start1, t_stop=self.t_stop1,
  1367. waveforms=self.waveforms1,
  1368. left_sweep=self.left_sweep1,
  1369. sampling_rate=self.sampling_rate1,
  1370. name=self.name1,
  1371. description=self.description1,
  1372. **self.ann1)
  1373. def test__compliant(self):
  1374. assert_neo_object_is_compliant(self.train1)
  1375. def test__repr(self):
  1376. result = repr(self.train1)
  1377. targ = '<SpikeTrain(array([ 3., 4., 5.]) * ms, [0.5 ms, 10.0 ms])>'
  1378. self.assertEqual(result, targ)
  1379. def test__duration(self):
  1380. result1 = self.train1.duration
  1381. self.train1.t_start = None
  1382. assert_neo_object_is_compliant(self.train1)
  1383. result2 = self.train1.duration
  1384. self.train1.t_start = self.t_start1quant
  1385. self.train1.t_stop = None
  1386. assert_neo_object_is_compliant(self.train1)
  1387. result3 = self.train1.duration
  1388. self.assertEqual(result1, 9.5 * pq.ms)
  1389. self.assertEqual(result1.units, 1. * pq.ms)
  1390. self.assertEqual(result2, None)
  1391. self.assertEqual(result3, None)
  1392. def test__spike_duration(self):
  1393. result1 = self.train1.spike_duration
  1394. self.train1.sampling_rate = None
  1395. assert_neo_object_is_compliant(self.train1)
  1396. result2 = self.train1.spike_duration
  1397. self.train1.sampling_rate = self.sampling_rate1
  1398. self.train1.waveforms = None
  1399. assert_neo_object_is_compliant(self.train1)
  1400. result3 = self.train1.spike_duration
  1401. self.assertEqual(result1, 20. / pq.Hz)
  1402. self.assertEqual(result1.units, 1. / pq.Hz)
  1403. self.assertEqual(result2, None)
  1404. self.assertEqual(result3, None)
  1405. def test__sampling_period(self):
  1406. result1 = self.train1.sampling_period
  1407. self.train1.sampling_rate = None
  1408. assert_neo_object_is_compliant(self.train1)
  1409. result2 = self.train1.sampling_period
  1410. self.train1.sampling_rate = self.sampling_rate1
  1411. self.train1.sampling_period = 10. * pq.ms
  1412. assert_neo_object_is_compliant(self.train1)
  1413. result3a = self.train1.sampling_period
  1414. result3b = self.train1.sampling_rate
  1415. self.train1.sampling_period = None
  1416. result4a = self.train1.sampling_period
  1417. result4b = self.train1.sampling_rate
  1418. self.assertEqual(result1, 10. / pq.Hz)
  1419. self.assertEqual(result1.units, 1. / pq.Hz)
  1420. self.assertEqual(result2, None)
  1421. self.assertEqual(result3a, 10. * pq.ms)
  1422. self.assertEqual(result3a.units, 1. * pq.ms)
  1423. self.assertEqual(result3b, .1 / pq.ms)
  1424. self.assertEqual(result3b.units, 1. / pq.ms)
  1425. self.assertEqual(result4a, None)
  1426. self.assertEqual(result4b, None)
  1427. def test__right_sweep(self):
  1428. result1 = self.train1.right_sweep
  1429. self.train1.left_sweep = None
  1430. assert_neo_object_is_compliant(self.train1)
  1431. result2 = self.train1.right_sweep
  1432. self.train1.left_sweep = self.left_sweep1
  1433. self.train1.sampling_rate = None
  1434. assert_neo_object_is_compliant(self.train1)
  1435. result3 = self.train1.right_sweep
  1436. self.train1.sampling_rate = self.sampling_rate1
  1437. self.train1.waveforms = None
  1438. assert_neo_object_is_compliant(self.train1)
  1439. result4 = self.train1.right_sweep
  1440. self.assertEqual(result1, 22. * pq.s)
  1441. self.assertEqual(result1.units, 1. * pq.s)
  1442. self.assertEqual(result2, None)
  1443. self.assertEqual(result3, None)
  1444. self.assertEqual(result4, None)
  1445. def test__children(self):
  1446. segment = Segment(name='seg1')
  1447. segment.spiketrains = [self.train1]
  1448. segment.create_many_to_one_relationship()
  1449. unit = Unit(name='unit1')
  1450. unit.spiketrains = [self.train1]
  1451. unit.create_many_to_one_relationship()
  1452. self.assertEqual(self.train1._single_parent_objects,
  1453. ('Segment', 'Unit'))
  1454. self.assertEqual(self.train1._multi_parent_objects, ())
  1455. self.assertEqual(self.train1._single_parent_containers,
  1456. ('segment', 'unit'))
  1457. self.assertEqual(self.train1._multi_parent_containers, ())
  1458. self.assertEqual(self.train1._parent_objects,
  1459. ('Segment', 'Unit'))
  1460. self.assertEqual(self.train1._parent_containers,
  1461. ('segment', 'unit'))
  1462. self.assertEqual(len(self.train1.parents), 2)
  1463. self.assertEqual(self.train1.parents[0].name, 'seg1')
  1464. self.assertEqual(self.train1.parents[1].name, 'unit1')
  1465. assert_neo_object_is_compliant(self.train1)
  1466. @unittest.skipUnless(HAVE_IPYTHON, "requires IPython")
  1467. def test__pretty(self):
  1468. res = pretty(self.train1)
  1469. targ = ("SpikeTrain\n" +
  1470. "name: '%s'\ndescription: '%s'\nannotations: %s" %
  1471. (self.name1, self.description1, pretty(self.ann1)))
  1472. self.assertEqual(res, targ)
  1473. class TestMiscellaneous(unittest.TestCase):
  1474. def test__different_dtype_for_t_start_and_array(self):
  1475. data = np.array([0, 9.9999999], dtype=np.float64) * pq.s
  1476. data16 = data.astype(np.float16)
  1477. data32 = data.astype(np.float32)
  1478. data64 = data.astype(np.float64)
  1479. t_start = data[0]
  1480. t_stop = data[1]
  1481. t_start16 = data[0].astype(dtype=np.float16)
  1482. t_stop16 = data[1].astype(dtype=np.float16)
  1483. t_start32 = data[0].astype(dtype=np.float32)
  1484. t_stop32 = data[1].astype(dtype=np.float32)
  1485. t_start64 = data[0].astype(dtype=np.float64)
  1486. t_stop64 = data[1].astype(dtype=np.float64)
  1487. t_start_custom = 0.0
  1488. t_stop_custom = 10.0
  1489. t_start_custom16 = np.array(t_start_custom, dtype=np.float16)
  1490. t_stop_custom16 = np.array(t_stop_custom, dtype=np.float16)
  1491. t_start_custom32 = np.array(t_start_custom, dtype=np.float32)
  1492. t_stop_custom32 = np.array(t_stop_custom, dtype=np.float32)
  1493. t_start_custom64 = np.array(t_start_custom, dtype=np.float64)
  1494. t_stop_custom64 = np.array(t_stop_custom, dtype=np.float64)
  1495. # This is OK.
  1496. train = SpikeTrain(data64, copy=True, t_start=t_start, t_stop=t_stop)
  1497. assert_neo_object_is_compliant(train)
  1498. train = SpikeTrain(data16, copy=True, t_start=t_start, t_stop=t_stop,
  1499. dtype=np.float16)
  1500. assert_neo_object_is_compliant(train)
  1501. train = SpikeTrain(data16, copy=True, t_start=t_start, t_stop=t_stop,
  1502. dtype=np.float32)
  1503. assert_neo_object_is_compliant(train)
  1504. train = SpikeTrain(data32, copy=True, t_start=t_start, t_stop=t_stop,
  1505. dtype=np.float16)
  1506. assert_neo_object_is_compliant(train)
  1507. train = SpikeTrain(data32, copy=True, t_start=t_start, t_stop=t_stop,
  1508. dtype=np.float32)
  1509. assert_neo_object_is_compliant(train)
  1510. train = SpikeTrain(data32, copy=True,
  1511. t_start=t_start16, t_stop=t_stop16)
  1512. assert_neo_object_is_compliant(train)
  1513. train = SpikeTrain(data32, copy=True,
  1514. t_start=t_start16, t_stop=t_stop16,
  1515. dtype=np.float16)
  1516. assert_neo_object_is_compliant(train)
  1517. train = SpikeTrain(data32, copy=True,
  1518. t_start=t_start16, t_stop=t_stop16,
  1519. dtype=np.float32)
  1520. assert_neo_object_is_compliant(train)
  1521. train = SpikeTrain(data32, copy=True,
  1522. t_start=t_start16, t_stop=t_stop16,
  1523. dtype=np.float64)
  1524. assert_neo_object_is_compliant(train)
  1525. train = SpikeTrain(data32, copy=True,
  1526. t_start=t_start32, t_stop=t_stop32)
  1527. assert_neo_object_is_compliant(train)
  1528. train = SpikeTrain(data32, copy=True,
  1529. t_start=t_start32, t_stop=t_stop32,
  1530. dtype=np.float16)
  1531. assert_neo_object_is_compliant(train)
  1532. train = SpikeTrain(data32, copy=True,
  1533. t_start=t_start32, t_stop=t_stop32,
  1534. dtype=np.float32)
  1535. assert_neo_object_is_compliant(train)
  1536. train = SpikeTrain(data32, copy=True,
  1537. t_start=t_start32, t_stop=t_stop32,
  1538. dtype=np.float64)
  1539. assert_neo_object_is_compliant(train)
  1540. train = SpikeTrain(data32, copy=True,
  1541. t_start=t_start64, t_stop=t_stop64,
  1542. dtype=np.float16)
  1543. assert_neo_object_is_compliant(train)
  1544. train = SpikeTrain(data32, copy=True,
  1545. t_start=t_start64, t_stop=t_stop64,
  1546. dtype=np.float32)
  1547. assert_neo_object_is_compliant(train)
  1548. train = SpikeTrain(data16, copy=True,
  1549. t_start=t_start_custom, t_stop=t_stop_custom)
  1550. assert_neo_object_is_compliant(train)
  1551. train = SpikeTrain(data16, copy=True,
  1552. t_start=t_start_custom, t_stop=t_stop_custom,
  1553. dtype=np.float16)
  1554. assert_neo_object_is_compliant(train)
  1555. train = SpikeTrain(data16, copy=True,
  1556. t_start=t_start_custom, t_stop=t_stop_custom,
  1557. dtype=np.float32)
  1558. assert_neo_object_is_compliant(train)
  1559. train = SpikeTrain(data16, copy=True,
  1560. t_start=t_start_custom, t_stop=t_stop_custom,
  1561. dtype=np.float64)
  1562. assert_neo_object_is_compliant(train)
  1563. train = SpikeTrain(data32, copy=True,
  1564. t_start=t_start_custom, t_stop=t_stop_custom)
  1565. assert_neo_object_is_compliant(train)
  1566. train = SpikeTrain(data32, copy=True,
  1567. t_start=t_start_custom, t_stop=t_stop_custom,
  1568. dtype=np.float16)
  1569. assert_neo_object_is_compliant(train)
  1570. train = SpikeTrain(data32, copy=True,
  1571. t_start=t_start_custom, t_stop=t_stop_custom,
  1572. dtype=np.float32)
  1573. assert_neo_object_is_compliant(train)
  1574. train = SpikeTrain(data32, copy=True,
  1575. t_start=t_start_custom, t_stop=t_stop_custom,
  1576. dtype=np.float64)
  1577. assert_neo_object_is_compliant(train)
  1578. train = SpikeTrain(data16, copy=True,
  1579. t_start=t_start_custom, t_stop=t_stop_custom)
  1580. assert_neo_object_is_compliant(train)
  1581. train = SpikeTrain(data16, copy=True,
  1582. t_start=t_start_custom, t_stop=t_stop_custom,
  1583. dtype=np.float16)
  1584. assert_neo_object_is_compliant(train)
  1585. train = SpikeTrain(data16, copy=True,
  1586. t_start=t_start_custom, t_stop=t_stop_custom,
  1587. dtype=np.float32)
  1588. assert_neo_object_is_compliant(train)
  1589. train = SpikeTrain(data16, copy=True,
  1590. t_start=t_start_custom, t_stop=t_stop_custom,
  1591. dtype=np.float64)
  1592. assert_neo_object_is_compliant(train)
  1593. train = SpikeTrain(data32, copy=True,
  1594. t_start=t_start_custom16, t_stop=t_stop_custom16)
  1595. assert_neo_object_is_compliant(train)
  1596. train = SpikeTrain(data32, copy=True,
  1597. t_start=t_start_custom16, t_stop=t_stop_custom16,
  1598. dtype=np.float16)
  1599. assert_neo_object_is_compliant(train)
  1600. train = SpikeTrain(data32, copy=True,
  1601. t_start=t_start_custom16, t_stop=t_stop_custom16,
  1602. dtype=np.float32)
  1603. assert_neo_object_is_compliant(train)
  1604. train = SpikeTrain(data32, copy=True,
  1605. t_start=t_start_custom16, t_stop=t_stop_custom16,
  1606. dtype=np.float64)
  1607. assert_neo_object_is_compliant(train)
  1608. train = SpikeTrain(data32, copy=True,
  1609. t_start=t_start_custom32, t_stop=t_stop_custom32)
  1610. assert_neo_object_is_compliant(train)
  1611. train = SpikeTrain(data32, copy=True,
  1612. t_start=t_start_custom32, t_stop=t_stop_custom32,
  1613. dtype=np.float16)
  1614. assert_neo_object_is_compliant(train)
  1615. train = SpikeTrain(data32, copy=True,
  1616. t_start=t_start_custom32, t_stop=t_stop_custom32,
  1617. dtype=np.float32)
  1618. assert_neo_object_is_compliant(train)
  1619. train = SpikeTrain(data32, copy=True,
  1620. t_start=t_start_custom32, t_stop=t_stop_custom32,
  1621. dtype=np.float64)
  1622. assert_neo_object_is_compliant(train)
  1623. train = SpikeTrain(data32, copy=True,
  1624. t_start=t_start_custom64, t_stop=t_stop_custom64)
  1625. assert_neo_object_is_compliant(train)
  1626. train = SpikeTrain(data32, copy=True,
  1627. t_start=t_start_custom64, t_stop=t_stop_custom64,
  1628. dtype=np.float16)
  1629. assert_neo_object_is_compliant(train)
  1630. train = SpikeTrain(data32, copy=True,
  1631. t_start=t_start_custom64, t_stop=t_stop_custom64,
  1632. dtype=np.float32)
  1633. assert_neo_object_is_compliant(train)
  1634. train = SpikeTrain(data32, copy=True,
  1635. t_start=t_start_custom64, t_stop=t_stop_custom64,
  1636. dtype=np.float64)
  1637. assert_neo_object_is_compliant(train)
  1638. # This use to bug - see ticket #38
  1639. train = SpikeTrain(data16, copy=True, t_start=t_start, t_stop=t_stop)
  1640. assert_neo_object_is_compliant(train)
  1641. train = SpikeTrain(data16, copy=True, t_start=t_start, t_stop=t_stop,
  1642. dtype=np.float64)
  1643. assert_neo_object_is_compliant(train)
  1644. train = SpikeTrain(data32, copy=True, t_start=t_start, t_stop=t_stop)
  1645. assert_neo_object_is_compliant(train)
  1646. train = SpikeTrain(data32, copy=True, t_start=t_start, t_stop=t_stop,
  1647. dtype=np.float64)
  1648. assert_neo_object_is_compliant(train)
  1649. train = SpikeTrain(data32, copy=True,
  1650. t_start=t_start64, t_stop=t_stop64)
  1651. assert_neo_object_is_compliant(train)
  1652. train = SpikeTrain(data32, copy=True,
  1653. t_start=t_start64, t_stop=t_stop64,
  1654. dtype=np.float64)
  1655. assert_neo_object_is_compliant(train)
  1656. def test_as_array(self):
  1657. data = np.arange(10.0)
  1658. st = SpikeTrain(data, t_stop=10.0, units='ms')
  1659. st_as_arr = st.as_array()
  1660. self.assertIsInstance(st_as_arr, np.ndarray)
  1661. assert_array_equal(data, st_as_arr)
  1662. def test_as_quantity(self):
  1663. data = np.arange(10.0)
  1664. st = SpikeTrain(data, t_stop=10.0, units='ms')
  1665. st_as_q = st.as_quantity()
  1666. self.assertIsInstance(st_as_q, pq.Quantity)
  1667. assert_array_equal(data * pq.ms, st_as_q)
if __name__ == "__main__":
    # Run the TestCases defined in this module when executed as a script.
    unittest.main()