1
0

tools.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527
  1. # -*- coding: utf-8 -*-
  2. '''
  3. Tools for use with neo tests.
  4. '''
  5. import hashlib
  6. import os
  7. import numpy as np
  8. import quantities as pq
  9. import neo
  10. from neo.core import objectlist
  11. from neo.core.baseneo import _reference_name, _container_name
  12. def assert_arrays_equal(a, b, dtype=False):
  13. '''
  14. Check if two arrays have the same shape and contents.
  15. If dtype is True (default=False), then also theck that they have the same
  16. dtype.
  17. '''
  18. assert isinstance(a, np.ndarray), "a is a %s" % type(a)
  19. assert isinstance(b, np.ndarray), "b is a %s" % type(b)
  20. assert a.shape == b.shape, "%s != %s" % (a, b)
  21. # assert a.dtype == b.dtype, "%s and %s not same dtype %s %s" % (a, b,
  22. # a.dtype,
  23. # b.dtype)
  24. try:
  25. assert (a.flatten() == b.flatten()).all(), "%s != %s" % (a, b)
  26. except (AttributeError, ValueError):
  27. try:
  28. ar = np.array(a)
  29. br = np.array(b)
  30. assert (ar.flatten() == br.flatten()).all(), "%s != %s" % (ar, br)
  31. except (AttributeError, ValueError):
  32. assert np.all(a.flatten() == b.flatten()), "%s != %s" % (a, b)
  33. if dtype:
  34. assert a.dtype == b.dtype, "%s and %s not same dtype %s and %s" % (a, b, a.dtype, b.dtype)
  35. def assert_arrays_almost_equal(a, b, threshold, dtype=False):
  36. '''
  37. Check if two arrays have the same shape and contents that differ
  38. by abs(a - b) <= threshold for all elements.
  39. If threshold is None, do an absolute comparison rather than a relative
  40. comparison.
  41. '''
  42. if threshold is None:
  43. return assert_arrays_equal(a, b, dtype=dtype)
  44. assert isinstance(a, np.ndarray), "a is a %s" % type(a)
  45. assert isinstance(b, np.ndarray), "b is a %s" % type(b)
  46. assert a.shape == b.shape, "%s != %s" % (a, b)
  47. # assert a.dtype == b.dtype, "%s and %b not same dtype %s %s" % (a, b,
  48. # a.dtype,
  49. # b.dtype)
  50. if a.dtype.kind in ['f', 'c', 'i']:
  51. assert (abs(
  52. a - b) < threshold).all(), "abs(%s - %s) max(|a - b|) = %s threshold:%s" \
  53. "" % (a, b, (abs(a - b)).max(), threshold)
  54. if dtype:
  55. assert a.dtype == b.dtype, "%s and %s not same dtype %s and %s" % (a, b, a.dtype, b.dtype)
  56. def file_digest(filename):
  57. '''
  58. Get the sha1 hash of the file with the given filename.
  59. '''
  60. with open(filename, 'rb') as fobj:
  61. return hashlib.sha1(fobj.read()).hexdigest()
  62. def assert_file_contents_equal(a, b):
  63. '''
  64. Assert that two files have the same size and hash.
  65. '''
  66. def generate_error_message(a, b):
  67. '''
  68. This creates the error message for the assertion error
  69. '''
  70. size_a = os.stat(a).st_size
  71. size_b = os.stat(b).st_size
  72. if size_a == size_b:
  73. return "Files have the same size but different contents"
  74. else:
  75. return "Files have different sizes: a:%d b: %d" % (size_a, size_b)
  76. assert file_digest(a) == file_digest(b), generate_error_message(a, b)
  77. def assert_neo_object_is_compliant(ob):
  78. '''
  79. Test neo compliance of one object and sub objects
  80. (one_to_many_relation only):
  81. * check types and/or presence of necessary and recommended attribute.
  82. * If attribute is Quantities or numpy.ndarray it also check ndim.
  83. * If attribute is numpy.ndarray also check dtype.kind.
  84. '''
  85. assert type(ob) in objectlist, '%s is not a neo object' % (type(ob))
  86. classname = ob.__class__.__name__
  87. # test presence of necessary attributes
  88. for ioattr in ob._necessary_attrs:
  89. attrname, attrtype = ioattr[0], ioattr[1]
  90. # ~ if attrname != '':
  91. if not hasattr(ob, '_quantity_attr'):
  92. assert hasattr(ob, attrname), '%s neo obect does not have %s' % (classname, attrname)
  93. # test attributes types
  94. for ioattr in ob._all_attrs:
  95. attrname, attrtype = ioattr[0], ioattr[1]
  96. if (hasattr(ob, '_quantity_attr') and ob._quantity_attr == attrname and (
  97. attrtype == pq.Quantity or attrtype == np.ndarray)):
  98. # object inherits from Quantity (AnalogSignal, SpikeTrain, ...)
  99. ndim = ioattr[2]
  100. assert ob.ndim == ndim, '%s dimension is %d should be %d' % (classname, ob.ndim, ndim)
  101. if attrtype == np.ndarray:
  102. dtp = ioattr[3]
  103. assert ob.dtype.kind == dtp.kind, '%s dtype.kind is %s should be %s' \
  104. '' % (classname, ob.dtype.kind, dtp.kind)
  105. elif hasattr(ob, attrname):
  106. if getattr(ob, attrname) is not None:
  107. obattr = getattr(ob, attrname)
  108. assert issubclass(type(obattr), attrtype), '%s in %s is %s should be %s' \
  109. '' % (attrname, classname,
  110. type(obattr), attrtype)
  111. if attrtype == pq.Quantity or attrtype == np.ndarray:
  112. ndim = ioattr[2]
  113. assert obattr.ndim == ndim, '%s.%s dimension is %d should be %d' \
  114. '' % (classname, attrname, obattr.ndim, ndim)
  115. if attrtype == np.ndarray:
  116. dtp = ioattr[3]
  117. assert obattr.dtype.kind == dtp.kind, '%s.%s dtype.kind is %s should be %s' \
  118. '' % (classname, attrname,
  119. obattr.dtype.kind, dtp.kind)
  120. # test bijectivity : parents and children
  121. for container in getattr(ob, '_single_child_containers', []):
  122. for i, child in enumerate(getattr(ob, container, [])):
  123. assert hasattr(child, _reference_name(
  124. classname)), '%s should have %s attribute (2 way relationship)' \
  125. '' % (container, _reference_name(classname))
  126. if hasattr(child, _reference_name(classname)):
  127. parent = getattr(child, _reference_name(classname))
  128. assert parent == ob, '%s.%s %s is not symetric with %s.%s' \
  129. '' % (container, _reference_name(classname), i, classname,
  130. container)
  131. # recursive on one to many rel
  132. for i, child in enumerate(getattr(ob, 'children', [])):
  133. try:
  134. assert_neo_object_is_compliant(child)
  135. # intercept exceptions and add more information
  136. except BaseException as exc:
  137. exc.args += ('from %s %s of %s' % (child.__class__.__name__, i, classname),)
  138. raise
  139. def assert_same_sub_schema(ob1, ob2, equal_almost=True, threshold=1e-10, exclude=None):
  140. '''
  141. Test if ob1 and ob2 has the same sub schema.
  142. Explore all parent/child relationships.
  143. Many_to_many_relationship is not tested
  144. because of infinite recursive loops.
  145. Arguments:
  146. equal_almost: if False do a strict arrays_equal if
  147. True do arrays_almost_equal
  148. exclude: a list of attributes and annotations to ignore in
  149. the comparison
  150. '''
  151. assert type(ob1) == type(ob2), 'type(%s) != type(%s)' % (type(ob1), type(ob2))
  152. classname = ob1.__class__.__name__
  153. if exclude is None:
  154. exclude = []
  155. if isinstance(ob1, list):
  156. assert len(ob1) == len(ob2), 'lens %s and %s not equal for %s and %s' \
  157. '' % (len(ob1), len(ob2), ob1, ob2)
  158. for i, (sub1, sub2) in enumerate(zip(ob1, ob2)):
  159. try:
  160. assert_same_sub_schema(sub1, sub2, equal_almost=equal_almost, threshold=threshold,
  161. exclude=exclude)
  162. # intercept exceptions and add more information
  163. except BaseException as exc:
  164. exc.args += ('%s[%s]' % (classname, i),)
  165. raise
  166. return
  167. # test parent/child relationship
  168. for container in getattr(ob1, '_single_child_containers', []):
  169. if container in exclude:
  170. continue
  171. if not hasattr(ob1, container):
  172. assert not hasattr(ob2, container), '%s 2 does have %s but not %s 1' \
  173. '' % (classname, container, classname)
  174. continue
  175. else:
  176. assert hasattr(ob2, container), '%s 1 has %s but not %s 2' % (classname, container,
  177. classname)
  178. sub1 = getattr(ob1, container)
  179. sub2 = getattr(ob2, container)
  180. assert len(sub1) == len(
  181. sub2), 'theses two %s do not have the same %s number: %s and %s' \
  182. '' % (classname, container, len(sub1), len(sub2))
  183. for i in range(len(getattr(ob1, container))):
  184. # previously lacking parameter
  185. try:
  186. assert_same_sub_schema(sub1[i], sub2[i], equal_almost=equal_almost,
  187. threshold=threshold, exclude=exclude)
  188. # intercept exceptions and add more information
  189. except BaseException as exc:
  190. exc.args += ('from %s[%s] of %s' % (container, i, classname),)
  191. raise
  192. assert_same_attributes(ob1, ob2, equal_almost=equal_almost, threshold=threshold,
  193. exclude=exclude)
  194. def assert_same_attributes(ob1, ob2, equal_almost=True, threshold=1e-10, exclude=None):
  195. '''
  196. Test if ob1 and ob2 has the same attributes.
  197. Arguments:
  198. equal_almost: if False do a strict arrays_equal if
  199. True do arrays_almost_equal
  200. exclude: a list of attributes and annotations to ignore in
  201. the comparison
  202. '''
  203. classname = ob1.__class__.__name__
  204. if exclude is None:
  205. exclude = []
  206. if not equal_almost:
  207. threshold = None
  208. dtype = True
  209. else:
  210. dtype = False
  211. for ioattr in ob1._all_attrs:
  212. if ioattr[0] in exclude:
  213. continue
  214. attrname, attrtype = ioattr[0], ioattr[1]
  215. # ~ if attrname =='':
  216. if hasattr(ob1, '_quantity_attr') and ob1._quantity_attr == attrname:
  217. # object is hinerited from Quantity (AnalogSignal, SpikeTrain, ...)
  218. try:
  219. assert_arrays_almost_equal(ob1.magnitude, ob2.magnitude, threshold=threshold,
  220. dtype=dtype)
  221. # intercept exceptions and add more information
  222. except BaseException as exc:
  223. exc.args += ('from %s %s' % (classname, attrname),)
  224. raise
  225. assert ob1.dimensionality.string == ob2.dimensionality.string,\
  226. 'Units of %s %s are not the same: %s and %s' \
  227. '' % (classname, attrname, ob1.dimensionality.string, ob2.dimensionality.string)
  228. continue
  229. if not hasattr(ob1, attrname):
  230. assert not hasattr(ob2, attrname), '%s 2 does have %s but not %s 1' \
  231. '' % (classname, attrname, classname)
  232. continue
  233. else:
  234. assert hasattr(ob2, attrname), '%s 1 has %s but not %s 2' \
  235. '' % (classname, attrname, classname)
  236. if getattr(ob1, attrname) is None:
  237. assert getattr(ob2, attrname) is None, 'In %s.%s %s and %s differed' \
  238. '' % (classname, attrname,
  239. getattr(ob1, attrname),
  240. getattr(ob2, attrname))
  241. continue
  242. if getattr(ob2, attrname) is None:
  243. assert getattr(ob1, attrname) is None, 'In %s.%s %s and %s differed' \
  244. '' % (classname, attrname,
  245. getattr(ob1, attrname),
  246. getattr(ob2, attrname))
  247. continue
  248. if attrtype == pq.Quantity:
  249. # Compare magnitudes
  250. mag1 = getattr(ob1, attrname).magnitude
  251. mag2 = getattr(ob2, attrname).magnitude
  252. # print "2. ob1(%s) %s:%s\n ob2(%s) %s:%s" % \
  253. # (ob1,attrname,mag1,ob2,attrname,mag2)
  254. try:
  255. assert_arrays_almost_equal(mag1, mag2, threshold=threshold, dtype=dtype)
  256. # intercept exceptions and add more information
  257. except BaseException as exc:
  258. exc.args += ('from %s of %s' % (attrname, classname),)
  259. raise
  260. # Compare dimensionalities
  261. dim1 = getattr(ob1, attrname).dimensionality.simplified
  262. dim2 = getattr(ob2, attrname).dimensionality.simplified
  263. dimstr1 = getattr(ob1, attrname).dimensionality.string
  264. dimstr2 = getattr(ob2, attrname).dimensionality.string
  265. assert dim1 == dim2, 'Attribute %s of %s are not the same: %s != %s' \
  266. '' % (attrname, classname, dimstr1, dimstr2)
  267. elif attrtype == np.ndarray:
  268. try:
  269. assert_arrays_almost_equal(getattr(ob1, attrname), getattr(ob2, attrname),
  270. threshold=threshold, dtype=dtype)
  271. # intercept exceptions and add more information
  272. except BaseException as exc:
  273. exc.args += ('from %s of %s' % (attrname, classname),)
  274. raise
  275. else:
  276. # ~ print 'yep', getattr(ob1, attrname), getattr(ob2, attrname)
  277. assert getattr(ob1, attrname) == getattr(ob2, attrname),\
  278. 'Attribute %s.%s are not the same %s %s %s %s' \
  279. '' % (classname, attrname, type(getattr(ob1, attrname)), getattr(ob1, attrname),
  280. type(getattr(ob2, attrname)), getattr(ob2, attrname))
  281. def assert_same_annotations(ob1, ob2, equal_almost=True, threshold=1e-10, exclude=None):
  282. '''
  283. Test if ob1 and ob2 has the same annotations.
  284. Arguments:
  285. equal_almost: if False do a strict arrays_equal if
  286. True do arrays_almost_equal
  287. exclude: a list of attributes and annotations to ignore in
  288. the comparison
  289. '''
  290. if exclude is None:
  291. exclude = []
  292. if not equal_almost:
  293. threshold = None
  294. dtype = False
  295. else:
  296. dtype = True
  297. for key in ob2.annotations:
  298. if key in exclude:
  299. continue
  300. assert key in ob1.annotations
  301. for key, value in ob1.annotations.items():
  302. if key in exclude:
  303. continue
  304. assert key in ob2.annotations
  305. try:
  306. assert value == ob2.annotations[key]
  307. except ValueError:
  308. assert_arrays_almost_equal(ob1, ob2, threshold=threshold, dtype=False)
  309. def assert_same_array_annotations(ob1, ob2, equal_almost=True, threshold=1e-10, exclude=None):
  310. '''
  311. Test if ob1 and ob2 has the same annotations.
  312. Arguments:
  313. equal_almost: if False do a strict arrays_equal if
  314. True do arrays_almost_equal
  315. exclude: a list of attributes and annotations to ignore in
  316. the comparison
  317. '''
  318. if exclude is None:
  319. exclude = []
  320. if not equal_almost:
  321. threshold = None
  322. dtype = False
  323. else:
  324. dtype = True
  325. for key in ob2.array_annotations:
  326. if key in exclude:
  327. continue
  328. assert key in ob1.array_annotations
  329. for key, value in ob1.array_annotations.items():
  330. if key in exclude:
  331. continue
  332. assert key in ob2.array_annotations
  333. try:
  334. assert_arrays_equal(value, ob2.array_annotations[key])
  335. except ValueError:
  336. assert_arrays_almost_equal(ob1, ob2, threshold=threshold, dtype=False)
  337. def assert_sub_schema_is_lazy_loaded(ob):
  338. '''
  339. This is util for testing lazy load. All object must load with ndarray.size
  340. or Quantity.size ==0
  341. '''
  342. classname = ob.__class__.__name__
  343. for container in getattr(ob, '_single_child_containers', []):
  344. if not hasattr(ob, container):
  345. continue
  346. sub = getattr(ob, container)
  347. for i, child in enumerate(sub):
  348. try:
  349. assert_sub_schema_is_lazy_loaded(child)
  350. # intercept exceptions and add more information
  351. except BaseException as exc:
  352. exc.args += ('from %s %s of %s' % (container, i, classname),)
  353. raise
  354. for ioattr in ob._all_attrs:
  355. attrname, attrtype = ioattr[0], ioattr[1]
  356. # ~ print 'xdsd', classname, attrname
  357. # ~ if attrname == '':
  358. if hasattr(ob, '_quantity_attr') and ob._quantity_attr == attrname:
  359. assert ob.size == 0, 'Lazy loaded error %s.size = %s' % (classname, ob.size)
  360. assert hasattr(ob, 'lazy_shape'),\
  361. 'Lazy loaded error, %s should have lazy_shape attribute' % classname
  362. continue
  363. if not hasattr(ob, attrname) or getattr(ob, attrname) is None:
  364. continue
  365. # ~ print 'hjkjh'
  366. if (attrtype == pq.Quantity or attrtype == np.ndarray):
  367. # FIXME: it is a workaround for recordingChannelGroup.channel_names
  368. # which is nupy.array but allowed to be loaded when lazy == True
  369. if ob.__class__ == neo.ChannelIndex:
  370. continue
  371. ndim = ioattr[2]
  372. # ~ print 'ndim', ndim
  373. # ~ print getattr(ob, attrname).size
  374. if ndim >= 1:
  375. assert getattr(ob, attrname).size == 0,\
  376. 'Lazy loaded error %s.%s.size = %s' % (classname, attrname,
  377. getattr(ob, attrname).size)
  378. assert hasattr(ob, 'lazy_shape'),\
  379. 'Lazy loaded error %s should have lazy_shape attribute ' % classname + \
  380. 'because of %s attribute' % attrname
  381. lazy_shape_arrays = {'SpikeTrain': 'times', 'AnalogSignal': 'signal', 'Event': 'times',
  382. 'Epoch': 'times'}
  383. def assert_lazy_sub_schema_can_be_loaded(ob, io):
  384. '''
  385. This is util for testing lazy load. All object must load with ndarray.size
  386. or Quantity.size ==0
  387. '''
  388. classname = ob.__class__.__name__
  389. if classname in lazy_shape_arrays:
  390. new_load = io.load_lazy_object(ob)
  391. assert hasattr(ob, 'lazy_shape'), 'Object %s was not lazy loaded' % classname
  392. assert not hasattr(new_load, 'lazy_shape'),\
  393. 'Newly loaded object from %s was also lazy loaded' % classname
  394. if hasattr(ob, '_quantity_attr'):
  395. assert ob.lazy_shape == new_load.shape,\
  396. 'Shape of loaded object %sis not equal to lazy shape' % classname
  397. else:
  398. assert ob.lazy_shape == getattr(new_load, lazy_shape_arrays[
  399. classname]).shape, 'Shape of loaded object %s not equal to lazy shape' % classname
  400. return
  401. for container in getattr(ob, '_single_child_containers', []):
  402. if not hasattr(ob, container):
  403. continue
  404. sub = getattr(ob, container)
  405. for i, child in enumerate(sub):
  406. try:
  407. assert_lazy_sub_schema_can_be_loaded(child, io)
  408. # intercept exceptions and add more information
  409. except BaseException as exc:
  410. exc.args += ('from of %s %s of %s' % (container, i, classname),)
  411. raise
  412. def assert_objects_equivalent(obj1, obj2):
  413. '''
  414. Compares two NEO objects by looping over the attributes and annotations
  415. and asserting their hashes. No relationships involved.
  416. '''
  417. def assert_attr(obj1, obj2, attr_name):
  418. '''
  419. Assert a single attribute and annotation are the same
  420. '''
  421. assert hasattr(obj1, attr_name)
  422. attr1 = hashlib.md5(getattr(obj1, attr_name)).hexdigest()
  423. assert hasattr(obj2, attr_name)
  424. attr2 = hashlib.md5(getattr(obj2, attr_name)).hexdigest()
  425. assert attr1 == attr2, "Attribute %s for class %s is not equal." \
  426. "" % (attr_name, obj1.__class__.__name__)
  427. obj_type = obj1.__class__.__name__
  428. assert obj_type == obj2.__class__.__name__
  429. for ioattr in obj1._necessary_attrs:
  430. assert_attr(obj1, obj2, ioattr[0])
  431. for ioattr in obj1._recommended_attrs:
  432. if hasattr(obj1, ioattr[0]) or hasattr(obj2, ioattr[0]):
  433. assert_attr(obj1, obj2, ioattr[0])
  434. if hasattr(obj1, "annotations"):
  435. assert hasattr(obj2, "annotations")
  436. for key, value in obj1.annotations:
  437. assert hasattr(obj2.annotations, key)
  438. assert obj2.annotations[key] == value
  439. def assert_children_empty(obj, parent):
  440. '''
  441. Check that the children of a neo object are empty. Used
  442. to check the cascade is implemented properly
  443. '''
  444. classname = obj.__class__.__name__
  445. errmsg = '''%s reader with cascade=False should return
  446. empty children''' % parent.__name__
  447. if hasattr(obj, 'children'):
  448. assert not obj.children, errmsg