test_analogsignal.py 31 KB


  1. # -*- coding: utf-8 -*-
  2. """
  3. Tests of the neo.core.analogsignal.AnalogSignal class and related functions
  4. """
  5. # needed for python 3 compatibility
  6. from __future__ import division
  7. import os
  8. import pickle
  9. import copy
  10. import unittest
  11. import numpy as np
  12. import quantities as pq
  13. try:
  14. from IPython.lib.pretty import pretty
  15. except ImportError as err:
  16. HAVE_IPYTHON = False
  17. else:
  18. HAVE_IPYTHON = True
  19. from numpy.testing import assert_array_equal
  20. from neo.core.analogsignal import AnalogSignal, _get_sampling_rate
  21. from neo.core.channelindex import ChannelIndex
  22. from neo.core import Segment
  23. from neo.test.tools import (assert_arrays_almost_equal,
  24. assert_neo_object_is_compliant,
  25. assert_same_sub_schema,
  26. assert_objects_equivalent,
  27. assert_same_attributes,
  28. assert_same_sub_schema)
  29. from neo.test.generate_datasets import (get_fake_value, get_fake_values,
  30. fake_neo, TEST_ANNOTATIONS)
  31. class Test__generate_datasets(unittest.TestCase):
  32. def setUp(self):
  33. np.random.seed(0)
  34. self.annotations = dict([(str(x), TEST_ANNOTATIONS[x]) for x in
  35. range(len(TEST_ANNOTATIONS))])
  36. def test__fake_neo__cascade(self):
  37. self.annotations['seed'] = None
  38. obj_type = AnalogSignal
  39. cascade = True
  40. res = fake_neo(obj_type=obj_type, cascade=cascade)
  41. self.assertTrue(isinstance(res, AnalogSignal))
  42. assert_neo_object_is_compliant(res)
  43. self.assertEqual(res.annotations, self.annotations)
  44. def test__fake_neo__nocascade(self):
  45. self.annotations['seed'] = None
  46. obj_type = 'AnalogSignal'
  47. cascade = False
  48. res = fake_neo(obj_type=obj_type, cascade=cascade)
  49. self.assertTrue(isinstance(res, AnalogSignal))
  50. assert_neo_object_is_compliant(res)
  51. self.assertEqual(res.annotations, self.annotations)
  52. class TestAnalogSignalConstructor(unittest.TestCase):
  53. def test__create_from_list(self):
  54. data = range(10)
  55. rate = 1000*pq.Hz
  56. signal = AnalogSignal(data, sampling_rate=rate, units="mV")
  57. assert_neo_object_is_compliant(signal)
  58. self.assertEqual(signal.t_start, 0*pq.ms)
  59. self.assertEqual(signal.t_stop, len(data)/rate)
  60. self.assertEqual(signal[9, 0], 9000*pq.uV)
  61. def test__create_from_np_array(self):
  62. data = np.arange(10.0)
  63. rate = 1*pq.kHz
  64. signal = AnalogSignal(data, sampling_rate=rate, units="uV")
  65. assert_neo_object_is_compliant(signal)
  66. self.assertEqual(signal.t_start, 0*pq.ms)
  67. self.assertEqual(signal.t_stop, data.size/rate)
  68. self.assertEqual(signal[9, 0], 0.009*pq.mV)
  69. def test__create_from_quantities_array(self):
  70. data = np.arange(10.0) * pq.mV
  71. rate = 5000*pq.Hz
  72. signal = AnalogSignal(data, sampling_rate=rate)
  73. assert_neo_object_is_compliant(signal)
  74. self.assertEqual(signal.t_start, 0*pq.ms)
  75. self.assertEqual(signal.t_stop, data.size/rate)
  76. self.assertEqual(signal[9, 0], 0.009*pq.V)
  77. def test__create_from_array_no_units_ValueError(self):
  78. data = np.arange(10.0)
  79. self.assertRaises(ValueError, AnalogSignal, data,
  80. sampling_rate=1 * pq.kHz)
  81. def test__create_from_quantities_array_inconsistent_units_ValueError(self):
  82. data = np.arange(10.0) * pq.mV
  83. self.assertRaises(ValueError, AnalogSignal, data,
  84. sampling_rate=1 * pq.kHz, units="nA")
  85. def test__create_without_sampling_rate_or_period_ValueError(self):
  86. data = np.arange(10.0) * pq.mV
  87. self.assertRaises(ValueError, AnalogSignal, data)
  88. def test__create_with_None_sampling_rate_should_raise_ValueError(self):
  89. data = np.arange(10.0) * pq.mV
  90. self.assertRaises(ValueError, AnalogSignal, data, sampling_rate=None)
  91. def test__create_with_None_t_start_should_raise_ValueError(self):
  92. data = np.arange(10.0) * pq.mV
  93. rate = 5000 * pq.Hz
  94. self.assertRaises(ValueError, AnalogSignal, data,
  95. sampling_rate=rate, t_start=None)
  96. def test__create_inconsistent_sampling_rate_and_period_ValueError(self):
  97. data = np.arange(10.0) * pq.mV
  98. self.assertRaises(ValueError, AnalogSignal, data,
  99. sampling_rate=1 * pq.kHz, sampling_period=5 * pq.s)
  100. def test__create_with_copy_true_should_return_copy(self):
  101. data = np.arange(10.0) * pq.mV
  102. rate = 5000*pq.Hz
  103. signal = AnalogSignal(data, copy=True, sampling_rate=rate)
  104. data[3] = 99*pq.mV
  105. assert_neo_object_is_compliant(signal)
  106. self.assertNotEqual(signal[3, 0], 99*pq.mV)
  107. def test__create_with_copy_false_should_return_view(self):
  108. data = np.arange(10.0) * pq.mV
  109. rate = 5000*pq.Hz
  110. signal = AnalogSignal(data, copy=False, sampling_rate=rate)
  111. data[3] = 99*pq.mV
  112. assert_neo_object_is_compliant(signal)
  113. self.assertEqual(signal[3, 0], 99*pq.mV)
  114. def test__create2D_with_copy_false_should_return_view(self):
  115. data = np.arange(10.0) * pq.mV
  116. data = data.reshape((5, 2))
  117. rate = 5000*pq.Hz
  118. signal = AnalogSignal(data, copy=False, sampling_rate=rate)
  119. data[3, 0] = 99*pq.mV
  120. assert_neo_object_is_compliant(signal)
  121. self.assertEqual(signal[3, 0], 99*pq.mV)
  122. def test__create_with_additional_argument(self):
  123. signal = AnalogSignal([1, 2, 3], units="mV", sampling_rate=1*pq.kHz,
  124. file_origin='crack.txt', ratname='Nicolas')
  125. assert_neo_object_is_compliant(signal)
  126. self.assertEqual(signal.annotations, {'ratname': 'Nicolas'})
  127. # This one is universally recommended and handled by BaseNeo
  128. self.assertEqual(signal.file_origin, 'crack.txt')
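# NOTE(review): 2-D input is accepted (see
# test__create2D_with_copy_false_should_return_view), so "must be 1D" is stale.
# This TODO presumably means rejecting input with more than two dimensions,
# which no test here covers yet.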
  130. class TestAnalogSignalProperties(unittest.TestCase):
  131. def setUp(self):
  132. self.t_start = [0.0*pq.ms, 100*pq.ms, -200*pq.ms]
  133. self.rates = [1*pq.kHz, 420*pq.Hz, 999*pq.Hz]
  134. self.rates2 = [2*pq.kHz, 290*pq.Hz, 1111*pq.Hz]
  135. self.data = [np.arange(10.0)*pq.nA,
  136. np.arange(-100.0, 100.0, 10.0)*pq.mV,
  137. np.random.uniform(size=100)*pq.uV]
  138. self.signals = [AnalogSignal(D, sampling_rate=r, t_start=t,
  139. testattr='test')
  140. for r, D, t in zip(self.rates,
  141. self.data,
  142. self.t_start)]
  143. def test__compliant(self):
  144. for signal in self.signals:
  145. assert_neo_object_is_compliant(signal)
  146. def test__t_stop_getter(self):
  147. for i, signal in enumerate(self.signals):
  148. self.assertEqual(signal.t_stop,
  149. self.t_start[i] + self.data[i].size/self.rates[i])
  150. def test__duration_getter(self):
  151. for signal in self.signals:
  152. self.assertAlmostEqual(signal.duration,
  153. signal.t_stop - signal.t_start,
  154. delta=1e-15)
  155. def test__sampling_rate_getter(self):
  156. for signal, rate in zip(self.signals, self.rates):
  157. self.assertEqual(signal.sampling_rate, rate)
  158. def test__sampling_period_getter(self):
  159. for signal, rate in zip(self.signals, self.rates):
  160. self.assertEqual(signal.sampling_period, 1 / rate)
  161. def test__sampling_rate_setter(self):
  162. for signal, rate in zip(self.signals, self.rates2):
  163. signal.sampling_rate = rate
  164. assert_neo_object_is_compliant(signal)
  165. self.assertEqual(signal.sampling_rate, rate)
  166. self.assertEqual(signal.sampling_period, 1 / rate)
  167. def test__sampling_period_setter(self):
  168. for signal, rate in zip(self.signals, self.rates2):
  169. signal.sampling_period = 1 / rate
  170. assert_neo_object_is_compliant(signal)
  171. self.assertEqual(signal.sampling_rate, rate)
  172. self.assertEqual(signal.sampling_period, 1 / rate)
  173. def test__sampling_rate_setter_None_ValueError(self):
  174. self.assertRaises(ValueError, setattr, self.signals[0],
  175. 'sampling_rate', None)
  176. def test__sampling_rate_setter_not_quantity_ValueError(self):
  177. self.assertRaises(ValueError, setattr, self.signals[0],
  178. 'sampling_rate', 5.5)
  179. def test__sampling_period_setter_None_ValueError(self):
  180. signal = self.signals[0]
  181. assert_neo_object_is_compliant(signal)
  182. self.assertRaises(ValueError, setattr, signal, 'sampling_period', None)
  183. def test__sampling_period_setter_not_quantity_ValueError(self):
  184. self.assertRaises(ValueError, setattr, self.signals[0],
  185. 'sampling_period', 5.5)
  186. def test__t_start_setter_None_ValueError(self):
  187. signal = self.signals[0]
  188. assert_neo_object_is_compliant(signal)
  189. self.assertRaises(ValueError, setattr, signal, 't_start', None)
  190. def test__times_getter(self):
  191. for i, signal in enumerate(self.signals):
  192. targ = np.arange(self.data[i].size)
  193. targ = targ/self.rates[i] + self.t_start[i]
  194. assert_neo_object_is_compliant(signal)
  195. assert_arrays_almost_equal(signal.times, targ, 1e-12*pq.ms)
  196. def test__duplicate_with_new_array(self):
  197. signal1 = self.signals[1]
  198. signal2 = self.signals[2]
  199. data2 = self.data[2]
  200. signal1b = signal1.duplicate_with_new_array(data2)
  201. assert_arrays_almost_equal(np.asarray(signal1b),
  202. np.asarray(signal2/1000.), 1e-12)
  203. self.assertEqual(signal1b.t_start, signal1.t_start)
  204. self.assertEqual(signal1b.sampling_rate, signal1.sampling_rate)
  205. # def test__children(self):
  206. # signal = self.signals[0]
  207. #
  208. # segment = Segment(name='seg1')
  209. # segment.analogsignals = [signal]
  210. # segment.create_many_to_one_relationship()
  211. #
  212. # rchan = RecordingChannel(name='rchan1')
  213. # rchan.analogsignals = [signal]
  214. # rchan.create_many_to_one_relationship()
  215. #
  216. # self.assertEqual(signal._single_parent_objects,
  217. # ('Segment', 'RecordingChannel'))
  218. # self.assertEqual(signal._multi_parent_objects, ())
  219. #
  220. # self.assertEqual(signal._single_parent_containers,
  221. # ('segment', 'recordingchannel'))
  222. # self.assertEqual(signal._multi_parent_containers, ())
  223. #
  224. # self.assertEqual(signal._parent_objects,
  225. # ('Segment', 'RecordingChannel'))
  226. # self.assertEqual(signal._parent_containers,
  227. # ('segment', 'recordingchannel'))
  228. #
  229. # self.assertEqual(len(signal.parents), 2)
  230. # self.assertEqual(signal.parents[0].name, 'seg1')
  231. # self.assertEqual(signal.parents[1].name, 'rchan1')
  232. #
  233. # assert_neo_object_is_compliant(signal)
  234. def test__repr(self):
  235. for i, signal in enumerate(self.signals):
  236. prepr = repr(signal)
  237. targ = '<AnalogSignal(%s, [%s, %s], sampling rate: %s)>' % \
  238. (repr(self.data[i].reshape(-1, 1)),
  239. self.t_start[i],
  240. self.t_start[i] + len(self.data[i])/self.rates[i],
  241. self.rates[i])
  242. self.assertEqual(prepr, targ)
  243. @unittest.skipUnless(HAVE_IPYTHON, "requires IPython")
  244. def test__pretty(self):
  245. for i, signal in enumerate(self.signals):
  246. prepr = pretty(signal)
  247. targ = (('AnalogSignal with %d channels of length %d; units %s; datatype %s \n' %
  248. (signal.shape[1], signal.shape[0], signal.units.dimensionality.unicode, signal.dtype)) +
  249. ('annotations: %s\n' % signal.annotations) +
  250. ('sampling rate: %s\n' % signal.sampling_rate) +
  251. ('time: %s to %s' % (signal.t_start, signal.t_stop)))
  252. self.assertEqual(prepr, targ)
  253. class TestAnalogSignalArrayMethods(unittest.TestCase):
  254. def setUp(self):
  255. self.data1 = np.arange(10.0)
  256. self.data1quant = self.data1 * pq.nA
  257. self.signal1 = AnalogSignal(self.data1quant, sampling_rate=1*pq.kHz,
  258. name='spam', description='eggs',
  259. file_origin='testfile.txt', arg1='test')
  260. self.signal1.segment = Segment()
  261. self.signal1.channel_index = ChannelIndex(index=[0])
  262. def test__compliant(self):
  263. assert_neo_object_is_compliant(self.signal1)
  264. def test__slice_should_return_AnalogSignalArray(self):
  265. # slice
  266. for index in (0, np.int64(0)):
  267. result = self.signal1[3:8, index]
  268. self.assertIsInstance(result, AnalogSignal)
  269. assert_neo_object_is_compliant(result)
  270. self.assertEqual(result.name, 'spam') # should slicing really preserve name and description?
  271. self.assertEqual(result.description, 'eggs') # perhaps these should be modified to indicate the slice?
  272. self.assertEqual(result.file_origin, 'testfile.txt')
  273. self.assertEqual(result.annotations, {'arg1': 'test'})
  274. self.assertEqual(result.size, 5)
  275. self.assertEqual(result.sampling_period, self.signal1.sampling_period)
  276. self.assertEqual(result.sampling_rate, self.signal1.sampling_rate)
  277. self.assertEqual(result.t_start,
  278. self.signal1.t_start+3*result.sampling_period)
  279. self.assertEqual(result.t_stop,
  280. result.t_start + 5*result.sampling_period)
  281. assert_array_equal(result.magnitude, self.data1[3:8].reshape(-1, 1))
  282. # Test other attributes were copied over (in this case, defaults)
  283. self.assertEqual(result.file_origin, self.signal1.file_origin)
  284. self.assertEqual(result.name, self.signal1.name)
  285. self.assertEqual(result.description, self.signal1.description)
  286. self.assertEqual(result.annotations, self.signal1.annotations)
  287. def test__slice_should_let_access_to_parents_objects(self):
  288. result = self.signal1.time_slice(1*pq.ms,3*pq.ms)
  289. self.assertEqual(result.segment, self.signal1.segment)
  290. self.assertEqual(result.channel_index, self.signal1.channel_index)
  291. def test__slice_should_change_sampling_period(self):
  292. result1 = self.signal1[:2, 0]
  293. result2 = self.signal1[::2, 0]
  294. result3 = self.signal1[1:7:2, 0]
  295. self.assertIsInstance(result1, AnalogSignal)
  296. assert_neo_object_is_compliant(result1)
  297. self.assertEqual(result1.name, 'spam')
  298. self.assertEqual(result1.description, 'eggs')
  299. self.assertEqual(result1.file_origin, 'testfile.txt')
  300. self.assertEqual(result1.annotations, {'arg1': 'test'})
  301. self.assertIsInstance(result2, AnalogSignal)
  302. assert_neo_object_is_compliant(result2)
  303. self.assertEqual(result2.name, 'spam')
  304. self.assertEqual(result2.description, 'eggs')
  305. self.assertEqual(result2.file_origin, 'testfile.txt')
  306. self.assertEqual(result2.annotations, {'arg1': 'test'})
  307. self.assertIsInstance(result3, AnalogSignal)
  308. assert_neo_object_is_compliant(result3)
  309. self.assertEqual(result3.name, 'spam')
  310. self.assertEqual(result3.description, 'eggs')
  311. self.assertEqual(result3.file_origin, 'testfile.txt')
  312. self.assertEqual(result3.annotations, {'arg1': 'test'})
  313. self.assertEqual(result1.sampling_period, self.signal1.sampling_period)
  314. self.assertEqual(result2.sampling_period,
  315. self.signal1.sampling_period * 2)
  316. self.assertEqual(result3.sampling_period,
  317. self.signal1.sampling_period * 2)
  318. assert_array_equal(result1.magnitude, self.data1[:2].reshape(-1, 1))
  319. assert_array_equal(result2.magnitude, self.data1[::2].reshape(-1, 1))
  320. assert_array_equal(result3.magnitude, self.data1[1:7:2].reshape(-1, 1))
  321. def test__slice_should_modify_linked_channelindex(self):
  322. n = 8 # number of channels
  323. signal = AnalogSignal(np.arange(n * 100.0).reshape(100, n),
  324. sampling_rate=1*pq.kHz,
  325. units="mV",
  326. name="test")
  327. self.assertEqual(signal.shape, (100, n))
  328. signal.channel_index = ChannelIndex(index=np.arange(n, dtype=int),
  329. channel_names=["channel{0}".format(i) for i in range(n)])
  330. signal.channel_index.analogsignals.append(signal)
  331. odd_channels = signal[:, 1::2]
  332. self.assertEqual(odd_channels.shape, (100, n//2))
  333. assert_array_equal(odd_channels.channel_index.index, np.arange(n//2, dtype=int))
  334. assert_array_equal(odd_channels.channel_index.channel_names, ["channel{0}".format(i) for i in range(1, n, 2)])
  335. assert_array_equal(signal.channel_index.channel_names, ["channel{0}".format(i) for i in range(n)])
  336. self.assertEqual(odd_channels.channel_index.analogsignals[0].name, signal.name)
  337. def test__copy_should_let_access_to_parents_objects(self):
  338. result = self.signal1.copy()
  339. self.assertIs(result.segment, self.signal1.segment)
  340. self.assertIs(result.channel_index, self.signal1.channel_index)
  341. def test__deepcopy_should_let_access_to_parents_objects(self):
  342. result = copy.deepcopy(self.signal1)
  343. self.assertIsInstance(result.segment, Segment)
  344. self.assertIsInstance(result.channel_index, ChannelIndex)
  345. assert_same_sub_schema(result.segment, self.signal1.segment)
  346. assert_same_sub_schema(result.channel_index, self.signal1.channel_index)
  347. def test__getitem_should_return_single_quantity(self):
  348. result1 = self.signal1[0, 0]
  349. result2 = self.signal1[9, 0]
  350. self.assertIsInstance(result1, pq.Quantity)
  351. self.assertFalse(hasattr(result1, 'name'))
  352. self.assertFalse(hasattr(result1, 'description'))
  353. self.assertFalse(hasattr(result1, 'file_origin'))
  354. self.assertFalse(hasattr(result1, 'annotations'))
  355. self.assertIsInstance(result2, pq.Quantity)
  356. self.assertFalse(hasattr(result2, 'name'))
  357. self.assertFalse(hasattr(result2, 'description'))
  358. self.assertFalse(hasattr(result2, 'file_origin'))
  359. self.assertFalse(hasattr(result2, 'annotations'))
  360. self.assertEqual(result1, 0*pq.nA)
  361. self.assertEqual(result2, 9*pq.nA)
  362. def test__getitem_out_of_bounds_IndexError(self):
  363. self.assertRaises(IndexError, self.signal1.__getitem__, (10, 0))
  364. def test_comparison_operators(self):
  365. assert_array_equal(self.signal1 >= 5*pq.nA,
  366. np.array([False, False, False, False, False,
  367. True, True, True, True, True]).reshape(-1, 1))
  368. assert_array_equal(self.signal1 >= 5*pq.pA,
  369. np.array([False, True, True, True, True,
  370. True, True, True, True, True]).reshape(-1, 1))
  371. def test__comparison_with_inconsistent_units_should_raise_Exception(self):
  372. self.assertRaises(ValueError, self.signal1.__gt__, 5*pq.mV)
  373. def test__simple_statistics(self):
  374. self.assertEqual(self.signal1.max(), 9*pq.nA)
  375. self.assertEqual(self.signal1.min(), 0*pq.nA)
  376. self.assertEqual(self.signal1.mean(), 4.5*pq.nA)
  377. def test__rescale_same(self):
  378. result = self.signal1.copy()
  379. result = result.rescale(pq.nA)
  380. self.assertIsInstance(result, AnalogSignal)
  381. assert_neo_object_is_compliant(result)
  382. self.assertEqual(result.name, 'spam')
  383. self.assertEqual(result.description, 'eggs')
  384. self.assertEqual(result.file_origin, 'testfile.txt')
  385. self.assertEqual(result.annotations, {'arg1': 'test'})
  386. self.assertEqual(result.units, 1*pq.nA)
  387. assert_array_equal(result.magnitude, self.data1.reshape(-1, 1))
  388. assert_same_sub_schema(result, self.signal1)
  389. self.assertIsInstance(result.channel_index, ChannelIndex)
  390. self.assertIsInstance(result.segment, Segment)
  391. self.assertIs(result.channel_index, self.signal1.channel_index)
  392. self.assertIs(result.segment, self.signal1.segment)
  393. def test__rescale_new(self):
  394. result = self.signal1.copy()
  395. result = result.rescale(pq.pA)
  396. self.assertIsInstance(result, AnalogSignal)
  397. assert_neo_object_is_compliant(result)
  398. self.assertEqual(result.name, 'spam')
  399. self.assertEqual(result.description, 'eggs')
  400. self.assertEqual(result.file_origin, 'testfile.txt')
  401. self.assertEqual(result.annotations, {'arg1': 'test'})
  402. self.assertEqual(result.units, 1*pq.pA)
  403. assert_arrays_almost_equal(np.array(result), self.data1.reshape(-1, 1)*1000., 1e-10)
  404. self.assertIsInstance(result.channel_index, ChannelIndex)
  405. self.assertIsInstance(result.segment, Segment)
  406. self.assertIs(result.channel_index, self.signal1.channel_index)
  407. self.assertIs(result.segment, self.signal1.segment)
  408. def test__rescale_new_incompatible_ValueError(self):
  409. self.assertRaises(ValueError, self.signal1.rescale, pq.mV)
  410. def test_as_array(self):
  411. sig_as_arr = self.signal1.as_array()
  412. self.assertIsInstance(sig_as_arr, np.ndarray)
  413. assert_array_equal(self.data1, sig_as_arr.flat)
  414. def test_as_quantity(self):
  415. sig_as_q = self.signal1.as_quantity()
  416. self.assertIsInstance(sig_as_q, pq.Quantity)
  417. assert_array_equal(self.data1, sig_as_q.magnitude.flat)
  418. class TestAnalogSignalEquality(unittest.TestCase):
  419. def test__signals_with_different_data_complement_should_be_not_equal(self):
  420. signal1 = AnalogSignal(np.arange(10.0), units="mV",
  421. sampling_rate=1*pq.kHz)
  422. signal2 = AnalogSignal(np.arange(10.0), units="mV",
  423. sampling_rate=2*pq.kHz)
  424. assert_neo_object_is_compliant(signal1)
  425. assert_neo_object_is_compliant(signal2)
  426. self.assertNotEqual(signal1, signal2)
  427. class TestAnalogSignalCombination(unittest.TestCase):
  428. def setUp(self):
  429. self.data1 = np.arange(10.0)
  430. self.data1quant = self.data1 * pq.mV
  431. self.signal1 = AnalogSignal(self.data1quant,
  432. sampling_rate=1*pq.kHz,
  433. name='spam', description='eggs',
  434. file_origin='testfile.txt',
  435. arg1='test')
  436. def test__compliant(self):
  437. assert_neo_object_is_compliant(self.signal1)
  438. self.assertEqual(self.signal1.name, 'spam')
  439. self.assertEqual(self.signal1.description, 'eggs')
  440. self.assertEqual(self.signal1.file_origin, 'testfile.txt')
  441. self.assertEqual(self.signal1.annotations, {'arg1': 'test'})
  442. def test__add_const_quantity_should_preserve_data_complement(self):
  443. result = self.signal1 + 0.065*pq.V
  444. self.assertIsInstance(result, AnalogSignal)
  445. assert_neo_object_is_compliant(result)
  446. self.assertEqual(result.name, 'spam')
  447. self.assertEqual(result.description, 'eggs')
  448. self.assertEqual(result.file_origin, 'testfile.txt')
  449. self.assertEqual(result.annotations, {'arg1': 'test'})
  450. assert_array_equal(result.magnitude.flatten(), self.data1 + 65)
  451. self.assertEqual(self.signal1[9, 0], 9*pq.mV)
  452. self.assertEqual(result[9, 0], 74*pq.mV)
  453. self.assertEqual(self.signal1.t_start, result.t_start)
  454. self.assertEqual(self.signal1.sampling_rate, result.sampling_rate)
  455. def test__add_quantity_should_preserve_data_complement(self):
  456. data2 = np.arange(10.0, 20.0).reshape(-1, 1)
  457. data2quant = data2*pq.mV
  458. result = self.signal1 + data2quant
  459. self.assertIsInstance(result, AnalogSignal)
  460. assert_neo_object_is_compliant(result)
  461. self.assertEqual(result.name, 'spam')
  462. self.assertEqual(result.description, 'eggs')
  463. self.assertEqual(result.file_origin, 'testfile.txt')
  464. self.assertEqual(result.annotations, {'arg1': 'test'})
  465. targ = AnalogSignal(np.arange(10.0, 30.0, 2.0), units="mV",
  466. sampling_rate=1*pq.kHz,
  467. name='spam', description='eggs',
  468. file_origin='testfile.txt', arg1='test')
  469. assert_neo_object_is_compliant(targ)
  470. assert_array_equal(result, targ)
  471. assert_same_sub_schema(result, targ)
  472. def test__add_two_consistent_signals_should_preserve_data_complement(self):
  473. data2 = np.arange(10.0, 20.0)
  474. data2quant = data2*pq.mV
  475. signal2 = AnalogSignal(data2quant, sampling_rate=1*pq.kHz)
  476. assert_neo_object_is_compliant(signal2)
  477. result = self.signal1 + signal2
  478. self.assertIsInstance(result, AnalogSignal)
  479. assert_neo_object_is_compliant(result)
  480. self.assertEqual(result.name, 'spam')
  481. self.assertEqual(result.description, 'eggs')
  482. self.assertEqual(result.file_origin, 'testfile.txt')
  483. self.assertEqual(result.annotations, {'arg1': 'test'})
  484. targ = AnalogSignal(np.arange(10.0, 30.0, 2.0), units="mV",
  485. sampling_rate=1*pq.kHz,
  486. name='spam', description='eggs',
  487. file_origin='testfile.txt', arg1='test')
  488. assert_neo_object_is_compliant(targ)
  489. assert_array_equal(result, targ)
  490. assert_same_sub_schema(result, targ)
  491. def test__add_signals_with_inconsistent_data_complement_ValueError(self):
  492. self.signal1.t_start = 0.0*pq.ms
  493. assert_neo_object_is_compliant(self.signal1)
  494. signal2 = AnalogSignal(np.arange(10.0), units="mV",
  495. t_start=100.0*pq.ms, sampling_rate=0.5*pq.kHz)
  496. assert_neo_object_is_compliant(signal2)
  497. self.assertRaises(ValueError, self.signal1.__add__, signal2)
  498. def test__subtract_const_should_preserve_data_complement(self):
  499. result = self.signal1 - 65*pq.mV
  500. self.assertIsInstance(result, AnalogSignal)
  501. assert_neo_object_is_compliant(result)
  502. self.assertEqual(result.name, 'spam')
  503. self.assertEqual(result.description, 'eggs')
  504. self.assertEqual(result.file_origin, 'testfile.txt')
  505. self.assertEqual(result.annotations, {'arg1': 'test'})
  506. self.assertEqual(self.signal1[9, 0], 9*pq.mV)
  507. self.assertEqual(result[9, 0], -56*pq.mV)
  508. assert_array_equal(result.magnitude.flatten(), self.data1 - 65)
  509. self.assertEqual(self.signal1.sampling_rate, result.sampling_rate)
  510. def test__subtract_from_const_should_return_signal(self):
  511. result = 10*pq.mV - self.signal1
  512. self.assertIsInstance(result, AnalogSignal)
  513. assert_neo_object_is_compliant(result)
  514. self.assertEqual(result.name, 'spam')
  515. self.assertEqual(result.description, 'eggs')
  516. self.assertEqual(result.file_origin, 'testfile.txt')
  517. self.assertEqual(result.annotations, {'arg1': 'test'})
  518. self.assertEqual(self.signal1[9, 0], 9*pq.mV)
  519. self.assertEqual(result[9, 0], 1*pq.mV)
  520. assert_array_equal(result.magnitude.flatten(), 10 - self.data1)
  521. self.assertEqual(self.signal1.sampling_rate, result.sampling_rate)
  522. def test__mult_by_const_float_should_preserve_data_complement(self):
  523. result = self.signal1*2
  524. self.assertIsInstance(result, AnalogSignal)
  525. assert_neo_object_is_compliant(result)
  526. self.assertEqual(result.name, 'spam')
  527. self.assertEqual(result.description, 'eggs')
  528. self.assertEqual(result.file_origin, 'testfile.txt')
  529. self.assertEqual(result.annotations, {'arg1': 'test'})
  530. self.assertEqual(self.signal1[9, 0], 9*pq.mV)
  531. self.assertEqual(result[9, 0], 18*pq.mV)
  532. assert_array_equal(result.magnitude.flatten(), self.data1*2)
  533. self.assertEqual(self.signal1.sampling_rate, result.sampling_rate)
  534. def test__divide_by_const_should_preserve_data_complement(self):
  535. result = self.signal1/0.5
  536. self.assertIsInstance(result, AnalogSignal)
  537. assert_neo_object_is_compliant(result)
  538. self.assertEqual(result.name, 'spam')
  539. self.assertEqual(result.description, 'eggs')
  540. self.assertEqual(result.file_origin, 'testfile.txt')
  541. self.assertEqual(result.annotations, {'arg1': 'test'})
  542. self.assertEqual(self.signal1[9, 0], 9*pq.mV)
  543. self.assertEqual(result[9, 0], 18*pq.mV)
  544. assert_array_equal(result.magnitude.flatten(), self.data1/0.5)
  545. self.assertEqual(self.signal1.sampling_rate, result.sampling_rate)
  546. class TestAnalogSignalFunctions(unittest.TestCase):
  547. def test__pickle(self):
  548. signal1 = AnalogSignal([1, 2, 3, 4], sampling_period=1*pq.ms,
  549. units=pq.S)
  550. signal1.annotations['index'] = 2
  551. signal1.channel_index = ChannelIndex(index=[0])
  552. fobj = open('./pickle', 'wb')
  553. pickle.dump(signal1, fobj)
  554. fobj.close()
  555. fobj = open('./pickle', 'rb')
  556. try:
  557. signal2 = pickle.load(fobj)
  558. except ValueError:
  559. signal2 = None
  560. assert_array_equal(signal1, signal2)
  561. assert_array_equal(signal2.channel_index.index, np.array([0]))
  562. fobj.close()
  563. os.remove('./pickle')
  564. class TestAnalogSignalSampling(unittest.TestCase):
  565. def test___get_sampling_rate__period_none_rate_none_ValueError(self):
  566. sampling_rate = None
  567. sampling_period = None
  568. self.assertRaises(ValueError, _get_sampling_rate,
  569. sampling_rate, sampling_period)
  570. def test___get_sampling_rate__period_quant_rate_none(self):
  571. sampling_rate = None
  572. sampling_period = pq.Quantity(10., units=pq.s)
  573. targ_rate = 1/sampling_period
  574. out_rate = _get_sampling_rate(sampling_rate, sampling_period)
  575. self.assertEqual(targ_rate, out_rate)
  576. def test___get_sampling_rate__period_none_rate_quant(self):
  577. sampling_rate = pq.Quantity(10., units=pq.Hz)
  578. sampling_period = None
  579. targ_rate = sampling_rate
  580. out_rate = _get_sampling_rate(sampling_rate, sampling_period)
  581. self.assertEqual(targ_rate, out_rate)
  582. def test___get_sampling_rate__period_rate_equivalent(self):
  583. sampling_rate = pq.Quantity(10., units=pq.Hz)
  584. sampling_period = pq.Quantity(0.1, units=pq.s)
  585. targ_rate = sampling_rate
  586. out_rate = _get_sampling_rate(sampling_rate, sampling_period)
  587. self.assertEqual(targ_rate, out_rate)
  588. def test___get_sampling_rate__period_rate_not_equivalent_ValueError(self):
  589. sampling_rate = pq.Quantity(10., units=pq.Hz)
  590. sampling_period = pq.Quantity(10, units=pq.s)
  591. self.assertRaises(ValueError, _get_sampling_rate,
  592. sampling_rate, sampling_period)
  593. def test___get_sampling_rate__period_none_rate_float_TypeError(self):
  594. sampling_rate = 10.
  595. sampling_period = None
  596. self.assertRaises(TypeError, _get_sampling_rate,
  597. sampling_rate, sampling_period)
  598. def test___get_sampling_rate__period_array_rate_none_TypeError(self):
  599. sampling_rate = None
  600. sampling_period = np.array(10.)
  601. self.assertRaises(TypeError, _get_sampling_rate,
  602. sampling_rate, sampling_period)
# Run this module's tests when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()