test_klustakwikio.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. # -*- coding: utf-8 -*-
  2. """
  3. Tests of neo.io.klustakwikio
  4. """
  5. # needed for python 3 compatibility
  6. from __future__ import absolute_import
  7. import glob
  8. import os.path
  9. import sys
  10. import tempfile
  11. import unittest
  12. import numpy as np
  13. import quantities as pq
  14. import neo
  15. from neo.test.iotest.common_io_test import BaseTestIO
  16. from neo.test.tools import assert_arrays_almost_equal
  17. from neo.io.klustakwikio import KlustaKwikIO, HAVE_MLAB
  @unittest.skipUnless(HAVE_MLAB, "requires matplotlib")
  @unittest.skipIf(sys.version_info[0] > 2, "not Python 3 compatible")
  class testFilenameParser(unittest.TestCase):
      """Check that KlustaKwik files are discovered from a basename.

      The test directory holds files for two basenames ('basename' and
      'basename2') together with decoy files whose group suffix is malformed
      (e.g. 'basename.fet.1a'); only the well-formed files may be returned.
      """

      def setUp(self):
          self.dirname = os.path.join(tempfile.gettempdir(),
                                      'files_for_testing_neo',
                                      'klustakwik/test1')
          if not os.path.exists(self.dirname):
              raise unittest.SkipTest('data directory does not exist: ' +
                                      self.dirname)

      def _path_in_testdir(self, filename):
          """Return the absolute path of *filename* inside the test directory."""
          return os.path.abspath(os.path.join(self.dirname, filename))

      def test1(self):
          """Tests that files can be loaded by basename"""
          reader = KlustaKwikIO(filename=os.path.join(self.dirname, 'basename'))
          if not BaseTestIO.use_network:
              raise unittest.SkipTest("Requires download of data from the web")
          found = reader._fp.read_filenames('fet')
          self.assertEqual(len(found), 2)
          # Only groups 0 and 1 are well formed; the decoys must be ignored.
          for group in (0, 1):
              self.assertEqual(os.path.abspath(found[group]),
                               self._path_in_testdir('basename.fet.%d' % group))

      def test2(self):
          """Tests that files are loaded even without basename"""
          # TODO: behaviour still in flux. The intended check: passing a
          # directory instead of a basename should pick one of the basenames
          # it contains (probably defaulting to os.path.split(dirname)[1] when
          # dirname is a directory), and read_filenames('fet') should then
          # return a non-empty result. Which basename wins depends on which
          # is found first, so only "did something without error" is checked.

      def test3(self):
          """Tests that files can be loaded by basename2"""
          reader = KlustaKwikIO(filename=os.path.join(self.dirname, 'basename2'))
          if not BaseTestIO.use_network:
              raise unittest.SkipTest("Requires download of data from the web")
          found = reader._fp.read_filenames('clu')
          self.assertEqual(len(found), 1)
          # Indexed by group number (presumably a dict keyed by group), so
          # [1] is group 1 even though there is a single entry.
          self.assertEqual(os.path.abspath(found[1]),
                           self._path_in_testdir('basename2.clu.1'))
  @unittest.skipUnless(HAVE_MLAB, "requires matplotlib")
  @unittest.skipIf(sys.version_info[0] > 2, "not Python 3 compatible")
  class testRead(unittest.TestCase):
      """Tests that data can be read from KlustaKwik files"""

      def setUp(self):
          self.dirname = os.path.join(tempfile.gettempdir(),
                                      'files_for_testing_neo',
                                      'klustakwik/test2')
          if not os.path.exists(self.dirname):
              raise unittest.SkipTest('data directory does not exist: ' +
                                      self.dirname)

      def _read_segment(self, basename):
          """Read *basename* with sampling_rate=1000. and return the first
          segment of the first block returned by the reader."""
          reader = KlustaKwikIO(filename=os.path.join(self.dirname, basename),
                                sampling_rate=1000.)
          first_block = reader.read()[0]
          return first_block.segments[0]

      def _check_train(self, train, name, cluster, group, times):
          """Assert the name, cluster/group annotations, t_start and spike
          times (in seconds) of a single spiketrain."""
          self.assertEqual(train.name, name)
          self.assertEqual(train.annotations['cluster'], cluster)
          self.assertEqual(train.annotations['group'], group)
          self.assertEqual(train.t_start, 0.0)
          self.assertTrue(np.all(train.times == np.array(times)))

      def test1(self):
          """Tests that data and metadata are read correctly"""
          seg = self._read_segment('base')
          self.assertEqual(len(seg.spiketrains), 4)
          for train in seg.spiketrains:
              self.assertEqual(train.units, np.array(1.0) * pq.s)
              self.assertEqual(train.t_start, 0.0)
          # (name, cluster, group, spike times) for each train, in order
          expected = [
              ('unit 1 from group 0', 1, 0, [.100, .200]),
              ('unit 2 from group 0', 2, 0, [.305]),
              ('unit -1 from group 1', -1, 1, [.253]),
              ('unit 2 from group 1', 2, 1, [.050, .152]),
          ]
          for train, spec in zip(seg.spiketrains, expected):
              self._check_train(train, *spec)

      def test2(self):
          """Checks that cluster id autosets to 0 without clu file"""
          seg = self._read_segment('base2')
          self.assertEqual(len(seg.spiketrains), 1)
          self._check_train(seg.spiketrains[0], 'unit 0 from group 5', 0, 5,
                            [0.026, 0.122, 0.228])
  @unittest.skipUnless(HAVE_MLAB, "requires matplotlib")
  @unittest.skipIf(sys.version_info[0] > 2, "not Python 3 compatible")
  class testWrite(unittest.TestCase):
      """Tests that the spiketrains of a Block are written to .fet/.clu files."""

      def setUp(self):
          self.dirname = os.path.join(tempfile.gettempdir(),
                                      'files_for_testing_neo',
                                      'klustakwik/test3')
          if not os.path.exists(self.dirname):
              raise unittest.SkipTest('data directory does not exist: ' +
                                      self.dirname)

      def _read_ints(self, filename):
          """Return every line of *filename* (inside self.dirname) as an int.

          ``open`` inside ``with`` guarantees the handle is closed and, unlike
          the Python-2-only ``file`` builtin, also exists on Python 3.
          """
          with open(os.path.join(self.dirname, filename)) as handle:
              return [int(line) for line in handle.readlines()]

      def test1(self):
          """Create clu and fet files based on spiketrains in a block.

          Checks that:

          * files are created, one .fet/.clu pair per group;
          * spike times are converted to samples correctly;
          * a missing sampling rate is taken from the IO's default;
          * spiketrains without cluster info are assigned to cluster 0;
          * spiketrains across segments are concatenated.
          """
          block = neo.Block()
          segment = neo.Segment()
          segment2 = neo.Segment()
          block.segments.append(segment)
          block.segments.append(segment2)

          # Fake spiketrain 1 on group 0, cluster 0
          st1 = neo.SpikeTrain(times=[.002, .004, .006], units='s', t_stop=1.)
          st1.annotations['cluster'] = 0
          st1.annotations['group'] = 0
          segment.spiketrains.append(st1)
          # Fake spiketrain 1B, on another segment. No group specified,
          # default is 0.
          st1B = neo.SpikeTrain(times=[.106], units='s', t_stop=1.)
          st1B.annotations['cluster'] = 0
          segment2.spiketrains.append(st1B)
          # Fake spiketrain 2 on the same group, cluster 1
          st2 = neo.SpikeTrain(times=[.001, .003, .011], units='s', t_stop=1.)
          st2.annotations['cluster'] = 1
          st2.annotations['group'] = 0
          segment.spiketrains.append(st2)
          # Fake spiketrain 3 on a new group (1), cluster -1
          st3 = neo.SpikeTrain(times=[.05, .09, .10], units='s', t_stop=1.)
          st3.annotations['cluster'] = -1
          st3.annotations['group'] = 1
          segment.spiketrains.append(st3)
          # Fake spiketrain 4 on a new group (2), without cluster info
          st4 = neo.SpikeTrain(times=[.005, .009], units='s', t_stop=1.)
          st4.annotations['group'] = 2
          segment.spiketrains.append(st4)

          # Start from an empty output directory, and empty it again when the
          # test finishes, whether it passes or fails.
          delete_test_session(self.dirname)
          self.addCleanup(delete_test_session, self.dirname)

          # None of the spiketrains sets a sampling rate, so the writer's
          # default (1000.) is used for the time -> sample conversion.
          kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'base1'),
                             sampling_rate=1000.)
          kio.write_block(block)

          # One integer per line: a header line (feature count for .fet; for
          # .clu the expected values suggest the cluster count) followed by
          # one entry per spike, segments concatenated in order.
          expected = [
              ('base1.fet.0', [0, 2, 4, 6, 1, 3, 11, 106]),
              ('base1.clu.0', [2, 0, 0, 0, 1, 1, 1, 0]),
              ('base1.fet.1', [0, 50, 90, 100]),
              ('base1.clu.1', [1, -1, -1, -1]),
              ('base1.fet.2', [0, 5, 9]),
              ('base1.clu.2', [1, 0, 0]),
          ]
          # Check every group's files were created (including group 2)
          for filename, _ in expected:
              self.assertTrue(
                  os.path.exists(os.path.join(self.dirname, filename)),
                  filename)
          # Check the files contain the expected content
          for filename, values in expected:
              self.assertEqual(self._read_ints(filename), values)
  @unittest.skipUnless(HAVE_MLAB, "requires matplotlib")
  @unittest.skipIf(sys.version_info[0] > 2, "not Python 3 compatible")
  class testWriteWithFeatures(unittest.TestCase):
      """Tests writing and re-reading spiketrains carrying waveform features."""

      def setUp(self):
          self.dirname = os.path.join(tempfile.gettempdir(),
                                      'files_for_testing_neo',
                                      'klustakwik/test4')
          # test4 is a scratch output directory that is not part of the
          # downloaded data (it is absent from CommonTests.files_to_download),
          # so create it on demand instead of skipping; only skip when the
          # KlustaKwik data directory itself is missing.
          parent = os.path.dirname(self.dirname)
          if not os.path.exists(parent):
              raise unittest.SkipTest('data directory does not exist: ' +
                                      parent)
          if not os.path.exists(self.dirname):
              os.mkdir(self.dirname)

      def test1(self):
          """Write waveform features to .fet/.clu files and read them back.

          Checks that:

          * the .fet and .clu files are created;
          * the .fet header holds the number of features;
          * each .fet line holds the features followed by the spike time in
            samples;
          * the features survive a write/read round trip.
          """
          block = neo.Block()
          segment = neo.Segment()
          segment2 = neo.Segment()
          block.segments.append(segment)
          block.segments.append(segment2)
          # Fake spiketrain 1 with two waveform features per spike
          st1 = neo.SpikeTrain(times=[.002, .004, .006], units='s', t_stop=1.)
          st1.annotations['cluster'] = 0
          st1.annotations['group'] = 0
          wff = np.array([
              [11.3, 0.2],
              [-0.3, 12.3],
              [3.0, -2.5]])
          st1.annotations['waveform_features'] = wff
          segment.spiketrains.append(st1)

          # Start from an empty output directory, and empty it again when the
          # test finishes, whether it passes or fails.
          delete_test_session(self.dirname)
          self.addCleanup(delete_test_session, self.dirname)

          kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'base2'),
                             sampling_rate=1000.)
          kio.write_block(block)

          # Check files were created
          for fn in ['.fet.0', '.clu.0']:
              path = os.path.join(self.dirname, 'base2' + fn)
              self.assertTrue(os.path.exists(path), path)

          # First line is nbFeatures; each following line holds the feature
          # values followed by the spike time in samples.
          with open(os.path.join(self.dirname, 'base2.fet.0')) as fet_file:
              self.assertEqual(fet_file.readline(), '2\n')
              data = fet_file.readlines()
          new_wff = []
          new_times = []
          for line in data:
              line_split = line.split()
              new_wff.append([float(val) for val in line_split[:-1]])
              new_times.append(int(line_split[-1]))
          self.assertEqual(new_times, [2, 4, 6])
          assert_arrays_almost_equal(wff, np.array(new_wff), .00001)

          # Clusters on group 0
          with open(os.path.join(self.dirname, 'base2.clu.0')) as clu_file:
              clusters = [int(line) for line in clu_file.readlines()]
          self.assertEqual(clusters, [1, 0, 0, 0])

          # Now read the features back and check they round-tripped
          block = kio.read_block()
          train = block.segments[0].spiketrains[0]
          assert_arrays_almost_equal(wff, train.annotations['waveform_features'],
                                     .00001)
  @unittest.skipUnless(HAVE_MLAB, "requires matplotlib")
  @unittest.skipIf(sys.version_info[0] > 2, "not Python 3 compatible")
  class CommonTests(BaseTestIO, unittest.TestCase):
      """Run the generic BaseTestIO compliance suite against KlustaKwikIO."""

      ioclass = KlustaKwikIO

      # Basenames (relative to the KlustaKwik data directory) that the common
      # suite reads and checks for compliance.
      files_to_test = [
          'test2/base',
          'test2/base2',
      ]

      # Files fetched from g-node when they do not already exist locally.
      # NOTE(review): the download presumably happens in BaseTestIO's fixture
      # setup. The lowercase test* classes in this module only skip when
      # their directory is missing, so they rely on this data being present.
      # CommonTests probably runs first because the default loader orders
      # classes alphabetically ('C' < 't'). Confirm in
      # neo.test.iotest.common_io_test.
      files_to_download = [
          # test1: filename-parser fixtures, including decoys whose group
          # suffix is malformed ('-1', '1a', 'a1')
          'test1/basename.clu.0',
          'test1/basename.fet.-1',
          'test1/basename.fet.0',
          'test1/basename.fet.1',
          'test1/basename.fet.1a',
          'test1/basename.fet.a1',
          'test1/basename2.clu.1',
          'test1/basename2.fet.1',
          'test1/basename2.fet.1a',
          # test2: reader fixtures
          'test2/base2.fet.5',
          'test2/base.clu.0',
          'test2/base.clu.1',
          'test2/base.fet.0',
          'test2/base.fet.1',
          # test3: writer output directory contents
          'test3/base1.clu.0',
          'test3/base1.clu.1',
          'test3/base1.clu.2',
          'test3/base1.fet.0',
          'test3/base1.fet.1',
          'test3/base1.fet.2',
      ]
  def delete_test_session(dirname=None):
      """Remove the files directly inside *dirname*.

      The writer tests use this to start from, and leave behind, an empty
      output directory.

      Sub-directories are left untouched; previously os.remove would raise
      on them. Names starting with a dot are not matched by the '*' glob and
      are kept. A missing directory is a no-op because glob returns nothing.

      Parameters
      ----------
      dirname : str, optional
          Directory to empty. Defaults to
          ``<tempfile.gettempdir()>/files_for_testing_neo/klustakwik/test3``,
          the same location the writer test (testWrite) writes to.
      """
      if dirname is None:
          dirname = os.path.join(tempfile.gettempdir(),
                                 'files_for_testing_neo',
                                 'klustakwik/test3')
      for path in glob.glob(os.path.join(dirname, '*')):
          # glob also matches directories, which os.remove cannot delete
          if not os.path.isdir(path):
              os.remove(path)
  if __name__ == '__main__':
      # Running the module as a script executes all test classes above.
      unittest.main()