spike_train_correlation.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698
  1. # -*- coding: utf-8 -*-
  2. """
  3. This module provides functions to calculate correlations between spike trains.
  4. :copyright: Copyright 2015-2016 by the Elephant team, see AUTHORS.txt.
  5. :license: Modified BSD, see LICENSE.txt for details.
  6. """
  7. from __future__ import division
  8. import numpy as np
  9. import neo
  10. import quantities as pq
  11. def covariance(binned_sts, binary=False):
  12. '''
  13. Calculate the NxN matrix of pairwise covariances between all combinations
  14. of N binned spike trains.
  15. For each pair of spike trains :math:`(i,j)`, the covariance :math:`C[i,j]`
  16. is obtained by binning :math:`i` and :math:`j` at the desired bin size. Let
  17. :math:`b_i` and :math:`b_j` denote the binary vectors and :math:`m_i` and
  18. :math:`m_j` their respective averages. Then
  19. .. math::
  20. C[i,j] = <b_i-m_i, b_j-m_j> / (l-1)
  21. where <..,.> is the scalar product of two vectors.
  22. For an input of n spike trains, a n x n matrix is returned containing the
  23. covariances for each combination of input spike trains.
  24. If binary is True, the binned spike trains are clipped to 0 or 1 before
  25. computing the covariance, so that the binned vectors :math:`b_i` and
  26. :math:`b_j` are binary.
  27. Parameters
  28. ----------
  29. binned_sts : elephant.conversion.BinnedSpikeTrain
  30. A binned spike train containing the spike trains to be evaluated.
  31. binary : bool, optional
  32. If True, two spikes of a particular spike train falling in the same bin
  33. are counted as 1, resulting in binary binned vectors :math:`b_i`. If
  34. False, the binned vectors :math:`b_i` contain the spike counts per bin.
  35. Default: False
  36. Returns
  37. -------
  38. C : ndarrray
  39. The square matrix of covariances. The element :math:`C[i,j]=C[j,i]` is
  40. the covariance between binned_sts[i] and binned_sts[j].
  41. Examples
  42. --------
  43. Generate two Poisson spike trains
  44. >>> from elephant.spike_train_generation import homogeneous_poisson_process
  45. >>> st1 = homogeneous_poisson_process(
  46. rate=10.0*Hz, t_start=0.0*s, t_stop=10.0*s)
  47. >>> st2 = homogeneous_poisson_process(
  48. rate=10.0*Hz, t_start=0.0*s, t_stop=10.0*s)
  49. Calculate the covariance matrix.
  50. >>> from elephant.conversion import BinnedSpikeTrain
  51. >>> cov_matrix = covariance(BinnedSpikeTrain([st1, st2], binsize=5*ms))
  52. The covariance between the spike trains is stored in cc_matrix[0,1] (or
  53. cov_matrix[1,0]).
  54. Notes
  55. -----
  56. * The spike trains in the binned structure are assumed to all cover the
  57. complete time span of binned_sts [t_start,t_stop).
  58. '''
  59. return __calculate_correlation_or_covariance(
  60. binned_sts, binary, corrcoef_norm=False)
  61. def corrcoef(binned_sts, binary=False):
  62. '''
  63. Calculate the NxN matrix of pairwise Pearson's correlation coefficients
  64. between all combinations of N binned spike trains.
  65. For each pair of spike trains :math:`(i,j)`, the correlation coefficient
  66. :math:`C[i,j]` is obtained by binning :math:`i` and :math:`j` at the
  67. desired bin size. Let :math:`b_i` and :math:`b_j` denote the binary vectors
  68. and :math:`m_i` and :math:`m_j` their respective averages. Then
  69. .. math::
  70. C[i,j] = <b_i-m_i, b_j-m_j> /
  71. \sqrt{<b_i-m_i, b_i-m_i>*<b_j-m_j,b_j-m_j>}
  72. where <..,.> is the scalar product of two vectors.
  73. For an input of n spike trains, a n x n matrix is returned.
  74. Each entry in the matrix is a real number ranging between -1 (perfectly
  75. anti-correlated spike trains) and +1 (perfectly correlated spike trains).
  76. If binary is True, the binned spike trains are clipped to 0 or 1 before
  77. computing the correlation coefficients, so that the binned vectors
  78. :math:`b_i` and :math:`b_j` are binary.
  79. Parameters
  80. ----------
  81. binned_sts : elephant.conversion.BinnedSpikeTrain
  82. A binned spike train containing the spike trains to be evaluated.
  83. binary : bool, optional
  84. If True, two spikes of a particular spike train falling in the same bin
  85. are counted as 1, resulting in binary binned vectors :math:`b_i`. If
  86. False, the binned vectors :math:`b_i` contain the spike counts per bin.
  87. Default: False
  88. Returns
  89. -------
  90. C : ndarrray
  91. The square matrix of correlation coefficients. The element
  92. :math:`C[i,j]=C[j,i]` is the Pearson's correlation coefficient between
  93. binned_sts[i] and binned_sts[j]. If binned_sts contains only one
  94. SpikeTrain, C=1.0.
  95. Examples
  96. --------
  97. Generate two Poisson spike trains
  98. >>> from elephant.spike_train_generation import homogeneous_poisson_process
  99. >>> st1 = homogeneous_poisson_process(
  100. rate=10.0*Hz, t_start=0.0*s, t_stop=10.0*s)
  101. >>> st2 = homogeneous_poisson_process(
  102. rate=10.0*Hz, t_start=0.0*s, t_stop=10.0*s)
  103. Calculate the correlation matrix.
  104. >>> from elephant.conversion import BinnedSpikeTrain
  105. >>> cc_matrix = corrcoef(BinnedSpikeTrain([st1, st2], binsize=5*ms))
  106. The correlation coefficient between the spike trains is stored in
  107. cc_matrix[0,1] (or cc_matrix[1,0]).
  108. Notes
  109. -----
  110. * The spike trains in the binned structure are assumed to all cover the
  111. complete time span of binned_sts [t_start,t_stop).
  112. '''
  113. return __calculate_correlation_or_covariance(
  114. binned_sts, binary, corrcoef_norm=True)
  115. def __calculate_correlation_or_covariance(binned_sts, binary, corrcoef_norm):
  116. '''
  117. Helper function for covariance() and corrcoef() that performs the complete
  118. calculation for either the covariance (corrcoef_norm=False) or correlation
  119. coefficient (corrcoef_norm=True). Both calculations differ only by the
  120. denominator.
  121. Parameters
  122. ----------
  123. binned_sts : elephant.conversion.BinnedSpikeTrain
  124. See covariance() or corrcoef(), respectively.
  125. binary : bool
  126. See covariance() or corrcoef(), respectively.
  127. corrcoef_norm : bool
  128. Use normalization factor for the correlation coefficient rather than
  129. for the covariance.
  130. '''
  131. num_neurons = binned_sts.matrix_rows
  132. # Pre-allocate correlation matrix
  133. C = np.zeros((num_neurons, num_neurons))
  134. # Retrieve unclipped matrix
  135. spmat = binned_sts.to_sparse_array()
  136. # For each row, extract the nonzero column indices and the corresponding
  137. # data in the matrix (for performance reasons)
  138. bin_idx_unique = []
  139. bin_counts_unique = []
  140. if binary:
  141. for s in spmat:
  142. bin_idx_unique.append(s.nonzero()[1])
  143. else:
  144. for s in spmat:
  145. bin_counts_unique.append(s.data)
  146. # All combinations of spike trains
  147. for i in range(num_neurons):
  148. for j in range(i, num_neurons):
  149. # Enumerator:
  150. # $$ <b_i-m_i, b_j-m_j>
  151. # = <b_i, b_j> + l*m_i*m_j - <b_i, M_j> - <b_j, M_i>
  152. # =: ij + l*m_i*m_j - n_i * m_j - n_j * m_i
  153. # = ij - n_i*n_j/l $$
  154. # where $n_i$ is the spike count of spike train $i$,
  155. # $l$ is the number of bins used (i.e., length of $b_i$ or $b_j$),
  156. # and $M_i$ is a vector [m_i, m_i,..., m_i].
  157. if binary:
  158. # Intersect indices to identify number of coincident spikes in
  159. # i and j (more efficient than directly using the dot product)
  160. ij = len(np.intersect1d(
  161. bin_idx_unique[i], bin_idx_unique[j], assume_unique=True))
  162. # Number of spikes in i and j
  163. n_i = len(bin_idx_unique[i])
  164. n_j = len(bin_idx_unique[j])
  165. else:
  166. # Calculate dot product b_i*b_j between unclipped matrices
  167. ij = spmat[i].dot(spmat[j].transpose()).toarray()[0][0]
  168. # Number of spikes in i and j
  169. n_i = np.sum(bin_counts_unique[i])
  170. n_j = np.sum(bin_counts_unique[j])
  171. enumerator = ij - n_i * n_j / binned_sts.num_bins
  172. # Denominator:
  173. if corrcoef_norm:
  174. # Correlation coefficient
  175. # Note:
  176. # $$ <b_i-m_i, b_i-m_i>
  177. # = <b_i, b_i> + m_i^2 - 2 <b_i, M_i>
  178. # =: ii + m_i^2 - 2 n_i * m_i
  179. # = ii - n_i^2 / $$
  180. if binary:
  181. # Here, b_i*b_i is just the number of filled bins (since
  182. # each filled bin of a clipped spike train has value equal
  183. # to 1)
  184. ii = len(bin_idx_unique[i])
  185. jj = len(bin_idx_unique[j])
  186. else:
  187. # directly calculate the dot product based on the counts of
  188. # all filled entries (more efficient than using the dot
  189. # product of the rows of the sparse matrix)
  190. ii = np.dot(bin_counts_unique[i], bin_counts_unique[i])
  191. jj = np.dot(bin_counts_unique[j], bin_counts_unique[j])
  192. denominator = np.sqrt(
  193. (ii - (n_i ** 2) / binned_sts.num_bins) *
  194. (jj - (n_j ** 2) / binned_sts.num_bins))
  195. else:
  196. # Covariance
  197. # $$ l-1 $$
  198. denominator = (binned_sts.num_bins - 1)
  199. # Fill entry of correlation matrix
  200. C[i, j] = C[j, i] = enumerator / denominator
  201. return np.squeeze(C)
  202. def cross_correlation_histogram(
  203. binned_st1, binned_st2, window='full', border_correction=False,
  204. binary=False, kernel=None, method='speed', cross_corr_coef=False):
  205. """
  206. Computes the cross-correlation histogram (CCH) between two binned spike
  207. trains binned_st1 and binned_st2.
  208. Parameters
  209. ----------
  210. binned_st1, binned_st2 : BinnedSpikeTrain
  211. Binned spike trains to cross-correlate. The two spike trains must have
  212. same t_start and t_stop
  213. window : string or list of integer (optional)
  214. ‘full’: This returns the crosscorrelation at each point of overlap,
  215. with an output shape of (N+M-1,). At the end-points of the
  216. cross-correlogram, the signals do not overlap completely, and
  217. boundary effects may be seen.
  218. ‘valid’: Mode valid returns output of length max(M, N) - min(M, N) + 1.
  219. The cross-correlation product is only given for points where the
  220. signals overlap completely.
  221. Values outside the signal boundary have no effect.
  222. list of integer (window[0]=minimum lag, window[1]=maximum lag): The
  223. entries of window are two integers representing the left and
  224. right extremes (expressed as number of bins) where the
  225. crosscorrelation is computed
  226. Default: 'full'
  227. border_correction : bool (optional)
  228. whether to correct for the border effect. If True, the value of the
  229. CCH at bin b (for b=-H,-H+1, ...,H, where H is the CCH half-length)
  230. is multiplied by the correction factor:
  231. (H+1)/(H+1-|b|),
  232. which linearly corrects for loss of bins at the edges.
  233. Default: False
  234. binary : bool (optional)
  235. whether to binary spikes from the same spike train falling in the
  236. same bin. If True, such spikes are considered as a single spike;
  237. otherwise they are considered as different spikes.
  238. Default: False.
  239. kernel : array or None (optional)
  240. A one dimensional array containing an optional smoothing kernel applied
  241. to the resulting CCH. The length N of the kernel indicates the
  242. smoothing window. The smoothing window cannot be larger than the
  243. maximum lag of the CCH. The kernel is normalized to unit area before
  244. being applied to the resulting CCH. Popular choices for the kernel are
  245. * normalized boxcar kernel: numpy.ones(N)
  246. * hamming: numpy.hamming(N)
  247. * hanning: numpy.hanning(N)
  248. * bartlett: numpy.bartlett(N)
  249. If None is specified, the CCH is not smoothed.
  250. Default: None
  251. method : string (optional)
  252. Defines the algorithm to use. "speed" uses numpy.correlate to calculate
  253. the correlation between two binned spike trains using a non-sparse data
  254. representation. Due to various optimizations, it is the fastest
  255. realization. In contrast, the option "memory" uses an own
  256. implementation to calculate the correlation based on sparse matrices,
  257. which is more memory efficient but slower than the "speed" option.
  258. Default: "speed"
  259. cross_corr_coef : bool (optional)
  260. Normalizes the CCH to obtain the cross-correlation coefficient
  261. function ranging from -1 to 1 according to Equation (5.10) in
  262. "Analysis of parallel spike trains", 2010, Gruen & Rotter, Vol 7
  263. Returns
  264. -------
  265. cch : AnalogSignal
  266. Containing the cross-correlation histogram between binned_st1 and
  267. binned_st2.
  268. The central bin of the histogram represents correlation at zero
  269. delay. Offset bins correspond to correlations at a delay equivalent
  270. to the difference between the spike times of binned_st1 and those of
  271. binned_st2: an entry at positive lags corresponds to a spike in
  272. binned_st2 following a spike in binned_st1 bins to the right, and an
  273. entry at negative lags corresponds to a spike in binned_st1 following
  274. a spike in binned_st2.
  275. To illustrate this definition, consider the two spike trains:
  276. binned_st1: 0 0 0 0 1 0 0 0 0 0 0
  277. binned_st2: 0 0 0 0 0 0 0 1 0 0 0
  278. Here, the CCH will have an entry of 1 at lag h=+3.
  279. Consistent with the definition of AnalogSignals, the time axis
  280. represents the left bin borders of each histogram bin. For example,
  281. the time axis might be:
  282. np.array([-2.5 -1.5 -0.5 0.5 1.5]) * ms
  283. bin_ids : ndarray of int
  284. Contains the IDs of the individual histogram bins, where the central
  285. bin has ID 0, bins the left have negative IDs and bins to the right
  286. have positive IDs, e.g.,:
  287. np.array([-3, -2, -1, 0, 1, 2, 3])
  288. Example
  289. -------
  290. Plot the cross-correlation histogram between two Poisson spike trains
  291. >>> import elephant
  292. >>> import matplotlib.pyplot as plt
  293. >>> import quantities as pq
  294. >>> binned_st1 = elephant.conversion.BinnedSpikeTrain(
  295. elephant.spike_train_generation.homogeneous_poisson_process(
  296. 10. * pq.Hz, t_start=0 * pq.ms, t_stop=5000 * pq.ms),
  297. binsize=5. * pq.ms)
  298. >>> binned_st2 = elephant.conversion.BinnedSpikeTrain(
  299. elephant.spike_train_generation.homogeneous_poisson_process(
  300. 10. * pq.Hz, t_start=0 * pq.ms, t_stop=5000 * pq.ms),
  301. binsize=5. * pq.ms)
  302. >>> cc_hist = \
  303. elephant.spike_train_correlation.cross_correlation_histogram(
  304. binned_st1, binned_st2, window=[-30,30],
  305. border_correction=False,
  306. binary=False, kernel=None, method='memory')
  307. >>> plt.bar(
  308. left=cc_hist[0].times.magnitude,
  309. height=cc_hist[0][:, 0].magnitude,
  310. width=cc_hist[0].sampling_period.magnitude)
  311. >>> plt.xlabel('time (' + str(cc_hist[0].times.units) + ')')
  312. >>> plt.ylabel('cross-correlation histogram')
  313. >>> plt.axis('tight')
  314. >>> plt.show()
  315. Alias
  316. -----
  317. cch
  318. """
  319. def _cross_corr_coef(cch_result, binned_st1, binned_st2):
  320. # Normalizes the CCH to obtain the cross-correlation
  321. # coefficient function ranging from -1 to 1
  322. N = max(binned_st1.num_bins, binned_st2.num_bins)
  323. Nx = len(binned_st1.spike_indices[0])
  324. Ny = len(binned_st2.spike_indices[0])
  325. spmat = [binned_st1.to_sparse_array(), binned_st2.to_sparse_array()]
  326. bin_counts_unique = []
  327. for s in spmat:
  328. bin_counts_unique.append(s.data)
  329. ii = np.dot(bin_counts_unique[0], bin_counts_unique[0])
  330. jj = np.dot(bin_counts_unique[1], bin_counts_unique[1])
  331. rho_xy = (cch_result - Nx * Ny / N) / \
  332. np.sqrt((ii - Nx**2. / N) * (jj - Ny**2. / N))
  333. return rho_xy
  334. def _border_correction(counts, max_num_bins, l, r):
  335. # Correct the values taking into account lacking contributes
  336. # at the edges
  337. correction = float(max_num_bins + 1) / np.array(
  338. max_num_bins + 1 - abs(
  339. np.arange(l, r + 1)), float)
  340. return counts * correction
  341. def _kernel_smoothing(counts, kern, l, r):
  342. # Define the kern for smoothing as an ndarray
  343. if hasattr(kern, '__iter__'):
  344. if len(kern) > np.abs(l) + np.abs(r) + 1:
  345. raise ValueError(
  346. 'The length of the kernel cannot be larger than the '
  347. 'length %d of the resulting CCH.' % (
  348. np.abs(l) + np.abs(r) + 1))
  349. kern = np.array(kern, dtype=float)
  350. kern = 1. * kern / sum(kern)
  351. # Check kern parameter
  352. else:
  353. raise ValueError('Invalid smoothing kernel.')
  354. # Smooth the cross-correlation histogram with the kern
  355. return np.convolve(counts, kern, mode='same')
  356. def _cch_memory(binned_st1, binned_st2, left_edge, right_edge,
  357. border_corr, binary, kern):
  358. # Retrieve unclipped matrix
  359. st1_spmat = binned_st1.to_sparse_array()
  360. st2_spmat = binned_st2.to_sparse_array()
  361. # For each row, extract the nonzero column indices
  362. # and the corresponding # data in the matrix (for performance reasons)
  363. st1_bin_idx_unique = st1_spmat.nonzero()[1]
  364. st2_bin_idx_unique = st2_spmat.nonzero()[1]
  365. # Case with binary entries
  366. if binary:
  367. st1_bin_counts_unique = np.array(st1_spmat.data > 0, dtype=int)
  368. st2_bin_counts_unique = np.array(st2_spmat.data > 0, dtype=int)
  369. # Case with all values
  370. else:
  371. st1_bin_counts_unique = st1_spmat.data
  372. st2_bin_counts_unique = st2_spmat.data
  373. # Initialize the counts to an array of zeroes,
  374. # and the bin IDs to integers
  375. # spanning the time axis
  376. counts = np.zeros(np.abs(left_edge) + np.abs(right_edge) + 1)
  377. bin_ids = np.arange(left_edge, right_edge + 1)
  378. # Compute the CCH at lags in left_edge,...,right_edge only
  379. for idx, i in enumerate(st1_bin_idx_unique):
  380. il = np.searchsorted(st2_bin_idx_unique, left_edge + i)
  381. ir = np.searchsorted(st2_bin_idx_unique,
  382. right_edge + i, side='right')
  383. timediff = st2_bin_idx_unique[il:ir] - i
  384. assert ((timediff >= left_edge) & (
  385. timediff <= right_edge)).all(), 'Not all the '
  386. 'entries of cch lie in the window'
  387. counts[timediff + np.abs(left_edge)] += (
  388. st1_bin_counts_unique[idx] * st2_bin_counts_unique[il:ir])
  389. st2_bin_idx_unique = st2_bin_idx_unique[il:]
  390. st2_bin_counts_unique = st2_bin_counts_unique[il:]
  391. # Border correction
  392. if border_corr is True:
  393. counts = _border_correction(
  394. counts, max_num_bins, left_edge, right_edge)
  395. if kern is not None:
  396. # Smoothing
  397. counts = _kernel_smoothing(counts, kern, left_edge, right_edge)
  398. # Transform the array count into an AnalogSignal
  399. cch_result = neo.AnalogSignal(
  400. signal=counts.reshape(counts.size, 1),
  401. units=pq.dimensionless,
  402. t_start=(bin_ids[0] - 0.5) * binned_st1.binsize,
  403. sampling_period=binned_st1.binsize)
  404. # Return only the hist_bins bins and counts before and after the
  405. # central one
  406. return cch_result, bin_ids
  407. def _cch_speed(binned_st1, binned_st2, left_edge, right_edge, cch_mode,
  408. border_corr, binary, kern):
  409. # Retrieve the array of the binne spike train
  410. st1_arr = binned_st1.to_array()[0, :]
  411. st2_arr = binned_st2.to_array()[0, :]
  412. # Convert the to binary version
  413. if binary:
  414. st1_arr = np.array(st1_arr > 0, dtype=int)
  415. st2_arr = np.array(st2_arr > 0, dtype=int)
  416. if cch_mode == 'pad':
  417. # Zero padding to stay between left_edge and right_edge
  418. st1_arr = np.pad(st1_arr,
  419. (int(np.abs(np.min([left_edge, 0]))), np.max(
  420. [right_edge, 0])),
  421. mode='constant')
  422. cch_mode = 'valid'
  423. # Cross correlate the spike trains
  424. counts = np.correlate(st2_arr, st1_arr, mode=cch_mode)
  425. bin_ids = np.r_[left_edge:right_edge + 1]
  426. # Border correction
  427. if border_corr is True:
  428. counts = _border_correction(
  429. counts, max_num_bins, left_edge, right_edge)
  430. if kern is not None:
  431. # Smoothing
  432. counts = _kernel_smoothing(counts, kern, left_edge, right_edge)
  433. # Transform the array count into an AnalogSignal
  434. cch_result = neo.AnalogSignal(
  435. signal=counts.reshape(counts.size, 1),
  436. units=pq.dimensionless,
  437. t_start=(bin_ids[0] - 0.5) * binned_st1.binsize,
  438. sampling_period=binned_st1.binsize)
  439. # Return only the hist_bins bins and counts before and after the
  440. # central one
  441. return cch_result, bin_ids
  442. # Check that the spike trains are binned with the same temporal
  443. # resolution
  444. if not binned_st1.matrix_rows == 1:
  445. raise AssertionError("Spike train must be one dimensional")
  446. if not binned_st2.matrix_rows == 1:
  447. raise AssertionError("Spike train must be one dimensional")
  448. if not binned_st1.binsize == binned_st2.binsize:
  449. raise AssertionError("Bin sizes must be equal")
  450. # Check t_start and t_stop identical (to drop once that the
  451. # pad functionality wil be available in the BinnedSpikeTrain classe)
  452. if not binned_st1.t_start == binned_st2.t_start:
  453. raise AssertionError("Spike train must have same t start")
  454. if not binned_st1.t_stop == binned_st2.t_stop:
  455. raise AssertionError("Spike train must have same t stop")
  456. # The maximum number of of bins
  457. max_num_bins = max(binned_st1.num_bins, binned_st2.num_bins)
  458. # Set the time window in which is computed the cch
  459. # Window parameter given in number of bins (integer)
  460. if isinstance(window[0], int) and isinstance(window[1], int):
  461. # Check the window parameter values
  462. if window[0] >= window[1] or window[0] <= -max_num_bins \
  463. or window[1] >= max_num_bins:
  464. raise ValueError(
  465. "The window exceeds the length of the spike trains")
  466. # Assign left and right edges of the cch
  467. left_edge, right_edge = window[0], window[1]
  468. # The mode in which to compute the cch for the speed implementation
  469. cch_mode = 'pad'
  470. # Case without explicit window parameter
  471. elif window == 'full':
  472. # cch computed for all the possible entries
  473. # Assign left and right edges of the cch
  474. right_edge = binned_st2.num_bins - 1
  475. left_edge = - binned_st1.num_bins + 1
  476. cch_mode = window
  477. # cch compute only for the entries that completely overlap
  478. elif window == 'valid':
  479. # cch computed only for valid entries
  480. # Assign left and right edges of the cch
  481. right_edge = max(binned_st2.num_bins - binned_st1.num_bins, 0)
  482. left_edge = min(binned_st2.num_bins - binned_st1.num_bins, 0)
  483. cch_mode = window
  484. # Check the mode parameter
  485. else:
  486. raise KeyError("Invalid window parameter")
  487. if method == "memory":
  488. cch_result, bin_ids = _cch_memory(
  489. binned_st1, binned_st2, left_edge, right_edge, border_correction,
  490. binary, kernel)
  491. elif method == "speed":
  492. cch_result, bin_ids = _cch_speed(
  493. binned_st1, binned_st2, left_edge, right_edge, cch_mode,
  494. border_correction, binary, kernel)
  495. if cross_corr_coef:
  496. cch_result = _cross_corr_coef(cch_result, binned_st1, binned_st2)
  497. return cch_result, bin_ids
  498. # Alias for common abbreviation
  499. cch = cross_correlation_histogram
  500. def spike_time_tiling_coefficient(spiketrain_1, spiketrain_2, dt=0.005 * pq.s):
  501. """
  502. Calculates the Spike Time Tiling Coefficient (STTC) as described in
  503. (Cutts & Eglen, 2014) following Cutts' implementation in C.
  504. The STTC is a pairwise measure of correlation between spike trains.
  505. It has been proposed as a replacement for the correlation index as it
  506. presents several advantages (e.g. it's not confounded by firing rate,
  507. appropriately distinguishes lack of correlation from anti-correlation,
  508. periods of silence don't add to the correlation and it's sensible to
  509. firing pattern).
  510. The STTC is calculated as follows:
  511. .. math::
  512. STTC = 1/2((PA - TB)/(1 - PA*TB) + (PB - TA)/(1 - PB*TA))
  513. Where `PA` is the proportion of spikes from train 1 that lie within
  514. `[-dt, +dt]` of any spike of train 2 divided by the total number of spikes
  515. in train 1, `PB` is the same proportion for the spikes in train 2;
  516. `TA` is the proportion of total recording time within `[-dt, +dt]` of any
  517. spike in train 1, TB is the same propotion for train 2.
  518. This is a Python implementation compatible with the elephant library of
  519. the original code by C. Cutts written in C and avaiable at:
  520. (https://github.com/CCutts/Detecting_pairwise_correlations_in_spike_trains/blob/master/spike_time_tiling_coefficient.c)
  521. Parameters
  522. ----------
  523. spiketrain_1, spiketrain_2: neo.Spiketrain objects to cross-correlate.
  524. Must have the same t_start and t_stop.
  525. dt: Python Quantity.
  526. The synchronicity window is used for both: the quantification of the
  527. propotion of total recording time that lies [-dt, +dt] of each spike
  528. in each train and the proportion of spikes in `spiketrain_1` that lies
  529. `[-dt, +dt]` of any spike in `spiketrain_2`.
  530. Default : 0.005 * pq.s
  531. Returns
  532. -------
  533. index: float
  534. The spike time tiling coefficient (STTC). Returns np.nan if any spike
  535. train is empty.
  536. References
  537. ----------
  538. Cutts, C. S., & Eglen, S. J. (2014). Detecting Pairwise Correlations in
  539. Spike Trains: An Objective Comparison of Methods and Application to the
  540. Study of Retinal Waves. Journal of Neuroscience, 34(43), 14288–14303.
  541. """
  542. def run_P(spiketrain_1, spiketrain_2, N1, N2, dt):
  543. """
  544. Check every spike in train 1 to see if there's a spike in train 2
  545. within dt
  546. """
  547. Nab = 0
  548. j = 0
  549. for i in range(N1):
  550. while j < N2: # don't need to search all j each iteration
  551. if np.abs(spiketrain_1[i] - spiketrain_2[j]) <= dt:
  552. Nab = Nab + 1
  553. break
  554. elif spiketrain_2[j] > spiketrain_1[i]:
  555. break
  556. else:
  557. j = j + 1
  558. return Nab
  559. def run_T(spiketrain, N, dt):
  560. """
  561. Calculate the proportion of the total recording time 'tiled' by spikes.
  562. """
  563. time_A = 2 * N * dt # maxium possible time
  564. if N == 1: # for just one spike in train
  565. if spiketrain[0] - spiketrain.t_start < dt:
  566. time_A = time_A - dt + spiketrain[0] - spiketrain.t_start
  567. elif spiketrain[0] + dt > spiketrain.t_stop:
  568. time_A = time_A - dt - spiketrain[0] + spiketrain.t_stop
  569. else: # if more than one spike in train
  570. i = 0
  571. while i < (N - 1):
  572. diff = spiketrain[i + 1] - spiketrain[i]
  573. if diff < (2 * dt): # subtract overlap
  574. time_A = time_A - 2 * dt + diff
  575. i += 1
  576. # check if spikes are within dt of the start and/or end
  577. # if so just need to subract overlap of first and/or last spike
  578. if (spiketrain[0] - spiketrain.t_start) < dt:
  579. time_A = time_A + spiketrain[0] - dt - spiketrain.t_start
  580. if (spiketrain.t_stop - spiketrain[N - 1]) < dt:
  581. time_A = time_A - spiketrain[-1] - dt + spiketrain.t_stop
  582. T = (time_A / (spiketrain.t_stop - spiketrain.t_start)).item()
  583. return T
  584. N1 = len(spiketrain_1)
  585. N2 = len(spiketrain_2)
  586. if N1 == 0 or N2 == 0:
  587. index = np.nan
  588. else:
  589. TA = run_T(spiketrain_1, N1, dt)
  590. TB = run_T(spiketrain_2, N2, dt)
  591. PA = run_P(spiketrain_1, spiketrain_2, N1, N2, dt)
  592. PA = PA / N1
  593. PB = run_P(spiketrain_2, spiketrain_1, N2, N1, dt)
  594. PB = PB / N2
  595. index = 0.5 * (PA - TB) / (1 - PA * TB) + 0.5 * (PB - TA) / (
  596. 1 - PB * TA)
  597. return index
  598. sttc = spike_time_tiling_coefficient