neuronModels.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. from copy import copy
  2. from typing import Union, Iterable
  3. import numpy as np
  4. from brian2 import NeuronGroup, TimedArray, StateMonitor, SpikeMonitor, SpikeGeneratorGroup, array
  5. from brian2 import Synapses
  6. from brian2 import units
  7. from brian2.core.network import Network
  8. from brian2.equations.codestrings import CodeString
  9. from brian2.equations.equations import Equations
  10. from brian2.units.fundamentalunits import Quantity
  11. from brianUtils import getSimT
  12. def addSynNameVar(var: str, name: str) -> str:
  13. return "_".join([var, name])
  14. def addSynNameEqs(model: str, prePosts: Iterable[Union[str, None]], synName: str) -> tuple:
  15. mEq = Equations(model)
  16. newM = copy(model)
  17. for name in mEq.names:
  18. if (not name.endswith("_post")) and (not name.endswith("_pre")):
  19. newM = newM.replace(name, addSynNameVar(name, synName))
  20. newPrePosts = []
  21. prePostCSs = []
  22. for p in prePosts:
  23. newP = copy(p)
  24. if p:
  25. cs = CodeString(p)
  26. for name in cs.identifiers:
  27. if (not name.endswith("_post")) and (not name.endswith("_pre")):
  28. newP = newP.replace(name, addSynNameVar(name, synName))
  29. newPrePosts.append(newP)
  30. prePostCSs.append(cs)
  31. else:
  32. newPrePosts.append(None)
  33. prePostCSs.append(None)
  34. return newM, mEq, newPrePosts, prePostCSs
  35. class VSNeuron(object):
  36. def __init__(self, model: str, name: str,
  37. inits: dict,
  38. threshold: str,
  39. reset: str,
  40. method: str = "euler"):
  41. super().__init__()
  42. self.ngParams = {"model": model, "threshold": threshold, "reset": reset, "method": method,
  43. "name": name}
  44. self.inits = inits
  45. self.incomingSynapses = {}
  46. self.incomingSynapsePars = {}
  47. self.synCurrentNames = []
  48. self.recordMemVFlag = False
  49. self.recordSpikesFlag = False
  50. self.ng = None
  51. def updateInits(self, initUpdate: dict):
  52. self.inits.update(initUpdate)
  53. def setInputCurrent(self, I: Union[TimedArray, float]):
  54. self.inits["I"] = I
  55. def recordMembraneV(self):
  56. self.recordMemVFlag = True
  57. def recordSpikes(self):
  58. self.recordSpikesFlag = True
  59. def getMemVTrace(self):
  60. assert self.recordMemVFlag, 'Membrane Voltage was not recorded' \
  61. 'for this neuron'
  62. return self.memVRecord.t, self.memVRecord[0].V
  63. def getSpikes(self):
  64. assert self.recordSpikesFlag, "Spikes were not recorded for this neuron"
  65. return self.spikeRecord.t
  66. def addToNetwork(self, network: Network):
  67. self.ngParams["model"] = "\n".join((self.ngParams["model"], "Iext: amp"))
  68. self.inits["Iext"] = 0 * units.amp
  69. eq2Add = "I = Iext "
  70. for synCurrentName in self.synCurrentNames:
  71. self.ngParams["model"] = "\n".join((self.ngParams["model"], "{} : amp".format(synCurrentName)))
  72. self.inits[synCurrentName] = 0 * units.amp
  73. eq2Add += " + {} ".format(synCurrentName)
  74. eq2Add += ": amp"
  75. self.ngParams["model"] = "\n".join((self.ngParams["model"], eq2Add))
  76. self.ng = NeuronGroup(N=1, **self.ngParams)
  77. self.initSim()
  78. network.add(self.ng)
  79. if self.recordMemVFlag:
  80. self.memVRecord = StateMonitor(self.ng, "V", record=[0])
  81. network.add(self.memVRecord)
  82. if self.recordSpikesFlag:
  83. self.spikeRecord = SpikeMonitor(self.ng)
  84. network.add(self.spikeRecord)
  85. for synName, synPars in self.incomingSynapsePars.items():
  86. syn = Synapses(synPars["source"], self.ng,
  87. model=synPars["model"],
  88. on_pre=synPars["on_pre"],
  89. on_post=synPars["on_post"],
  90. method=synPars["method"])
  91. syn.connect(i=synPars["sourceInd"], j=synPars["destInd"])
  92. for k, v in synPars["initMap"].items():
  93. setattr(syn, k, v)
  94. self.incomingSynapses[synName] = syn
  95. network.add(syn)
  96. def initSim(self):
  97. for k, v in self.inits.items():
  98. setattr(self.ng, k, v)
  99. def addSynapse(self, synName: str, sourceNG: NeuronGroup,
  100. model: str, synParsInits: dict, synStateInits: dict,
  101. on_pre: Union[str, None] = None,
  102. on_post: Union[str, None] = None,
  103. sourceInd: int = 0, destInd: int = 0,
  104. method: str = "euler"):
  105. assert synName not in self.incomingSynapses, 'A Synapse with {} already exists'.format(synName)
  106. ISyn_PostInd = model.find("ISyn_post")
  107. assert ISyn_PostInd >= 0, "Synapse model should have an equation for" \
  108. "\'ISyn_post\'"
  109. nextEndLineInd = model.find("\n", ISyn_PostInd)
  110. assert model[nextEndLineInd - 8: nextEndLineInd] == "(summed)", \
  111. "Equation for \'ISyn_post\' must have (summed) flag"
  112. newModel, mEq, [newOn_pre, newOn_post], prePostCSs = \
  113. addSynNameEqs(model, [on_pre, on_post], synName)
  114. allSV = mEq.diff_eq_names
  115. allPars = list(mEq.parameter_names)
  116. for cs in prePostCSs:
  117. if cs:
  118. for i in cs.identifiers:
  119. if i not in allSV:
  120. allPars.append(i)
  121. for par in allPars:
  122. assert par in synParsInits, "Initialization not provided for {} in synParsInits".format(par)
  123. for sv in allSV:
  124. assert sv in synStateInits, "Initialization not provided for {} in synStateInits".format(sv)
  125. ISynName = "_".join(("ISyn", synName))
  126. self.synCurrentNames.append(ISynName)
  127. newModel = newModel.replace("ISyn", ISynName)
  128. initMap = {"delay": synParsInits["delay"]}
  129. for par in allPars:
  130. initMap[addSynNameVar(par, synName)] = synParsInits[par]
  131. for sv in allSV:
  132. initMap[addSynNameVar(sv, synName)] = synStateInits[sv]
  133. synPars = {"source": sourceNG, "model": newModel, "on_pre": newOn_pre,
  134. "on_post": newOn_post, "method": method,
  135. "sourceInd": sourceInd, "destInd": destInd, "initMap": initMap}
  136. self.incomingSynapsePars[synName] = synPars
  137. class JOSpikes265(object):
  138. def __init__(self, nOutputs: int =1, simSettleTime: Quantity = 0 * units.ms,
  139. sinPulseStarts: array = array(()) * units.ms,
  140. sinPulseDurs: array = array(()) * units.ms):
  141. self.nOutputs = nOutputs
  142. freq = 265 * units.Hz
  143. spikePhase = np.deg2rad(240)
  144. phaseDelay = (1 / freq) * (spikePhase / (2 * np.pi))
  145. self.spikeTimes = []
  146. self.spikeInds = []
  147. simSettleTimeF = float(simSettleTime)
  148. for start, dur in zip(sinPulseStarts, sinPulseDurs):
  149. startF = float(start)
  150. durF = float(dur)
  151. periodF = float(1/freq)
  152. phaseDelayF = float(phaseDelay)
  153. cycleStarts = np.arange(startF, startF + durF, periodF)
  154. for i in range(nOutputs):
  155. self.spikeTimes += (simSettleTimeF + cycleStarts + phaseDelayF).tolist()
  156. self.spikeInds += [i] * len(cycleStarts)
  157. self.spikeTimes = self.spikeTimes * units.second
  158. self.JOSGG = SpikeGeneratorGroup(nOutputs, array(self.spikeInds),
  159. self.spikeTimes)
  160. def getSineInput(simDur: Quantity, simStepSize: Quantity,
  161. sinPulseStarts: Quantity, sinPulseDurs: Quantity,
  162. freq: Quantity, simSettleTime: Quantity = 0 * units.ms,):
  163. simT = getSimT(simSettleTime + simDur, simStepSize)
  164. sineInput = np.zeros(simT.shape)
  165. for start, dur in zip(sinPulseStarts, sinPulseDurs):
  166. settleStart = start + simSettleTime
  167. settleEnd = start + dur + simSettleTime
  168. timeMask = (simT >= settleStart) & (simT <= settleEnd)
  169. sineInput[timeMask] = np.sin(2 * np.pi * freq * (simT[timeMask] - (0.5 / freq) - start))
  170. return sineInput