test_spiketrain.py 88 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903
  1. # -*- coding: utf-8 -*-
  2. """
  3. Tests of the neo.core.spiketrain.SpikeTrain class and related functions
  4. """
  5. # needed for python 3 compatibility
  6. from __future__ import absolute_import
  7. import sys
  8. import unittest
  9. import warnings
  10. import numpy as np
  11. from numpy.testing import assert_array_equal
  12. import quantities as pq
  13. from neo.core.dataobject import ArrayDict
  14. try:
  15. from IPython.lib.pretty import pretty
  16. except ImportError as err:
  17. HAVE_IPYTHON = False
  18. else:
  19. HAVE_IPYTHON = True
  20. from neo.core.spiketrain import (check_has_dimensions_time, SpikeTrain, _check_time_in_range,
  21. _new_spiketrain)
  22. from neo.core import Segment, Unit
  23. from neo.core.baseneo import MergeError
  24. from neo.test.tools import (assert_arrays_equal, assert_arrays_almost_equal,
  25. assert_neo_object_is_compliant)
  26. from neo.test.generate_datasets import (get_fake_value, get_fake_values, fake_neo,
  27. TEST_ANNOTATIONS)
  28. class Test__generate_datasets(unittest.TestCase):
  29. def setUp(self):
  30. np.random.seed(0)
  31. self.annotations = dict(
  32. [(str(x), TEST_ANNOTATIONS[x]) for x in range(len(TEST_ANNOTATIONS))])
  33. def test__get_fake_values(self):
  34. self.annotations['seed'] = 0
  35. waveforms = get_fake_value('waveforms', pq.Quantity, seed=3, dim=3)
  36. shape = waveforms.shape[0]
  37. times = get_fake_value('times', pq.Quantity, seed=0, dim=1, shape=shape)
  38. t_start = get_fake_value('t_start', pq.Quantity, seed=1, dim=0)
  39. t_stop = get_fake_value('t_stop', pq.Quantity, seed=2, dim=0)
  40. left_sweep = get_fake_value('left_sweep', pq.Quantity, seed=4, dim=0)
  41. sampling_rate = get_fake_value('sampling_rate', pq.Quantity, seed=5, dim=0)
  42. name = get_fake_value('name', str, seed=6, obj=SpikeTrain)
  43. description = get_fake_value('description', str, seed=7, obj='SpikeTrain')
  44. file_origin = get_fake_value('file_origin', str)
  45. arr_ann = get_fake_value('array_annotations', dict, seed=9, obj=SpikeTrain, n=1)
  46. attrs1 = {'name': name, 'description': description, 'file_origin': file_origin}
  47. attrs2 = attrs1.copy()
  48. attrs2.update(self.annotations)
  49. attrs2['array_annotations'] = arr_ann
  50. res11 = get_fake_values(SpikeTrain, annotate=False, seed=0)
  51. res12 = get_fake_values('SpikeTrain', annotate=False, seed=0)
  52. res21 = get_fake_values(SpikeTrain, annotate=True, seed=0)
  53. res22 = get_fake_values('SpikeTrain', annotate=True, seed=0)
  54. assert_arrays_equal(res11.pop('times'), times)
  55. assert_arrays_equal(res12.pop('times'), times)
  56. assert_arrays_equal(res21.pop('times'), times)
  57. assert_arrays_equal(res22.pop('times'), times)
  58. assert_arrays_equal(res11.pop('t_start'), t_start)
  59. assert_arrays_equal(res12.pop('t_start'), t_start)
  60. assert_arrays_equal(res21.pop('t_start'), t_start)
  61. assert_arrays_equal(res22.pop('t_start'), t_start)
  62. assert_arrays_equal(res11.pop('t_stop'), t_stop)
  63. assert_arrays_equal(res12.pop('t_stop'), t_stop)
  64. assert_arrays_equal(res21.pop('t_stop'), t_stop)
  65. assert_arrays_equal(res22.pop('t_stop'), t_stop)
  66. assert_arrays_equal(res11.pop('waveforms'), waveforms)
  67. assert_arrays_equal(res12.pop('waveforms'), waveforms)
  68. assert_arrays_equal(res21.pop('waveforms'), waveforms)
  69. assert_arrays_equal(res22.pop('waveforms'), waveforms)
  70. assert_arrays_equal(res11.pop('left_sweep'), left_sweep)
  71. assert_arrays_equal(res12.pop('left_sweep'), left_sweep)
  72. assert_arrays_equal(res21.pop('left_sweep'), left_sweep)
  73. assert_arrays_equal(res22.pop('left_sweep'), left_sweep)
  74. assert_arrays_equal(res11.pop('sampling_rate'), sampling_rate)
  75. assert_arrays_equal(res12.pop('sampling_rate'), sampling_rate)
  76. assert_arrays_equal(res21.pop('sampling_rate'), sampling_rate)
  77. assert_arrays_equal(res22.pop('sampling_rate'), sampling_rate)
  78. self.assertEqual(res11, attrs1)
  79. self.assertEqual(res12, attrs1)
  80. # Array annotations need to be compared separately
  81. # because numpy arrays define equality differently
  82. arr_ann_res21 = res21.pop('array_annotations')
  83. arr_ann_attrs2 = attrs2.pop('array_annotations')
  84. self.assertEqual(res21, attrs2)
  85. assert_arrays_equal(arr_ann_res21['valid'], arr_ann_attrs2['valid'])
  86. assert_arrays_equal(arr_ann_res21['number'], arr_ann_attrs2['number'])
  87. arr_ann_res22 = res22.pop('array_annotations')
  88. self.assertEqual(res22, attrs2)
  89. assert_arrays_equal(arr_ann_res22['valid'], arr_ann_attrs2['valid'])
  90. assert_arrays_equal(arr_ann_res22['number'], arr_ann_attrs2['number'])
  91. def test__fake_neo__cascade(self):
  92. self.annotations['seed'] = None
  93. obj_type = 'SpikeTrain'
  94. cascade = True
  95. res = fake_neo(obj_type=obj_type, cascade=cascade)
  96. self.assertTrue(isinstance(res, SpikeTrain))
  97. assert_neo_object_is_compliant(res)
  98. self.assertEqual(res.annotations, self.annotations)
  99. def test__fake_neo__nocascade(self):
  100. self.annotations['seed'] = None
  101. obj_type = SpikeTrain
  102. cascade = False
  103. res = fake_neo(obj_type=obj_type, cascade=cascade)
  104. self.assertTrue(isinstance(res, SpikeTrain))
  105. assert_neo_object_is_compliant(res)
  106. self.assertEqual(res.annotations, self.annotations)
  107. class Testcheck_has_dimensions_time(unittest.TestCase):
  108. def test__check_has_dimensions_time(self):
  109. a = np.arange(3) * pq.ms
  110. b = np.arange(3) * pq.mV
  111. c = np.arange(3) * pq.mA
  112. d = np.arange(3) * pq.minute
  113. check_has_dimensions_time(a)
  114. self.assertRaises(ValueError, check_has_dimensions_time, b)
  115. self.assertRaises(ValueError, check_has_dimensions_time, c)
  116. check_has_dimensions_time(d)
  117. self.assertRaises(ValueError, check_has_dimensions_time, a, b, c, d)
  118. class Testcheck_time_in_range(unittest.TestCase):
  119. def test__check_time_in_range_empty_array(self):
  120. value = np.array([])
  121. t_start = 0 * pq.s
  122. t_stop = 10 * pq.s
  123. _check_time_in_range(value, t_start=t_start, t_stop=t_stop)
  124. _check_time_in_range(value, t_start=t_start, t_stop=t_stop, view=False)
  125. _check_time_in_range(value, t_start=t_start, t_stop=t_stop, view=True)
  126. def test__check_time_in_range_empty_array_invalid_t_stop(self):
  127. value = np.array([])
  128. t_start = 6 * pq.s
  129. t_stop = 4 * pq.s
  130. self.assertRaises(ValueError, _check_time_in_range, value, t_start=t_start, t_stop=t_stop)
  131. def test__check_time_in_range_exact(self):
  132. value = np.array([0., 5., 10.]) * pq.s
  133. t_start = 0. * pq.s
  134. t_stop = 10. * pq.s
  135. _check_time_in_range(value, t_start=t_start, t_stop=t_stop)
  136. _check_time_in_range(value, t_start=t_start, t_stop=t_stop, view=False)
  137. _check_time_in_range(value, t_start=t_start, t_stop=t_stop, view=True)
  138. def test__check_time_in_range_scale(self):
  139. value = np.array([0., 5000., 10000.]) * pq.ms
  140. t_start = 0. * pq.s
  141. t_stop = 10. * pq.s
  142. _check_time_in_range(value, t_start=t_start, t_stop=t_stop)
  143. _check_time_in_range(value, t_start=t_start, t_stop=t_stop, view=False)
  144. def test__check_time_in_range_inside(self):
  145. value = np.array([0.1, 5., 9.9]) * pq.s
  146. t_start = 0. * pq.s
  147. t_stop = 10. * pq.s
  148. _check_time_in_range(value, t_start=t_start, t_stop=t_stop)
  149. _check_time_in_range(value, t_start=t_start, t_stop=t_stop, view=False)
  150. _check_time_in_range(value, t_start=t_start, t_stop=t_stop, view=True)
  151. def test__check_time_in_range_below(self):
  152. value = np.array([-0.1, 5., 10.]) * pq.s
  153. t_start = 0. * pq.s
  154. t_stop = 10. * pq.s
  155. self.assertRaises(ValueError, _check_time_in_range, value, t_start=t_start, t_stop=t_stop)
  156. self.assertRaises(ValueError, _check_time_in_range, value, t_start=t_start, t_stop=t_stop,
  157. view=False)
  158. self.assertRaises(ValueError, _check_time_in_range, value, t_start=t_start, t_stop=t_stop,
  159. view=True)
  160. def test__check_time_in_range_below_scale(self):
  161. value = np.array([-1., 5000., 10000.]) * pq.ms
  162. t_start = 0. * pq.s
  163. t_stop = 10. * pq.s
  164. self.assertRaises(ValueError, _check_time_in_range, value, t_start=t_start, t_stop=t_stop)
  165. self.assertRaises(ValueError, _check_time_in_range, value, t_start=t_start, t_stop=t_stop,
  166. view=False)
  167. def test__check_time_in_range_above(self):
  168. value = np.array([0., 5., 10.1]) * pq.s
  169. t_start = 0. * pq.s
  170. t_stop = 10. * pq.s
  171. self.assertRaises(ValueError, _check_time_in_range, value, t_start=t_start, t_stop=t_stop)
  172. self.assertRaises(ValueError, _check_time_in_range, value, t_start=t_start, t_stop=t_stop,
  173. view=False)
  174. self.assertRaises(ValueError, _check_time_in_range, value, t_start=t_start, t_stop=t_stop,
  175. view=True)
  176. def test__check_time_in_range_above_scale(self):
  177. value = np.array([0., 5000., 10001.]) * pq.ms
  178. t_start = 0. * pq.s
  179. t_stop = 10. * pq.s
  180. self.assertRaises(ValueError, _check_time_in_range, value, t_start=t_start, t_stop=t_stop)
  181. self.assertRaises(ValueError, _check_time_in_range, value, t_start=t_start, t_stop=t_stop,
  182. view=False)
  183. def test__check_time_in_range_above_below(self):
  184. value = np.array([-0.1, 5., 10.1]) * pq.s
  185. t_start = 0. * pq.s
  186. t_stop = 10. * pq.s
  187. self.assertRaises(ValueError, _check_time_in_range, value, t_start=t_start, t_stop=t_stop)
  188. self.assertRaises(ValueError, _check_time_in_range, value, t_start=t_start, t_stop=t_stop,
  189. view=False)
  190. self.assertRaises(ValueError, _check_time_in_range, value, t_start=t_start, t_stop=t_stop,
  191. view=True)
  192. def test__check_time_in_range_above_below_scale(self):
  193. value = np.array([-1., 5000., 10001.]) * pq.ms
  194. t_start = 0. * pq.s
  195. t_stop = 10. * pq.s
  196. self.assertRaises(ValueError, _check_time_in_range, value, t_start=t_start, t_stop=t_stop)
  197. self.assertRaises(ValueError, _check_time_in_range, value, t_start=t_start, t_stop=t_stop,
  198. view=False)
  199. class TestConstructor(unittest.TestCase):
  200. def result_spike_check(self, train, st_out, t_start_out, t_stop_out, dtype, units):
  201. assert_arrays_equal(train, st_out)
  202. assert_arrays_equal(train, train.times)
  203. assert_neo_object_is_compliant(train)
  204. self.assertEqual(train.t_start, t_start_out)
  205. self.assertEqual(train.t_stop, t_stop_out)
  206. self.assertEqual(train.units, units)
  207. self.assertEqual(train.units, train.times.units)
  208. self.assertEqual(train.t_start.units, units)
  209. self.assertEqual(train.t_stop.units, units)
  210. self.assertEqual(train.dtype, dtype)
  211. self.assertEqual(train.dtype, train.times.dtype)
  212. self.assertEqual(train.t_stop.dtype, dtype)
  213. self.assertEqual(train.t_start.dtype, dtype)
  214. def test__create_minimal(self):
  215. t_start = 0.0
  216. t_stop = 10.0
  217. train1 = SpikeTrain([] * pq.s, t_stop)
  218. train2 = _new_spiketrain(SpikeTrain, [] * pq.s, t_stop)
  219. dtype = np.float64
  220. units = 1 * pq.s
  221. t_start_out = t_start * units
  222. t_stop_out = t_stop * units
  223. st_out = [] * units
  224. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  225. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  226. def test__create_empty(self):
  227. t_start = 0.0
  228. t_stop = 10.0
  229. train1 = SpikeTrain([], t_start=t_start, t_stop=t_stop, units='s')
  230. train2 = _new_spiketrain(SpikeTrain, [], t_start=t_start, t_stop=t_stop, units='s')
  231. dtype = np.float64
  232. units = 1 * pq.s
  233. t_start_out = t_start * units
  234. t_stop_out = t_stop * units
  235. st_out = [] * units
  236. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  237. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  238. def test__create_empty_no_t_start(self):
  239. t_start = 0.0
  240. t_stop = 10.0
  241. train1 = SpikeTrain([], t_stop=t_stop, units='s')
  242. train2 = _new_spiketrain(SpikeTrain, [], t_stop=t_stop, units='s')
  243. dtype = np.float64
  244. units = 1 * pq.s
  245. t_start_out = t_start * units
  246. t_stop_out = t_stop * units
  247. st_out = [] * units
  248. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  249. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  250. def test__create_from_list(self):
  251. times = range(10)
  252. t_start = 0.0 * pq.s
  253. t_stop = 10000.0 * pq.ms
  254. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units="ms")
  255. train2 = _new_spiketrain(SpikeTrain, times, t_start=t_start, t_stop=t_stop, units="ms")
  256. dtype = np.float64
  257. units = 1 * pq.ms
  258. t_start_out = t_start
  259. t_stop_out = t_stop
  260. st_out = times * units
  261. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  262. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  263. def test__create_from_list_set_dtype(self):
  264. times = range(10)
  265. t_start = 0.0 * pq.s
  266. t_stop = 10000.0 * pq.ms
  267. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units="ms", dtype='f4')
  268. train2 = _new_spiketrain(SpikeTrain, times, t_start=t_start, t_stop=t_stop, units="ms",
  269. dtype='f4')
  270. dtype = np.float32
  271. units = 1 * pq.ms
  272. t_start_out = t_start.astype(dtype)
  273. t_stop_out = t_stop.astype(dtype)
  274. st_out = pq.Quantity(times, units=units, dtype=dtype)
  275. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  276. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  277. def test__create_from_list_no_start_stop_units(self):
  278. times = range(10)
  279. t_start = 0.0
  280. t_stop = 10000.0
  281. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units="ms")
  282. train2 = _new_spiketrain(SpikeTrain, times, t_start=t_start, t_stop=t_stop, units="ms")
  283. dtype = np.float64
  284. units = 1 * pq.ms
  285. t_start_out = t_start * units
  286. t_stop_out = t_stop * units
  287. st_out = times * units
  288. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  289. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  290. def test__create_from_list_no_start_stop_units_set_dtype(self):
  291. times = range(10)
  292. t_start = 0.0
  293. t_stop = 10000.0
  294. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units="ms", dtype='f4')
  295. train2 = _new_spiketrain(SpikeTrain, times, t_start=t_start, t_stop=t_stop, units="ms",
  296. dtype='f4')
  297. dtype = np.float32
  298. units = 1 * pq.ms
  299. t_start_out = pq.Quantity(t_start, units=units, dtype=dtype)
  300. t_stop_out = pq.Quantity(t_stop, units=units, dtype=dtype)
  301. st_out = pq.Quantity(times, units=units, dtype=dtype)
  302. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  303. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  304. def test__create_from_array(self):
  305. times = np.arange(10)
  306. t_start = 0.0 * pq.s
  307. t_stop = 10000.0 * pq.ms
  308. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units="s")
  309. train2 = _new_spiketrain(SpikeTrain, times, t_start=t_start, t_stop=t_stop, units="s")
  310. dtype = np.int
  311. units = 1 * pq.s
  312. t_start_out = t_start
  313. t_stop_out = t_stop
  314. st_out = times * units
  315. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  316. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  317. def test__create_from_array_with_dtype(self):
  318. times = np.arange(10, dtype='f4')
  319. t_start = 0.0 * pq.s
  320. t_stop = 10000.0 * pq.ms
  321. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units="s")
  322. train2 = _new_spiketrain(SpikeTrain, times, t_start=t_start, t_stop=t_stop, units="s")
  323. dtype = times.dtype
  324. units = 1 * pq.s
  325. t_start_out = t_start
  326. t_stop_out = t_stop
  327. st_out = times * units
  328. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  329. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  330. def test__create_from_array_set_dtype(self):
  331. times = np.arange(10)
  332. t_start = 0.0 * pq.s
  333. t_stop = 10000.0 * pq.ms
  334. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units="s", dtype='f4')
  335. train2 = _new_spiketrain(SpikeTrain, times, t_start=t_start, t_stop=t_stop, units="s",
  336. dtype='f4')
  337. dtype = np.float32
  338. units = 1 * pq.s
  339. t_start_out = t_start.astype(dtype)
  340. t_stop_out = t_stop.astype(dtype)
  341. st_out = times.astype(dtype) * units
  342. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  343. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  344. def test__create_from_array_no_start_stop_units(self):
  345. times = np.arange(10)
  346. t_start = 0.0
  347. t_stop = 10000.0
  348. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units="s")
  349. train2 = _new_spiketrain(SpikeTrain, times, t_start=t_start, t_stop=t_stop, units="s")
  350. dtype = np.int
  351. units = 1 * pq.s
  352. t_start_out = t_start * units
  353. t_stop_out = t_stop * units
  354. st_out = times * units
  355. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  356. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  357. def test__create_from_array_no_start_stop_units_with_dtype(self):
  358. times = np.arange(10, dtype='f4')
  359. t_start = 0.0
  360. t_stop = 10000.0
  361. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units="s")
  362. train2 = _new_spiketrain(SpikeTrain, times, t_start=t_start, t_stop=t_stop, units="s")
  363. dtype = np.float32
  364. units = 1 * pq.s
  365. t_start_out = t_start * units
  366. t_stop_out = t_stop * units
  367. st_out = times * units
  368. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  369. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  370. def test__create_from_array_no_start_stop_units_set_dtype(self):
  371. times = np.arange(10)
  372. t_start = 0.0
  373. t_stop = 10000.0
  374. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units="s", dtype='f4')
  375. train2 = _new_spiketrain(SpikeTrain, times, t_start=t_start, t_stop=t_stop, units="s",
  376. dtype='f4')
  377. dtype = np.float32
  378. units = 1 * pq.s
  379. t_start_out = pq.Quantity(t_start, units=units, dtype=dtype)
  380. t_stop_out = pq.Quantity(t_stop, units=units, dtype=dtype)
  381. st_out = times.astype(dtype) * units
  382. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  383. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  384. def test__create_from_quantity_array(self):
  385. times = np.arange(10) * pq.ms
  386. t_start = 0.0 * pq.s
  387. t_stop = 12.0 * pq.ms
  388. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop)
  389. train2 = _new_spiketrain(SpikeTrain, times, t_start=t_start, t_stop=t_stop)
  390. dtype = np.float64
  391. units = 1 * pq.ms
  392. t_start_out = t_start
  393. t_stop_out = t_stop
  394. st_out = times
  395. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  396. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  397. def test__create_from_quantity_array_with_dtype(self):
  398. times = np.arange(10, dtype='f4') * pq.ms
  399. t_start = 0.0 * pq.s
  400. t_stop = 12.0 * pq.ms
  401. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop)
  402. train2 = _new_spiketrain(SpikeTrain, times, t_start=t_start, t_stop=t_stop)
  403. dtype = np.float32
  404. units = 1 * pq.ms
  405. t_start_out = t_start.astype(dtype)
  406. t_stop_out = t_stop.astype(dtype)
  407. st_out = times.astype(dtype)
  408. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  409. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  410. def test__create_from_quantity_array_set_dtype(self):
  411. times = np.arange(10) * pq.ms
  412. t_start = 0.0 * pq.s
  413. t_stop = 12.0 * pq.ms
  414. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, dtype='f4')
  415. train2 = _new_spiketrain(SpikeTrain, times, t_start=t_start, t_stop=t_stop, dtype='f4')
  416. dtype = np.float32
  417. units = 1 * pq.ms
  418. t_start_out = t_start.astype(dtype)
  419. t_stop_out = t_stop.astype(dtype)
  420. st_out = times.astype(dtype)
  421. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  422. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  423. def test__create_from_quantity_array_no_start_stop_units(self):
  424. times = np.arange(10) * pq.ms
  425. t_start = 0.0
  426. t_stop = 12.0
  427. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop)
  428. train2 = _new_spiketrain(SpikeTrain, times, t_start=t_start, t_stop=t_stop)
  429. dtype = np.float64
  430. units = 1 * pq.ms
  431. t_start_out = t_start * units
  432. t_stop_out = t_stop * units
  433. st_out = times
  434. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  435. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  436. def test__create_from_quantity_array_no_start_stop_units_with_dtype(self):
  437. times = np.arange(10, dtype='f4') * pq.ms
  438. t_start = 0.0
  439. t_stop = 12.0
  440. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop)
  441. train2 = _new_spiketrain(SpikeTrain, times, t_start=t_start, t_stop=t_stop)
  442. dtype = np.float32
  443. units = 1 * pq.ms
  444. t_start_out = pq.Quantity(t_start, units=units, dtype=dtype)
  445. t_stop_out = pq.Quantity(t_stop, units=units, dtype=dtype)
  446. st_out = times.astype(dtype)
  447. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  448. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  449. def test__create_from_quantity_array_no_start_stop_units_set_dtype(self):
  450. times = np.arange(10) * pq.ms
  451. t_start = 0.0
  452. t_stop = 12.0
  453. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, dtype='f4')
  454. train2 = _new_spiketrain(SpikeTrain, times, t_start=t_start, t_stop=t_stop, dtype='f4')
  455. dtype = np.float32
  456. units = 1 * pq.ms
  457. t_start_out = pq.Quantity(t_start, units=units, dtype=dtype)
  458. t_stop_out = pq.Quantity(t_stop, units=units, dtype=dtype)
  459. st_out = times.astype(dtype)
  460. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  461. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  462. def test__create_from_quantity_array_units(self):
  463. times = np.arange(10) * pq.ms
  464. t_start = 0.0 * pq.s
  465. t_stop = 12.0 * pq.ms
  466. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units='s')
  467. train2 = _new_spiketrain(SpikeTrain, times, t_start=t_start, t_stop=t_stop, units='s')
  468. dtype = np.float64
  469. units = 1 * pq.s
  470. t_start_out = t_start
  471. t_stop_out = t_stop
  472. st_out = times
  473. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  474. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  475. def test__create_from_quantity_array_units_with_dtype(self):
  476. times = np.arange(10, dtype='f4') * pq.ms
  477. t_start = 0.0 * pq.s
  478. t_stop = 12.0 * pq.ms
  479. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units='s')
  480. train2 = _new_spiketrain(SpikeTrain, times, t_start=t_start, t_stop=t_stop, units='s')
  481. dtype = np.float32
  482. units = 1 * pq.s
  483. t_start_out = t_start.astype(dtype)
  484. t_stop_out = t_stop.rescale(units).astype(dtype)
  485. st_out = times.rescale(units).astype(dtype)
  486. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  487. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  488. def test__create_from_quantity_array_units_set_dtype(self):
  489. times = np.arange(10) * pq.ms
  490. t_start = 0.0 * pq.s
  491. t_stop = 12.0 * pq.ms
  492. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units='s', dtype='f4')
  493. train2 = _new_spiketrain(SpikeTrain, times, t_start=t_start, t_stop=t_stop, units='s',
  494. dtype='f4')
  495. dtype = np.float32
  496. units = 1 * pq.s
  497. t_start_out = t_start.astype(dtype)
  498. t_stop_out = t_stop.rescale(units).astype(dtype)
  499. st_out = times.rescale(units).astype(dtype)
  500. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  501. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  502. def test__create_from_quantity_array_units_no_start_stop_units(self):
  503. times = np.arange(10) * pq.ms
  504. t_start = 0.0
  505. t_stop = 12.0
  506. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units='s')
  507. train2 = _new_spiketrain(SpikeTrain, times, t_start=t_start, t_stop=t_stop, units='s')
  508. dtype = np.float64
  509. units = 1 * pq.s
  510. t_start_out = pq.Quantity(t_start, units=units, dtype=dtype)
  511. t_stop_out = pq.Quantity(t_stop, units=units, dtype=dtype)
  512. st_out = times
  513. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  514. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  515. def test__create_from_quantity_units_no_start_stop_units_set_dtype(self):
  516. times = np.arange(10) * pq.ms
  517. t_start = 0.0
  518. t_stop = 12.0
  519. train1 = SpikeTrain(times, t_start=t_start, t_stop=t_stop, units='s', dtype='f4')
  520. train2 = _new_spiketrain(SpikeTrain, times, t_start=t_start, t_stop=t_stop, units='s',
  521. dtype='f4')
  522. dtype = np.float32
  523. units = 1 * pq.s
  524. t_start_out = pq.Quantity(t_start, units=units, dtype=dtype)
  525. t_stop_out = pq.Quantity(t_stop, units=units, dtype=dtype)
  526. st_out = times.rescale(units).astype(dtype)
  527. self.result_spike_check(train1, st_out, t_start_out, t_stop_out, dtype, units)
  528. self.result_spike_check(train2, st_out, t_start_out, t_stop_out, dtype, units)
  529. def test__create_from_list_without_units_should_raise_ValueError(self):
  530. times = range(10)
  531. t_start = 0.0 * pq.s
  532. t_stop = 10000.0 * pq.ms
  533. self.assertRaises(ValueError, SpikeTrain, times, t_start=t_start, t_stop=t_stop)
  534. self.assertRaises(ValueError, _new_spiketrain, SpikeTrain, times, t_start=t_start,
  535. t_stop=t_stop)
  536. def test__create_from_array_without_units_should_raise_ValueError(self):
  537. times = np.arange(10)
  538. t_start = 0.0 * pq.s
  539. t_stop = 10000.0 * pq.ms
  540. self.assertRaises(ValueError, SpikeTrain, times, t_start=t_start, t_stop=t_stop)
  541. self.assertRaises(ValueError, _new_spiketrain, SpikeTrain, times, t_start=t_start,
  542. t_stop=t_stop)
  543. def test__create_from_array_with_incompatible_units_ValueError(self):
  544. times = np.arange(10) * pq.km
  545. t_start = 0.0 * pq.s
  546. t_stop = 10000.0 * pq.ms
  547. self.assertRaises(ValueError, SpikeTrain, times, t_start=t_start, t_stop=t_stop)
  548. self.assertRaises(ValueError, _new_spiketrain, SpikeTrain, times, t_start=t_start,
  549. t_stop=t_stop)
  550. def test__create_with_times_outside_tstart_tstop_ValueError(self):
  551. t_start = 23
  552. t_stop = 77
  553. train1 = SpikeTrain(np.arange(t_start, t_stop), units='ms', t_start=t_start, t_stop=t_stop)
  554. train2 = _new_spiketrain(SpikeTrain, np.arange(t_start, t_stop), units='ms',
  555. t_start=t_start, t_stop=t_stop)
  556. assert_neo_object_is_compliant(train1)
  557. assert_neo_object_is_compliant(train2)
  558. self.assertRaises(ValueError, SpikeTrain, np.arange(t_start - 5, t_stop), units='ms',
  559. t_start=t_start, t_stop=t_stop)
  560. self.assertRaises(ValueError, _new_spiketrain, SpikeTrain, np.arange(t_start - 5, t_stop),
  561. units='ms', t_start=t_start, t_stop=t_stop)
  562. self.assertRaises(ValueError, SpikeTrain, np.arange(t_start, t_stop + 5), units='ms',
  563. t_start=t_start, t_stop=t_stop)
  564. self.assertRaises(ValueError, _new_spiketrain, SpikeTrain, np.arange(t_start, t_stop + 5),
  565. units='ms', t_start=t_start, t_stop=t_stop)
  566. def test__create_with_len_times_different_size_than_waveform_shape1_ValueError(self):
  567. self.assertRaises(ValueError, SpikeTrain, times=np.arange(10), units='s', t_stop=4,
  568. waveforms=np.ones((10, 6, 50)))
  569. def test_defaults(self):
  570. # default recommended attributes
  571. train1 = SpikeTrain([3, 4, 5], units='sec', t_stop=10.0)
  572. train2 = _new_spiketrain(SpikeTrain, [3, 4, 5], units='sec', t_stop=10.0)
  573. assert_neo_object_is_compliant(train1)
  574. assert_neo_object_is_compliant(train2)
  575. self.assertEqual(train1.dtype, np.float)
  576. self.assertEqual(train2.dtype, np.float)
  577. self.assertEqual(train1.sampling_rate, 1.0 * pq.Hz)
  578. self.assertEqual(train2.sampling_rate, 1.0 * pq.Hz)
  579. self.assertEqual(train1.waveforms, None)
  580. self.assertEqual(train2.waveforms, None)
  581. self.assertEqual(train1.left_sweep, None)
  582. self.assertEqual(train2.left_sweep, None)
  583. self.assertEqual(train1.array_annotations, {})
  584. self.assertEqual(train2.array_annotations, {})
  585. self.assertIsInstance(train1.array_annotations, ArrayDict)
  586. self.assertIsInstance(train2.array_annotations, ArrayDict)
  587. def test_default_tstart(self):
  588. # t start defaults to zero
  589. train11 = SpikeTrain([3, 4, 5] * pq.s, t_stop=8000 * pq.ms)
  590. train21 = _new_spiketrain(SpikeTrain, [3, 4, 5] * pq.s, t_stop=8000 * pq.ms)
  591. assert_neo_object_is_compliant(train11)
  592. assert_neo_object_is_compliant(train21)
  593. self.assertEqual(train11.t_start, 0. * pq.s)
  594. self.assertEqual(train21.t_start, 0. * pq.s)
  595. # unless otherwise specified
  596. train12 = SpikeTrain([3, 4, 5] * pq.s, t_start=2.0, t_stop=8)
  597. train22 = _new_spiketrain(SpikeTrain, [3, 4, 5] * pq.s, t_start=2.0, t_stop=8)
  598. assert_neo_object_is_compliant(train12)
  599. assert_neo_object_is_compliant(train22)
  600. self.assertEqual(train12.t_start, 2. * pq.s)
  601. self.assertEqual(train22.t_start, 2. * pq.s)
  602. def test_tstop_units_conversion(self):
  603. train11 = SpikeTrain([3, 5, 4] * pq.s, t_stop=10)
  604. train21 = _new_spiketrain(SpikeTrain, [3, 5, 4] * pq.s, t_stop=10)
  605. assert_neo_object_is_compliant(train11)
  606. assert_neo_object_is_compliant(train21)
  607. self.assertEqual(train11.t_stop, 10. * pq.s)
  608. self.assertEqual(train21.t_stop, 10. * pq.s)
  609. train12 = SpikeTrain([3, 5, 4] * pq.s, t_stop=10000. * pq.ms)
  610. train22 = _new_spiketrain(SpikeTrain, [3, 5, 4] * pq.s, t_stop=10000. * pq.ms)
  611. assert_neo_object_is_compliant(train12)
  612. assert_neo_object_is_compliant(train22)
  613. self.assertEqual(train12.t_stop, 10. * pq.s)
  614. self.assertEqual(train22.t_stop, 10. * pq.s)
  615. train13 = SpikeTrain([3, 5, 4], units='sec', t_stop=10000. * pq.ms)
  616. train23 = _new_spiketrain(SpikeTrain, [3, 5, 4], units='sec', t_stop=10000. * pq.ms)
  617. assert_neo_object_is_compliant(train13)
  618. assert_neo_object_is_compliant(train23)
  619. self.assertEqual(train13.t_stop, 10. * pq.s)
  620. self.assertEqual(train23.t_stop, 10. * pq.s)
  621. class TestSorting(unittest.TestCase):
  622. def test_sort(self):
  623. waveforms = np.array([[[0., 1.]], [[2., 3.]], [[4., 5.]]]) * pq.mV
  624. train = SpikeTrain([3, 4, 5] * pq.s, waveforms=waveforms, name='n', t_stop=10.0,
  625. array_annotations={'a': np.arange(3)})
  626. assert_neo_object_is_compliant(train)
  627. train.sort()
  628. assert_neo_object_is_compliant(train)
  629. assert_arrays_equal(train, [3, 4, 5] * pq.s)
  630. assert_arrays_equal(train.waveforms, waveforms)
  631. self.assertEqual(train.name, 'n')
  632. self.assertEqual(train.t_stop, 10.0 * pq.s)
  633. assert_arrays_equal(train.array_annotations['a'], np.arange(3))
  634. train = SpikeTrain([3, 5, 4] * pq.s, waveforms=waveforms, name='n', t_stop=10.0,
  635. array_annotations={'a': np.arange(3)})
  636. assert_neo_object_is_compliant(train)
  637. train.sort()
  638. assert_neo_object_is_compliant(train)
  639. assert_arrays_equal(train, [3, 4, 5] * pq.s)
  640. assert_arrays_equal(train.waveforms, waveforms[[0, 2, 1]])
  641. self.assertEqual(train.name, 'n')
  642. self.assertEqual(train.t_start, 0.0 * pq.s)
  643. self.assertEqual(train.t_stop, 10.0 * pq.s)
  644. assert_arrays_equal(train.array_annotations['a'], np.array([0, 2, 1]))
  645. self.assertIsInstance(train.array_annotations, ArrayDict)
  646. class TestSlice(unittest.TestCase):
  647. def setUp(self):
  648. self.waveforms1 = np.array(
  649. [[[0., 1.], [0.1, 1.1]], [[2., 3.], [2.1, 3.1]], [[4., 5.], [4.1, 5.1]]]) * pq.mV
  650. self.data1 = np.array([3, 4, 5])
  651. self.data1quant = self.data1 * pq.s
  652. self.arr_ann = {'index': np.arange(1, 4), 'label': ['abc', 'def', 'ghi']}
  653. self.train1 = SpikeTrain(self.data1quant, waveforms=self.waveforms1, name='n', arb='arbb',
  654. t_stop=10.0, array_annotations=self.arr_ann)
  655. def test_compliant(self):
  656. assert_neo_object_is_compliant(self.train1)
  657. def test_slice(self):
  658. # slice spike train, keep sliced spike times
  659. result = self.train1[1:2]
  660. assert_arrays_equal(self.train1[1:2], result)
  661. targwaveforms = np.array([[[2., 3.], [2.1, 3.1]]]) * pq.mV
  662. # but keep everything else pristine
  663. assert_neo_object_is_compliant(result)
  664. self.assertEqual(self.train1.name, result.name)
  665. self.assertEqual(self.train1.description, result.description)
  666. self.assertEqual(self.train1.annotations, result.annotations)
  667. self.assertEqual(self.train1.file_origin, result.file_origin)
  668. self.assertEqual(self.train1.dtype, result.dtype)
  669. self.assertEqual(self.train1.t_start, result.t_start)
  670. self.assertEqual(self.train1.t_stop, result.t_stop)
  671. # except we update the waveforms
  672. assert_arrays_equal(self.train1.waveforms[1:2], result.waveforms)
  673. assert_arrays_equal(targwaveforms, result.waveforms)
  674. # Also array annotations should be updated
  675. assert_arrays_equal(result.array_annotations['index'], np.array([2]))
  676. assert_arrays_equal(result.array_annotations['label'], np.array(['def']))
  677. self.assertIsInstance(result.array_annotations['index'], np.ndarray)
  678. self.assertIsInstance(result.array_annotations['label'], np.ndarray)
  679. self.assertIsInstance(result.array_annotations, ArrayDict)
  680. def test_slice_to_end(self):
  681. # slice spike train, keep sliced spike times
  682. result = self.train1[1:]
  683. assert_arrays_equal(self.train1[1:], result)
  684. targwaveforms = np.array([[[2., 3.], [2.1, 3.1]], [[4., 5.], [4.1, 5.1]]]) * pq.mV
  685. # but keep everything else pristine
  686. assert_neo_object_is_compliant(result)
  687. self.assertEqual(self.train1.name, result.name)
  688. self.assertEqual(self.train1.description, result.description)
  689. self.assertEqual(self.train1.annotations, result.annotations)
  690. self.assertEqual(self.train1.file_origin, result.file_origin)
  691. self.assertEqual(self.train1.dtype, result.dtype)
  692. self.assertEqual(self.train1.t_start, result.t_start)
  693. self.assertEqual(self.train1.t_stop, result.t_stop)
  694. # except we update the waveforms
  695. assert_arrays_equal(self.train1.waveforms[1:], result.waveforms)
  696. assert_arrays_equal(targwaveforms, result.waveforms)
  697. # Also array annotations should be updated
  698. assert_arrays_equal(result.array_annotations['index'], np.array([2, 3]))
  699. assert_arrays_equal(result.array_annotations['label'], np.array(['def', 'ghi']))
  700. self.assertIsInstance(result.array_annotations['index'], np.ndarray)
  701. self.assertIsInstance(result.array_annotations['label'], np.ndarray)
  702. self.assertIsInstance(result.array_annotations, ArrayDict)
  703. def test_slice_from_beginning(self):
  704. # slice spike train, keep sliced spike times
  705. result = self.train1[:2]
  706. assert_arrays_equal(self.train1[:2], result)
  707. targwaveforms = np.array([[[0., 1.], [0.1, 1.1]], [[2., 3.], [2.1, 3.1]]]) * pq.mV
  708. # but keep everything else pristine
  709. assert_neo_object_is_compliant(result)
  710. self.assertEqual(self.train1.name, result.name)
  711. self.assertEqual(self.train1.description, result.description)
  712. self.assertEqual(self.train1.annotations, result.annotations)
  713. self.assertEqual(self.train1.file_origin, result.file_origin)
  714. self.assertEqual(self.train1.dtype, result.dtype)
  715. self.assertEqual(self.train1.t_start, result.t_start)
  716. self.assertEqual(self.train1.t_stop, result.t_stop)
  717. # except we update the waveforms
  718. assert_arrays_equal(self.train1.waveforms[:2], result.waveforms)
  719. assert_arrays_equal(targwaveforms, result.waveforms)
  720. # Also array annotations should be updated
  721. assert_arrays_equal(result.array_annotations['index'], np.array([1, 2]))
  722. assert_arrays_equal(result.array_annotations['label'], np.array(['abc', 'def']))
  723. self.assertIsInstance(result.array_annotations['index'], np.ndarray)
  724. self.assertIsInstance(result.array_annotations['label'], np.ndarray)
  725. self.assertIsInstance(result.array_annotations, ArrayDict)
  726. def test_slice_negative_idxs(self):
  727. # slice spike train, keep sliced spike times
  728. result = self.train1[:-1]
  729. assert_arrays_equal(self.train1[:-1], result)
  730. targwaveforms = np.array([[[0., 1.], [0.1, 1.1]], [[2., 3.], [2.1, 3.1]]]) * pq.mV
  731. # but keep everything else pristine
  732. assert_neo_object_is_compliant(result)
  733. self.assertEqual(self.train1.name, result.name)
  734. self.assertEqual(self.train1.description, result.description)
  735. self.assertEqual(self.train1.annotations, result.annotations)
  736. self.assertEqual(self.train1.file_origin, result.file_origin)
  737. self.assertEqual(self.train1.dtype, result.dtype)
  738. self.assertEqual(self.train1.t_start, result.t_start)
  739. self.assertEqual(self.train1.t_stop, result.t_stop)
  740. # except we update the waveforms
  741. assert_arrays_equal(self.train1.waveforms[:-1], result.waveforms)
  742. assert_arrays_equal(targwaveforms, result.waveforms)
  743. # Also array annotations should be updated
  744. assert_arrays_equal(result.array_annotations['index'], np.array([1, 2]))
  745. assert_arrays_equal(result.array_annotations['label'], np.array(['abc', 'def']))
  746. self.assertIsInstance(result.array_annotations['index'], np.ndarray)
  747. self.assertIsInstance(result.array_annotations['label'], np.ndarray)
  748. self.assertIsInstance(result.array_annotations, ArrayDict)
  749. class TestTimeSlice(unittest.TestCase):
  750. def setUp(self):
  751. self.waveforms1 = np.array(
  752. [[[0., 1.], [0.1, 1.1]], [[2., 3.], [2.1, 3.1]], [[4., 5.], [4.1, 5.1]],
  753. [[6., 7.], [6.1, 7.1]], [[8., 9.], [8.1, 9.1]], [[10., 11.], [10.1, 11.1]]]) * pq.mV
  754. self.data1 = np.array([0.1, 0.5, 1.2, 3.3, 6.4, 7])
  755. self.data1quant = self.data1 * pq.ms
  756. self.arr_ann = {'index': np.arange(1, 7), 'label': ['a', 'b', 'c', 'd', 'e', 'f']}
  757. self.train1 = SpikeTrain(self.data1quant, t_stop=10.0 * pq.ms, waveforms=self.waveforms1,
  758. array_annotations=self.arr_ann)
  759. def test_compliant(self):
  760. assert_neo_object_is_compliant(self.train1)
  761. def test_time_slice_typical(self):
  762. # time_slice spike train, keep sliced spike times
  763. # this is the typical time slice falling somewhere
  764. # in the middle of spikes
  765. t_start = 0.12 * pq.ms
  766. t_stop = 3.5 * pq.ms
  767. result = self.train1.time_slice(t_start, t_stop)
  768. targ = SpikeTrain([0.5, 1.2, 3.3] * pq.ms, t_stop=3.3)
  769. assert_arrays_equal(result, targ)
  770. targwaveforms = np.array(
  771. [[[2., 3.], [2.1, 3.1]], [[4., 5.], [4.1, 5.1]], [[6., 7.], [6.1, 7.1]]]) * pq.mV
  772. assert_arrays_equal(targwaveforms, result.waveforms)
  773. # but keep everything else pristine
  774. assert_neo_object_is_compliant(result)
  775. self.assertEqual(self.train1.name, result.name)
  776. self.assertEqual(self.train1.description, result.description)
  777. self.assertEqual(self.train1.annotations, result.annotations)
  778. self.assertEqual(self.train1.file_origin, result.file_origin)
  779. self.assertEqual(self.train1.dtype, result.dtype)
  780. self.assertEqual(t_start, result.t_start)
  781. self.assertEqual(t_stop, result.t_stop)
  782. # Array annotations should be updated according to time slice
  783. assert_arrays_equal(result.array_annotations['index'], np.array([2, 3, 4]))
  784. assert_arrays_equal(result.array_annotations['label'], np.array(['b', 'c', 'd']))
  785. self.assertIsInstance(result.array_annotations['index'], np.ndarray)
  786. self.assertIsInstance(result.array_annotations['label'], np.ndarray)
  787. self.assertIsInstance(result.array_annotations, ArrayDict)
  788. def test_time_slice_differnt_units(self):
  789. # time_slice spike train, keep sliced spike times
  790. t_start = 0.00012 * pq.s
  791. t_stop = 0.0035 * pq.s
  792. result = self.train1.time_slice(t_start, t_stop)
  793. targ = SpikeTrain([0.5, 1.2, 3.3] * pq.ms, t_stop=3.3)
  794. assert_arrays_equal(result, targ)
  795. targwaveforms = np.array(
  796. [[[2., 3.], [2.1, 3.1]], [[4., 5.], [4.1, 5.1]], [[6., 7.], [6.1, 7.1]]]) * pq.mV
  797. assert_arrays_equal(targwaveforms, result.waveforms)
  798. # but keep everything else pristine
  799. assert_neo_object_is_compliant(result)
  800. self.assertEqual(self.train1.name, result.name)
  801. self.assertEqual(self.train1.description, result.description)
  802. self.assertEqual(self.train1.annotations, result.annotations)
  803. self.assertEqual(self.train1.file_origin, result.file_origin)
  804. self.assertEqual(self.train1.dtype, result.dtype)
  805. self.assertEqual(t_start, result.t_start)
  806. self.assertEqual(t_stop, result.t_stop)
  807. # Array annotations should be updated according to time slice
  808. assert_arrays_equal(result.array_annotations['index'], np.array([2, 3, 4]))
  809. assert_arrays_equal(result.array_annotations['label'], np.array(['b', 'c', 'd']))
  810. self.assertIsInstance(result.array_annotations['index'], np.ndarray)
  811. self.assertIsInstance(result.array_annotations['label'], np.ndarray)
  812. self.assertIsInstance(result.array_annotations, ArrayDict)
  813. def test_time_slice_matching_ends(self):
  814. # time_slice spike train, keep sliced spike times
  815. t_start = 0.1 * pq.ms
  816. t_stop = 7.0 * pq.ms
  817. result = self.train1.time_slice(t_start, t_stop)
  818. assert_arrays_equal(self.train1, result)
  819. assert_arrays_equal(self.waveforms1, result.waveforms)
  820. # but keep everything else pristine
  821. assert_neo_object_is_compliant(result)
  822. self.assertEqual(self.train1.name, result.name)
  823. self.assertEqual(self.train1.description, result.description)
  824. self.assertEqual(self.train1.annotations, result.annotations)
  825. self.assertEqual(self.train1.file_origin, result.file_origin)
  826. self.assertEqual(self.train1.dtype, result.dtype)
  827. self.assertEqual(t_start, result.t_start)
  828. self.assertEqual(t_stop, result.t_stop)
  829. # Array annotations should be updated according to time slice
  830. assert_arrays_equal(result.array_annotations['index'], np.array(self.arr_ann['index']))
  831. assert_arrays_equal(result.array_annotations['label'], np.array(self.arr_ann['label']))
  832. self.assertIsInstance(result.array_annotations['index'], np.ndarray)
  833. self.assertIsInstance(result.array_annotations['label'], np.ndarray)
  834. self.assertIsInstance(result.array_annotations, ArrayDict)
  835. def test_time_slice_out_of_boundries(self):
  836. self.train1.t_start = 0.1 * pq.ms
  837. assert_neo_object_is_compliant(self.train1)
  838. # time_slice spike train, keep sliced spike times
  839. t_start = 0.01 * pq.ms
  840. t_stop = 70.0 * pq.ms
  841. result = self.train1.time_slice(t_start, t_stop)
  842. assert_arrays_equal(self.train1, result)
  843. assert_arrays_equal(self.waveforms1, result.waveforms)
  844. # but keep everything else pristine
  845. assert_neo_object_is_compliant(result)
  846. self.assertEqual(self.train1.name, result.name)
  847. self.assertEqual(self.train1.description, result.description)
  848. self.assertEqual(self.train1.annotations, result.annotations)
  849. self.assertEqual(self.train1.file_origin, result.file_origin)
  850. self.assertEqual(self.train1.dtype, result.dtype)
  851. self.assertEqual(self.train1.t_start, result.t_start)
  852. self.assertEqual(self.train1.t_stop, result.t_stop)
  853. # Array annotations should be updated according to time slice
  854. assert_arrays_equal(result.array_annotations['index'], np.array(self.arr_ann['index']))
  855. assert_arrays_equal(result.array_annotations['label'], np.array(self.arr_ann['label']))
  856. self.assertIsInstance(result.array_annotations['index'], np.ndarray)
  857. self.assertIsInstance(result.array_annotations['label'], np.ndarray)
  858. self.assertIsInstance(result.array_annotations, ArrayDict)
  859. def test_time_slice_empty(self):
  860. waveforms = np.array([[[]]]) * pq.mV
  861. train = SpikeTrain([] * pq.ms, t_stop=10.0, waveforms=waveforms)
  862. assert_neo_object_is_compliant(train)
  863. # time_slice spike train, keep sliced spike times
  864. t_start = 0.01 * pq.ms
  865. t_stop = 70.0 * pq.ms
  866. result = train.time_slice(t_start, t_stop)
  867. assert_arrays_equal(train, result)
  868. assert_arrays_equal(waveforms[:-1], result.waveforms)
  869. # but keep everything else pristine
  870. assert_neo_object_is_compliant(result)
  871. self.assertEqual(train.name, result.name)
  872. self.assertEqual(train.description, result.description)
  873. self.assertEqual(train.annotations, result.annotations)
  874. self.assertEqual(train.file_origin, result.file_origin)
  875. self.assertEqual(train.dtype, result.dtype)
  876. self.assertEqual(t_start, result.t_start)
  877. self.assertEqual(train.t_stop, result.t_stop)
  878. def test_time_slice_none_stop(self):
  879. # time_slice spike train, keep sliced spike times
  880. t_start = 1 * pq.ms
  881. result = self.train1.time_slice(t_start, None)
  882. assert_arrays_equal([1.2, 3.3, 6.4, 7] * pq.ms, result)
  883. targwaveforms = np.array(
  884. [[[4., 5.], [4.1, 5.1]], [[6., 7.], [6.1, 7.1]], [[8., 9.], [8.1, 9.1]],
  885. [[10., 11.], [10.1, 11.1]]]) * pq.mV
  886. assert_arrays_equal(targwaveforms, result.waveforms)
  887. # but keep everything else pristine
  888. assert_neo_object_is_compliant(result)
  889. self.assertEqual(self.train1.name, result.name)
  890. self.assertEqual(self.train1.description, result.description)
  891. self.assertEqual(self.train1.annotations, result.annotations)
  892. self.assertEqual(self.train1.file_origin, result.file_origin)
  893. self.assertEqual(self.train1.dtype, result.dtype)
  894. self.assertEqual(t_start, result.t_start)
  895. self.assertEqual(self.train1.t_stop, result.t_stop)
  896. # Array annotations should be updated according to time slice
  897. assert_arrays_equal(result.array_annotations['index'], np.array([3, 4, 5, 6]))
  898. assert_arrays_equal(result.array_annotations['label'], np.array(['c', 'd', 'e', 'f']))
  899. self.assertIsInstance(result.array_annotations['index'], np.ndarray)
  900. self.assertIsInstance(result.array_annotations['label'], np.ndarray)
  901. self.assertIsInstance(result.array_annotations, ArrayDict)
  902. def test_time_slice_none_start(self):
  903. # time_slice spike train, keep sliced spike times
  904. t_stop = 1 * pq.ms
  905. result = self.train1.time_slice(None, t_stop)
  906. assert_arrays_equal([0.1, 0.5] * pq.ms, result)
  907. targwaveforms = np.array([[[0., 1.], [0.1, 1.1]], [[2., 3.], [2.1, 3.1]]]) * pq.mV
  908. assert_arrays_equal(targwaveforms, result.waveforms)
  909. # but keep everything else pristine
  910. assert_neo_object_is_compliant(result)
  911. self.assertEqual(self.train1.name, result.name)
  912. self.assertEqual(self.train1.description, result.description)
  913. self.assertEqual(self.train1.annotations, result.annotations)
  914. self.assertEqual(self.train1.file_origin, result.file_origin)
  915. self.assertEqual(self.train1.dtype, result.dtype)
  916. self.assertEqual(self.train1.t_start, result.t_start)
  917. self.assertEqual(t_stop, result.t_stop)
  918. # Array annotations should be updated according to time slice
  919. assert_arrays_equal(result.array_annotations['index'], np.array([1, 2]))
  920. assert_arrays_equal(result.array_annotations['label'], np.array(['a', 'b']))
  921. self.assertIsInstance(result.array_annotations['index'], np.ndarray)
  922. self.assertIsInstance(result.array_annotations['label'], np.ndarray)
  923. self.assertIsInstance(result.array_annotations, ArrayDict)
  924. def test_time_slice_none_both(self):
  925. self.train1.t_start = 0.1 * pq.ms
  926. assert_neo_object_is_compliant(self.train1)
  927. # time_slice spike train, keep sliced spike times
  928. result = self.train1.time_slice(None, None)
  929. assert_arrays_equal(self.train1, result)
  930. assert_arrays_equal(self.waveforms1, result.waveforms)
  931. # but keep everything else pristine
  932. assert_neo_object_is_compliant(result)
  933. self.assertEqual(self.train1.name, result.name)
  934. self.assertEqual(self.train1.description, result.description)
  935. self.assertEqual(self.train1.annotations, result.annotations)
  936. self.assertEqual(self.train1.file_origin, result.file_origin)
  937. self.assertEqual(self.train1.dtype, result.dtype)
  938. self.assertEqual(self.train1.t_start, result.t_start)
  939. self.assertEqual(self.train1.t_stop, result.t_stop)
  940. # Array annotations should be updated according to time slice
  941. assert_arrays_equal(result.array_annotations['index'], np.array(self.arr_ann['index']))
  942. assert_arrays_equal(result.array_annotations['label'], np.array(self.arr_ann['label']))
  943. self.assertIsInstance(result.array_annotations['index'], np.ndarray)
  944. self.assertIsInstance(result.array_annotations['label'], np.ndarray)
  945. self.assertIsInstance(result.array_annotations, ArrayDict)
  946. class TestMerge(unittest.TestCase):
  947. def setUp(self):
  948. self.waveforms1 = np.array(
  949. [[[0., 1.], [0.1, 1.1]], [[2., 3.], [2.1, 3.1]], [[4., 5.], [4.1, 5.1]],
  950. [[6., 7.], [6.1, 7.1]], [[8., 9.], [8.1, 9.1]], [[10., 11.], [10.1, 11.1]]]) * pq.mV
  951. self.data1 = np.array([0.1, 0.5, 1.2, 3.3, 6.4, 7])
  952. self.data1quant = self.data1 * pq.ms
  953. self.arr_ann1 = {'index': np.arange(1, 7), 'label': ['a', 'b', 'c', 'd', 'e', 'f']}
  954. self.train1 = SpikeTrain(self.data1quant, t_stop=10.0 * pq.ms, waveforms=self.waveforms1,
  955. array_annotations=self.arr_ann1)
  956. self.waveforms2 = np.array(
  957. [[[0., 1.], [0.1, 1.1]], [[2., 3.], [2.1, 3.1]], [[4., 5.], [4.1, 5.1]],
  958. [[6., 7.], [6.1, 7.1]], [[8., 9.], [8.1, 9.1]], [[10., 11.], [10.1, 11.1]]]) * pq.mV
  959. self.data2 = np.array([0.1, 0.5, 1.2, 3.3, 6.4, 7])
  960. self.data2quant = self.data1 * pq.ms
  961. self.arr_ann2 = {'index': np.arange(101, 107), 'label2': ['g', 'h', 'i', 'j', 'k', 'l']}
  962. self.train2 = SpikeTrain(self.data1quant, t_stop=10.0 * pq.ms, waveforms=self.waveforms1,
  963. array_annotations=self.arr_ann2)
  964. self.segment = Segment()
  965. self.segment.spiketrains.extend([self.train1, self.train2])
  966. self.train1.segment = self.segment
  967. self.train2.segment = self.segment
  968. def test_compliant(self):
  969. assert_neo_object_is_compliant(self.train1)
  970. assert_neo_object_is_compliant(self.train2)
  971. def test_merge_typical(self):
  972. self.train1.waveforms = None
  973. self.train2.waveforms = None
  974. with warnings.catch_warnings(record=True) as w:
  975. result = self.train1.merge(self.train2)
  976. self.assertTrue(len(w) == 1)
  977. self.assertEqual(w[0].category, UserWarning)
  978. self.assertSequenceEqual(str(w[0].message), "The following array annotations were "
  979. "omitted, because they were only present"
  980. " in one of the merged objects: "
  981. "['label'] from the one that was merged "
  982. "into and ['label2'] from the one that "
  983. "was merged into the other")
  984. assert_neo_object_is_compliant(result)
  985. # Make sure array annotations are merged correctly
  986. self.assertTrue('label' not in result.array_annotations)
  987. self.assertTrue('label2' not in result.array_annotations)
  988. assert_arrays_equal(result.array_annotations['index'],
  989. np.array([1, 101, 2, 102, 3, 103, 4, 104, 5, 105, 6, 106]))
  990. self.assertIsInstance(result.array_annotations, ArrayDict)
  991. def test_merge_with_waveforms(self):
  992. # Array annotations merge warning was already tested, can be ignored now
  993. with warnings.catch_warnings(record=True) as w:
  994. result = self.train1.merge(self.train2)
  995. self.assertEqual(len(w), 1)
  996. self.assertTrue("array annotations" in str(w[0].message))
  997. assert_neo_object_is_compliant(result)
  998. def test_correct_shape(self):
  999. # Array annotations merge warning was already tested, can be ignored now
  1000. with warnings.catch_warnings(record=True) as w:
  1001. result = self.train1.merge(self.train2)
  1002. self.assertEqual(len(w), 1)
  1003. self.assertTrue("array annotations" in str(w[0].message))
  1004. self.assertEqual(len(result.shape), 1)
  1005. self.assertEqual(result.shape[0], self.train1.shape[0] + self.train2.shape[0])
  1006. def test_correct_times(self):
  1007. # Array annotations merge warning was already tested, can be ignored now
  1008. with warnings.catch_warnings(record=True) as w:
  1009. result = self.train1.merge(self.train2)
  1010. self.assertEqual(len(w), 1)
  1011. self.assertTrue("array annotations" in str(w[0].message))
  1012. expected = sorted(np.concatenate((self.train1.times, self.train2.times)))
  1013. np.testing.assert_array_equal(result, expected)
  1014. # Make sure array annotations are merged correctly
  1015. self.assertTrue('label' not in result.array_annotations)
  1016. self.assertTrue('label2' not in result.array_annotations)
  1017. assert_arrays_equal(result.array_annotations['index'],
  1018. np.array([1, 101, 2, 102, 3, 103, 4, 104, 5, 105, 6, 106]))
  1019. self.assertIsInstance(result.array_annotations, ArrayDict)
  1020. def test_rescaling_units(self):
  1021. train3 = self.train1.duplicate_with_new_data(self.train1.times.magnitude * pq.microsecond)
  1022. train3.segment = self.train1.segment
  1023. train3.array_annotate(**self.arr_ann1)
  1024. # Array annotations merge warning was already tested, can be ignored now
  1025. with warnings.catch_warnings(record=True) as w:
  1026. result = train3.merge(self.train2)
  1027. self.assertEqual(len(w), 1)
  1028. self.assertTrue("array annotations" in str(w[0].message))
  1029. time_unit = result.units
  1030. expected = sorted(np.concatenate(
  1031. (train3.rescale(time_unit).times, self.train2.rescale(time_unit).times)))
  1032. expected = expected * time_unit
  1033. np.testing.assert_array_equal(result.rescale(time_unit), expected)
  1034. # Make sure array annotations are merged correctly
  1035. self.assertTrue('label' not in result.array_annotations)
  1036. self.assertTrue('label2' not in result.array_annotations)
  1037. assert_arrays_equal(result.array_annotations['index'],
  1038. np.array([1, 2, 3, 4, 5, 6, 101, 102, 103, 104, 105, 106]))
  1039. self.assertIsInstance(result.array_annotations, ArrayDict)
  1040. def test_sampling_rate(self):
  1041. # Array annotations merge warning was already tested, can be ignored now
  1042. with warnings.catch_warnings(record=True) as w:
  1043. result = self.train1.merge(self.train2)
  1044. self.assertEqual(len(w), 1)
  1045. self.assertTrue("array annotations" in str(w[0].message))
  1046. self.assertEqual(result.sampling_rate, self.train1.sampling_rate)
  1047. def test_neo_relations(self):
  1048. # Array annotations merge warning was already tested, can be ignored now
  1049. with warnings.catch_warnings(record=True) as w:
  1050. result = self.train1.merge(self.train2)
  1051. self.assertEqual(len(w), 1)
  1052. self.assertTrue("array annotations" in str(w[0].message))
  1053. self.assertEqual(self.train1.segment, result.segment)
  1054. self.assertTrue(result in result.segment.spiketrains)
  1055. def test_missing_waveforms_error(self):
  1056. self.train1.waveforms = None
  1057. with self.assertRaises(MergeError):
  1058. self.train1.merge(self.train2)
  1059. with self.assertRaises(MergeError):
  1060. self.train2.merge(self.train1)
  1061. def test_incompatible_t_start(self):
  1062. train3 = self.train1.duplicate_with_new_data(self.train1, t_start=-1 * pq.s)
  1063. train3.segment = self.train1.segment
  1064. with self.assertRaises(MergeError):
  1065. train3.merge(self.train2)
  1066. with self.assertRaises(MergeError):
  1067. self.train2.merge(train3)
  1068. class TestDuplicateWithNewData(unittest.TestCase):
  1069. def setUp(self):
  1070. self.waveforms = np.array(
  1071. [[[0., 1.], [0.1, 1.1]], [[2., 3.], [2.1, 3.1]], [[4., 5.], [4.1, 5.1]],
  1072. [[6., 7.], [6.1, 7.1]], [[8., 9.], [8.1, 9.1]], [[10., 11.], [10.1, 11.1]]]) * pq.mV
  1073. self.data = np.array([0.1, 0.5, 1.2, 3.3, 6.4, 7])
  1074. self.dataquant = self.data * pq.ms
  1075. self.arr_ann = {'index': np.arange(6)}
  1076. self.train = SpikeTrain(self.dataquant, t_stop=10.0 * pq.ms, waveforms=self.waveforms,
  1077. array_annotations=self.arr_ann)
  1078. def test_duplicate_with_new_data(self):
  1079. signal1 = self.train
  1080. new_t_start = -10 * pq.s
  1081. new_t_stop = 10 * pq.s
  1082. new_data = np.sort(np.random.uniform(new_t_start.magnitude, new_t_stop.magnitude,
  1083. len(self.train))) * pq.ms
  1084. signal1b = signal1.duplicate_with_new_data(new_data, t_start=new_t_start,
  1085. t_stop=new_t_stop)
  1086. assert_arrays_almost_equal(np.asarray(signal1b), np.asarray(new_data), 1e-12)
  1087. self.assertEqual(signal1b.t_start, new_t_start)
  1088. self.assertEqual(signal1b.t_stop, new_t_stop)
  1089. self.assertEqual(signal1b.sampling_rate, signal1.sampling_rate)
  1090. # After duplicating, array annotations should always be empty,
  1091. # because different length of data would cause inconsistencies
  1092. self.assertEqual(signal1b.array_annotations, {})
  1093. self.assertIsInstance(signal1b.array_annotations, ArrayDict)
  1094. def test_deep_copy_attributes(self):
  1095. signal1 = self.train
  1096. new_t_start = -10 * pq.s
  1097. new_t_stop = 10 * pq.s
  1098. new_data = np.sort(np.random.uniform(new_t_start.magnitude, new_t_stop.magnitude,
  1099. len(self.train))) * pq.ms
  1100. signal1b = signal1.duplicate_with_new_data(new_data, t_start=new_t_start,
  1101. t_stop=new_t_stop)
  1102. signal1.annotate(new_annotation='for signal 1')
  1103. self.assertTrue('new_annotation' not in signal1b.annotations)
  1104. class TestAttributesAnnotations(unittest.TestCase):
  1105. def test_set_universally_recommended_attributes(self):
  1106. train = SpikeTrain([3, 4, 5], units='sec', name='Name', description='Desc',
  1107. file_origin='crack.txt', t_stop=99.9)
  1108. assert_neo_object_is_compliant(train)
  1109. self.assertEqual(train.name, 'Name')
  1110. self.assertEqual(train.description, 'Desc')
  1111. self.assertEqual(train.file_origin, 'crack.txt')
  1112. def test_autoset_universally_recommended_attributes(self):
  1113. train = SpikeTrain([3, 4, 5] * pq.s, t_stop=10.0)
  1114. assert_neo_object_is_compliant(train)
  1115. self.assertEqual(train.name, None)
  1116. self.assertEqual(train.description, None)
  1117. self.assertEqual(train.file_origin, None)
  1118. def test_annotations(self):
  1119. train = SpikeTrain([3, 4, 5] * pq.s, t_stop=11.1)
  1120. assert_neo_object_is_compliant(train)
  1121. self.assertEqual(train.annotations, {})
  1122. train = SpikeTrain([3, 4, 5] * pq.s, t_stop=11.1, ratname='Phillippe')
  1123. assert_neo_object_is_compliant(train)
  1124. self.assertEqual(train.annotations, {'ratname': 'Phillippe'})
  1125. def test_array_annotations(self):
  1126. train = SpikeTrain([3, 4, 5] * pq.s, t_stop=11.1)
  1127. assert_neo_object_is_compliant(train)
  1128. self.assertEqual(train.array_annotations, {})
  1129. self.assertIsInstance(train.array_annotations, ArrayDict)
  1130. train = SpikeTrain([3, 4, 5] * pq.s, t_stop=11.1,
  1131. array_annotations={'ratnames': ['L', 'N', 'E']})
  1132. assert_neo_object_is_compliant(train)
  1133. assert_arrays_equal(train.array_annotations['ratnames'], np.array(['L', 'N', 'E']))
  1134. self.assertIsInstance(train.array_annotations, ArrayDict)
  1135. train.array_annotate(index=[1, 2, 3])
  1136. assert_neo_object_is_compliant(train)
  1137. assert_arrays_equal(train.array_annotations['index'], np.arange(1, 4))
  1138. self.assertIsInstance(train.array_annotations, ArrayDict)
  1139. class TestChanging(unittest.TestCase):
  1140. def test_change_with_copy_default(self):
  1141. # Default is copy = True
  1142. # Changing spike train does not change data
  1143. # Data source is quantity
  1144. data = [3, 4, 5] * pq.s
  1145. train = SpikeTrain(data, t_stop=100.0)
  1146. train[0] = 99 * pq.s
  1147. assert_neo_object_is_compliant(train)
  1148. self.assertEqual(train[0], 99 * pq.s)
  1149. self.assertEqual(data[0], 3 * pq.s)
  1150. def test_change_with_copy_false(self):
  1151. # Changing spike train also changes data, because it is a view
  1152. # Data source is quantity
  1153. data = [3, 4, 5] * pq.s
  1154. train = SpikeTrain(data, copy=False, t_stop=100.0)
  1155. train[0] = 99 * pq.s
  1156. assert_neo_object_is_compliant(train)
  1157. self.assertEqual(train[0], 99 * pq.s)
  1158. self.assertEqual(data[0], 99 * pq.s)
  1159. def test_change_with_copy_false_and_fake_rescale(self):
  1160. # Changing spike train also changes data, because it is a view
  1161. # Data source is quantity
  1162. data = [3000, 4000, 5000] * pq.ms
  1163. # even though we specify units, it still returns a view
  1164. train = SpikeTrain(data, units='ms', copy=False, t_stop=100000)
  1165. train[0] = 99000 * pq.ms
  1166. assert_neo_object_is_compliant(train)
  1167. self.assertEqual(train[0], 99000 * pq.ms)
  1168. self.assertEqual(data[0], 99000 * pq.ms)
  1169. def test_change_with_copy_false_and_rescale_true(self):
  1170. # When rescaling, a view cannot be returned
  1171. # Changing spike train also changes data, because it is a view
  1172. data = [3, 4, 5] * pq.s
  1173. self.assertRaises(ValueError, SpikeTrain, data, units='ms', copy=False, t_stop=10000)
  1174. def test_init_with_rescale(self):
  1175. data = [3, 4, 5] * pq.s
  1176. train = SpikeTrain(data, units='ms', t_stop=6000)
  1177. assert_neo_object_is_compliant(train)
  1178. self.assertEqual(train[0], 3000 * pq.ms)
  1179. self.assertEqual(train._dimensionality, pq.ms._dimensionality)
  1180. self.assertEqual(train.t_stop, 6000 * pq.ms)
  1181. def test_change_with_copy_true(self):
  1182. # Changing spike train does not change data
  1183. # Data source is quantity
  1184. data = [3, 4, 5] * pq.s
  1185. train = SpikeTrain(data, copy=True, t_stop=100)
  1186. train[0] = 99 * pq.s
  1187. assert_neo_object_is_compliant(train)
  1188. self.assertEqual(train[0], 99 * pq.s)
  1189. self.assertEqual(data[0], 3 * pq.s)
  1190. def test_change_with_copy_default_and_data_not_quantity(self):
  1191. # Default is copy = True
  1192. # Changing spike train does not change data
  1193. # Data source is array
  1194. # Array and quantity are tested separately because copy default
  1195. # is different for these two.
  1196. data = [3, 4, 5]
  1197. train = SpikeTrain(data, units='sec', t_stop=100)
  1198. train[0] = 99 * pq.s
  1199. assert_neo_object_is_compliant(train)
  1200. self.assertEqual(train[0], 99 * pq.s)
  1201. self.assertEqual(data[0], 3 * pq.s)
  1202. def test_change_with_copy_false_and_data_not_quantity(self):
  1203. # Changing spike train also changes data, because it is a view
  1204. # Data source is array
  1205. # Array and quantity are tested separately because copy default
  1206. # is different for these two.
  1207. data = np.array([3, 4, 5])
  1208. train = SpikeTrain(data, units='sec', copy=False, dtype=np.int, t_stop=101)
  1209. train[0] = 99 * pq.s
  1210. assert_neo_object_is_compliant(train)
  1211. self.assertEqual(train[0], 99 * pq.s)
  1212. self.assertEqual(data[0], 99)
  1213. def test_change_with_copy_false_and_dtype_change(self):
  1214. # You cannot change dtype and request a view
  1215. data = np.array([3, 4, 5])
  1216. self.assertRaises(ValueError, SpikeTrain, data, units='sec', copy=False, t_stop=101,
  1217. dtype=np.float64)
  1218. def test_change_with_copy_true_and_data_not_quantity(self):
  1219. # Changing spike train does not change data
  1220. # Data source is array
  1221. # Array and quantity are tested separately because copy default
  1222. # is different for these two.
  1223. data = [3, 4, 5]
  1224. train = SpikeTrain(data, units='sec', copy=True, t_stop=123.4)
  1225. train[0] = 99 * pq.s
  1226. assert_neo_object_is_compliant(train)
  1227. self.assertEqual(train[0], 99 * pq.s)
  1228. self.assertEqual(data[0], 3)
  1229. def test_changing_slice_changes_original_spiketrain(self):
  1230. # If we slice a spiketrain and then change the slice, the
  1231. # original spiketrain should change.
  1232. # Whether the original data source changes is dependent on the
  1233. # copy parameter.
  1234. # This is compatible with both np and quantity default behavior.
  1235. data = [3, 4, 5] * pq.s
  1236. train = SpikeTrain(data, copy=True, t_stop=99.9)
  1237. result = train[1:3]
  1238. result[0] = 99 * pq.s
  1239. assert_neo_object_is_compliant(train)
  1240. self.assertEqual(train[1], 99 * pq.s)
  1241. self.assertEqual(result[0], 99 * pq.s)
  1242. self.assertEqual(data[1], 4 * pq.s)
  1243. def test_changing_slice_changes_original_spiketrain_with_copy_false(self):
  1244. # If we slice a spiketrain and then change the slice, the
  1245. # original spiketrain should change.
  1246. # Whether the original data source changes is dependent on the
  1247. # copy parameter.
  1248. # This is compatible with both np and quantity default behavior.
  1249. data = [3, 4, 5] * pq.s
  1250. train = SpikeTrain(data, copy=False, t_stop=100.0)
  1251. result = train[1:3]
  1252. result[0] = 99 * pq.s
  1253. assert_neo_object_is_compliant(train)
  1254. assert_neo_object_is_compliant(result)
  1255. self.assertEqual(train[1], 99 * pq.s)
  1256. self.assertEqual(result[0], 99 * pq.s)
  1257. self.assertEqual(data[1], 99 * pq.s)
  1258. def test__changing_spiketime_should_check_time_in_range(self):
  1259. data = [3, 4, 5] * pq.ms
  1260. train = SpikeTrain(data, copy=False, t_start=0.5, t_stop=10.0)
  1261. assert_neo_object_is_compliant(train)
  1262. self.assertRaises(ValueError, train.__setitem__, 0, 10.1 * pq.ms)
  1263. self.assertRaises(ValueError, train.__setitem__, 1, 5.0 * pq.s)
  1264. self.assertRaises(ValueError, train.__setitem__, 2, 5.0 * pq.s)
  1265. self.assertRaises(ValueError, train.__setitem__, 0, 0)
  1266. def test__changing_multiple_spiketimes(self):
  1267. data = [3, 4, 5] * pq.ms
  1268. train = SpikeTrain(data, copy=False, t_start=0.5, t_stop=10.0)
  1269. train[:] = [7, 8, 9] * pq.ms
  1270. assert_neo_object_is_compliant(train)
  1271. assert_arrays_equal(train, np.array([7, 8, 9]))
  1272. def test__changing_multiple_spiketimes_should_check_time_in_range(self):
  1273. data = [3, 4, 5] * pq.ms
  1274. train = SpikeTrain(data, copy=False, t_start=0.5, t_stop=10.0)
  1275. assert_neo_object_is_compliant(train)
  1276. if sys.version_info[0] == 2:
  1277. self.assertRaises(ValueError, train.__setslice__, 0, 3, [3, 4, 11] * pq.ms)
  1278. self.assertRaises(ValueError, train.__setslice__, 0, 3, [0, 4, 5] * pq.ms)
  1279. def test__adding_time_scalar(self):
  1280. data = [3, 4, 5] * pq.ms
  1281. train = SpikeTrain(data, copy=False, t_start=0.5, t_stop=10.0)
  1282. assert_neo_object_is_compliant(train)
  1283. # t_start and t_stop are also changed
  1284. self.assertEqual((train + 10 * pq.ms).t_start, 10.5 * pq.ms)
  1285. self.assertEqual((train + 11 * pq.ms).t_stop, 21.0 * pq.ms)
  1286. assert_arrays_equal(train + 1 * pq.ms, data + 1 * pq.ms)
  1287. self.assertIsInstance(train + 10 * pq.ms, SpikeTrain)
  1288. def test__adding_time_array(self):
  1289. data = [3, 4, 5] * pq.ms
  1290. train = SpikeTrain(data, copy=False, t_start=0.5, t_stop=10.0)
  1291. assert_neo_object_is_compliant(train)
  1292. delta = [-2, 2, 4] * pq.ms
  1293. assert_arrays_equal(train + delta, np.array([1, 6, 9]) * pq.ms)
  1294. self.assertIsInstance(train + delta, SpikeTrain)
  1295. # if new times are within t_start and t_stop, they
  1296. # are not changed
  1297. self.assertEqual((train + delta).t_start, train.t_start)
  1298. self.assertEqual((train + delta).t_stop, train.t_stop)
  1299. # if new times are outside t_start and/or t_stop, these are
  1300. # expanded to fit
  1301. delta = [-4, 2, 6] * pq.ms
  1302. self.assertEqual((train + delta).t_start, -1 * pq.ms)
  1303. self.assertEqual((train + delta).t_stop, 11 * pq.ms)
  1304. def test__adding_two_spike_trains(self):
  1305. data = [3, 4, 5] * pq.ms
  1306. train1 = SpikeTrain(data, copy=False, t_start=0.5, t_stop=10.0)
  1307. train2 = SpikeTrain(data, copy=False, t_start=0.5, t_stop=10.0)
  1308. self.assertRaises(TypeError, train1.__add__, train2)
  1309. def test__subtracting_time_scalar(self):
  1310. data = [3, 4, 5] * pq.ms
  1311. train = SpikeTrain(data, copy=False, t_start=0.5, t_stop=10.0)
  1312. assert_neo_object_is_compliant(train)
  1313. # t_start and t_stop are also changed
  1314. self.assertEqual((train - 1 * pq.ms).t_start, -0.5 * pq.ms)
  1315. self.assertEqual((train - 3.0 * pq.ms).t_stop, 7.0 * pq.ms)
  1316. assert_arrays_equal(train - 1 * pq.ms, data - 1 * pq.ms)
  1317. self.assertIsInstance(train - 5 * pq.ms, SpikeTrain)
  1318. def test__subtracting_time_array(self):
  1319. data = [3, 4, 5] * pq.ms
  1320. train = SpikeTrain(data, copy=False, t_start=0.5, t_stop=10.0)
  1321. assert_neo_object_is_compliant(train)
  1322. delta = [2, 1, -2] * pq.ms
  1323. self.assertIsInstance(train - delta, SpikeTrain)
  1324. # if new times are within t_start and t_stop, they
  1325. # are not changed
  1326. self.assertEqual((train - delta).t_start, train.t_start)
  1327. self.assertEqual((train - delta).t_stop, train.t_stop)
  1328. # if new times are outside t_start and/or t_stop, these are
  1329. # expanded to fit
  1330. delta = [4, 1, -6] * pq.ms
  1331. self.assertEqual((train - delta).t_start, -1 * pq.ms)
  1332. self.assertEqual((train - delta).t_stop, 11 * pq.ms)
  1333. def test__subtracting_two_spike_trains(self):
  1334. train1 = SpikeTrain([3, 4, 5] * pq.ms, copy=False, t_start=0.5, t_stop=10.0)
  1335. train2 = SpikeTrain([4, 5, 6] * pq.ms, copy=False, t_start=0.5, t_stop=10.0)
  1336. train3 = SpikeTrain([3, 4, 5, 6] * pq.ms, copy=False, t_start=0.5, t_stop=10.0)
  1337. self.assertRaises(TypeError, train1.__sub__, train3)
  1338. self.assertRaises(TypeError, train3.__sub__, train1)
  1339. self.assertIsInstance(train1 - train2, pq.Quantity)
  1340. self.assertNotIsInstance(train1 - train2, SpikeTrain)
  1341. def test__rescale(self):
  1342. data = [3, 4, 5] * pq.ms
  1343. train = SpikeTrain(data, t_start=0.5, t_stop=10.0)
  1344. train.segment = Segment()
  1345. train.unit = Unit()
  1346. self.assertEqual(train.t_start.magnitude, 0.5)
  1347. self.assertEqual(train.t_stop.magnitude, 10.0)
  1348. result = train.rescale(pq.s)
  1349. assert_neo_object_is_compliant(train)
  1350. assert_neo_object_is_compliant(result)
  1351. assert_arrays_equal(train, result)
  1352. self.assertEqual(result.units, 1 * pq.s)
  1353. self.assertIs(result.segment, train.segment)
  1354. self.assertIs(result.unit, train.unit)
  1355. self.assertEqual(result.t_start.magnitude, 0.0005)
  1356. self.assertEqual(result.t_stop.magnitude, 0.01)
  1357. def test__rescale_same_units(self):
  1358. data = [3, 4, 5] * pq.ms
  1359. train = SpikeTrain(data, t_start=0.5, t_stop=10.0)
  1360. result = train.rescale(pq.ms)
  1361. assert_neo_object_is_compliant(train)
  1362. assert_arrays_equal(train, result)
  1363. self.assertEqual(result.units, 1 * pq.ms)
  1364. def test__rescale_incompatible_units_ValueError(self):
  1365. data = [3, 4, 5] * pq.ms
  1366. train = SpikeTrain(data, t_start=0.5, t_stop=10.0)
  1367. assert_neo_object_is_compliant(train)
  1368. self.assertRaises(ValueError, train.rescale, pq.m)
  1369. class TestPropertiesMethods(unittest.TestCase):
  1370. def setUp(self):
  1371. self.data1 = [3, 4, 5]
  1372. self.data1quant = self.data1 * pq.ms
  1373. self.waveforms1 = np.array(
  1374. [[[0., 1.], [0.1, 1.1]], [[2., 3.], [2.1, 3.1]], [[4., 5.], [4.1, 5.1]]]) * pq.mV
  1375. self.t_start1 = 0.5
  1376. self.t_stop1 = 10.0
  1377. self.t_start1quant = self.t_start1 * pq.ms
  1378. self.t_stop1quant = self.t_stop1 * pq.ms
  1379. self.sampling_rate1 = .1 * pq.Hz
  1380. self.left_sweep1 = 2. * pq.s
  1381. self.name1 = 'train 1'
  1382. self.description1 = 'a test object'
  1383. self.ann1 = {'targ0': [1, 2], 'targ1': 1.1}
  1384. self.train1 = SpikeTrain(self.data1quant, t_start=self.t_start1, t_stop=self.t_stop1,
  1385. waveforms=self.waveforms1, left_sweep=self.left_sweep1,
  1386. sampling_rate=self.sampling_rate1, name=self.name1,
  1387. description=self.description1, **self.ann1)
  1388. def test__compliant(self):
  1389. assert_neo_object_is_compliant(self.train1)
  1390. def test__repr(self):
  1391. result = repr(self.train1)
  1392. if np.__version__.split(".")[:2] > ['1', '13']:
  1393. # see https://github.com/numpy/numpy/blob/master/doc/release/1.14.0-notes.rst#many
  1394. # -changes-to-array-printing-disableable-with-the-new-legacy-printing-mode # nopep8
  1395. targ = '<SpikeTrain(array([3., 4., 5.]) * ms, [0.5 ms, 10.0 ms])>'
  1396. else:
  1397. targ = '<SpikeTrain(array([ 3., 4., 5.]) * ms, [0.5 ms, 10.0 ms])>'
  1398. self.assertEqual(result, targ)
  1399. def test__duration(self):
  1400. result1 = self.train1.duration
  1401. self.train1.t_start = None
  1402. assert_neo_object_is_compliant(self.train1)
  1403. result2 = self.train1.duration
  1404. self.train1.t_start = self.t_start1quant
  1405. self.train1.t_stop = None
  1406. assert_neo_object_is_compliant(self.train1)
  1407. result3 = self.train1.duration
  1408. self.assertEqual(result1, 9.5 * pq.ms)
  1409. self.assertEqual(result1.units, 1. * pq.ms)
  1410. self.assertEqual(result2, None)
  1411. self.assertEqual(result3, None)
  1412. def test__spike_duration(self):
  1413. result1 = self.train1.spike_duration
  1414. self.train1.sampling_rate = None
  1415. assert_neo_object_is_compliant(self.train1)
  1416. result2 = self.train1.spike_duration
  1417. self.train1.sampling_rate = self.sampling_rate1
  1418. self.train1.waveforms = None
  1419. assert_neo_object_is_compliant(self.train1)
  1420. result3 = self.train1.spike_duration
  1421. self.assertEqual(result1, 20. / pq.Hz)
  1422. self.assertEqual(result1.units, 1. / pq.Hz)
  1423. self.assertEqual(result2, None)
  1424. self.assertEqual(result3, None)
  1425. def test__sampling_period(self):
  1426. result1 = self.train1.sampling_period
  1427. self.train1.sampling_rate = None
  1428. assert_neo_object_is_compliant(self.train1)
  1429. result2 = self.train1.sampling_period
  1430. self.train1.sampling_rate = self.sampling_rate1
  1431. self.train1.sampling_period = 10. * pq.ms
  1432. assert_neo_object_is_compliant(self.train1)
  1433. result3a = self.train1.sampling_period
  1434. result3b = self.train1.sampling_rate
  1435. self.train1.sampling_period = None
  1436. result4a = self.train1.sampling_period
  1437. result4b = self.train1.sampling_rate
  1438. self.assertEqual(result1, 10. / pq.Hz)
  1439. self.assertEqual(result1.units, 1. / pq.Hz)
  1440. self.assertEqual(result2, None)
  1441. self.assertEqual(result3a, 10. * pq.ms)
  1442. self.assertEqual(result3a.units, 1. * pq.ms)
  1443. self.assertEqual(result3b, .1 / pq.ms)
  1444. self.assertEqual(result3b.units, 1. / pq.ms)
  1445. self.assertEqual(result4a, None)
  1446. self.assertEqual(result4b, None)
  1447. def test__right_sweep(self):
  1448. result1 = self.train1.right_sweep
  1449. self.train1.left_sweep = None
  1450. assert_neo_object_is_compliant(self.train1)
  1451. result2 = self.train1.right_sweep
  1452. self.train1.left_sweep = self.left_sweep1
  1453. self.train1.sampling_rate = None
  1454. assert_neo_object_is_compliant(self.train1)
  1455. result3 = self.train1.right_sweep
  1456. self.train1.sampling_rate = self.sampling_rate1
  1457. self.train1.waveforms = None
  1458. assert_neo_object_is_compliant(self.train1)
  1459. result4 = self.train1.right_sweep
  1460. self.assertEqual(result1, 22. * pq.s)
  1461. self.assertEqual(result1.units, 1. * pq.s)
  1462. self.assertEqual(result2, None)
  1463. self.assertEqual(result3, None)
  1464. self.assertEqual(result4, None)
  1465. def test__times(self):
  1466. result1 = self.train1.times
  1467. self.assertIsInstance(result1, pq.Quantity)
  1468. self.assertTrue((result1 == self.train1).all)
  1469. self.assertEqual(len(result1), len(self.train1))
  1470. self.assertEqual(result1.units, self.train1.units)
  1471. self.assertEqual(result1.dtype, self.train1.dtype)
  1472. def test__children(self):
  1473. segment = Segment(name='seg1')
  1474. segment.spiketrains = [self.train1]
  1475. segment.create_many_to_one_relationship()
  1476. unit = Unit(name='unit1')
  1477. unit.spiketrains = [self.train1]
  1478. unit.create_many_to_one_relationship()
  1479. self.assertEqual(self.train1._single_parent_objects, ('Segment', 'Unit'))
  1480. self.assertEqual(self.train1._multi_parent_objects, ())
  1481. self.assertEqual(self.train1._single_parent_containers, ('segment', 'unit'))
  1482. self.assertEqual(self.train1._multi_parent_containers, ())
  1483. self.assertEqual(self.train1._parent_objects, ('Segment', 'Unit'))
  1484. self.assertEqual(self.train1._parent_containers, ('segment', 'unit'))
  1485. self.assertEqual(len(self.train1.parents), 2)
  1486. self.assertEqual(self.train1.parents[0].name, 'seg1')
  1487. self.assertEqual(self.train1.parents[1].name, 'unit1')
  1488. assert_neo_object_is_compliant(self.train1)
  1489. @unittest.skipUnless(HAVE_IPYTHON, "requires IPython")
  1490. def test__pretty(self):
  1491. res = pretty(self.train1)
  1492. targ = ("SpikeTrain\n" + "name: '%s'\ndescription: '%s'\nannotations: %s"
  1493. "" % (self.name1, self.description1, pretty(self.ann1)))
  1494. self.assertEqual(res, targ)
  1495. class TestMiscellaneous(unittest.TestCase):
  1496. def test__different_dtype_for_t_start_and_array(self):
  1497. data = np.array([0, 9.9999999], dtype=np.float64) * pq.s
  1498. data16 = data.astype(np.float16)
  1499. data32 = data.astype(np.float32)
  1500. data64 = data.astype(np.float64)
  1501. t_start = data[0]
  1502. t_stop = data[1]
  1503. t_start16 = data[0].astype(dtype=np.float16)
  1504. t_stop16 = data[1].astype(dtype=np.float16)
  1505. t_start32 = data[0].astype(dtype=np.float32)
  1506. t_stop32 = data[1].astype(dtype=np.float32)
  1507. t_start64 = data[0].astype(dtype=np.float64)
  1508. t_stop64 = data[1].astype(dtype=np.float64)
  1509. t_start_custom = 0.0
  1510. t_stop_custom = 10.0
  1511. t_start_custom16 = np.array(t_start_custom, dtype=np.float16)
  1512. t_stop_custom16 = np.array(t_stop_custom, dtype=np.float16)
  1513. t_start_custom32 = np.array(t_start_custom, dtype=np.float32)
  1514. t_stop_custom32 = np.array(t_stop_custom, dtype=np.float32)
  1515. t_start_custom64 = np.array(t_start_custom, dtype=np.float64)
  1516. t_stop_custom64 = np.array(t_stop_custom, dtype=np.float64)
  1517. # This is OK.
  1518. train = SpikeTrain(data64, copy=True, t_start=t_start, t_stop=t_stop)
  1519. assert_neo_object_is_compliant(train)
  1520. train = SpikeTrain(data16, copy=True, t_start=t_start, t_stop=t_stop, dtype=np.float16)
  1521. assert_neo_object_is_compliant(train)
  1522. train = SpikeTrain(data16, copy=True, t_start=t_start, t_stop=t_stop, dtype=np.float32)
  1523. assert_neo_object_is_compliant(train)
  1524. train = SpikeTrain(data32, copy=True, t_start=t_start, t_stop=t_stop, dtype=np.float16)
  1525. assert_neo_object_is_compliant(train)
  1526. train = SpikeTrain(data32, copy=True, t_start=t_start, t_stop=t_stop, dtype=np.float32)
  1527. assert_neo_object_is_compliant(train)
  1528. train = SpikeTrain(data32, copy=True, t_start=t_start16, t_stop=t_stop16)
  1529. assert_neo_object_is_compliant(train)
  1530. train = SpikeTrain(data32, copy=True, t_start=t_start16, t_stop=t_stop16, dtype=np.float16)
  1531. assert_neo_object_is_compliant(train)
  1532. train = SpikeTrain(data32, copy=True, t_start=t_start16, t_stop=t_stop16, dtype=np.float32)
  1533. assert_neo_object_is_compliant(train)
  1534. train = SpikeTrain(data32, copy=True, t_start=t_start16, t_stop=t_stop16, dtype=np.float64)
  1535. assert_neo_object_is_compliant(train)
  1536. train = SpikeTrain(data32, copy=True, t_start=t_start32, t_stop=t_stop32)
  1537. assert_neo_object_is_compliant(train)
  1538. train = SpikeTrain(data32, copy=True, t_start=t_start32, t_stop=t_stop32, dtype=np.float16)
  1539. assert_neo_object_is_compliant(train)
  1540. train = SpikeTrain(data32, copy=True, t_start=t_start32, t_stop=t_stop32, dtype=np.float32)
  1541. assert_neo_object_is_compliant(train)
  1542. train = SpikeTrain(data32, copy=True, t_start=t_start32, t_stop=t_stop32, dtype=np.float64)
  1543. assert_neo_object_is_compliant(train)
  1544. train = SpikeTrain(data32, copy=True, t_start=t_start64, t_stop=t_stop64, dtype=np.float16)
  1545. assert_neo_object_is_compliant(train)
  1546. train = SpikeTrain(data32, copy=True, t_start=t_start64, t_stop=t_stop64, dtype=np.float32)
  1547. assert_neo_object_is_compliant(train)
  1548. train = SpikeTrain(data16, copy=True, t_start=t_start_custom, t_stop=t_stop_custom)
  1549. assert_neo_object_is_compliant(train)
  1550. train = SpikeTrain(data16, copy=True, t_start=t_start_custom, t_stop=t_stop_custom,
  1551. dtype=np.float16)
  1552. assert_neo_object_is_compliant(train)
  1553. train = SpikeTrain(data16, copy=True, t_start=t_start_custom, t_stop=t_stop_custom,
  1554. dtype=np.float32)
  1555. assert_neo_object_is_compliant(train)
  1556. train = SpikeTrain(data16, copy=True, t_start=t_start_custom, t_stop=t_stop_custom,
  1557. dtype=np.float64)
  1558. assert_neo_object_is_compliant(train)
  1559. train = SpikeTrain(data32, copy=True, t_start=t_start_custom, t_stop=t_stop_custom)
  1560. assert_neo_object_is_compliant(train)
  1561. train = SpikeTrain(data32, copy=True, t_start=t_start_custom, t_stop=t_stop_custom,
  1562. dtype=np.float16)
  1563. assert_neo_object_is_compliant(train)
  1564. train = SpikeTrain(data32, copy=True, t_start=t_start_custom, t_stop=t_stop_custom,
  1565. dtype=np.float32)
  1566. assert_neo_object_is_compliant(train)
  1567. train = SpikeTrain(data32, copy=True, t_start=t_start_custom, t_stop=t_stop_custom,
  1568. dtype=np.float64)
  1569. assert_neo_object_is_compliant(train)
  1570. train = SpikeTrain(data16, copy=True, t_start=t_start_custom, t_stop=t_stop_custom)
  1571. assert_neo_object_is_compliant(train)
  1572. train = SpikeTrain(data16, copy=True, t_start=t_start_custom, t_stop=t_stop_custom,
  1573. dtype=np.float16)
  1574. assert_neo_object_is_compliant(train)
  1575. train = SpikeTrain(data16, copy=True, t_start=t_start_custom, t_stop=t_stop_custom,
  1576. dtype=np.float32)
  1577. assert_neo_object_is_compliant(train)
  1578. train = SpikeTrain(data16, copy=True, t_start=t_start_custom, t_stop=t_stop_custom,
  1579. dtype=np.float64)
  1580. assert_neo_object_is_compliant(train)
  1581. train = SpikeTrain(data32, copy=True, t_start=t_start_custom16, t_stop=t_stop_custom16)
  1582. assert_neo_object_is_compliant(train)
  1583. train = SpikeTrain(data32, copy=True, t_start=t_start_custom16, t_stop=t_stop_custom16,
  1584. dtype=np.float16)
  1585. assert_neo_object_is_compliant(train)
  1586. train = SpikeTrain(data32, copy=True, t_start=t_start_custom16, t_stop=t_stop_custom16,
  1587. dtype=np.float32)
  1588. assert_neo_object_is_compliant(train)
  1589. train = SpikeTrain(data32, copy=True, t_start=t_start_custom16, t_stop=t_stop_custom16,
  1590. dtype=np.float64)
  1591. assert_neo_object_is_compliant(train)
  1592. train = SpikeTrain(data32, copy=True, t_start=t_start_custom32, t_stop=t_stop_custom32)
  1593. assert_neo_object_is_compliant(train)
  1594. train = SpikeTrain(data32, copy=True, t_start=t_start_custom32, t_stop=t_stop_custom32,
  1595. dtype=np.float16)
  1596. assert_neo_object_is_compliant(train)
  1597. train = SpikeTrain(data32, copy=True, t_start=t_start_custom32, t_stop=t_stop_custom32,
  1598. dtype=np.float32)
  1599. assert_neo_object_is_compliant(train)
  1600. train = SpikeTrain(data32, copy=True, t_start=t_start_custom32, t_stop=t_stop_custom32,
  1601. dtype=np.float64)
  1602. assert_neo_object_is_compliant(train)
  1603. train = SpikeTrain(data32, copy=True, t_start=t_start_custom64, t_stop=t_stop_custom64)
  1604. assert_neo_object_is_compliant(train)
  1605. train = SpikeTrain(data32, copy=True, t_start=t_start_custom64, t_stop=t_stop_custom64,
  1606. dtype=np.float16)
  1607. assert_neo_object_is_compliant(train)
  1608. train = SpikeTrain(data32, copy=True, t_start=t_start_custom64, t_stop=t_stop_custom64,
  1609. dtype=np.float32)
  1610. assert_neo_object_is_compliant(train)
  1611. train = SpikeTrain(data32, copy=True, t_start=t_start_custom64, t_stop=t_stop_custom64,
  1612. dtype=np.float64)
  1613. assert_neo_object_is_compliant(train)
  1614. # This use to bug - see ticket #38
  1615. train = SpikeTrain(data16, copy=True, t_start=t_start, t_stop=t_stop)
  1616. assert_neo_object_is_compliant(train)
  1617. train = SpikeTrain(data16, copy=True, t_start=t_start, t_stop=t_stop, dtype=np.float64)
  1618. assert_neo_object_is_compliant(train)
  1619. train = SpikeTrain(data32, copy=True, t_start=t_start, t_stop=t_stop)
  1620. assert_neo_object_is_compliant(train)
  1621. train = SpikeTrain(data32, copy=True, t_start=t_start, t_stop=t_stop, dtype=np.float64)
  1622. assert_neo_object_is_compliant(train)
  1623. train = SpikeTrain(data32, copy=True, t_start=t_start64, t_stop=t_stop64)
  1624. assert_neo_object_is_compliant(train)
  1625. train = SpikeTrain(data32, copy=True, t_start=t_start64, t_stop=t_stop64, dtype=np.float64)
  1626. assert_neo_object_is_compliant(train)
  1627. def test_as_array(self):
  1628. data = np.arange(10.0)
  1629. st = SpikeTrain(data, t_stop=10.0, units='ms')
  1630. st_as_arr = st.as_array()
  1631. self.assertIsInstance(st_as_arr, np.ndarray)
  1632. assert_array_equal(data, st_as_arr)
  1633. def test_as_quantity(self):
  1634. data = np.arange(10.0)
  1635. st = SpikeTrain(data, t_stop=10.0, units='ms')
  1636. st_as_q = st.as_quantity()
  1637. self.assertIsInstance(st_as_q, pq.Quantity)
  1638. assert_array_equal(data * pq.ms, st_as_q)
  1639. if __name__ == "__main__":
  1640. unittest.main()