generate_nix_testfiles.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. from packaging.version import Version
  2. import numpy as np
  3. import neo
  4. import quantities as pq
  # Number of segments in the generated block, and number of objects of each
  # kind created inside every segment.
  n_segments = 2
  n_spiketrains = 8
  # NOTE(review): not used anywhere in this script; spike-train waveforms are
  # built with a hard-coded shape instead -- confirm whether they should use it.
  n_waveform_samples = 10
  n_analogsignals = 4
  n_irregularlysampledsignals = 1
  n_events = 2
  n_epochs = 3

  # Seeded generator shared by get_rand(), so every run reproduces the same data.
  random_generator = np.random.default_rng(42)
  def get_rand(shape=None, min=0, max=1, sorted=False):
      """Draw uniform random values from ``random_generator``, rescaled to min..max.

      ``shape`` is forwarded to ``Generator.random`` (``None`` yields a single
      float). When ``sorted`` is true the result is passed through ``np.sort``,
      i.e. sorted along its last axis.

      The parameter names shadow the builtins ``min``/``max``/``sorted``; they
      are kept because callers pass them as keyword arguments.
      """
      unit_draws = random_generator.random(shape)
      # Map the [0, 1) draws onto the requested range.
      scaled = min + (max - min) * unit_draws
      if not sorted:
          return scaled
      return np.sort(scaled)
  def generate_basic_block():
      """Build a neo Block filled with reproducible random test data.

      The block gets ``n_segments`` segments. Each segment receives
      ``n_spiketrains`` spike trains, ``n_analogsignals`` analog signals,
      ``n_irregularlysampledsignals`` irregularly sampled signals,
      ``n_events`` events and ``n_epochs`` epochs, created in that order.
      All random values come from ``get_rand`` (and thus the seeded module-level
      generator), so this creation order determines the generated data.

      Returns
      -------
      neo.Block
          The populated block.
      """
      block = neo.Block(name='my_block')
      for seg_idx in range(n_segments):
          seg = neo.Segment(name=f'my_segment_{seg_idx}')
          block.segments.append(seg)
          # Keep this call order: it fixes the sequence of random draws.
          _add_spiketrains(seg)
          _add_analogsignals(seg)
          _add_irregularlysampledsignals(seg)
          _add_events(seg)
          _add_epochs(seg)
      return block


  def _add_spiketrains(seg):
      # Create the spike trains of one segment (10 spikes each) and attach them.
      for spiketrain_idx in range(n_spiketrains):
          # NOTE(review): waveforms are 2-D (10 spikes x 14 values) while the
          # module defines an unused n_waveform_samples = 10; confirm the
          # intended waveform layout for the neo/NIX version in use.
          waveforms = get_rand((10, 14)) * pq.V
          st = neo.SpikeTrain(times=get_rand(10, max=10, sorted=True) * pq.s,
                              t_stop=10 * pq.s,
                              name=f'my_spiketrain_{spiketrain_idx}',
                              waveforms=waveforms,
                              left_sweep=4)
          st.segment = seg
          seg.spiketrains.append(st)


  def _add_analogsignals(seg):
      # Create analog signal i with a (100, i + 2) array sampled at (i + 1) Hz.
      for anasig_idx in range(n_analogsignals):
          anasig = neo.AnalogSignal(signal=get_rand((100, anasig_idx + 2), max=1000) * pq.V,
                                    t_start=0 * pq.s,
                                    sampling_rate=(anasig_idx + 1) * pq.Hz,
                                    name=f'my_analogsignal_{anasig_idx}')
          anasig.segment = seg
          seg.analogsignals.append(anasig)


  def _add_irregularlysampledsignals(seg):
      # Create irregularly sampled signals on 100 sorted time points in [0, 100) s.
      for irr_idx in range(n_irregularlysampledsignals):
          times = get_rand(100, sorted=True, max=100) * pq.s
          # Size the signal from this loop's own index (mirroring the analog
          # signals). Reading the analog-signal loop variable here would tie the
          # shape to n_analogsignals and fail when no analog signals are created.
          signal = get_rand((100, irr_idx + 2), max=1000) * pq.V
          # NOTE(review): t_start/t_stop are passed through as before; confirm the
          # installed neo accepts them as constructor arguments rather than
          # storing them as annotations.
          irrsig = neo.IrregularlySampledSignal(times=times,
                                                signal=signal,
                                                t_start=0 * pq.s, t_stop=100 * pq.s,
                                                name=f'my_irregularlysampledsignal_{irr_idx}')
          irrsig.segment = seg
          seg.irregularlysampledsignals.append(irrsig)


  def _add_events(seg):
      # Create events with 10 sorted timestamps in [0, 10) s and one label each.
      for ev_idx in range(n_events):
          event = neo.Event(times=get_rand(10, max=10, sorted=True) * pq.s,
                            labels=np.array([f'my_event_timestamp_{i}' for i in range(10)]),
                            name=f'my_event_{ev_idx}')
          event.segment = seg
          seg.events.append(event)


  def _add_epochs(seg):
      # Create epochs with 3 sorted start times in [0, 6) s and fixed durations.
      for ep_idx in range(n_epochs):
          epoch = neo.Epoch(times=get_rand(3, max=6, sorted=True) * pq.s,
                            durations=[1.1, 2.2, 3.3] * pq.s,
                            labels=np.array([f'my_epoch_timestamp_{i}' for i in range(3)]),
                            name=f'my_epoch_{ep_idx}')
          epoch.segment = seg
          seg.epochs.append(epoch)
  def add_spiketrain_groups(block):
      """Collect same-index spike trains across segments into ``neo.Group`` objects.

      One group (restricted to ``neo.SpikeTrain``) is created per spike-train
      index. The i-th spike train of every segment is placed in the i-th group,
      and all groups are appended to ``block.groups``. The block is modified in
      place; nothing is returned.
      """
      groups = [
          neo.Group(name=f'my_spiketrain_group_{i}', allowed_types=[neo.SpikeTrain])
          for i in range(n_spiketrains)
      ]
      for segment in block.segments:
          for i, spiketrain in enumerate(segment.spiketrains):
              target = groups[i]
              spiketrain.group = target
              target.spiketrains.append(spiketrain)
      # Groups are attached to the block only after they are fully populated.
      block.groups.extend(groups)
  def add_spiketrain_units_channel_indexes(block):
      """Link spike trains to Units and Units to ChannelIndexes (pre-Group neo API).

      Linking pattern:

      * one ``neo.ChannelIndex`` per analog-signal index ``k``, with index array
        ``[k]``, holding the k-th analog signal of every segment;
      * one ``neo.Unit`` per spike-train index ``i``, holding the i-th spike
        train of every segment;
      * two units per channel index: unit ``i`` is attached to channel index
        ``i // 2`` (so ``n_analogsignals`` must be at least half of
        ``n_spiketrains``).

      The block is modified in place; nothing is returned.
      """
      units = [neo.Unit(f'my_unit_{i}') for i in range(n_spiketrains)]

      chxs = []
      for asig_idx in range(n_analogsignals):
          same_index_signals = [segment.analogsignals[asig_idx] for segment in block.segments]
          chx = neo.ChannelIndex([asig_idx], name=f'my_channel_index_{asig_idx}')
          chx.analogsignals.extend(same_index_signals)
          chxs.append(chx)
          chx.block = block

      for segment in block.segments:
          for i, spiketrain in enumerate(segment.spiketrains):
              owner = units[i]
              spiketrain.unit = owner
              owner.spiketrains.append(spiketrain)

      block.channel_indexes.extend(chxs)

      for i, unit in enumerate(units):
          # Pair consecutive units onto the same channel index.
          chx = chxs[i // 2]
          chx.units.append(unit)
          unit.channel_index = chx
  if __name__ == '__main__':
      # Write one NIX file named after the installed neo version.
      neo_version = Version(neo.__version__)
      print(f'Generating test file for neo version {neo_version} ...', end='')
      block = generate_basic_block()

      # Represent spike-train "units" with whichever container API this neo
      # provides: prefer neo.Group, fall back to Unit/ChannelIndex.
      if hasattr(neo, 'Group'):
          add_spiketrain_groups(block)
      elif hasattr(neo, 'ChannelIndex'):
          add_spiketrain_units_channel_indexes(block)
      else:
          raise ValueError(
              f'No mechanism to represent spike units found for neo version: {neo_version}')

      # NixIO's default mode ('rw') reopens an existing file, so re-running the
      # script would not produce a clean file; 'ow' overwrites it so the file
      # contains only this block.
      nix_io = neo.NixIO(f'generated_file_neo{neo_version}.nix', mode='ow')
      try:
          nix_io.write_block(block)
      finally:
          # Release the file even if writing fails.
          nix_io.close()
      print('done')