DLInt2try.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import os
  2. import shutil
  3. import seaborn as sns
  4. from brian2 import defaultclock, units
  5. from brian2.core.network import Network
  6. from matplotlib import pyplot as plt
  7. from dirDefs import homeFolder
  8. from models.neuronModels import VSNeuron, JOSpikes265, getSineInput
  9. from mplPars import mplPars
  10. from paramLists import synapsePropsList, inputParsList
  11. sns.set(style="whitegrid", rc=mplPars)
  12. simSettleTime = 500 * units.ms
  13. simStepSize = 0.5 * units.ms
  14. simDuration = 100 * units.ms
  15. # inputParsName = 'onePulse'
  16. # inputParsName = 'twoPulse'
  17. inputParsName = 'threePulse'
  18. # simStepSize = 0.5 * units.ms
  19. # simDuration = 1100 * units.ms
  20. # # inputParsName = 'oneSecondPulse'
  21. # # inputParsName = 'pulseTrainInt20Dur10'
  22. # inputParsName = 'pulseTrainInt20Dur16'
  23. # # inputParsName = 'pulseTrainInt33Dur10'
  24. # # inputParsName = 'pulseTrainInt33Dur16'
  25. NeuronProps = "DLInt2Try2"
  26. NeuronSynapseProps = 'DLInt2_syn_try2'
  27. dlint2 = VSNeuron(NeuronProps)
  28. opDir = os.path.join(homeFolder, NeuronProps, NeuronSynapseProps, inputParsName)
  29. if os.path.isdir(opDir):
  30. ch = input('Results already exist at {}. Delete?(y/n):'.format(opDir))
  31. if ch == 'y':
  32. shutil.rmtree(opDir)
  33. os.makedirs(opDir)
  34. period265 = (1 / 265)
  35. inputPars = getattr(inputParsList, inputParsName)
  36. JO = JOSpikes265(nOutputs=1, simSettleTime=simSettleTime, **inputPars)
  37. dlint2.addExp2Synapses(name='JO', nSyn=1, sourceNG=JO.JOSGG,
  38. sourceInd=0,
  39. **getattr(synapsePropsList, NeuronSynapseProps))
  40. net = Network()
  41. net.add(JO.JOSGG)
  42. dlint2.addToNetwork(net)
  43. defaultclock.dt = simStepSize
  44. totalSimDur = simDuration + simSettleTime
  45. net.run(totalSimDur, report='text')
  46. simT, memV = dlint2.getMemVTrace()
  47. spikeTimes = dlint2.getSpikes()
  48. fig, axs = plt.subplots(nrows=2, figsize=(10, 6.25), sharex='col')
  49. axs[0].plot(simT / units.ms, memV / units.mV)
  50. spikesY = memV.min() + 1.05 * (memV.max() - memV.min())
  51. axs[0].plot(spikeTimes / units.ms, [spikesY / units.mV] * spikeTimes.shape[0], 'k^')
  52. axs[0].set_ylabel('DLInt1 \nmemV (mV)')
  53. axs[0].set_xlim([simSettleTime / units.ms - 50, totalSimDur / units.ms + 50])
  54. sineInput = getSineInput(simSettleTime=simSettleTime, simDur=simDuration,
  55. simStepSize=simStepSize,
  56. sinPulseDurs=inputPars['sinPulseDurs'],
  57. sinPulseStarts=inputPars['sinPulseStarts'],
  58. freq=265 * units.Hz)
  59. axs[1].plot(simT / units.ms, sineInput, 'r-', label='Vibration Input')
  60. axs[1].plot(JO.spikeTimes / units.ms, [sineInput.max() * 1.05] * len(JO.spikeTimes), 'k^',
  61. label='JO Spikes')
  62. axs[1].legend(loc='upper right')
  63. axs[1].set_xlabel('time (ms)')
  64. axs[1].set_ylabel('Vibration \nInput/JO\n Spikes')
  65. fig.tight_layout()
  66. fig.canvas.draw()
  67. # plt.show()
  68. fig.savefig(os.path.join(opDir, 'Traces.png'), dpi=150)