generate_datasets.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. # -*- coding: utf-8 -*-
  2. '''
  3. Generate datasets for testing
  4. '''
  5. # needed for python 3 compatibility
  6. from __future__ import absolute_import
  7. from datetime import datetime
  8. import numpy as np
  9. from numpy.random import rand
  10. import quantities as pq
  11. from neo.core import (AnalogSignal,
  12. Block,
  13. Epoch, Event,
  14. IrregularlySampledSignal,
  15. ChannelIndex,
  16. Segment, SpikeTrain,
  17. Unit,
  18. class_by_name)
  19. from neo.core.baseneo import _container_name
  20. TEST_ANNOTATIONS = [1, 0, 1.5, "this is a test",
  21. datetime.fromtimestamp(424242424), None]
  22. def generate_one_simple_block(block_name='block_0', nb_segment=3,
  23. supported_objects=[], **kws):
  24. if supported_objects and Block not in supported_objects:
  25. raise ValueError('Block must be in supported_objects')
  26. bl = Block() # name = block_name)
  27. objects = supported_objects
  28. if Segment in objects:
  29. for s in range(nb_segment):
  30. seg = generate_one_simple_segment(seg_name="seg" + str(s),
  31. supported_objects=objects, **kws)
  32. bl.segments.append(seg)
  33. #if RecordingChannel in objects:
  34. # populate_RecordingChannel(bl)
  35. bl.create_many_to_one_relationship()
  36. return bl
  37. def generate_one_simple_segment(seg_name='segment 0',
  38. supported_objects=[],
  39. nb_analogsignal=4,
  40. t_start=0.*pq.s,
  41. sampling_rate=10*pq.kHz,
  42. duration=6.*pq.s,
  43. nb_spiketrain=6,
  44. spikerate_range=[.5*pq.Hz, 12*pq.Hz],
  45. event_types={'stim': ['a', 'b',
  46. 'c', 'd'],
  47. 'enter_zone': ['one',
  48. 'two'],
  49. 'color': ['black',
  50. 'yellow',
  51. 'green'],
  52. },
  53. event_size_range=[5, 20],
  54. epoch_types={'animal state': ['Sleep',
  55. 'Freeze',
  56. 'Escape'],
  57. 'light': ['dark',
  58. 'lighted']
  59. },
  60. epoch_duration_range=[.5, 3.],
  61. ):
  62. if supported_objects and Segment not in supported_objects:
  63. raise ValueError('Segment must be in supported_objects')
  64. seg = Segment(name=seg_name)
  65. if AnalogSignal in supported_objects:
  66. for a in range(nb_analogsignal):
  67. anasig = AnalogSignal(rand(int(sampling_rate * duration)),
  68. sampling_rate=sampling_rate, t_start=t_start,
  69. units=pq.mV, channel_index=a,
  70. name='sig %d for segment %s' % (a, seg.name))
  71. seg.analogsignals.append(anasig)
  72. if SpikeTrain in supported_objects:
  73. for s in range(nb_spiketrain):
  74. spikerate = rand()*np.diff(spikerate_range)
  75. spikerate += spikerate_range[0].magnitude
  76. #spikedata = rand(int((spikerate*duration).simplified))*duration
  77. #sptr = SpikeTrain(spikedata,
  78. # t_start=t_start, t_stop=t_start+duration)
  79. # #, name = 'spiketrain %d'%s)
  80. spikes = rand(int((spikerate*duration).simplified))
  81. spikes.sort() # spikes are supposed to be an ascending sequence
  82. sptr = SpikeTrain(spikes*duration,
  83. t_start=t_start, t_stop=t_start+duration)
  84. sptr.annotations['channel_index'] = s
  85. seg.spiketrains.append(sptr)
  86. if Event in supported_objects:
  87. for name, labels in event_types.items():
  88. evt_size = rand()*np.diff(event_size_range)
  89. evt_size += event_size_range[0]
  90. evt_size = int(evt_size)
  91. labels = np.array(labels, dtype='S')
  92. labels = labels[(rand(evt_size)*len(labels)).astype('i')]
  93. evt = Event(times=rand(evt_size)*duration, labels=labels)
  94. seg.events.append(evt)
  95. if Epoch in supported_objects:
  96. for name, labels in epoch_types.items():
  97. t = 0
  98. times = []
  99. durations = []
  100. while t < duration:
  101. times.append(t)
  102. dur = rand()*np.diff(epoch_duration_range)
  103. dur += epoch_duration_range[0]
  104. durations.append(dur)
  105. t = t+dur
  106. labels = np.array(labels, dtype='S')
  107. labels = labels[(rand(len(times))*len(labels)).astype('i')]
  108. epc = Epoch(times=pq.Quantity(times, units=pq.s),
  109. durations=pq.Quantity([x[0] for x in durations],
  110. units=pq.s),
  111. labels=labels,
  112. )
  113. seg.epochs.append(epc)
  114. # TODO : Spike, Event
  115. seg.create_many_to_one_relationship()
  116. return seg
  117. def generate_from_supported_objects(supported_objects):
  118. #~ create_many_to_one_relationship
  119. if not supported_objects:
  120. raise ValueError('No objects specified')
  121. objects = supported_objects
  122. if Block in supported_objects:
  123. higher = generate_one_simple_block(supported_objects=objects)
  124. # Chris we do not create RC and RCG if it is not in objects
  125. # there is a test in generate_one_simple_block so I removed
  126. #finalize_block(higher)
  127. elif Segment in objects:
  128. higher = generate_one_simple_segment(supported_objects=objects)
  129. else:
  130. #TODO
  131. return None
  132. higher.create_many_to_one_relationship()
  133. return higher
  134. def get_fake_value(name, datatype, dim=0, dtype='float', seed=None,
  135. units=None, obj=None, n=None, shape=None):
  136. """
  137. Returns default value for a given attribute based on neo.core
  138. If seed is not None, use the seed to set the random number generator.
  139. """
  140. if not obj:
  141. obj = 'TestObject'
  142. elif not hasattr(obj, 'lower'):
  143. obj = obj.__name__
  144. if (name in ['name', 'file_origin', 'description'] and
  145. (datatype != str or dim)):
  146. raise ValueError('%s must be str, not a %sD %s' % (name, dim,
  147. datatype))
  148. if name == 'file_origin':
  149. return 'test_file.txt'
  150. if name == 'name':
  151. return '%s%s' % (obj, get_fake_value('', datatype, seed=seed))
  152. if name == 'description':
  153. return 'test %s %s' % (obj, get_fake_value('', datatype, seed=seed))
  154. if seed is not None:
  155. np.random.seed(seed)
  156. if datatype == str:
  157. return str(np.random.randint(100000))
  158. if datatype == int:
  159. return np.random.randint(100)
  160. if datatype == float:
  161. return 1000. * np.random.random()
  162. if datatype == datetime:
  163. return datetime.fromtimestamp(1000000000*np.random.random())
  164. if (name in ['t_start', 't_stop', 'sampling_rate'] and
  165. (datatype != pq.Quantity or dim)):
  166. raise ValueError('%s must be a 0D Quantity, not a %sD %s' % (name, dim,
  167. datatype))
  168. # only put array types below here
  169. if units is not None:
  170. pass
  171. elif name in ['t_start', 't_stop',
  172. 'time', 'times',
  173. 'duration', 'durations']:
  174. units = pq.millisecond
  175. elif name == 'sampling_rate':
  176. units = pq.Hz
  177. elif datatype == pq.Quantity:
  178. units = np.random.choice(['nA', 'mA', 'A', 'mV', 'V'])
  179. units = getattr(pq, units)
  180. if name == 'sampling_rate':
  181. data = np.array(10000.0)
  182. elif name == 't_start':
  183. data = np.array(0.0)
  184. elif name == 't_stop':
  185. data = np.array(1.0)
  186. elif n and name == 'channel_indexes':
  187. data = np.arange(n)
  188. elif n and name == 'channel_names':
  189. data = np.array(["ch%d" % i for i in range(n)])
  190. elif n and obj == 'AnalogSignal':
  191. if name == 'signal':
  192. size = []
  193. for _ in range(int(dim)):
  194. size.append(np.random.randint(5) + 1)
  195. size[1] = n
  196. data = np.random.random(size)*1000.
  197. else:
  198. size = []
  199. for _ in range(int(dim)):
  200. if shape is None :
  201. if name == "times":
  202. size.append(5)
  203. else :
  204. size.append(np.random.randint(5) + 1)
  205. else:
  206. size.append(shape)
  207. data = np.random.random(size)
  208. if name not in ['time', 'times']:
  209. data *= 1000.
  210. if np.dtype(dtype) != np.float64:
  211. data = data.astype(dtype)
  212. if datatype == np.ndarray:
  213. return data
  214. if datatype == list:
  215. return data.tolist()
  216. if datatype == pq.Quantity:
  217. return data * units # set the units
  218. # we have gone through everything we know, so it must be something invalid
  219. raise ValueError('Unknown name/datatype combination %s %s' % (name,
  220. datatype))
  221. def get_fake_values(cls, annotate=True, seed=None, n=None):
  222. """
  223. Returns a dict containing the default values for all attribute for
  224. a class from neo.core.
  225. If seed is not None, use the seed to set the random number generator.
  226. The seed is incremented by 1 for each successive object.
  227. If annotate is True (default), also add annotations to the values.
  228. """
  229. if hasattr(cls, 'lower'): # is this a test that cls is a string? better to use isinstance(cls, basestring), no?
  230. cls = class_by_name[cls]
  231. kwargs = {} # assign attributes
  232. for i, attr in enumerate(cls._necessary_attrs + cls._recommended_attrs):
  233. if seed is not None:
  234. iseed = seed + i
  235. else:
  236. iseed = None
  237. kwargs[attr[0]] = get_fake_value(*attr, seed=iseed, obj=cls, n=n)
  238. if 'waveforms' in kwargs : #everything here is to force the kwargs to have len(time) == kwargs["waveforms"].shape[0]
  239. if len(kwargs["times"]) != kwargs["waveforms"].shape[0] :
  240. if len(kwargs["times"]) < kwargs["waveforms"].shape[0] :
  241. dif = kwargs["waveforms"].shape[0] - len(kwargs["times"])
  242. new_times =[]
  243. for i in kwargs["times"].magnitude :
  244. new_times.append(i)
  245. np.random.seed(0)
  246. new_times = np.concatenate([new_times, np.random.random(dif)])
  247. kwargs["times"] = pq.Quantity(new_times, units=pq.ms)
  248. else :
  249. kwargs['times'] = kwargs['times'][:kwargs["waveforms"].shape[0]]
  250. if 'times' in kwargs and 'signal' in kwargs:
  251. kwargs['times'] = kwargs['times'][:len(kwargs['signal'])]
  252. kwargs['signal'] = kwargs['signal'][:len(kwargs['times'])]
  253. if annotate:
  254. kwargs.update(get_annotations())
  255. kwargs['seed'] = seed
  256. return kwargs
  257. def get_annotations():
  258. '''
  259. Returns a dict containing the default values for annotations for
  260. a class from neo.core.
  261. '''
  262. return dict([(str(i), ann) for i, ann in enumerate(TEST_ANNOTATIONS)])
