test_unit.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553
  1. # -*- coding: utf-8 -*-
  2. """
  3. Tests of the neo.core.unit.Unit class
  4. """
  5. # needed for python 3 compatibility
  6. from __future__ import absolute_import, division, print_function
  7. try:
  8. import unittest2 as unittest
  9. except ImportError:
  10. import unittest
  11. import numpy as np
  12. try:
  13. from IPython.lib.pretty import pretty
  14. except ImportError as err:
  15. HAVE_IPYTHON = False
  16. else:
  17. HAVE_IPYTHON = True
  18. from neo.core.unit import Unit
  19. from neo.core.container import filterdata
  20. from neo.core import SpikeTrain, ChannelIndex
  21. from neo.test.tools import (assert_neo_object_is_compliant,
  22. assert_arrays_equal,
  23. assert_same_sub_schema)
  24. from neo.test.generate_datasets import (fake_neo, get_fake_value,
  25. get_fake_values, get_annotations,
  26. clone_object, TEST_ANNOTATIONS)
  27. class Test__generate_datasets(unittest.TestCase):
  28. def setUp(self):
  29. np.random.seed(0)
  30. self.annotations = dict([(str(x), TEST_ANNOTATIONS[x]) for x in
  31. range(len(TEST_ANNOTATIONS))])
  32. def test__get_fake_values(self):
  33. self.annotations['seed'] = 0
  34. name = get_fake_value('name', str, seed=0, obj=Unit)
  35. description = get_fake_value('description', str, seed=1, obj='Unit')
  36. file_origin = get_fake_value('file_origin', str)
  37. attrs1 = {'name': name,
  38. 'description': description,
  39. 'file_origin': file_origin}
  40. attrs2 = attrs1.copy()
  41. attrs2.update(self.annotations)
  42. res11 = get_fake_values(Unit, annotate=False, seed=0)
  43. res12 = get_fake_values('Unit', annotate=False, seed=0)
  44. res21 = get_fake_values(Unit, annotate=True, seed=0)
  45. res22 = get_fake_values('Unit', annotate=True, seed=0)
  46. self.assertEqual(res11, attrs1)
  47. self.assertEqual(res12, attrs1)
  48. self.assertEqual(res21, attrs2)
  49. self.assertEqual(res22, attrs2)
  50. def test__fake_neo__cascade(self):
  51. self.annotations['seed'] = None
  52. obj_type = 'Unit'
  53. cascade = True
  54. res = fake_neo(obj_type=obj_type, cascade=cascade)
  55. self.assertTrue(isinstance(res, Unit))
  56. assert_neo_object_is_compliant(res)
  57. self.assertEqual(res.annotations, self.annotations)
  58. self.assertEqual(len(res.spiketrains), 1)
  59. for child in res.children_recur:
  60. del child.annotations['i']
  61. del child.annotations['j']
  62. self.assertEqual(res.spiketrains[0].annotations,
  63. self.annotations)
  64. def test__fake_neo__nocascade(self):
  65. self.annotations['seed'] = None
  66. obj_type = Unit
  67. cascade = False
  68. res = fake_neo(obj_type=obj_type, cascade=cascade)
  69. self.assertTrue(isinstance(res, Unit))
  70. assert_neo_object_is_compliant(res)
  71. self.assertEqual(res.annotations, self.annotations)
  72. self.assertEqual(len(res.spiketrains), 0)
  73. class TestUnit(unittest.TestCase):
  74. def setUp(self):
  75. self.nchildren = 2
  76. self.seed1 = 0
  77. self.seed2 = 10000
  78. self.unit1 = fake_neo(Unit, seed=self.seed1, n=self.nchildren)
  79. self.unit2 = fake_neo(Unit, seed=self.seed2, n=self.nchildren)
  80. self.targobj = self.unit1
  81. self.trains1 = self.unit1.spiketrains
  82. self.trains2 = self.unit2.spiketrains
  83. self.trains1a = clone_object(self.trains1)
  84. def check_creation(self, unit):
  85. assert_neo_object_is_compliant(unit)
  86. seed = unit.annotations['seed']
  87. targ1 = get_fake_value('name', str, seed=seed, obj=Unit)
  88. self.assertEqual(unit.name, targ1)
  89. targ2 = get_fake_value('description', str,
  90. seed=seed+1, obj=Unit)
  91. self.assertEqual(unit.description, targ2)
  92. targ3 = get_fake_value('file_origin', str)
  93. self.assertEqual(unit.file_origin, targ3)
  94. targ4 = get_annotations()
  95. targ4['seed'] = seed
  96. self.assertEqual(unit.annotations, targ4)
  97. self.assertTrue(hasattr(unit, 'spiketrains'))
  98. self.assertEqual(len(unit.spiketrains), self.nchildren)
  99. def test__creation(self):
  100. self.check_creation(self.unit1)
  101. self.check_creation(self.unit2)
  102. def test__merge(self):
  103. unit1a = fake_neo(Unit, seed=self.seed1, n=self.nchildren)
  104. assert_same_sub_schema(self.unit1, unit1a)
  105. unit1a.annotate(seed=self.seed2)
  106. unit1a.spiketrains.append(self.trains2[0])
  107. unit1a.merge(self.unit2)
  108. self.check_creation(self.unit2)
  109. assert_same_sub_schema(self.trains1a + self.trains2,
  110. unit1a.spiketrains)
  111. def test__children(self):
  112. chx = ChannelIndex(index=np.arange(self.nchildren), name='chx1')
  113. chx.units = [self.unit1]
  114. chx.create_many_to_one_relationship()
  115. assert_neo_object_is_compliant(self.unit1)
  116. assert_neo_object_is_compliant(chx)
  117. self.assertEqual(self.unit1._container_child_objects, ())
  118. self.assertEqual(self.unit1._data_child_objects, ('SpikeTrain',))
  119. self.assertEqual(self.unit1._single_parent_objects,
  120. ('ChannelIndex',))
  121. self.assertEqual(self.unit1._multi_child_objects, ())
  122. self.assertEqual(self.unit1._multi_parent_objects, ())
  123. self.assertEqual(self.unit1._child_properties, ())
  124. self.assertEqual(self.unit1._single_child_objects, ('SpikeTrain',))
  125. self.assertEqual(self.unit1._container_child_containers, ())
  126. self.assertEqual(self.unit1._data_child_containers, ('spiketrains',))
  127. self.assertEqual(self.unit1._single_child_containers, ('spiketrains',))
  128. self.assertEqual(self.unit1._single_parent_containers,
  129. ('channel_index',))
  130. self.assertEqual(self.unit1._multi_child_containers, ())
  131. self.assertEqual(self.unit1._multi_parent_containers, ())
  132. self.assertEqual(self.unit1._child_objects, ('SpikeTrain',))
  133. self.assertEqual(self.unit1._child_containers, ('spiketrains',))
  134. self.assertEqual(self.unit1._parent_objects,
  135. ('ChannelIndex',))
  136. self.assertEqual(self.unit1._parent_containers,
  137. ('channel_index',))
  138. self.assertEqual(len(self.unit1._single_children), self.nchildren)
  139. self.assertEqual(len(self.unit1._multi_children), 0)
  140. self.assertEqual(len(self.unit1.data_children), self.nchildren)
  141. self.assertEqual(len(self.unit1.data_children_recur), self.nchildren)
  142. self.assertEqual(len(self.unit1.container_children), 0)
  143. self.assertEqual(len(self.unit1.container_children_recur), 0)
  144. self.assertEqual(len(self.unit1.children), self.nchildren)
  145. self.assertEqual(len(self.unit1.children_recur), self.nchildren)
  146. self.assertEqual(self.unit1._multi_children, ())
  147. self.assertEqual(self.unit1.container_children, ())
  148. self.assertEqual(self.unit1.container_children_recur, ())
  149. assert_same_sub_schema(list(self.unit1._single_children),
  150. self.trains1a)
  151. assert_same_sub_schema(list(self.unit1.data_children),
  152. self.trains1a)
  153. assert_same_sub_schema(list(self.unit1.data_children_recur),
  154. self.trains1a)
  155. assert_same_sub_schema(list(self.unit1.children),
  156. self.trains1a)
  157. assert_same_sub_schema(list(self.unit1.children_recur),
  158. self.trains1a)
  159. self.assertEqual(len(self.unit1.parents), 1)
  160. self.assertEqual(self.unit1.parents[0].name, 'chx1')
  161. def test__size(self):
  162. targ = {'spiketrains': self.nchildren}
  163. self.assertEqual(self.targobj.size, targ)
  164. def test__filter_none(self):
  165. targ = []
  166. res1 = self.targobj.filter()
  167. res2 = self.targobj.filter({})
  168. res3 = self.targobj.filter([])
  169. res4 = self.targobj.filter([{}])
  170. res5 = self.targobj.filter([{}, {}])
  171. res6 = self.targobj.filter([{}, {}])
  172. res7 = self.targobj.filter(targdict={})
  173. res8 = self.targobj.filter(targdict=[])
  174. res9 = self.targobj.filter(targdict=[{}])
  175. res10 = self.targobj.filter(targdict=[{}, {}])
  176. assert_same_sub_schema(res1, targ)
  177. assert_same_sub_schema(res2, targ)
  178. assert_same_sub_schema(res3, targ)
  179. assert_same_sub_schema(res4, targ)
  180. assert_same_sub_schema(res5, targ)
  181. assert_same_sub_schema(res6, targ)
  182. assert_same_sub_schema(res7, targ)
  183. assert_same_sub_schema(res8, targ)
  184. assert_same_sub_schema(res9, targ)
  185. assert_same_sub_schema(res10, targ)
  186. def test__filter_annotation_single(self):
  187. targ = [self.trains1a[1]]
  188. res0 = self.targobj.filter(j=1)
  189. res1 = self.targobj.filter({'j': 1})
  190. res2 = self.targobj.filter(targdict={'j': 1})
  191. res3 = self.targobj.filter([{'j': 1}])
  192. res4 = self.targobj.filter(targdict=[{'j': 1}])
  193. assert_same_sub_schema(res0, targ)
  194. assert_same_sub_schema(res1, targ)
  195. assert_same_sub_schema(res2, targ)
  196. assert_same_sub_schema(res3, targ)
  197. assert_same_sub_schema(res4, targ)
  198. def test__filter_single_annotation_nores(self):
  199. targ = []
  200. res0 = self.targobj.filter(j=5)
  201. res1 = self.targobj.filter({'j': 5})
  202. res2 = self.targobj.filter(targdict={'j': 5})
  203. res3 = self.targobj.filter([{'j': 5}])
  204. res4 = self.targobj.filter(targdict=[{'j': 5}])
  205. assert_same_sub_schema(res0, targ)
  206. assert_same_sub_schema(res1, targ)
  207. assert_same_sub_schema(res2, targ)
  208. assert_same_sub_schema(res3, targ)
  209. assert_same_sub_schema(res4, targ)
  210. def test__filter_attribute_single(self):
  211. targ = [self.trains1a[0]]
  212. name = self.trains1a[0].name
  213. res0 = self.targobj.filter(name=name)
  214. res1 = self.targobj.filter({'name': name})
  215. res2 = self.targobj.filter(targdict={'name': name})
  216. assert_same_sub_schema(res0, targ)
  217. assert_same_sub_schema(res1, targ)
  218. assert_same_sub_schema(res2, targ)
  219. def test__filter_attribute_single_nores(self):
  220. targ = []
  221. name = self.trains2[0].name
  222. res0 = self.targobj.filter(name=name)
  223. res1 = self.targobj.filter({'name': name})
  224. res2 = self.targobj.filter(targdict={'name': name})
  225. assert_same_sub_schema(res0, targ)
  226. assert_same_sub_schema(res1, targ)
  227. assert_same_sub_schema(res2, targ)
  228. def test__filter_multi(self):
  229. targ = [self.trains1a[1], self.trains1a[0]]
  230. name = self.trains1a[0].name
  231. res0 = self.targobj.filter(name=name, j=1)
  232. res1 = self.targobj.filter({'name': name, 'j': 1})
  233. res2 = self.targobj.filter(targdict={'name': name, 'j': 1})
  234. assert_same_sub_schema(res0, targ)
  235. assert_same_sub_schema(res1, targ)
  236. assert_same_sub_schema(res2, targ)
  237. def test__filter_multi_nores(self):
  238. targ = []
  239. name0 = self.trains2[0].name
  240. res0 = self.targobj.filter([{'j': 5}, {}])
  241. res1 = self.targobj.filter({}, j=0)
  242. res2 = self.targobj.filter([{}], i=0)
  243. res3 = self.targobj.filter({'name': name0}, j=1)
  244. res4 = self.targobj.filter(targdict={'name': name0}, j=1)
  245. res5 = self.targobj.filter(name=name0, targdict={'j': 1})
  246. res6 = self.targobj.filter(name=name0, j=5)
  247. res7 = self.targobj.filter({'name': name0, 'j': 5})
  248. res8 = self.targobj.filter(targdict={'name': name0, 'j': 5})
  249. res9 = self.targobj.filter({'name': name0}, j=5)
  250. res10 = self.targobj.filter(targdict={'name': name0}, j=5)
  251. res11 = self.targobj.filter(name=name0, targdict={'j': 5})
  252. res12 = self.targobj.filter({'name': name0}, j=5)
  253. res13 = self.targobj.filter(targdict={'name': name0}, j=5)
  254. res14 = self.targobj.filter(name=name0, targdict={'j': 5})
  255. assert_same_sub_schema(res0, targ)
  256. assert_same_sub_schema(res1, targ)
  257. assert_same_sub_schema(res2, targ)
  258. assert_same_sub_schema(res3, targ)
  259. assert_same_sub_schema(res4, targ)
  260. assert_same_sub_schema(res5, targ)
  261. assert_same_sub_schema(res6, targ)
  262. assert_same_sub_schema(res7, targ)
  263. assert_same_sub_schema(res8, targ)
  264. assert_same_sub_schema(res9, targ)
  265. assert_same_sub_schema(res10, targ)
  266. assert_same_sub_schema(res11, targ)
  267. assert_same_sub_schema(res12, targ)
  268. assert_same_sub_schema(res13, targ)
  269. assert_same_sub_schema(res14, targ)
  270. def test__filter_multi_partres(self):
  271. targ = [self.trains1a[0]]
  272. name = self.trains1a[0].name
  273. res0 = self.targobj.filter(name=name, j=5)
  274. res1 = self.targobj.filter({'name': name, 'j': 5})
  275. res2 = self.targobj.filter(targdict={'name': name, 'j': 5})
  276. res3 = self.targobj.filter([{'j': 0}, {'i': 0}])
  277. res4 = self.targobj.filter({'j': 0}, i=0)
  278. res5 = self.targobj.filter([{'j': 0}], i=0)
  279. assert_same_sub_schema(res0, targ)
  280. assert_same_sub_schema(res1, targ)
  281. assert_same_sub_schema(res2, targ)
  282. assert_same_sub_schema(res3, targ)
  283. assert_same_sub_schema(res4, targ)
  284. assert_same_sub_schema(res5, targ)
  285. def test__filter_single_annotation_obj_single(self):
  286. targ = [self.trains1a[1]]
  287. res0 = self.targobj.filter(j=1, objects='SpikeTrain')
  288. res1 = self.targobj.filter(j=1, objects=SpikeTrain)
  289. res2 = self.targobj.filter(j=1, objects=['SpikeTrain'])
  290. res3 = self.targobj.filter(j=1, objects=[SpikeTrain])
  291. res4 = self.targobj.filter(j=1, objects=[SpikeTrain,
  292. ChannelIndex])
  293. assert_same_sub_schema(res0, targ)
  294. assert_same_sub_schema(res1, targ)
  295. assert_same_sub_schema(res2, targ)
  296. assert_same_sub_schema(res3, targ)
  297. assert_same_sub_schema(res4, targ)
  298. def test__filter_single_annotation_obj_none(self):
  299. targ = []
  300. res0 = self.targobj.filter(j=1, objects=ChannelIndex)
  301. res1 = self.targobj.filter(j=1, objects='ChannelIndex')
  302. res2 = self.targobj.filter(j=1, objects=[])
  303. assert_same_sub_schema(res0, targ)
  304. assert_same_sub_schema(res1, targ)
  305. assert_same_sub_schema(res2, targ)
  306. def test__filter_single_annotation_norecur(self):
  307. targ = [self.trains1a[1]]
  308. res0 = self.targobj.filter(j=1, recursive=False)
  309. assert_same_sub_schema(res0, targ)
  310. def test__filter_single_attribute_norecur(self):
  311. targ = [self.trains1a[0]]
  312. res0 = self.targobj.filter(name=self.trains1a[0].name, recursive=False)
  313. assert_same_sub_schema(res0, targ)
  314. def test__filter_single_annotation_nodata(self):
  315. targ = []
  316. res0 = self.targobj.filter(j=1, data=False)
  317. assert_same_sub_schema(res0, targ)
  318. def test__filter_single_attribute_nodata(self):
  319. targ = []
  320. res0 = self.targobj.filter(name=self.trains1a[0].name, data=False)
  321. assert_same_sub_schema(res0, targ)
  322. def test__filter_single_annotation_nodata_norecur(self):
  323. targ = []
  324. res0 = self.targobj.filter(j=1,
  325. data=False, recursive=False)
  326. assert_same_sub_schema(res0, targ)
  327. def test__filter_single_attribute_nodata_norecur(self):
  328. targ = []
  329. res0 = self.targobj.filter(name=self.trains1a[0].name,
  330. data=False, recursive=False)
  331. assert_same_sub_schema(res0, targ)
  332. def test__filter_single_annotation_container(self):
  333. targ = [self.trains1a[1]]
  334. res0 = self.targobj.filter(j=1, container=True)
  335. assert_same_sub_schema(res0, targ)
  336. def test__filter_single_attribute_container(self):
  337. targ = [self.trains1a[0]]
  338. res0 = self.targobj.filter(name=self.trains1a[0].name, container=True)
  339. assert_same_sub_schema(res0, targ)
  340. def test__filter_single_annotation_container_norecur(self):
  341. targ = [self.trains1a[1]]
  342. res0 = self.targobj.filter(j=1, container=True, recursive=False)
  343. assert_same_sub_schema(res0, targ)
  344. def test__filter_single_attribute_container_norecur(self):
  345. targ = [self.trains1a[0]]
  346. res0 = self.targobj.filter(name=self.trains1a[0].name,
  347. container=True, recursive=False)
  348. assert_same_sub_schema(res0, targ)
  349. def test__filter_single_annotation_nodata_container(self):
  350. targ = []
  351. res0 = self.targobj.filter(j=1,
  352. data=False, container=True)
  353. assert_same_sub_schema(res0, targ)
  354. def test__filter_single_attribute_nodata_container(self):
  355. targ = []
  356. res0 = self.targobj.filter(name=self.trains1a[0].name,
  357. data=False, container=True)
  358. assert_same_sub_schema(res0, targ)
  359. def test__filter_single_annotation_nodata_container_norecur(self):
  360. targ = []
  361. res0 = self.targobj.filter(j=1,
  362. data=False, container=True,
  363. recursive=False)
  364. assert_same_sub_schema(res0, targ)
  365. def test__filter_single_attribute_nodata_container_norecur(self):
  366. targ = []
  367. res0 = self.targobj.filter(name=self.trains1a[0].name,
  368. data=False, container=True,
  369. recursive=False)
  370. assert_same_sub_schema(res0, targ)
  371. def test__filterdata_multi(self):
  372. data = self.targobj.children_recur
  373. targ = [self.trains1a[1], self.trains1a[0]]
  374. name = self.trains1a[0].name
  375. res0 = filterdata(data, name=name, j=1)
  376. res1 = filterdata(data, {'name': name, 'j': 1})
  377. res2 = filterdata(data, targdict={'name': name, 'j': 1})
  378. assert_same_sub_schema(res0, targ)
  379. assert_same_sub_schema(res1, targ)
  380. assert_same_sub_schema(res2, targ)
  381. def test__filterdata_multi_nores(self):
  382. data = self.targobj.children_recur
  383. targ = []
  384. name1 = self.trains1a[0].name
  385. name2 = self.trains2[0].name
  386. res0 = filterdata(data, [{'j': 0}, {}])
  387. res1 = filterdata(data, {}, i=0)
  388. res2 = filterdata(data, [{}], i=0)
  389. res3 = filterdata(data, name=name1, targdict={'j': 1})
  390. res4 = filterdata(data, {'name': name1}, j=1)
  391. res5 = filterdata(data, targdict={'name': name1}, j=1)
  392. res6 = filterdata(data, name=name2, j=5)
  393. res7 = filterdata(data, {'name': name2, 'j': 5})
  394. res8 = filterdata(data, targdict={'name': name2, 'j': 5})
  395. res9 = filterdata(data, {'name': name2}, j=5)
  396. res10 = filterdata(data, targdict={'name': name2}, j=5)
  397. res11 = filterdata(data, name=name2, targdict={'j': 5})
  398. res12 = filterdata(data, {'name': name1}, j=5)
  399. res13 = filterdata(data, targdict={'name': name1}, j=5)
  400. res14 = filterdata(data, name=name1, targdict={'j': 5})
  401. assert_same_sub_schema(res0, targ)
  402. assert_same_sub_schema(res1, targ)
  403. assert_same_sub_schema(res2, targ)
  404. assert_same_sub_schema(res3, targ)
  405. assert_same_sub_schema(res4, targ)
  406. assert_same_sub_schema(res5, targ)
  407. assert_same_sub_schema(res6, targ)
  408. assert_same_sub_schema(res7, targ)
  409. assert_same_sub_schema(res8, targ)
  410. assert_same_sub_schema(res9, targ)
  411. assert_same_sub_schema(res10, targ)
  412. assert_same_sub_schema(res11, targ)
  413. assert_same_sub_schema(res12, targ)
  414. assert_same_sub_schema(res13, targ)
  415. assert_same_sub_schema(res14, targ)
  416. def test__filterdata_multi_partres(self):
  417. data = self.targobj.children_recur
  418. targ = [self.trains1a[0]]
  419. name = self.trains1a[0].name
  420. res0 = filterdata(data, name=name, j=5)
  421. res1 = filterdata(data, {'name': name, 'j': 5})
  422. res2 = filterdata(data, targdict={'name': name, 'j': 5})
  423. res3 = filterdata(data, [{'j': 0}, {'i': 0}])
  424. res4 = filterdata(data, {'j': 0}, i=0)
  425. res5 = filterdata(data, [{'j': 0}], i=0)
  426. assert_same_sub_schema(res0, targ)
  427. assert_same_sub_schema(res1, targ)
  428. assert_same_sub_schema(res2, targ)
  429. assert_same_sub_schema(res3, targ)
  430. assert_same_sub_schema(res4, targ)
  431. assert_same_sub_schema(res5, targ)
  432. # @unittest.skipUnless(HAVE_IPYTHON, "requires IPython")
  433. # def test__pretty(self):
  434. # res = pretty(self.unit1)
  435. # ann = get_annotations()
  436. # ann['seed'] = self.seed1
  437. # ann = pretty(ann).replace('\n ', '\n ')
  438. # targ = ("Unit with " +
  439. # ("%s spiketrains\n" % len(self.trains1a)) +
  440. # ("name: '%s'\ndescription: '%s'\n" % (self.unit1.name,
  441. # self.unit1.description)
  442. # ) +
  443. # ("annotations: %s" % ann))
  444. #
  445. # self.assertEqual(res, targ)
  446. if __name__ == "__main__":
  447. unittest.main()