test_spectral.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. # -*- coding: utf-8 -*-
  2. """
  3. Unit tests for the spectral module.
  4. :copyright: Copyright 2015 by the Elephant team, see AUTHORS.txt.
  5. :license: Modified BSD, see LICENSE.txt for details.
  6. """
  7. import unittest
  8. import numpy as np
  9. import scipy.signal as spsig
  10. import quantities as pq
  11. import neo.core as n
  12. import elephant.spectral
  13. class WelchPSDTestCase(unittest.TestCase):
  14. def test_welch_psd_errors(self):
  15. # generate a dummy data
  16. data = n.AnalogSignal(np.zeros(5000), sampling_period=0.001*pq.s,
  17. units='mV')
  18. # check for invalid parameter values
  19. # - length of segments
  20. self.assertRaises(ValueError, elephant.spectral.welch_psd, data,
  21. len_seg=0)
  22. self.assertRaises(ValueError, elephant.spectral.welch_psd, data,
  23. len_seg=data.shape[0] * 2)
  24. # - number of segments
  25. self.assertRaises(ValueError, elephant.spectral.welch_psd, data,
  26. num_seg=0)
  27. self.assertRaises(ValueError, elephant.spectral.welch_psd, data,
  28. num_seg=data.shape[0] * 2)
  29. # - frequency resolution
  30. self.assertRaises(ValueError, elephant.spectral.welch_psd, data,
  31. freq_res=-1)
  32. self.assertRaises(ValueError, elephant.spectral.welch_psd, data,
  33. freq_res=data.sampling_rate/(data.shape[0]+1))
  34. # - overlap
  35. self.assertRaises(ValueError, elephant.spectral.welch_psd, data,
  36. overlap=-1.0)
  37. self.assertRaises(ValueError, elephant.spectral.welch_psd, data,
  38. overlap=1.1)
  39. def test_welch_psd_behavior(self):
  40. # generate data by adding white noise and a sinusoid
  41. data_length = 5000
  42. sampling_period = 0.001
  43. signal_freq = 100.0
  44. noise = np.random.normal(size=data_length)
  45. signal = [np.sin(2*np.pi*signal_freq*t)
  46. for t in np.arange(0, data_length*sampling_period,
  47. sampling_period)]
  48. data = n.AnalogSignal(np.array(signal+noise),
  49. sampling_period=sampling_period*pq.s,
  50. units='mV')
  51. # consistency between different ways of specifying segment length
  52. freqs1, psd1 = elephant.spectral.welch_psd(data, len_seg=data_length//5, overlap=0)
  53. freqs2, psd2 = elephant.spectral.welch_psd(data, num_seg=5, overlap=0)
  54. self.assertTrue((psd1==psd2).all() and (freqs1==freqs2).all())
  55. # frequency resolution and consistency with data
  56. freq_res = 1.0 * pq.Hz
  57. freqs, psd = elephant.spectral.welch_psd(data, freq_res=freq_res)
  58. self.assertAlmostEqual(freq_res, freqs[1]-freqs[0])
  59. self.assertEqual(freqs[psd.argmax()], signal_freq)
  60. freqs_np, psd_np = elephant.spectral.welch_psd(data.magnitude.flatten(), fs=1/sampling_period, freq_res=freq_res)
  61. self.assertTrue((freqs==freqs_np).all() and (psd==psd_np).all())
  62. # check of scipy.signal.welch() parameters
  63. params = {'window': 'hamming', 'nfft': 1024, 'detrend': 'linear',
  64. 'return_onesided': False, 'scaling': 'spectrum'}
  65. for key, val in params.items():
  66. freqs, psd = elephant.spectral.welch_psd(data, len_seg=1000, overlap=0, **{key: val})
  67. freqs_spsig, psd_spsig = spsig.welch(np.rollaxis(data, 0, len(data.shape)),
  68. fs=1/sampling_period, nperseg=1000, noverlap=0, **{key: val})
  69. self.assertTrue((freqs==freqs_spsig).all() and (psd==psd_spsig).all())
  70. # - generate multidimensional data for check of parameter `axis`
  71. num_channel = 4
  72. data_length = 5000
  73. data_multidim = np.random.normal(size=(num_channel, data_length))
  74. freqs, psd = elephant.spectral.welch_psd(data_multidim)
  75. freqs_T, psd_T = elephant.spectral.welch_psd(data_multidim.T, axis=0)
  76. self.assertTrue(np.all(freqs==freqs_T))
  77. self.assertTrue(np.all(psd==psd_T.T))
  78. def test_welch_psd_input_types(self):
  79. # generate a test data
  80. sampling_period = 0.001
  81. data = n.AnalogSignal(np.array(np.random.normal(size=5000)),
  82. sampling_period=sampling_period*pq.s,
  83. units='mV')
  84. # outputs from AnalogSignal input are of Quantity type (standard usage)
  85. freqs_neo, psd_neo = elephant.spectral.welch_psd(data)
  86. self.assertTrue(isinstance(freqs_neo, pq.quantity.Quantity))
  87. self.assertTrue(isinstance(psd_neo, pq.quantity.Quantity))
  88. # outputs from Quantity array input are of Quantity type
  89. freqs_pq, psd_pq = elephant.spectral.welch_psd(data.magnitude.flatten()*data.units, fs=1/sampling_period)
  90. self.assertTrue(isinstance(freqs_pq, pq.quantity.Quantity))
  91. self.assertTrue(isinstance(psd_pq, pq.quantity.Quantity))
  92. # outputs from Numpy ndarray input are NOT of Quantity type
  93. freqs_np, psd_np = elephant.spectral.welch_psd(data.magnitude.flatten(), fs=1/sampling_period)
  94. self.assertFalse(isinstance(freqs_np, pq.quantity.Quantity))
  95. self.assertFalse(isinstance(psd_np, pq.quantity.Quantity))
  96. # check if the results from different input types are identical
  97. self.assertTrue((freqs_neo==freqs_pq).all() and (psd_neo==psd_pq).all())
  98. self.assertTrue((freqs_neo==freqs_np).all() and (psd_neo==psd_np).all())
  99. def test_welch_psd_multidim_input(self):
  100. # generate multidimensional data
  101. num_channel = 4
  102. data_length = 5000
  103. sampling_period = 0.001
  104. noise = np.random.normal(size=(num_channel, data_length))
  105. data_np = np.array(noise)
  106. # Since row-column order in AnalogSignal is different from the
  107. # conventional one, `data_np` needs to be transposed when its used to
  108. # define an AnalogSignal
  109. data_neo = n.AnalogSignal(data_np.T,
  110. sampling_period=sampling_period*pq.s,
  111. units='mV')
  112. data_neo_1dim = n.AnalogSignal(data_np[0],
  113. sampling_period=sampling_period*pq.s,
  114. units='mV')
  115. # check if the results from different input types are identical
  116. freqs_np, psd_np = elephant.spectral.welch_psd(data_np,
  117. fs=1/sampling_period)
  118. freqs_neo, psd_neo = elephant.spectral.welch_psd(data_neo)
  119. freqs_neo_1dim, psd_neo_1dim = elephant.spectral.welch_psd(data_neo_1dim)
  120. self.assertTrue(np.all(freqs_np==freqs_neo))
  121. self.assertTrue(np.all(psd_np==psd_neo))
  122. self.assertTrue(np.all(psd_neo_1dim==psd_neo[0]))
  123. class WelchCohereTestCase(unittest.TestCase):
  124. def test_welch_cohere_errors(self):
  125. # generate a dummy data
  126. x = n.AnalogSignal(np.zeros(5000), sampling_period=0.001*pq.s,
  127. units='mV')
  128. y = n.AnalogSignal(np.zeros(5000), sampling_period=0.001*pq.s,
  129. units='mV')
  130. # check for invalid parameter values
  131. # - length of segments
  132. self.assertRaises(ValueError, elephant.spectral.welch_cohere, x, y,
  133. len_seg=0)
  134. self.assertRaises(ValueError, elephant.spectral.welch_cohere, x, y,
  135. len_seg=x.shape[0] * 2)
  136. # - number of segments
  137. self.assertRaises(ValueError, elephant.spectral.welch_cohere, x, y,
  138. num_seg=0)
  139. self.assertRaises(ValueError, elephant.spectral.welch_cohere, x, y,
  140. num_seg=x.shape[0] * 2)
  141. # - frequency resolution
  142. self.assertRaises(ValueError, elephant.spectral.welch_cohere, x, y,
  143. freq_res=-1)
  144. self.assertRaises(ValueError, elephant.spectral.welch_cohere, x, y,
  145. freq_res=x.sampling_rate/(x.shape[0]+1))
  146. # - overlap
  147. self.assertRaises(ValueError, elephant.spectral.welch_cohere, x, y,
  148. overlap=-1.0)
  149. self.assertRaises(ValueError, elephant.spectral.welch_cohere, x, y,
  150. overlap=1.1)
  151. def test_welch_cohere_behavior(self):
  152. # generate data by adding white noise and a sinusoid
  153. data_length = 5000
  154. sampling_period = 0.001
  155. signal_freq = 100.0
  156. noise1 = np.random.normal(size=data_length) * 0.01
  157. noise2 = np.random.normal(size=data_length) * 0.01
  158. signal1 = [np.cos(2*np.pi*signal_freq*t)
  159. for t in np.arange(0, data_length*sampling_period,
  160. sampling_period)]
  161. signal2 = [np.sin(2*np.pi*signal_freq*t)
  162. for t in np.arange(0, data_length*sampling_period,
  163. sampling_period)]
  164. x = n.AnalogSignal(np.array(signal1+noise1), units='mV',
  165. sampling_period=sampling_period*pq.s)
  166. y = n.AnalogSignal(np.array(signal2+noise2), units='mV',
  167. sampling_period=sampling_period*pq.s)
  168. # consistency between different ways of specifying segment length
  169. freqs1, coherency1, phase_lag1 = elephant.spectral.welch_cohere(x, y,
  170. len_seg=data_length//5, overlap=0)
  171. freqs2, coherency2, phase_lag2 = elephant.spectral.welch_cohere(x, y,
  172. num_seg=5, overlap=0)
  173. self.assertTrue((coherency1==coherency2).all() and
  174. (phase_lag1==phase_lag2).all() and
  175. (freqs1==freqs2).all())
  176. # frequency resolution and consistency with data
  177. freq_res = 1.0 * pq.Hz
  178. freqs, coherency, phase_lag = elephant.spectral.welch_cohere(x, y,
  179. freq_res=freq_res)
  180. self.assertAlmostEqual(freq_res, freqs[1]-freqs[0])
  181. self.assertAlmostEqual(freqs[coherency.argmax()], signal_freq,
  182. places=2)
  183. self.assertAlmostEqual(phase_lag[coherency.argmax()], np.pi/2,
  184. places=2)
  185. freqs_np, coherency_np, phase_lag_np =\
  186. elephant.spectral.welch_cohere(x.magnitude.flatten(), y.magnitude.flatten(),
  187. fs=1/sampling_period, freq_res=freq_res)
  188. self.assertTrue((freqs == freqs_np).all() and
  189. (coherency[:, 0] == coherency_np).all() and
  190. (phase_lag[:, 0] == phase_lag_np).all())
  191. # - check the behavior of parameter `axis` using multidimensional data
  192. num_channel = 4
  193. data_length = 5000
  194. x_multidim = np.random.normal(size=(num_channel, data_length))
  195. y_multidim = np.random.normal(size=(num_channel, data_length))
  196. freqs, coherency, phase_lag =\
  197. elephant.spectral.welch_cohere(x_multidim, y_multidim)
  198. freqs_T, coherency_T, phase_lag_T =\
  199. elephant.spectral.welch_cohere(x_multidim.T, y_multidim.T, axis=0)
  200. self.assertTrue(np.all(freqs==freqs_T))
  201. self.assertTrue(np.all(coherency==coherency_T.T))
  202. self.assertTrue(np.all(phase_lag==phase_lag_T.T))
  203. def test_welch_cohere_input_types(self):
  204. # generate a test data
  205. sampling_period = 0.001
  206. x = n.AnalogSignal(np.array(np.random.normal(size=5000)),
  207. sampling_period=sampling_period*pq.s,
  208. units='mV')
  209. y = n.AnalogSignal(np.array(np.random.normal(size=5000)),
  210. sampling_period=sampling_period*pq.s,
  211. units='mV')
  212. # outputs from AnalogSignal input are of Quantity type
  213. # (standard usage)
  214. freqs_neo, coherency_neo, phase_lag_neo =\
  215. elephant.spectral.welch_cohere(x, y)
  216. self.assertTrue(isinstance(freqs_neo, pq.quantity.Quantity))
  217. self.assertTrue(isinstance(phase_lag_neo, pq.quantity.Quantity))
  218. # outputs from Quantity array input are of Quantity type
  219. freqs_pq, coherency_pq, phase_lag_pq =\
  220. elephant.spectral.welch_cohere(x.magnitude.flatten()*x.units,
  221. y.magnitude.flatten()*y.units, fs=1/sampling_period)
  222. self.assertTrue(isinstance(freqs_pq, pq.quantity.Quantity))
  223. self.assertTrue(isinstance(phase_lag_pq, pq.quantity.Quantity))
  224. # outputs from Numpy ndarray input are NOT of Quantity type
  225. freqs_np, coherency_np, phase_lag_np =\
  226. elephant.spectral.welch_cohere(x.magnitude.flatten(), y.magnitude.flatten(),
  227. fs=1/sampling_period)
  228. self.assertFalse(isinstance(freqs_np, pq.quantity.Quantity))
  229. self.assertFalse(isinstance(phase_lag_np, pq.quantity.Quantity))
  230. # check if the results from different input types are identical
  231. self.assertTrue((freqs_neo==freqs_pq).all() and
  232. (coherency_neo[:, 0]==coherency_pq).all() and
  233. (phase_lag_neo[:, 0]==phase_lag_pq).all())
  234. self.assertTrue((freqs_neo==freqs_np).all() and
  235. (coherency_neo[:, 0]==coherency_np).all() and
  236. (phase_lag_neo[:, 0]==phase_lag_np).all())
  237. def test_welch_cohere_multidim_input(self):
  238. # generate multidimensional data
  239. num_channel = 4
  240. data_length = 5000
  241. sampling_period = 0.001
  242. x_np = np.array(np.random.normal(size=(num_channel, data_length)))
  243. y_np = np.array(np.random.normal(size=(num_channel, data_length)))
  244. # Since row-column order in AnalogSignal is different from the
  245. # convention in NumPy/SciPy, `data_np` needs to be transposed when its
  246. # used to define an AnalogSignal
  247. x_neo = n.AnalogSignal(x_np.T, units='mV',
  248. sampling_period=sampling_period*pq.s)
  249. y_neo = n.AnalogSignal(y_np.T, units='mV',
  250. sampling_period=sampling_period*pq.s)
  251. x_neo_1dim = n.AnalogSignal(x_np[0], units='mV',
  252. sampling_period=sampling_period*pq.s)
  253. y_neo_1dim = n.AnalogSignal(y_np[0], units='mV',
  254. sampling_period=sampling_period*pq.s)
  255. # check if the results from different input types are identical
  256. freqs_np, coherency_np, phase_lag_np =\
  257. elephant.spectral.welch_cohere(x_np, y_np, fs=1/sampling_period)
  258. freqs_neo, coherency_neo, phase_lag_neo =\
  259. elephant.spectral.welch_cohere(x_neo, y_neo)
  260. freqs_neo_1dim, coherency_neo_1dim, phase_lag_neo_1dim =\
  261. elephant.spectral.welch_cohere(x_neo_1dim, y_neo_1dim)
  262. self.assertTrue(np.all(freqs_np==freqs_neo))
  263. self.assertTrue(np.all(coherency_np.T==coherency_neo))
  264. self.assertTrue(np.all(phase_lag_np.T==phase_lag_neo))
  265. self.assertTrue(np.all(coherency_neo_1dim[:, 0]==coherency_neo[:, 0]))
  266. self.assertTrue(np.all(phase_lag_neo_1dim[:, 0]==phase_lag_neo[:, 0]))
  267. def suite():
  268. suite = unittest.makeSuite(WelchPSDTestCase, 'test')
  269. return suite
  270. if __name__ == "__main__":
  271. runner = unittest.TextTestRunner(verbosity=2)
  272. runner.run(suite())