spike_train_generation.py 40 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063
  1. # -*- coding: utf-8 -*-
  2. """
  3. Functions to generate spike trains from analog signals,
  4. or to generate random spike trains.
  5. Some functions are based on the NeuroTools stgen module, which was mostly
  6. written by Eilif Muller, or from the NeuroTools signals.analogs module.
  7. :copyright: Copyright 2015 by the Elephant team, see AUTHORS.txt.
  8. :license: Modified BSD, see LICENSE.txt for details.
  9. """
  10. from __future__ import division
  11. import numpy as np
  12. from quantities import ms, mV, Hz, Quantity, dimensionless
  13. from neo import SpikeTrain
  14. import random
  15. from elephant.spike_train_surrogates import dither_spike_train
  16. import warnings
  17. def spike_extraction(signal, threshold=0.0 * mV, sign='above',
  18. time_stamps=None, extr_interval=(-2 * ms, 4 * ms)):
  19. """
  20. Return the peak times for all events that cross threshold and the
  21. waveforms. Usually used for extracting spikes from a membrane
  22. potential to calculate waveform properties.
  23. Similar to spike_train_generation.peak_detection.
  24. Parameters
  25. ----------
  26. signal : neo AnalogSignal object
  27. 'signal' is an analog signal.
  28. threshold : A quantity, e.g. in mV
  29. 'threshold' contains a value that must be reached for an event
  30. to be detected. Default: 0.0 * mV.
  31. sign : 'above' or 'below'
  32. 'sign' determines whether to count thresholding crossings
  33. that cross above or below the threshold. Default: 'above'.
  34. time_stamps: None, quantity array or Object with .times interface
  35. if 'spike_train' is a quantity array or exposes a quantity array
  36. exposes the .times interface, it provides the time_stamps
  37. around which the waveform is extracted. If it is None, the
  38. function peak_detection is used to calculate the time_stamps
  39. from signal. Default: None.
  40. extr_interval: unpackable time quantities, len == 2
  41. 'extr_interval' specifies the time interval around the
  42. time_stamps where the waveform is extracted. The default is an
  43. interval of '6 ms'. Default: (-2 * ms, 4 * ms).
  44. Returns
  45. -------
  46. result_st : neo SpikeTrain object
  47. 'result_st' contains the time_stamps of each of the spikes and
  48. the waveforms in result_st.waveforms.
  49. """
  50. # Get spike time_stamps
  51. if time_stamps is None:
  52. time_stamps = peak_detection(signal, threshold, sign=sign)
  53. elif hasattr(time_stamps, 'times'):
  54. time_stamps = time_stamps.times
  55. elif type(time_stamps) is Quantity:
  56. raise TypeError("time_stamps must be None, a quantity array or" +
  57. " expose the.times interface")
  58. if len(time_stamps) == 0:
  59. return SpikeTrain(time_stamps, units=signal.times.units,
  60. t_start=signal.t_start, t_stop=signal.t_stop,
  61. waveforms=np.array([]),
  62. sampling_rate=signal.sampling_rate)
  63. # Unpack the extraction interval from tuple or array
  64. extr_left, extr_right = extr_interval
  65. if extr_left > extr_right:
  66. raise ValueError("extr_interval[0] must be < extr_interval[1]")
  67. if any(np.diff(time_stamps) < extr_interval[1]):
  68. warnings.warn("Waveforms overlap.", UserWarning)
  69. data_left = ((extr_left * signal.sampling_rate).simplified).magnitude
  70. data_right = ((extr_right * signal.sampling_rate).simplified).magnitude
  71. data_stamps = (((time_stamps - signal.t_start) *
  72. signal.sampling_rate).simplified).magnitude
  73. data_stamps = data_stamps.astype(int)
  74. borders_left = data_stamps + data_left
  75. borders_right = data_stamps + data_right
  76. borders = np.dstack((borders_left, borders_right)).flatten()
  77. waveforms = np.array(
  78. np.split(np.array(signal), borders.astype(int))[1::2]) * signal.units
  79. # len(np.shape(waveforms)) == 1 if waveforms do not have the same width.
  80. # this can occur when extr_interval indexes beyond the signal.
  81. # Workaround: delete spikes shorter than the maximum length with
  82. if len(np.shape(waveforms)) == 1:
  83. max_len = (np.array([len(x) for x in waveforms])).max()
  84. to_delete = np.array([idx for idx, x in enumerate(waveforms)
  85. if len(x) < max_len])
  86. waveforms = np.delete(waveforms, to_delete, axis=0)
  87. waveforms = np.array([x for x in waveforms])
  88. warnings.warn("Waveforms " +
  89. ("{:d}, " * len(to_delete)).format(*to_delete) +
  90. "exceeded signal and had to be deleted. " +
  91. "Change extr_interval to keep.")
  92. waveforms = waveforms[:, np.newaxis, :]
  93. return SpikeTrain(time_stamps, units=signal.times.units,
  94. t_start=signal.t_start, t_stop=signal.t_stop,
  95. sampling_rate=signal.sampling_rate, waveforms=waveforms,
  96. left_sweep=extr_left)
  97. def threshold_detection(signal, threshold=0.0 * mV, sign='above'):
  98. """
  99. Returns the times when the analog signal crosses a threshold.
  100. Usually used for extracting spike times from a membrane potential.
  101. Adapted from version in NeuroTools.
  102. Parameters
  103. ----------
  104. signal : neo AnalogSignal object
  105. 'signal' is an analog signal.
  106. threshold : A quantity, e.g. in mV
  107. 'threshold' contains a value that must be reached
  108. for an event to be detected. Default: 0.0 * mV.
  109. sign : 'above' or 'below'
  110. 'sign' determines whether to count thresholding crossings
  111. that cross above or below the threshold.
  112. format : None or 'raw'
  113. Whether to return as SpikeTrain (None)
  114. or as a plain array of times ('raw').
  115. Returns
  116. -------
  117. result_st : neo SpikeTrain object
  118. 'result_st' contains the spike times of each of the events (spikes)
  119. extracted from the signal.
  120. """
  121. assert threshold is not None, "A threshold must be provided"
  122. if sign is 'above':
  123. cutout = np.where(signal > threshold)[0]
  124. elif sign in 'below':
  125. cutout = np.where(signal < threshold)[0]
  126. if len(cutout) <= 0:
  127. events = np.zeros(0)
  128. else:
  129. take = np.where(np.diff(cutout) > 1)[0] + 1
  130. take = np.append(0, take)
  131. time = signal.times
  132. events = time[cutout][take]
  133. events_base = events.base
  134. if events_base is None:
  135. # This occurs in some Python 3 builds due to some
  136. # bug in quantities.
  137. events_base = np.array([event.base for event in events]) # Workaround
  138. result_st = SpikeTrain(events_base, units=signal.times.units,
  139. t_start=signal.t_start, t_stop=signal.t_stop)
  140. return result_st
  141. def peak_detection(signal, threshold=0.0 * mV, sign='above', format=None):
  142. """
  143. Return the peak times for all events that cross threshold.
  144. Usually used for extracting spike times from a membrane potential.
  145. Similar to spike_train_generation.threshold_detection.
  146. Parameters
  147. ----------
  148. signal : neo AnalogSignal object
  149. 'signal' is an analog signal.
  150. threshold : A quantity, e.g. in mV
  151. 'threshold' contains a value that must be reached
  152. for an event to be detected.
  153. sign : 'above' or 'below'
  154. 'sign' determines whether to count thresholding crossings that
  155. cross above or below the threshold. Default: 'above'.
  156. format : None or 'raw'
  157. Whether to return as SpikeTrain (None) or as a plain array
  158. of times ('raw'). Default: None.
  159. Returns
  160. -------
  161. result_st : neo SpikeTrain object
  162. 'result_st' contains the spike times of each of the events
  163. (spikes) extracted from the signal.
  164. """
  165. assert threshold is not None, "A threshold must be provided"
  166. if sign is 'above':
  167. cutout = np.where(signal > threshold)[0]
  168. peak_func = np.argmax
  169. elif sign in 'below':
  170. cutout = np.where(signal < threshold)[0]
  171. peak_func = np.argmin
  172. else:
  173. raise ValueError("sign must be 'above' or 'below'")
  174. if len(cutout) <= 0:
  175. events_base = np.zeros(0)
  176. else:
  177. # Select thr crossings lasting at least 2 dtps, np.diff(cutout) > 2
  178. # This avoids empty slices
  179. border_start = np.where(np.diff(cutout) > 1)[0]
  180. border_end = border_start + 1
  181. borders = np.concatenate((border_start, border_end))
  182. borders = np.append(0, borders)
  183. borders = np.append(borders, len(cutout)-1)
  184. borders = np.sort(borders)
  185. true_borders = cutout[borders]
  186. right_borders = true_borders[1::2] + 1
  187. true_borders = np.sort(np.append(true_borders[0::2], right_borders))
  188. # Workaround for bug that occurs when signal goes below thr for 1 dtp,
  189. # Workaround eliminates empy slices from np. split
  190. backward_mask = np.absolute(np.ediff1d(true_borders, to_begin=1)) > 0
  191. forward_mask = np.absolute(np.ediff1d(true_borders[::-1],
  192. to_begin=1)[::-1]) > 0
  193. true_borders = true_borders[backward_mask * forward_mask]
  194. split_signal = np.split(np.array(signal), true_borders)[1::2]
  195. maxima_idc_split = np.array([peak_func(x) for x in split_signal])
  196. max_idc = maxima_idc_split + true_borders[0::2]
  197. events = signal.times[max_idc]
  198. events_base = events.base
  199. if events_base is None:
  200. # This occurs in some Python 3 builds due to some
  201. # bug in quantities.
  202. events_base = np.array([event.base for event in events]) # Workaround
  203. if format is None:
  204. result_st = SpikeTrain(events_base, units=signal.times.units,
  205. t_start=signal.t_start, t_stop=signal.t_stop)
  206. elif 'raw':
  207. result_st = events_base
  208. else:
  209. raise ValueError("Format argument must be None or 'raw'")
  210. return result_st
  211. def _homogeneous_process(interval_generator, args, mean_rate, t_start, t_stop,
  212. as_array):
  213. """
  214. Returns a spike train whose spikes are a realization of a random process
  215. generated by the function `interval_generator` with the given rate,
  216. starting at time `t_start` and stopping `time t_stop`.
  217. """
  218. def rescale(x):
  219. return (x / mean_rate.units).rescale(t_stop.units)
  220. n = int(((t_stop - t_start) * mean_rate).simplified)
  221. number = np.ceil(n + 3 * np.sqrt(n))
  222. if number < 100:
  223. number = min(5 + np.ceil(2 * n), 100)
  224. assert number > 4 # if positive, number cannot be less than 5
  225. isi = rescale(interval_generator(*args, size=int(number)))
  226. spikes = np.cumsum(isi)
  227. spikes += t_start
  228. i = spikes.searchsorted(t_stop)
  229. if i == len(spikes):
  230. # ISI buffer overrun
  231. extra_spikes = []
  232. t_last = spikes[-1] + rescale(interval_generator(*args, size=1))[0]
  233. while t_last < t_stop:
  234. extra_spikes.append(t_last)
  235. t_last = t_last + rescale(interval_generator(*args, size=1))[0]
  236. # np.concatenate does not conserve units
  237. spikes = Quantity(
  238. np.concatenate(
  239. (spikes, extra_spikes)).magnitude, units=spikes.units)
  240. else:
  241. spikes = spikes[:i]
  242. if as_array:
  243. spikes = spikes.magnitude
  244. else:
  245. spikes = SpikeTrain(
  246. spikes, t_start=t_start, t_stop=t_stop, units=spikes.units)
  247. return spikes
  248. def homogeneous_poisson_process(rate, t_start=0.0 * ms, t_stop=1000.0 * ms,
  249. as_array=False):
  250. """
  251. Returns a spike train whose spikes are a realization of a Poisson process
  252. with the given rate, starting at time `t_start` and stopping time `t_stop`.
  253. All numerical values should be given as Quantities, e.g. 100*Hz.
  254. Parameters
  255. ----------
  256. rate : Quantity scalar with dimension 1/time
  257. The rate of the discharge.
  258. t_start : Quantity scalar with dimension time
  259. The beginning of the spike train.
  260. t_stop : Quantity scalar with dimension time
  261. The end of the spike train.
  262. as_array : bool
  263. If True, a NumPy array of sorted spikes is returned,
  264. rather than a SpikeTrain object.
  265. Raises
  266. ------
  267. ValueError : If `t_start` and `t_stop` are not of type `pq.Quantity`.
  268. Examples
  269. --------
  270. >>> from quantities import Hz, ms
  271. >>> spikes = homogeneous_poisson_process(50*Hz, 0*ms, 1000*ms)
  272. >>> spikes = homogeneous_poisson_process(
  273. 20*Hz, 5000*ms, 10000*ms, as_array=True)
  274. """
  275. if not isinstance(t_start, Quantity) or not isinstance(t_stop, Quantity):
  276. raise ValueError("t_start and t_stop must be of type pq.Quantity")
  277. rate = rate.rescale((1 / t_start).units)
  278. mean_interval = 1 / rate.magnitude
  279. return _homogeneous_process(
  280. np.random.exponential, (mean_interval,), rate, t_start, t_stop,
  281. as_array)
  282. def inhomogeneous_poisson_process(rate, as_array=False):
  283. """
  284. Returns a spike train whose spikes are a realization of an inhomogeneous
  285. Poisson process with the given rate profile.
  286. Parameters
  287. ----------
  288. rate : neo.AnalogSignal
  289. A `neo.AnalogSignal` representing the rate profile evolving over time.
  290. Its values have all to be `>=0`. The output spiketrain will have
  291. `t_start = rate.t_start` and `t_stop = rate.t_stop`
  292. as_array : bool
  293. If True, a NumPy array of sorted spikes is returned,
  294. rather than a SpikeTrain object.
  295. Raises
  296. ------
  297. ValueError : If `rate` contains any negative value.
  298. """
  299. # Check rate contains only positive values
  300. if any(rate < 0) or not rate.size:
  301. raise ValueError(
  302. 'rate must be a positive non empty signal, representing the'
  303. 'rate at time t')
  304. else:
  305. #Generate n hidden Poisson SpikeTrains with rate equal to the peak rate
  306. max_rate = np.max(rate)
  307. homogeneous_poiss = homogeneous_poisson_process(
  308. rate=max_rate, t_stop=rate.t_stop, t_start=rate.t_start)
  309. # Compute the rate profile at each spike time by interpolation
  310. rate_interpolated = _analog_signal_linear_interp(
  311. signal=rate, times=homogeneous_poiss.magnitude *
  312. homogeneous_poiss.units)
  313. # Accept each spike at time t with probability rate(t)/max_rate
  314. u = np.random.uniform(size=len(homogeneous_poiss)) * max_rate
  315. spikes = homogeneous_poiss[u < rate_interpolated.flatten()]
  316. if as_array:
  317. spikes = spikes.magnitude
  318. return spikes
  319. def _analog_signal_linear_interp(signal, times):
  320. '''
  321. Compute the linear interpolation of a signal at desired times.
  322. Given the `signal` (neo.AnalogSignal) taking value `s0` and `s1` at two
  323. consecutive time points `t0` and `t1` `(t0 < t1)`, for every time `t` in
  324. `times`, such that `t0<t<=t1` is returned the value of the linear
  325. interpolation, given by:
  326. `s = ((s1 - s0) / (t1 - t0)) * t + s0`.
  327. Parameters
  328. ----------
  329. times : Quantity vector(time)
  330. The time points for which the interpolation is computed
  331. signal : neo.core.AnalogSignal
  332. The analog signal containing the discretization of the function to
  333. interpolate
  334. Returns
  335. ------
  336. out: Quantity array representing the values of the interpolated signal at the
  337. times given by times
  338. Notes
  339. -----
  340. If `signal` has sampling period `dt=signal.sampling_period`, its values
  341. are defined at `t=signal.times`, such that `t[i] = signal.t_start + i * dt`
  342. The last of such times is lower than
  343. signal.t_stop`:t[-1] = signal.t_stop - dt`.
  344. For the interpolation at times t such that `t[-1] <= t <= signal.t_stop`,
  345. the value of `signal` at `signal.t_stop` is taken to be that
  346. at time `t[-1]`.
  347. '''
  348. dt = signal.sampling_period
  349. t_start = signal.t_start.rescale(signal.times.units)
  350. t_stop = signal.t_stop.rescale(signal.times.units)
  351. # Extend the signal (as a dimensionless array) copying the last value
  352. # one time, and extend its times to t_stop
  353. signal_extended = np.vstack(
  354. [signal.magnitude, signal[-1].magnitude]).flatten()
  355. times_extended = np.hstack([signal.times, t_stop]) * signal.times.units
  356. time_ids = np.floor(((times - t_start) / dt).rescale(
  357. dimensionless).magnitude).astype('i')
  358. # Compute the slope m of the signal at each time in times
  359. y1 = signal_extended[time_ids]
  360. y2 = signal_extended[time_ids + 1]
  361. m = (y2 - y1) / dt
  362. # Interpolate the signal at each time in times by linear interpolation
  363. out = (y1 + m * (times - times_extended[time_ids])) * signal.units
  364. return out.rescale(signal.units)
  365. def homogeneous_gamma_process(a, b, t_start=0.0 * ms, t_stop=1000.0 * ms,
  366. as_array=False):
  367. """
  368. Returns a spike train whose spikes are a realization of a gamma process
  369. with the given parameters, starting at time `t_start` and stopping time
  370. `t_stop` (average rate will be b/a).
  371. All numerical values should be given as Quantities, e.g. 100*Hz.
  372. Parameters
  373. ----------
  374. a : int or float
  375. The shape parameter of the gamma distribution.
  376. b : Quantity scalar with dimension 1/time
  377. The rate parameter of the gamma distribution.
  378. t_start : Quantity scalar with dimension time
  379. The beginning of the spike train.
  380. t_stop : Quantity scalar with dimension time
  381. The end of the spike train.
  382. as_array : bool
  383. If True, a NumPy array of sorted spikes is returned,
  384. rather than a SpikeTrain object.
  385. Raises
  386. ------
  387. ValueError : If `t_start` and `t_stop` are not of type `pq.Quantity`.
  388. Examples
  389. --------
  390. >>> from quantities import Hz, ms
  391. >>> spikes = homogeneous_gamma_process(2.0, 50*Hz, 0*ms, 1000*ms)
  392. >>> spikes = homogeneous_gamma_process(
  393. 5.0, 20*Hz, 5000*ms, 10000*ms, as_array=True)
  394. """
  395. if not isinstance(t_start, Quantity) or not isinstance(t_stop, Quantity):
  396. raise ValueError("t_start and t_stop must be of type pq.Quantity")
  397. b = b.rescale((1 / t_start).units).simplified
  398. rate = b / a
  399. k, theta = a, (1 / b.magnitude)
  400. return _homogeneous_process(np.random.gamma, (k, theta), rate, t_start, t_stop, as_array)
  401. def _n_poisson(rate, t_stop, t_start=0.0 * ms, n=1):
  402. """
  403. Generates one or more independent Poisson spike trains.
  404. Parameters
  405. ----------
  406. rate : Quantity or Quantity array
  407. Expected firing rate (frequency) of each output SpikeTrain.
  408. Can be one of:
  409. * a single Quantity value: expected firing rate of each output
  410. SpikeTrain
  411. * a Quantity array: rate[i] is the expected firing rate of the i-th
  412. output SpikeTrain
  413. t_stop : Quantity
  414. Single common stop time of each output SpikeTrain. Must be > t_start.
  415. t_start : Quantity (optional)
  416. Single common start time of each output SpikeTrain. Must be < t_stop.
  417. Default: 0 s.
  418. n: int (optional)
  419. If rate is a single Quantity value, n specifies the number of
  420. SpikeTrains to be generated. If rate is an array, n is ignored and the
  421. number of SpikeTrains is equal to len(rate).
  422. Default: 1
  423. Returns
  424. -------
  425. list of neo.SpikeTrain
  426. Each SpikeTrain contains one of the independent Poisson spike trains,
  427. either n SpikeTrains of the same rate, or len(rate) SpikeTrains with
  428. varying rates according to the rate parameter. The time unit of the
  429. SpikeTrains is given by t_stop.
  430. """
  431. # Check that the provided input is Hertz of return error
  432. try:
  433. for r in rate.reshape(-1, 1):
  434. r.rescale('Hz')
  435. except AttributeError:
  436. raise ValueError('rate argument must have rate unit (1/time)')
  437. # Check t_start < t_stop and create their strip dimensions
  438. if not t_start < t_stop:
  439. raise ValueError(
  440. 't_start (=%s) must be < t_stop (=%s)' % (t_start, t_stop))
  441. # Set number n of output spike trains (specified or set to len(rate))
  442. if not (type(n) == int and n > 0):
  443. raise ValueError('n (=%s) must be a positive integer' % str(n))
  444. rate_dl = rate.simplified.magnitude.flatten()
  445. # Check rate input parameter
  446. if len(rate_dl) == 1:
  447. if rate_dl < 0:
  448. raise ValueError('rate (=%s) must be non-negative.' % rate)
  449. rates = np.array([rate_dl] * n)
  450. else:
  451. rates = rate_dl.flatten()
  452. if any(rates < 0):
  453. raise ValueError('rate must have non-negative elements.')
  454. sts = []
  455. for r in rates:
  456. sts.append(homogeneous_poisson_process(r * Hz, t_start, t_stop))
  457. return sts
  458. def single_interaction_process(
  459. rate, rate_c, t_stop, n=2, jitter=0 * ms, coincidences='deterministic',
  460. t_start=0 * ms, min_delay=0 * ms, return_coinc=False):
  461. """
  462. Generates a multidimensional Poisson SIP (single interaction process)
  463. plus independent Poisson processes
  464. A Poisson SIP consists of Poisson time series which are independent
  465. except for simultaneous events in all of them. This routine generates
  466. a SIP plus additional parallel independent Poisson processes.
  467. See [1].
  468. Parameters
  469. -----------
  470. t_stop: quantities.Quantity
  471. Total time of the simulated processes. The events are drawn between
  472. 0 and `t_stop`.
  473. rate: quantities.Quantity
  474. Overall mean rate of the time series to be generated (coincidence
  475. rate `rate_c` is subtracted to determine the background rate). Can be:
  476. * a float, representing the overall mean rate of each process. If
  477. so, it must be higher than `rate_c`.
  478. * an iterable of floats (one float per process), each float
  479. representing the overall mean rate of a process. If so, all the
  480. entries must be larger than `rate_c`.
  481. rate_c: quantities.Quantity
  482. Coincidence rate (rate of coincidences for the n-dimensional SIP).
  483. The SIP spike trains will have coincident events with rate `rate_c`
  484. plus independent 'background' events with rate `rate-rate_c`.
  485. n: int, optional
  486. If `rate` is a single Quantity value, `n` specifies the number of
  487. SpikeTrains to be generated. If rate is an array, `n` is ignored and
  488. the number of SpikeTrains is equal to `len(rate)`.
  489. Default: 1
  490. jitter: quantities.Quantity, optional
  491. Jitter for the coincident events. If `jitter == 0`, the events of all
  492. n correlated processes are exactly coincident. Otherwise, they are
  493. jittered around a common time randomly, up to +/- `jitter`.
  494. coincidences: string, optional
  495. Whether the total number of injected coincidences must be determin-
  496. istic (i.e. rate_c is the actual rate with which coincidences are
  497. generated) or stochastic (i.e. rate_c is the mean rate of coincid-
  498. ences):
  499. * 'deterministic': deterministic rate
  500. * 'stochastic': stochastic rate
  501. Default: 'deterministic'
  502. t_start: quantities.Quantity, optional
  503. Starting time of the series. If specified, it must be lower than
  504. t_stop
  505. Default: 0 * ms
  506. min_delay: quantities.Quantity, optional
  507. Minimum delay between consecutive coincidence times.
  508. Default: 0 * ms
  509. return_coinc: bool, optional
  510. Whether to return the coincidence times for the SIP process
  511. Default: False
  512. Returns
  513. --------
  514. output: list
  515. Realization of a SIP consisting of n Poisson processes characterized
  516. by synchronous events (with the given jitter)
  517. If `return_coinc` is `True`, the coincidence times are returned as a
  518. second output argument. They also have an associated time unit (same
  519. as `t_stop`).
  520. References
  521. ----------
  522. [1] Kuhn, Aertsen, Rotter (2003) Neural Comput 15(1):67-101
  523. EXAMPLE:
  524. >>> import quantities as qt
  525. >>> import jelephant.core.stocmod as sm
  526. >>> sip, coinc = sm.sip_poisson(n=10, n=0, t_stop=1*qt.sec, \
  527. rate=20*qt.Hz, rate_c=4, return_coinc = True)
  528. *************************************************************************
  529. """
  530. # Check if n is a positive integer
  531. if not (isinstance(n, int) and n > 0):
  532. raise ValueError('n (=%s) must be a positive integer' % str(n))
  533. # Assign time unit to jitter, or check that its existing unit is a time
  534. # unit
  535. jitter = abs(jitter)
  536. # Define the array of rates from input argument rate. Check that its length
  537. # matches with n
  538. if rate.ndim == 0:
  539. if rate < 0 * Hz:
  540. raise ValueError(
  541. 'rate (=%s) must be non-negative.' % str(rate))
  542. rates_b = np.array(
  543. [rate.magnitude for _ in range(n)]) * rate.units
  544. else:
  545. rates_b = np.array(rate).flatten() * rate.units
  546. if not all(rates_b >= 0. * Hz):
  547. raise ValueError('*rate* must have non-negative elements')
  548. # Check: rate>=rate_c
  549. if np.any(rates_b < rate_c):
  550. raise ValueError('all elements of *rate* must be >= *rate_c*')
  551. # Check min_delay < 1./rate_c
  552. if not (rate_c == 0 * Hz or min_delay < 1. / rate_c):
  553. raise ValueError(
  554. "'*min_delay* (%s) must be lower than 1/*rate_c* (%s)." %
  555. (str(min_delay), str((1. / rate_c).rescale(min_delay.units))))
  556. # Generate the n Poisson processes there are the basis for the SIP
  557. # (coincidences still lacking)
  558. embedded_poisson_trains = _n_poisson(
  559. rate=rates_b - rate_c, t_stop=t_stop, t_start=t_start)
  560. # Convert the trains from neo SpikeTrain objects to simpler Quantity
  561. # objects
  562. embedded_poisson_trains = [
  563. emb.view(Quantity) for emb in embedded_poisson_trains]
  564. # Generate the array of times for coincident events in SIP, not closer than
  565. # min_delay. The array is generated as a quantity from the Quantity class
  566. # in the quantities module
  567. if coincidences == 'deterministic':
  568. Nr_coinc = int(((t_stop - t_start) * rate_c).rescale(dimensionless))
  569. while True:
  570. coinc_times = t_start + \
  571. np.sort(np.random.random(Nr_coinc)) * (t_stop - t_start)
  572. if len(coinc_times) < 2 or min(np.diff(coinc_times)) >= min_delay:
  573. break
  574. elif coincidences == 'stochastic':
  575. while True:
  576. coinc_times = homogeneous_poisson_process(
  577. rate=rate_c, t_stop=t_stop, t_start=t_start)
  578. if len(coinc_times) < 2 or min(np.diff(coinc_times)) >= min_delay:
  579. break
  580. # Convert coinc_times from a neo SpikeTrain object to a Quantity object
  581. # pq.Quantity(coinc_times.base)*coinc_times.units
  582. coinc_times = coinc_times.view(Quantity)
  583. # Set the coincidence times to T-jitter if larger. This ensures that
  584. # the last jittered spike time is <T
  585. for i in range(len(coinc_times)):
  586. if coinc_times[i] > t_stop - jitter:
  587. coinc_times[i] = t_stop - jitter
  588. # Replicate coinc_times n times, and jitter each event in each array by
  589. # +/- jitter (within (t_start, t_stop))
  590. embedded_coinc = coinc_times + \
  591. np.random.random(
  592. (len(rates_b), len(coinc_times))) * 2 * jitter - jitter
  593. embedded_coinc = embedded_coinc + \
  594. (t_start - embedded_coinc) * (embedded_coinc < t_start) - \
  595. (t_stop - embedded_coinc) * (embedded_coinc > t_stop)
  596. # Inject coincident events into the n SIP processes generated above, and
  597. # merge with the n independent processes
  598. sip_process = [
  599. np.sort(np.concatenate((
  600. embedded_poisson_trains[m].rescale(t_stop.units),
  601. embedded_coinc[m].rescale(t_stop.units))) * t_stop.units)
  602. for m in range(len(rates_b))]
  603. # Convert back sip_process and coinc_times from Quantity objects to
  604. # neo.SpikeTrain objects
  605. sip_process = [
  606. SpikeTrain(t, t_start=t_start, t_stop=t_stop).rescale(t_stop.units)
  607. for t in sip_process]
  608. coinc_times = [
  609. SpikeTrain(t, t_start=t_start, t_stop=t_stop).rescale(t_stop.units)
  610. for t in embedded_coinc]
  611. # Return the processes in the specified output_format
  612. if not return_coinc:
  613. output = sip_process
  614. else:
  615. output = sip_process, coinc_times
  616. return output
  617. def _pool_two_spiketrains(a, b, extremes='inner'):
  618. """
  619. Pool the spikes of two spike trains a and b into a unique spike train.
  620. Parameters
  621. ----------
  622. a, b : neo.SpikeTrains
  623. Spike trains to be pooled
  624. extremes: str, optional
  625. Only spikes of a and b in the specified extremes are considered.
  626. * 'inner': pool all spikes from max(a.tstart_ b.t_start) to
  627. min(a.t_stop, b.t_stop)
  628. * 'outer': pool all spikes from min(a.tstart_ b.t_start) to
  629. max(a.t_stop, b.t_stop)
  630. Default: 'inner'
  631. Output
  632. ------
  633. neo.SpikeTrain containing all spikes in a and b falling in the
  634. specified extremes
  635. """
  636. unit = a.units
  637. times_a_dimless = list(a.view(Quantity).magnitude)
  638. times_b_dimless = list(b.rescale(unit).view(Quantity).magnitude)
  639. times = (times_a_dimless + times_b_dimless) * unit
  640. if extremes == 'outer':
  641. t_start = min(a.t_start, b.t_start)
  642. t_stop = max(a.t_stop, b.t_stop)
  643. elif extremes == 'inner':
  644. t_start = max(a.t_start, b.t_start)
  645. t_stop = min(a.t_stop, b.t_stop)
  646. times = times[times > t_start]
  647. times = times[times < t_stop]
  648. else:
  649. raise ValueError(
  650. 'extremes (%s) can only be "inner" or "outer"' % extremes)
  651. pooled_train = SpikeTrain(
  652. times=sorted(times.magnitude), units=unit, t_start=t_start,
  653. t_stop=t_stop)
  654. return pooled_train
  655. def _pool_spiketrains(trains, extremes='inner'):
  656. """
  657. Pool spikes from any number of spike trains into a unique spike train.
  658. Parameters
  659. ----------
  660. trains: list
  661. list of spike trains to merge
  662. extremes: str, optional
  663. Only spikes of a and b in the specified extremes are considered.
  664. * 'inner': pool all spikes from min(a.t_start b.t_start) to
  665. max(a.t_stop, b.t_stop)
  666. * 'outer': pool all spikes from max(a.tstart_ b.t_start) to
  667. min(a.t_stop, b.t_stop)
  668. Default: 'inner'
  669. Output
  670. ------
  671. neo.SpikeTrain containing all spikes in trains falling in the
  672. specified extremes
  673. """
  674. merge_trains = trains[0]
  675. for t in trains[1:]:
  676. merge_trains = _pool_two_spiketrains(
  677. merge_trains, t, extremes=extremes)
  678. t_start, t_stop = merge_trains.t_start, merge_trains.t_stop
  679. merge_trains = sorted(merge_trains)
  680. merge_trains = np.squeeze(merge_trains)
  681. merge_trains = SpikeTrain(
  682. merge_trains, t_stop=t_stop, t_start=t_start, units=trains[0].units)
  683. return merge_trains
  684. def _sample_int_from_pdf(a, n):
  685. """
  686. Draw n independent samples from the set {0,1,...,L}, where L=len(a)-1,
  687. according to the probability distribution a.
  688. a[j] is the probability to sample j, for each j from 0 to L.
  689. Parameters
  690. -----
  691. a: numpy.array
  692. Probability vector (i..e array of sum 1) that at each entry j carries
  693. the probability to sample j (j=0,1,...,len(a)-1).
  694. n: int
  695. Number of samples generated with the function
  696. Output
  697. -------
  698. array of n samples taking values between 0 and n=len(a)-1.
  699. """
  700. A = np.cumsum(a) # cumulative distribution of a
  701. u = np.random.uniform(0, 1, size=n)
  702. U = np.array([u for i in a]).T # copy u (as column vector) len(a) times
  703. return (A < U).sum(axis=1)
  704. def _mother_proc_cpp_stat(A, t_stop, rate, t_start=0 * ms):
  705. """
  706. Generate the hidden ("mother") Poisson process for a Compound Poisson
  707. Process (CPP).
  708. Parameters
  709. ----------
  710. A : numpy.array
  711. Amplitude distribution. A[j] represents the probability of a
  712. synchronous event of size j.
  713. The sum over all entries of a must be equal to one.
  714. t_stop : quantities.Quantity
  715. The stopping time of the mother process
  716. rate : quantities.Quantity
  717. Homogeneous rate of the n spike trains that will be genereted by the
  718. CPP function
  719. t_start : quantities.Quantity, optional
  720. The starting time of the mother process
  721. Default: 0 ms
  722. Output
  723. ------
  724. Poisson spike train representing the mother process generating the CPP
  725. """
  726. N = len(A) - 1
  727. exp_A = np.dot(A, range(N + 1)) # expected value of a
  728. exp_mother = (N * rate) / float(exp_A) # rate of the mother process
  729. return homogeneous_poisson_process(
  730. rate=exp_mother, t_stop=t_stop, t_start=t_start)
  731. def _cpp_hom_stat(A, t_stop, rate, t_start=0 * ms):
  732. """
  733. Generate a Compound Poisson Process (CPP) with amplitude distribution
  734. A and heterogeneous firing rates r=r[0], r[1], ..., r[-1].
  735. Parameters
  736. ----------
  737. A : numpy.ndarray
  738. Amplitude distribution. A[j] represents the probability of a
  739. synchronous event of size j.
  740. The sum over all entries of A must be equal to one.
  741. t_stop : quantities.Quantity
  742. The end time of the output spike trains
  743. rate : quantities.Quantity
  744. Average rate of each spike train generated
  745. t_start : quantities.Quantity, optional
  746. The start time of the output spike trains
  747. Default: 0 ms
  748. Output
  749. ------
  750. List of n neo.SpikeTrains, having average firing rate r and correlated
  751. such to form a CPP with amplitude distribution a
  752. """
  753. # Generate mother process and associated spike labels
  754. mother = _mother_proc_cpp_stat(
  755. A=A, t_stop=t_stop, rate=rate, t_start=t_start)
  756. labels = _sample_int_from_pdf(A, len(mother))
  757. N = len(A) - 1 # Number of trains in output
  758. try: # Faster but more memory-consuming approach
  759. M = len(mother) # number of spikes in the mother process
  760. spike_matrix = np.zeros((N, M), dtype=bool)
  761. # for each spike, take its label l
  762. for spike_id, l in enumerate(labels):
  763. # choose l random trains
  764. train_ids = random.sample(range(N), l)
  765. # and set the spike matrix for that train
  766. for train_id in train_ids:
  767. spike_matrix[train_id, spike_id] = True # and spike to True
  768. times = [[] for i in range(N)]
  769. for train_id, row in enumerate(spike_matrix):
  770. times[train_id] = mother[row].view(Quantity)
  771. except MemoryError: # Slower (~2x) but less memory-consuming approach
  772. print('memory case')
  773. times = [[] for i in range(N)]
  774. for t, l in zip(mother, labels):
  775. train_ids = random.sample(range(N), l)
  776. for train_id in train_ids:
  777. times[train_id].append(t)
  778. trains = [SpikeTrain(
  779. times=t, t_start=t_start, t_stop=t_stop) for t in times]
  780. return trains
  781. def _cpp_het_stat(A, t_stop, rate, t_start=0. * ms):
  782. """
  783. Generate a Compound Poisson Process (CPP) with amplitude distribution
  784. A and heterogeneous firing rates r=r[0], r[1], ..., r[-1].
  785. Parameters
  786. ----------
  787. A : array
  788. CPP's amplitude distribution. A[j] represents the probability of
  789. a synchronous event of size j among the generated spike trains.
  790. The sum over all entries of A must be equal to one.
  791. t_stop : Quantity (time)
  792. The end time of the output spike trains
  793. rate : Quantity (1/time)
  794. Average rate of each spike train generated
  795. t_start : quantities.Quantity, optional
  796. The start time of the output spike trains
  797. Default: 0 ms
  798. Output
  799. ------
  800. List of neo.SpikeTrains with different firing rates, forming
  801. a CPP with amplitude distribution A
  802. """
  803. # Computation of Parameters of the two CPPs that will be merged
  804. # (uncorrelated with heterog. rates + correlated with homog. rates)
  805. N = len(rate) # number of output spike trains
  806. A_exp = np.dot(A, range(N + 1)) # expectation of A
  807. r_sum = np.sum(rate) # sum of all output firing rates
  808. r_min = np.min(rate) # minimum of the firing rates
  809. r1 = r_sum - N * r_min # rate of the uncorrelated CPP
  810. r2 = r_sum / float(A_exp) - r1 # rate of the correlated CPP
  811. r_mother = r1 + r2 # rate of the hidden mother process
  812. # Check the analytical constraint for the amplitude distribution
  813. if A[1] < (r1 / r_mother).rescale(dimensionless).magnitude:
  814. raise ValueError('A[1] too small / A[i], i>1 too high')
  815. # Compute the amplitude distrib of the correlated CPP, and generate it
  816. a = [(r_mother * i) / float(r2) for i in A]
  817. a[1] = a[1] - r1 / float(r2)
  818. CPP = _cpp_hom_stat(a, t_stop, r_min, t_start)
  819. # Generate the independent heterogeneous Poisson processes
  820. POISS = [
  821. homogeneous_poisson_process(i - r_min, t_start, t_stop) for i in rate]
  822. # Pool the correlated CPP and the corresponding Poisson processes
  823. out = [_pool_two_spiketrains(CPP[i], POISS[i]) for i in range(N)]
  824. return out
  825. def compound_poisson_process(rate, A, t_stop, shift=None, t_start=0 * ms):
  826. """
  827. Generate a Compound Poisson Process (CPP; see [1]) with a given amplitude
  828. distribution A and stationary marginal rates r.
  829. The CPP process is a model for parallel, correlated processes with Poisson
  830. spiking statistics at pre-defined firing rates. It is composed of len(A)-1
  831. spike trains with a correlation structure determined by the amplitude
  832. distribution A: A[j] is the probability that a spike occurs synchronously
  833. in any j spike trains.
  834. The CPP is generated by creating a hidden mother Poisson process, and then
  835. copying spikes of the mother process to j of the output spike trains with
  836. probability A[j].
  837. Note that this function decorrelates the firing rate of each SpikeTrain
  838. from the probability for that SpikeTrain to participate in a synchronous
  839. event (which is uniform across SpikeTrains).
  840. Parameters
  841. ----------
  842. rate : quantities.Quantity
  843. Average rate of each spike train generated. Can be:
  844. - a single value, all spike trains will have same rate rate
  845. - an array of values (of length len(A)-1), each indicating the
  846. firing rate of one process in output
  847. A : array
  848. CPP's amplitude distribution. A[j] represents the probability of
  849. a synchronous event of size j among the generated spike trains.
  850. The sum over all entries of A must be equal to one.
  851. t_stop : quantities.Quantity
  852. The end time of the output spike trains.
  853. shift : None or quantities.Quantity, optional
  854. If None, the injected synchrony is exact. If shift is a Quantity, all
  855. the spike trains are shifted independently by a random amount in
  856. the interval [-shift, +shift].
  857. Default: None
  858. t_start : quantities.Quantity, optional
  859. The t_start time of the output spike trains.
  860. Default: 0 s
  861. Returns
  862. -------
  863. List of neo.SpikeTrains
  864. SpikeTrains with specified firing rates forming the CPP with amplitude
  865. distribution A.
  866. References
  867. ----------
  868. [1] Staude, Rotter, Gruen (2010) J Comput Neurosci 29:327-350.
  869. """
  870. # Check A is a probability distribution (it sums to 1 and is positive)
  871. if abs(sum(A) - 1) > np.finfo('float').eps:
  872. raise ValueError(
  873. 'A must be a probability vector, sum(A)= %f !=1' % (sum(A)))
  874. if any([a < 0 for a in A]):
  875. raise ValueError(
  876. 'A must be a probability vector, all the elements of must be >0')
  877. # Check that the rate is not an empty Quantity
  878. if rate.ndim == 1 and len(rate.magnitude) == 0:
  879. raise ValueError('Rate is an empty Quantity array')
  880. # Return empty spike trains for specific parameters
  881. elif A[0] == 1 or np.sum(np.abs(rate.magnitude)) == 0:
  882. return [
  883. SpikeTrain([] * t_stop.units, t_stop=t_stop,
  884. t_start=t_start) for i in range(len(A) - 1)]
  885. else:
  886. # Homogeneous rates
  887. if rate.ndim == 0:
  888. cpp = _cpp_hom_stat(A=A, t_stop=t_stop, rate=rate, t_start=t_start)
  889. # Heterogeneous rates
  890. else:
  891. cpp = _cpp_het_stat(A=A, t_stop=t_stop, rate=rate, t_start=t_start)
  892. if shift is None:
  893. return cpp
  894. # Dither the output spiketrains
  895. else:
  896. cpp = [
  897. dither_spike_train(cp, shift=shift, edges=True)[0]
  898. for cp in cpp]
  899. return cpp
  900. # Alias for the compound poisson process
  901. cpp = compound_poisson_process