spike_train_surrogates.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. # -*- coding: utf-8 -*-
  2. """
  3. Module to generate surrogates of a spike train by randomising its spike times
  4. in different ways (see [1]). Different methods destroy different features of
  5. the original data:
  6. * randomise_spikes:
  7. randomly reposition all spikes inside the time interval (t_start, t_stop).
  8. Keeps spike count, generates Poisson spike trains with time-stationary
  9. firing rate
  10. * dither_spikes:
  11. dither each spike time around original position by a random amount;
  12. keeps spike count and firing rates computed on a slow temporal scale;
  13. destroys ISIs, making them more exponentially distributed
  14. * dither_spike_train:
  15. dither the whole input spike train (i.e. all spikes equally) by a random
  16. amount; keeps spike count, ISIs, and firing rates computed on a slow
  17. temporal scale
  18. * jitter_spikes:
  19. discretise the full time interval (t_start, t_stop) into time segments
  20. and locally randomise the spike times (see randomise_spikes) inside each
  21. segment. Keeps spike count inside each segment and creates locally Poisson
  22. spike trains with locally time-stationary rates
  23. * shuffle_isis:
  24. shuffle the inter-spike intervals (ISIs) of the spike train randomly,
  25. keeping t_start fixed (the interval from t_start to the first spike is
  26. shuffled as an ISI too) and regenerating all spike times from the new sequence of ISIs. Keeps spike count and ISIs, flattens the firing rate
  27. profile
  28. [1] Louis et al (2010) Surrogate Spike Train Generation Through Dithering in
  29. Operational Time. Front Comput Neurosci. 2010; 4: 127.
  30. Original implementation by: Emiliano Torre [e.torre@fz-juelich.de]
  31. :copyright: Copyright 2015-2016 by the Elephant team, see AUTHORS.txt.
  32. :license: Modified BSD, see LICENSE.txt for details.
  33. """
  34. import numpy as np
  35. import quantities as pq
  36. import neo
  37. try:
  38. import elephant.statistics as es
  39. isi = es.isi
  40. except ImportError:
  41. from .statistics import isi # Convenience when in elephant working dir.
  42. def dither_spikes(spiketrain, dither, n=1, decimals=None, edges=True):
  43. """
  44. Generates surrogates of a spike train by spike dithering.
  45. The surrogates are obtained by uniformly dithering times around the
  46. original position. The dithering is performed independently for each
  47. surrogate.
  48. The surrogates retain the :attr:`t_start` and :attr:`t_stop` of the
  49. original `SpikeTrain` object. Spikes moved beyond this range are lost or
  50. moved to the range's ends, depending on the parameter edge.
  51. Parameters
  52. ----------
  53. spiketrain : neo.SpikeTrain
  54. The spike train from which to generate the surrogates
  55. dither : quantities.Quantity
  56. Amount of dithering. A spike at time t is placed randomly within
  57. ]t-dither, t+dither[.
  58. n : int (optional)
  59. Number of surrogates to be generated.
  60. Default: 1
  61. decimals : int or None (optional)
  62. Number of decimal points for every spike time in the surrogates
  63. If None, machine precision is used.
  64. Default: None
  65. edges : bool (optional)
  66. For surrogate spikes falling outside the range
  67. `[spiketrain.t_start, spiketrain.t_stop)`, whether to drop them out
  68. (for edges = True) or set that to the range's closest end
  69. (for edges = False).
  70. Default: True
  71. Returns
  72. -------
  73. list of neo.SpikeTrain
  74. A list of `neo.SpikeTrain`, each obtained from :attr:`spiketrain` by
  75. randomly dithering its spikes. The range of the surrogate spike trains
  76. is the same as :attr:`spiketrain`.
  77. Examples
  78. --------
  79. >>> import quantities as pq
  80. >>> import neo
  81. >>>
  82. >>> st = neo.SpikeTrain([100, 250, 600, 800]*pq.ms, t_stop=1*pq.s)
  83. >>> print dither_spikes(st, dither = 20*pq.ms) # doctest: +SKIP
  84. [<SpikeTrain(array([ 96.53801903, 248.57047376, 601.48865767,
  85. 815.67209811]) * ms, [0.0 ms, 1000.0 ms])>]
  86. >>> print dither_spikes(st, dither = 20*pq.ms, n=2) # doctest: +SKIP
  87. [<SpikeTrain(array([ 104.24942044, 246.0317873 , 584.55938657,
  88. 818.84446913]) * ms, [0.0 ms, 1000.0 ms])>,
  89. <SpikeTrain(array([ 111.36693058, 235.15750163, 618.87388515,
  90. 786.1807108 ]) * ms, [0.0 ms, 1000.0 ms])>]
  91. >>> print dither_spikes(st, dither = 20*pq.ms, decimals=0) # doctest: +SKIP
  92. [<SpikeTrain(array([ 81., 242., 595., 799.]) * ms,
  93. [0.0 ms, 1000.0 ms])>]
  94. """
  95. # Transform spiketrain into a Quantity object (needed for matrix algebra)
  96. data = spiketrain.view(pq.Quantity)
  97. # Main: generate the surrogates
  98. surr = data.reshape((1, len(data))) + 2 * dither * np.random.random_sample(
  99. (n, len(data))) - dither
  100. # Round the surrogate data to decimal position, if requested
  101. if decimals is not None:
  102. surr = surr.round(decimals)
  103. if edges is False:
  104. # Move all spikes outside [spiketrain.t_start, spiketrain.t_stop] to
  105. # the range's ends
  106. surr = np.minimum(np.maximum(surr.base,
  107. (spiketrain.t_start / spiketrain.units).base),
  108. (spiketrain.t_stop / spiketrain.units).base) * spiketrain.units
  109. else:
  110. # Leave out all spikes outside [spiketrain.t_start, spiketrain.t_stop]
  111. tstart, tstop = (spiketrain.t_start / spiketrain.units).base, \
  112. (spiketrain.t_stop / spiketrain.units).base
  113. surr = [np.sort(s[np.all([s >= tstart, s < tstop], axis=0)]) * spiketrain.units
  114. for s in surr.base]
  115. # Return the surrogates as SpikeTrains
  116. return [neo.SpikeTrain(s,
  117. t_start=spiketrain.t_start,
  118. t_stop=spiketrain.t_stop).rescale(spiketrain.units)
  119. for s in surr]
  120. def randomise_spikes(spiketrain, n=1, decimals=None):
  121. """
  122. Generates surrogates of a spike trains by spike time randomisation.
  123. The surrogates are obtained by keeping the spike count of the original
  124. `SpikeTrain` object, but placing them randomly into the interval
  125. `[spiketrain.t_start, spiketrain.t_stop]`.
  126. This generates independent Poisson neo.SpikeTrain objects (exponentially
  127. distributed inter-spike intervals) while keeping the spike count as in
  128. :attr:`spiketrain`.
  129. Parameters
  130. ----------
  131. spiketrain : neo.SpikeTrain
  132. The spike train from which to generate the surrogates
  133. n : int (optional)
  134. Number of surrogates to be generated.
  135. Default: 1
  136. decimals : int or None (optional)
  137. Number of decimal points for every spike time in the surrogates
  138. If None, machine precision is used.
  139. Default: None
  140. Returns
  141. -------
  142. list of neo.SpikeTrain object(s)
  143. A list of `neo.SpikeTrain` objects, each obtained from :attr:`spiketrain`
  144. by randomly dithering its spikes. The range of the surrogate spike trains
  145. is the same as :attr:`spiketrain`.
  146. Examples
  147. --------
  148. >>> import quantities as pq
  149. >>> import neo
  150. >>>
  151. >>> st = neo.SpikeTrain([100, 250, 600, 800]*pq.ms, t_stop=1*pq.s)
  152. >>> print randomise_spikes(st) # doctest: +SKIP
  153. [<SpikeTrain(array([ 131.23574603, 262.05062963, 549.84371387,
  154. 940.80503832]) * ms, [0.0 ms, 1000.0 ms])>]
  155. >>> print randomise_spikes(st, n=2) # doctest: +SKIP
  156. [<SpikeTrain(array([ 84.53274955, 431.54011743, 733.09605806,
  157. 852.32426583]) * ms, [0.0 ms, 1000.0 ms])>,
  158. <SpikeTrain(array([ 197.74596726, 528.93517359, 567.44599968,
  159. 775.97843799]) * ms, [0.0 ms, 1000.0 ms])>]
  160. >>> print randomise_spikes(st, decimals=0) # doctest: +SKIP
  161. [<SpikeTrain(array([ 29., 667., 720., 774.]) * ms,
  162. [0.0 ms, 1000.0 ms])>]
  163. """
  164. # Create surrogate spike trains as rows of a Quantity array
  165. sts = ((spiketrain.t_stop - spiketrain.t_start) *
  166. np.random.random(size=(n, len(spiketrain))) +
  167. spiketrain.t_start).rescale(spiketrain.units)
  168. # Round the surrogate data to decimal position, if requested
  169. if decimals is not None:
  170. sts = sts.round(decimals)
  171. # Convert the Quantity array to a list of SpikeTrains, and return them
  172. return [neo.SpikeTrain(np.sort(st), t_start=spiketrain.t_start, t_stop=spiketrain.t_stop)
  173. for st in sts]
  174. def shuffle_isis(spiketrain, n=1, decimals=None):
  175. """
  176. Generates surrogates of a neo.SpikeTrain object by inter-spike-interval
  177. (ISI) shuffling.
  178. The surrogates are obtained by randomly sorting the ISIs of the given input
  179. :attr:`spiketrain`. This generates independent `SpikeTrain` object(s) with
  180. same ISI distribution and spike count as in :attr:`spiketrain`, while
  181. destroying temporal dependencies and firing rate profile.
  182. Parameters
  183. ----------
  184. spiketrain : neo.SpikeTrain
  185. The spike train from which to generate the surrogates
  186. n : int (optional)
  187. Number of surrogates to be generated.
  188. Default: 1
  189. decimals : int or None (optional)
  190. Number of decimal points for every spike time in the surrogates
  191. If None, machine precision is used.
  192. Default: None
  193. Returns
  194. -------
  195. list of SpikeTrain
  196. A list of spike trains, each obtained from `spiketrain` by random ISI
  197. shuffling. The range of the surrogate `neo.SpikeTrain` objects is the
  198. same as :attr:`spiketrain`.
  199. Examples
  200. --------
  201. >>> import quantities as pq
  202. >>> import neo
  203. >>>
  204. >>> st = neo.SpikeTrain([100, 250, 600, 800]*pq.ms, t_stop=1*pq.s)
  205. >>> print shuffle_isis(st) # doctest: +SKIP
  206. [<SpikeTrain(array([ 200., 350., 700., 800.]) * ms,
  207. [0.0 ms, 1000.0 ms])>]
  208. >>> print shuffle_isis(st, n=2) # doctest: +SKIP
  209. [<SpikeTrain(array([ 100., 300., 450., 800.]) * ms,
  210. [0.0 ms, 1000.0 ms])>,
  211. <SpikeTrain(array([ 200., 350., 700., 800.]) * ms,
  212. [0.0 ms, 1000.0 ms])>]
  213. """
  214. if len(spiketrain) > 0:
  215. isi0 = spiketrain[0] - spiketrain.t_start
  216. ISIs = np.hstack([isi0, isi(spiketrain)])
  217. # Round the ISIs to decimal position, if requested
  218. if decimals is not None:
  219. ISIs = ISIs.round(decimals)
  220. # Create list of surrogate spike trains by random ISI permutation
  221. sts = []
  222. for i in range(n):
  223. surr_times = np.cumsum(np.random.permutation(ISIs)) *\
  224. spiketrain.units + spiketrain.t_start
  225. sts.append(neo.SpikeTrain(
  226. surr_times, t_start=spiketrain.t_start,
  227. t_stop=spiketrain.t_stop))
  228. else:
  229. sts = []
  230. empty_train = neo.SpikeTrain([] * spiketrain.units,
  231. t_start=spiketrain.t_start,
  232. t_stop=spiketrain.t_stop)
  233. for i in range(n):
  234. sts.append(empty_train)
  235. return sts
  236. def dither_spike_train(spiketrain, shift, n=1, decimals=None, edges=True):
  237. """
  238. Generates surrogates of a neo.SpikeTrain by spike train shifting.
  239. The surrogates are obtained by shifting the whole spike train by a
  240. random amount (independent for each surrogate). Thus, ISIs and temporal
  241. correlations within the spike train are kept. For small shifts, the
  242. firing rate profile is also kept with reasonable accuracy.
  243. The surrogates retain the :attr:`t_start` and :attr:`t_stop` of the
  244. :attr:`spiketrain`. Spikes moved beyond this range are lost or moved to
  245. the range's ends, depending on the parameter edge.
  246. Parameters
  247. ----------
  248. spiketrain : neo.SpikeTrain
  249. The spike train from which to generate the surrogates
  250. shift : quantities.Quantity
  251. Amount of shift. spiketrain is shifted by a random amount uniformly
  252. drawn from the range ]-shift, +shift[.
  253. n : int (optional)
  254. Number of surrogates to be generated.
  255. Default: 1
  256. decimals : int or None (optional)
  257. Number of decimal points for every spike time in the surrogates
  258. If None, machine precision is used.
  259. Default: None
  260. edges : bool
  261. For surrogate spikes falling outside the range `[spiketrain.t_start,
  262. spiketrain.t_stop)`, whether to drop them out (for edges = True) or set
  263. that to the range's closest end (for edges = False).
  264. Default: True
  265. Returns
  266. -------
  267. list of SpikeTrain
  268. A list of spike trains, each obtained from spiketrain by randomly
  269. dithering its spikes. The range of the surrogate spike trains is the
  270. same as :attr:`spiketrain`.
  271. Examples
  272. --------
  273. >>> import quantities as pq
  274. >>> import neo
  275. >>>
  276. >>> st = neo.SpikeTrain([100, 250, 600, 800]*pq.ms, t_stop=1*pq.s)
  277. >>>
  278. >>> print dither_spike_train(st, shift = 20*pq.ms) # doctest: +SKIP
  279. [<SpikeTrain(array([ 96.53801903, 248.57047376, 601.48865767,
  280. 815.67209811]) * ms, [0.0 ms, 1000.0 ms])>]
  281. >>> print dither_spike_train(st, shift = 20*pq.ms, n=2) # doctest: +SKIP
  282. [<SpikeTrain(array([ 92.89084054, 242.89084054, 592.89084054,
  283. 792.89084054]) * ms, [0.0 ms, 1000.0 ms])>,
  284. <SpikeTrain(array([ 84.61079043, 234.61079043, 584.61079043,
  285. 784.61079043]) * ms, [0.0 ms, 1000.0 ms])>]
  286. >>> print dither_spike_train(st, shift = 20*pq.ms, decimals=0) # doctest: +SKIP
  287. [<SpikeTrain(array([ 82., 232., 582., 782.]) * ms,
  288. [0.0 ms, 1000.0 ms])>]
  289. """
  290. # Transform spiketrain into a Quantity object (needed for matrix algebra)
  291. data = spiketrain.view(pq.Quantity)
  292. # Main: generate the surrogates by spike train shifting
  293. surr = data.reshape((1, len(data))) + 2 * shift * \
  294. np.random.random_sample((n, 1)) - shift
  295. # Round the surrogate data to decimal position, if requested
  296. if decimals is not None:
  297. surr = surr.round(decimals)
  298. if edges is False:
  299. # Move all spikes outside [spiketrain.t_start, spiketrain.t_stop] to
  300. # the range's ends
  301. surr = np.minimum(np.maximum(surr.base,
  302. (spiketrain.t_start / spiketrain.units).base),
  303. (spiketrain.t_stop / spiketrain.units).base) * spiketrain.units
  304. else:
  305. # Leave out all spikes outside [spiketrain.t_start, spiketrain.t_stop]
  306. tstart, tstop = (spiketrain.t_start / spiketrain.units).base,\
  307. (spiketrain.t_stop / spiketrain.units).base
  308. surr = [s[np.all([s >= tstart, s < tstop], axis=0)] * spiketrain.units
  309. for s in surr.base]
  310. # Return the surrogates as SpikeTrains
  311. return [neo.SpikeTrain(s, t_start=spiketrain.t_start,
  312. t_stop=spiketrain.t_stop).rescale(spiketrain.units)
  313. for s in surr]
  314. def jitter_spikes(spiketrain, binsize, n=1):
  315. """
  316. Generates surrogates of a :attr:`spiketrain` by spike jittering.
  317. The surrogates are obtained by defining adjacent time bins spanning the
  318. :attr:`spiketrain` range, and random re-positioning (independently for each
  319. surrogate) each spike in the time bin it falls into.
  320. The surrogates retain the :attr:`t_start and :attr:`t_stop` of the
  321. :attr:`spike train`. Note that within each time bin the surrogate
  322. `neo.SpikeTrain` objects are locally poissonian (the inter-spike-interval
  323. are exponentially distributed).
  324. Parameters
  325. ----------
  326. spiketrain : neo.SpikeTrain
  327. The spike train from which to generate the surrogates
  328. binsize : quantities.Quantity
  329. Size of the time bins within which to randomise the spike times.
  330. Note: the last bin arrives until `spiketrain.t_stop` and might have
  331. width different from `binsize`.
  332. n : int (optional)
  333. Number of surrogates to be generated.
  334. Default: 1
  335. Returns
  336. -------
  337. list of SpikeTrain
  338. A list of spike trains, each obtained from `spiketrain` by randomly
  339. replacing its spikes within bins of user-defined width. The range of the
  340. surrogate spike trains is the same as `spiketrain`.
  341. Examples
  342. --------
  343. >>> import quantities as pq
  344. >>> import neo
  345. >>>
  346. >>> st = neo.SpikeTrain([80, 150, 320, 480]*pq.ms, t_stop=1*pq.s)
  347. >>> print jitter_spikes(st, binsize=100*pq.ms) # doctest: +SKIP
  348. [<SpikeTrain(array([ 98.82898293, 178.45805954, 346.93993867,
  349. 461.34268507]) * ms, [0.0 ms, 1000.0 ms])>]
  350. >>> print jitter_spikes(st, binsize=100*pq.ms, n=2) # doctest: +SKIP
  351. [<SpikeTrain(array([ 97.15720041, 199.06945744, 397.51928207,
  352. 402.40065162]) * ms, [0.0 ms, 1000.0 ms])>,
  353. <SpikeTrain(array([ 80.74513157, 173.69371317, 338.05860962,
  354. 495.48869981]) * ms, [0.0 ms, 1000.0 ms])>]
  355. >>> print jitter_spikes(st, binsize=100*pq.ms) # doctest: +SKIP
  356. [<SpikeTrain(array([ 4.55064897e-01, 1.31927046e+02, 3.57846265e+02,
  357. 4.69370604e+02]) * ms, [0.0 ms, 1000.0 ms])>]
  358. """
  359. # Define standard time unit; all time Quantities are converted to
  360. # scalars after being rescaled to this unit, to use the power of numpy
  361. std_unit = binsize.units
  362. # Compute bin edges for the jittering procedure
  363. # !: the last bin arrives until spiketrain.t_stop and might have
  364. # size != binsize
  365. start_dl = spiketrain.t_start.rescale(std_unit).magnitude
  366. stop_dl = spiketrain.t_stop.rescale(std_unit).magnitude
  367. bin_edges = start_dl + np.arange(start_dl, stop_dl, binsize.magnitude)
  368. bin_edges = np.hstack([bin_edges, stop_dl])
  369. # Create n surrogates with spikes randomly placed in the interval (0,1)
  370. surr_poiss01 = np.random.random_sample((n, len(spiketrain)))
  371. # Compute the bin id of each spike
  372. bin_ids = np.array(
  373. (spiketrain.view(pq.Quantity) /
  374. binsize).rescale(pq.dimensionless).magnitude, dtype=int)
  375. # Compute the size of each time bin (as a numpy array)
  376. bin_sizes_dl = np.diff(bin_edges)
  377. # For each spike compute its offset (the left end of the bin it falls
  378. # into) and the size of the bin it falls into
  379. offsets = start_dl + np.array([bin_edges[bin_id] for bin_id in bin_ids])
  380. dilats = np.array([bin_sizes_dl[bin_id] for bin_id in bin_ids])
  381. # Compute each surrogate by dilatating and shifting each spike s in the
  382. # poisson 0-1 spike trains to dilat * s + offset. Attach time unit again
  383. surr = np.sort(surr_poiss01 * dilats + offsets, axis=1) * std_unit
  384. return [neo.SpikeTrain(s, t_start=spiketrain.t_start,
  385. t_stop=spiketrain.t_stop).rescale(spiketrain.units)
  386. for s in surr]
  387. def surrogates(
  388. spiketrain, n=1, surr_method='dither_spike_train', dt=None, decimals=None,
  389. edges=True):
  390. """
  391. Generates surrogates of a :attr:`spiketrain` by a desired generation
  392. method.
  393. This routine is a wrapper for the other surrogate generators in the
  394. module.
  395. The surrogates retain the :attr:`t_start` and :attr:`t_stop` of the
  396. original :attr:`spiketrain`.
  397. Parameters
  398. ----------
  399. spiketrain : neo.SpikeTrain
  400. The spike train from which to generate the surrogates
  401. n : int, optional
  402. Number of surrogates to be generated.
  403. Default: 1
  404. surr_method : str, optional
  405. The method to use to generate surrogate spike trains. Can be one of:
  406. * 'dither_spike_train': see surrogates.dither_spike_train() [dt needed]
  407. * 'dither_spikes': see surrogates.dither_spikes() [dt needed]
  408. * 'jitter_spikes': see surrogates.jitter_spikes() [dt needed]
  409. * 'randomise_spikes': see surrogates.randomise_spikes()
  410. * 'shuffle_isis': see surrogates.shuffle_isis()
  411. Default: 'dither_spike_train'
  412. dt : quantities.Quantity, optional
  413. For methods shifting spike times randomly around their original time
  414. (spike dithering, train shifting) or replacing them randomly within a
  415. certain window (spike jittering), dt represents the size of that
  416. shift / window. For other methods, dt is ignored.
  417. Default: None
  418. decimals : int or None, optional
  419. Number of decimal points for every spike time in the surrogates
  420. If None, machine precision is used.
  421. Default: None
  422. edges : bool
  423. For surrogate spikes falling outside the range `[spiketrain.t_start,
  424. spiketrain.t_stop)`, whether to drop them out (for edges = True) or set
  425. that to the range's closest end (for edges = False).
  426. Default: True
  427. Returns
  428. -------
  429. list of neo.SpikeTrain objects
  430. A list of spike trains, each obtained from `spiketrain` by randomly
  431. dithering its spikes. The range of the surrogate `neo.SpikeTrain`
  432. object(s) is the same as `spiketrain`.
  433. """
  434. # Define the surrogate function to use, depending on the specified method
  435. surrogate_types = {
  436. 'dither_spike_train': dither_spike_train,
  437. 'dither_spikes': dither_spikes,
  438. 'jitter_spikes': jitter_spikes,
  439. 'randomise_spikes': randomise_spikes,
  440. 'shuffle_isis': shuffle_isis}
  441. if surr_method not in surrogate_types.keys():
  442. raise ValueError('specified surr_method (=%s) not valid' % surr_method)
  443. if surr_method in ['dither_spike_train', 'dither_spikes', 'jitter_spikes']:
  444. return surrogate_types[surr_method](
  445. spiketrain, dt, n=n, decimals=decimals, edges=edges)
  446. elif surr_method in ['randomise_spikes', 'shuffle_isis']:
  447. return surrogate_types[surr_method](
  448. spiketrain, n=n, decimals=decimals)