def fake_neo(obj_type="Block", cascade=True, seed=None, n=1):
    '''
    Create a fake NEO object of a given type. Follows one-to-many
    and many-to-many relationships if cascade.

    n (default=1) is the number of child objects of each type that will be
    created. In cases like segment.spiketrains, there will be more than this
    number because there will be n for each unit, of which there will be n
    for each channelindex, of which there will be n.

    :param obj_type: neo class, or its name as a string.
    :param cascade: if falsy, only the object itself is built. When the
        recursion starts from a Block, it is replaced by the string 'block'
        and passed down to all descendants.
    :param seed: base seed. Each child gets ``10*seed + 100*i + 1000*j``,
        where i indexes the child type and j the copy. The seed is also
        stored as a 'seed' annotation via get_fake_values.
    :returns: the new object; each child carries ``i``/``j`` annotations.
    '''
    if hasattr(obj_type, 'lower'):  # given by name
        cls = class_by_name[obj_type]
    else:  # given as a class
        cls = obj_type
        obj_type = obj_type.__name__
    kwargs = get_fake_values(obj_type, annotate=True, seed=seed, n=n)
    obj = cls(**kwargs)

    # if not cascading, we don't need to do any of the stuff after this
    if not cascade:
        return obj

    # this is used to signal other containers that they shouldn't duplicate
    # data (the string is still truthy, so descendants keep cascading)
    if obj_type == 'Block':
        cascade = 'block'
    for i, childname in enumerate(getattr(obj, '_child_objects', [])):
        # we create a few of each class
        for j in range(n):
            if seed is not None:
                iseed = 10*seed+100*i+1000*j
            else:
                iseed = None
            child = fake_neo(obj_type=childname, cascade=cascade,
                             seed=iseed, n=n)
            child.annotate(i=i, j=j)

            # When cascading from a Block, a child is attached only to the
            # LAST type listed in its _parent_objects. Other parents skip it
            # here; for signals and spiketrains, the Block wiring below
            # links them to segments instead. Note the skipped child has
            # still been generated above.
            if (cascade == 'block' and len(child._parent_objects) > 0 and
                    obj_type != child._parent_objects[-1]):
                continue
            getattr(obj, _container_name(childname)).append(child)

    # need to manually create 'implicit' connections
    if obj_type == 'Block':
        # connect data objects to segment: the k-th signal/spiketrain of each
        # ChannelIndex (or Unit) goes into the k-th segment.
        # NOTE(review): presumably relies on these lists being no longer than
        # obj.segments (both are built with n items here), else IndexError.
        for i, chx in enumerate(obj.channel_indexes):
            for k, sigarr in enumerate(chx.analogsignals):
                obj.segments[k].analogsignals.append(sigarr)
            for k, sigarr in enumerate(chx.irregularlysampledsignals):
                obj.segments[k].irregularlysampledsignals.append(sigarr)
            for j, unit in enumerate(chx.units):
                for k, train in enumerate(unit.spiketrains):
                    obj.segments[k].spiketrains.append(train)
    #elif obj_type == 'ChannelIndex':
    #    inds = []
    #    names = []
    #    chinds = np.array([unit.channel_indexes[0] for unit in obj.units])
    #    obj.indexes = np.array(inds, dtype='i')
    #    obj.channel_names = np.array(names).astype('S')

    if hasattr(obj, 'create_many_to_one_relationship'):
        obj.create_many_to_one_relationship()

    return obj
  323. def clone_object(obj, n=None):
  324. '''
  325. Generate a new object and new objects with the same rules as the original.
  326. '''
  327. if hasattr(obj, '__iter__') and not hasattr(obj, 'ndim'):
  328. return [clone_object(iobj, n=n) for iobj in obj]
  329. cascade = hasattr(obj, 'children') and len(obj.children)
  330. if n is not None:
  331. pass
  332. elif cascade:
  333. n = min(len(getattr(obj, cont)) for cont in obj._child_containers)
  334. else:
  335. n = 0
  336. seed = obj.annotations.get('seed', None)
  337. newobj = fake_neo(obj.__class__, cascade=cascade, seed=seed, n=n)
  338. if 'i' in obj.annotations:
  339. newobj.annotate(i=obj.annotations['i'], j=obj.annotations['j'])
  340. return newobj