tools.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529
  1. # -*- coding: utf-8 -*-
  2. '''
  3. Tools for use with neo tests.
  4. '''
  5. import hashlib
  6. import os
  7. import numpy as np
  8. import quantities as pq
  9. import neo
  10. from neo.core import objectlist
  11. from neo.core.baseneo import _reference_name, _container_name
  12. def assert_arrays_equal(a, b, dtype=False):
  13. '''
  14. Check if two arrays have the same shape and contents.
  15. If dtype is True (default=False), then also theck that they have the same
  16. dtype.
  17. '''
  18. assert isinstance(a, np.ndarray), "a is a %s" % type(a)
  19. assert isinstance(b, np.ndarray), "b is a %s" % type(b)
  20. assert a.shape == b.shape, "%s != %s" % (a, b)
  21. #assert a.dtype == b.dtype, "%s and %s not same dtype %s %s" % (a, b,
  22. # a.dtype,
  23. # b.dtype)
  24. try:
  25. assert (a.flatten() == b.flatten()).all(), "%s != %s" % (a, b)
  26. except (AttributeError, ValueError):
  27. try:
  28. ar = np.array(a)
  29. br = np.array(b)
  30. assert (ar.flatten() == br.flatten()).all(), "%s != %s" % (ar, br)
  31. except (AttributeError, ValueError):
  32. assert np.all(a.flatten() == b.flatten()), "%s != %s" % (a, b)
  33. if dtype:
  34. assert a.dtype == b.dtype, \
  35. "%s and %s not same dtype %s and %s" % (a, b, a.dtype, b.dtype)
  36. def assert_arrays_almost_equal(a, b, threshold, dtype=False):
  37. '''
  38. Check if two arrays have the same shape and contents that differ
  39. by abs(a - b) <= threshold for all elements.
  40. If threshold is None, do an absolute comparison rather than a relative
  41. comparison.
  42. '''
  43. if threshold is None:
  44. return assert_arrays_equal(a, b, dtype=dtype)
  45. assert isinstance(a, np.ndarray), "a is a %s" % type(a)
  46. assert isinstance(b, np.ndarray), "b is a %s" % type(b)
  47. assert a.shape == b.shape, "%s != %s" % (a, b)
  48. #assert a.dtype == b.dtype, "%s and %b not same dtype %s %s" % (a, b,
  49. # a.dtype,
  50. # b.dtype)
  51. if a.dtype.kind in ['f', 'c', 'i']:
  52. assert (abs(a - b) < threshold).all(), \
  53. "abs(%s - %s) max(|a - b|) = %s threshold:%s" % \
  54. (a, b, (abs(a - b)).max(), threshold)
  55. if dtype:
  56. assert a.dtype == b.dtype, \
  57. "%s and %s not same dtype %s and %s" % (a, b, a.dtype, b.dtype)
  58. def file_digest(filename):
  59. '''
  60. Get the sha1 hash of the file with the given filename.
  61. '''
  62. with open(filename, 'rb') as fobj:
  63. return hashlib.sha1(fobj.read()).hexdigest()
  64. def assert_file_contents_equal(a, b):
  65. '''
  66. Assert that two files have the same size and hash.
  67. '''
  68. def generate_error_message(a, b):
  69. '''
  70. This creates the error message for the assertion error
  71. '''
  72. size_a = os.stat(a).st_size
  73. size_b = os.stat(b).st_size
  74. if size_a == size_b:
  75. return "Files have the same size but different contents"
  76. else:
  77. return "Files have different sizes: a:%d b: %d" % (size_a, size_b)
  78. assert file_digest(a) == file_digest(b), generate_error_message(a, b)
  79. def assert_neo_object_is_compliant(ob):
  80. '''
  81. Test neo compliance of one object and sub objects
  82. (one_to_many_relation only):
  83. * check types and/or presence of necessary and recommended attribute.
  84. * If attribute is Quantities or numpy.ndarray it also check ndim.
  85. * If attribute is numpy.ndarray also check dtype.kind.
  86. '''
  87. assert type(ob) in objectlist, \
  88. '%s is not a neo object' % (type(ob))
  89. classname = ob.__class__.__name__
  90. # test presence of necessary attributes
  91. for ioattr in ob._necessary_attrs:
  92. attrname, attrtype = ioattr[0], ioattr[1]
  93. #~ if attrname != '':
  94. if not hasattr(ob, '_quantity_attr'):
  95. assert hasattr(ob, attrname), '%s neo obect does not have %s' % \
  96. (classname, attrname)
  97. # test attributes types
  98. for ioattr in ob._all_attrs:
  99. attrname, attrtype = ioattr[0], ioattr[1]
  100. if (hasattr(ob, '_quantity_attr') and
  101. ob._quantity_attr == attrname and
  102. (attrtype == pq.Quantity or attrtype == np.ndarray)):
  103. # object inherits from Quantity (AnalogSignal, SpikeTrain, ...)
  104. ndim = ioattr[2]
  105. assert ob.ndim == ndim, \
  106. '%s dimension is %d should be %d' % (classname, ob.ndim, ndim)
  107. if attrtype == np.ndarray:
  108. dtp = ioattr[3]
  109. assert ob.dtype.kind == dtp.kind, \
  110. '%s dtype.kind is %s should be %s' % (classname,
  111. ob.dtype.kind,
  112. dtp.kind)
  113. elif hasattr(ob, attrname):
  114. if getattr(ob, attrname) is not None:
  115. obattr = getattr(ob, attrname)
  116. assert issubclass(type(obattr), attrtype), \
  117. '%s in %s is %s should be %s' % \
  118. (attrname, classname, type(obattr), attrtype)
  119. if attrtype == pq.Quantity or attrtype == np.ndarray:
  120. ndim = ioattr[2]
  121. assert obattr.ndim == ndim, \
  122. '%s.%s dimension is %d should be %d' % \
  123. (classname, attrname, obattr.ndim, ndim)
  124. if attrtype == np.ndarray:
  125. dtp = ioattr[3]
  126. assert obattr.dtype.kind == dtp.kind, \
  127. '%s.%s dtype.kind is %s should be %s' % \
  128. (classname, attrname, obattr.dtype.kind, dtp.kind)
  129. # test bijectivity : parents and children
  130. for container in getattr(ob, '_single_child_containers', []):
  131. for i, child in enumerate(getattr(ob, container, [])):
  132. assert hasattr(child, _reference_name(classname)), \
  133. '%s should have %s attribute (2 way relationship)' % \
  134. (container, _reference_name(classname))
  135. if hasattr(child, _reference_name(classname)):
  136. parent = getattr(child, _reference_name(classname))
  137. assert parent == ob, \
  138. '%s.%s %s is not symetric with %s.%s' % \
  139. (container, _reference_name(classname), i,
  140. classname, container)
  141. # recursive on one to many rel
  142. for i, child in enumerate(getattr(ob, 'children', [])):
  143. try:
  144. assert_neo_object_is_compliant(child)
  145. # intercept exceptions and add more information
  146. except BaseException as exc:
  147. exc.args += ('from %s %s of %s' % (child.__class__.__name__, i,
  148. classname),)
  149. raise
  150. def assert_same_sub_schema(ob1, ob2, equal_almost=True, threshold=1e-10,
  151. exclude=None):
  152. '''
  153. Test if ob1 and ob2 has the same sub schema.
  154. Explore all parent/child relationships.
  155. Many_to_many_relationship is not tested
  156. because of infinite recursive loops.
  157. Arguments:
  158. equal_almost: if False do a strict arrays_equal if
  159. True do arrays_almost_equal
  160. exclude: a list of attributes and annotations to ignore in
  161. the comparison
  162. '''
  163. assert type(ob1) == type(ob2), 'type(%s) != type(%s)' % (type(ob1),
  164. type(ob2))
  165. classname = ob1.__class__.__name__
  166. if exclude is None:
  167. exclude = []
  168. if isinstance(ob1, list):
  169. assert len(ob1) == len(ob2), \
  170. 'lens %s and %s not equal for %s and %s' % \
  171. (len(ob1), len(ob2), ob1, ob2)
  172. for i, (sub1, sub2) in enumerate(zip(ob1, ob2)):
  173. try:
  174. assert_same_sub_schema(sub1, sub2, equal_almost=equal_almost,
  175. threshold=threshold, exclude=exclude)
  176. # intercept exceptions and add more information
  177. except BaseException as exc:
  178. exc.args += ('%s[%s]' % (classname, i),)
  179. raise
  180. return
  181. # test parent/child relationship
  182. for container in getattr(ob1, '_single_child_containers', []):
  183. if container in exclude:
  184. continue
  185. if not hasattr(ob1, container):
  186. assert not hasattr(ob2, container), \
  187. '%s 2 does have %s but not %s 1' % (classname, container,
  188. classname)
  189. continue
  190. else:
  191. assert hasattr(ob2, container), \
  192. '%s 1 has %s but not %s 2' % (classname, container,
  193. classname)
  194. sub1 = getattr(ob1, container)
  195. sub2 = getattr(ob2, container)
  196. assert len(sub1) == len(sub2), \
  197. 'theses two %s do not have the same %s number: %s and %s' % \
  198. (classname, container, len(sub1), len(sub2))
  199. for i in range(len(getattr(ob1, container))):
  200. # previously lacking parameter
  201. try:
  202. assert_same_sub_schema(sub1[i], sub2[i],
  203. equal_almost=equal_almost,
  204. threshold=threshold,
  205. exclude=exclude)
  206. # intercept exceptions and add more information
  207. except BaseException as exc:
  208. exc.args += ('from %s[%s] of %s' % (container, i,
  209. classname),)
  210. raise
  211. assert_same_attributes(ob1, ob2, equal_almost=equal_almost,
  212. threshold=threshold, exclude=exclude)
  213. def assert_same_attributes(ob1, ob2, equal_almost=True, threshold=1e-10,
  214. exclude=None):
  215. '''
  216. Test if ob1 and ob2 has the same attributes.
  217. Arguments:
  218. equal_almost: if False do a strict arrays_equal if
  219. True do arrays_almost_equal
  220. exclude: a list of attributes and annotations to ignore in
  221. the comparison
  222. '''
  223. classname = ob1.__class__.__name__
  224. if exclude is None:
  225. exclude = []
  226. if not equal_almost:
  227. threshold = None
  228. dtype = True
  229. else:
  230. dtype = False
  231. for ioattr in ob1._all_attrs:
  232. if ioattr[0] in exclude:
  233. continue
  234. attrname, attrtype = ioattr[0], ioattr[1]
  235. #~ if attrname =='':
  236. if hasattr(ob1, '_quantity_attr') and ob1._quantity_attr == attrname:
  237. # object is hinerited from Quantity (AnalogSignal, SpikeTrain, ...)
  238. try:
  239. assert_arrays_almost_equal(ob1.magnitude, ob2.magnitude,
  240. threshold=threshold,
  241. dtype=dtype)
  242. # intercept exceptions and add more information
  243. except BaseException as exc:
  244. exc.args += ('from %s %s' % (classname, attrname),)
  245. raise
  246. assert ob1.dimensionality.string == ob2.dimensionality.string, \
  247. 'Units of %s %s are not the same: %s and %s' % \
  248. (classname, attrname,
  249. ob1.dimensionality.string, ob2.dimensionality.string)
  250. continue
  251. if not hasattr(ob1, attrname):
  252. assert not hasattr(ob2, attrname), \
  253. '%s 2 does have %s but not %s 1' % (classname, attrname,
  254. classname)
  255. continue
  256. else:
  257. assert hasattr(ob2, attrname), \
  258. '%s 1 has %s but not %s 2' % (classname, attrname, classname)
  259. if getattr(ob1, attrname) is None:
  260. assert getattr(ob2, attrname) is None, \
  261. 'In %s.%s %s and %s differed' % (classname, attrname,
  262. getattr(ob1, attrname),
  263. getattr(ob2, attrname))
  264. continue
  265. if getattr(ob2, attrname) is None:
  266. assert getattr(ob1, attrname) is None, \
  267. 'In %s.%s %s and %s differed' % (classname, attrname,
  268. getattr(ob1, attrname),
  269. getattr(ob2, attrname))
  270. continue
  271. if attrtype == pq.Quantity:
  272. # Compare magnitudes
  273. mag1 = getattr(ob1, attrname).magnitude
  274. mag2 = getattr(ob2, attrname).magnitude
  275. #print "2. ob1(%s) %s:%s\n ob2(%s) %s:%s" % \
  276. #(ob1,attrname,mag1,ob2,attrname,mag2)
  277. try:
  278. assert_arrays_almost_equal(mag1, mag2,
  279. threshold=threshold,
  280. dtype=dtype)
  281. # intercept exceptions and add more information
  282. except BaseException as exc:
  283. exc.args += ('from %s of %s' % (attrname, classname),)
  284. raise
  285. # Compare dimensionalities
  286. dim1 = getattr(ob1, attrname).dimensionality.simplified
  287. dim2 = getattr(ob2, attrname).dimensionality.simplified
  288. dimstr1 = getattr(ob1, attrname).dimensionality.string
  289. dimstr2 = getattr(ob2, attrname).dimensionality.string
  290. assert dim1 == dim2, \
  291. 'Attribute %s of %s are not the same: %s != %s' % \
  292. (attrname, classname, dimstr1, dimstr2)
  293. elif attrtype == np.ndarray:
  294. try:
  295. assert_arrays_almost_equal(getattr(ob1, attrname),
  296. getattr(ob2, attrname),
  297. threshold=threshold,
  298. dtype=dtype)
  299. # intercept exceptions and add more information
  300. except BaseException as exc:
  301. exc.args += ('from %s of %s' % (attrname, classname),)
  302. raise
  303. else:
  304. #~ print 'yep', getattr(ob1, attrname), getattr(ob2, attrname)
  305. assert getattr(ob1, attrname) == getattr(ob2, attrname), \
  306. 'Attribute %s.%s are not the same %s %s %s %s' % \
  307. (classname, attrname,
  308. type(getattr(ob1, attrname)), getattr(ob1, attrname),
  309. type(getattr(ob2, attrname)), getattr(ob2, attrname))
  310. def assert_same_annotations(ob1, ob2, equal_almost=True, threshold=1e-10,
  311. exclude=None):
  312. '''
  313. Test if ob1 and ob2 has the same annotations.
  314. Arguments:
  315. equal_almost: if False do a strict arrays_equal if
  316. True do arrays_almost_equal
  317. exclude: a list of attributes and annotations to ignore in
  318. the comparison
  319. '''
  320. if exclude is None:
  321. exclude = []
  322. if not equal_almost:
  323. threshold = None
  324. dtype = False
  325. else:
  326. dtype = True
  327. for key in ob2.annotations:
  328. if key in exclude:
  329. continue
  330. assert key in ob1.annotations
  331. for key, value in ob1.annotations.items():
  332. if key in exclude:
  333. continue
  334. assert key in ob2.annotations
  335. try:
  336. assert value == ob2.annotations[key]
  337. except ValueError:
  338. assert_arrays_almost_equal(ob1, ob2,
  339. threshold=threshold, dtype=False)
  340. def assert_sub_schema_is_lazy_loaded(ob):
  341. '''
  342. This is util for testing lazy load. All object must load with ndarray.size
  343. or Quantity.size ==0
  344. '''
  345. classname = ob.__class__.__name__
  346. for container in getattr(ob, '_single_child_containers', []):
  347. if not hasattr(ob, container):
  348. continue
  349. sub = getattr(ob, container)
  350. for i, child in enumerate(sub):
  351. try:
  352. assert_sub_schema_is_lazy_loaded(child)
  353. # intercept exceptions and add more information
  354. except BaseException as exc:
  355. exc.args += ('from %s %s of %s' % (container, i, classname),)
  356. raise
  357. for ioattr in ob._all_attrs:
  358. attrname, attrtype = ioattr[0], ioattr[1]
  359. #~ print 'xdsd', classname, attrname
  360. #~ if attrname == '':
  361. if hasattr(ob, '_quantity_attr') and ob._quantity_attr == attrname:
  362. assert ob.size == 0, \
  363. 'Lazy loaded error %s.size = %s' % (classname, ob.size)
  364. assert hasattr(ob, 'lazy_shape'), \
  365. 'Lazy loaded error, %s should have lazy_shape attribute' % \
  366. classname
  367. continue
  368. if not hasattr(ob, attrname) or getattr(ob, attrname) is None:
  369. continue
  370. #~ print 'hjkjh'
  371. if (attrtype == pq.Quantity or attrtype == np.ndarray):
  372. # FIXME: it is a workaround for recordingChannelGroup.channel_names
  373. # which is nupy.array but allowed to be loaded when lazy == True
  374. if ob.__class__ == neo.ChannelIndex:
  375. continue
  376. ndim = ioattr[2]
  377. #~ print 'ndim', ndim
  378. #~ print getattr(ob, attrname).size
  379. if ndim >= 1:
  380. assert getattr(ob, attrname).size == 0, \
  381. 'Lazy loaded error %s.%s.size = %s' % \
  382. (classname, attrname, getattr(ob, attrname).size)
  383. assert hasattr(ob, 'lazy_shape'), \
  384. 'Lazy loaded error ' +\
  385. '%s should have lazy_shape attribute ' % classname +\
  386. 'because of %s attribute' % attrname
  387. lazy_shape_arrays = {'SpikeTrain': 'times',
  388. 'AnalogSignal': 'signal',
  389. 'Event': 'times', 'Epoch': 'times'}
  390. def assert_lazy_sub_schema_can_be_loaded(ob, io):
  391. '''
  392. This is util for testing lazy load. All object must load with ndarray.size
  393. or Quantity.size ==0
  394. '''
  395. classname = ob.__class__.__name__
  396. if classname in lazy_shape_arrays:
  397. new_load = io.load_lazy_object(ob)
  398. assert hasattr(ob, 'lazy_shape'), \
  399. 'Object %s was not lazy loaded' % classname
  400. assert not hasattr(new_load, 'lazy_shape'), \
  401. 'Newly loaded object from %s was also lazy loaded' % classname
  402. if hasattr(ob, '_quantity_attr'):
  403. assert ob.lazy_shape == new_load.shape, \
  404. 'Shape of loaded object %sis not equal to lazy shape' % \
  405. classname
  406. else:
  407. assert ob.lazy_shape == \
  408. getattr(new_load, lazy_shape_arrays[classname]).shape, \
  409. 'Shape of loaded object %s not equal to lazy shape' %\
  410. classname
  411. return
  412. for container in getattr(ob, '_single_child_containers', []):
  413. if not hasattr(ob, container):
  414. continue
  415. sub = getattr(ob, container)
  416. for i, child in enumerate(sub):
  417. try:
  418. assert_lazy_sub_schema_can_be_loaded(child, io)
  419. # intercept exceptions and add more information
  420. except BaseException as exc:
  421. exc.args += ('from of %s %s of %s' %
  422. (container, i, classname),)
  423. raise
  424. def assert_objects_equivalent(obj1, obj2):
  425. '''
  426. Compares two NEO objects by looping over the attributes and annotations
  427. and asserting their hashes. No relationships involved.
  428. '''
  429. def assert_attr(obj1, obj2, attr_name):
  430. '''
  431. Assert a single attribute and annotation are the same
  432. '''
  433. assert hasattr(obj1, attr_name)
  434. attr1 = hashlib.md5(getattr(obj1, attr_name)).hexdigest()
  435. assert hasattr(obj2, attr_name)
  436. attr2 = hashlib.md5(getattr(obj2, attr_name)).hexdigest()
  437. assert attr1 == attr2, "Attribute %s for class %s is not equal." % \
  438. (attr_name, obj1.__class__.__name__)
  439. obj_type = obj1.__class__.__name__
  440. assert obj_type == obj2.__class__.__name__
  441. for ioattr in obj1._necessary_attrs:
  442. assert_attr(obj1, obj2, ioattr[0])
  443. for ioattr in obj1._recommended_attrs:
  444. if hasattr(obj1, ioattr[0]) or hasattr(obj2, ioattr[0]):
  445. assert_attr(obj1, obj2, ioattr[0])
  446. if hasattr(obj1, "annotations"):
  447. assert hasattr(obj2, "annotations")
  448. for key, value in obj1.annotations:
  449. assert hasattr(obj2.annotations, key)
  450. assert obj2.annotations[key] == value
  451. def assert_children_empty(obj, parent):
  452. '''
  453. Check that the children of a neo object are empty. Used
  454. to check the cascade is implemented properly
  455. '''
  456. classname = obj.__class__.__name__
  457. errmsg = '''%s reader with cascade=False should return
  458. empty children''' % parent.__name__
  459. if hasattr(obj, 'children'):
  460. assert not obj.children, errmsg