data_overview_1.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741
  1. # -*- coding: utf-8 -*-
  2. """
  3. Code for generating the first data figure in the manuscript.
  4. Authors: Julia Sprenger, Lyuba Zehl, Michael Denker
  5. Copyright (c) 2017, Institute of Neuroscience and Medicine (INM-6),
  6. Forschungszentrum Juelich, Germany
  7. All rights reserved.
  8. Redistribution and use in source and binary forms, with or without
  9. modification, are permitted provided that the following conditions are met:
  10. * Redistributions of source code must retain the above copyright notice, this
  11. list of conditions and the following disclaimer.
  12. * Redistributions in binary form must reproduce the above copyright notice,
  13. this list of conditions and the following disclaimer in the documentation
  14. and/or other materials provided with the distribution.
  15. * Neither the names of the copyright holders nor the names of the contributors
  16. may be used to endorse or promote products derived from this software without
  17. specific prior written permission.
  18. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  19. ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  20. WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  21. DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
  22. FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  23. DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  24. SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  25. CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  26. OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  27. OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  28. """
  29. import os
  30. import numpy as np
  31. from scipy import stats
  32. import quantities as pq
  33. import matplotlib.pyplot as plt
  34. from matplotlib import gridspec, ticker
  35. from reachgraspio import reachgraspio
  36. import odml.tools
  37. from neo import utils as neo_utils
  38. from neo_utils import load_segment
  39. import odml_utils
  40. # =============================================================================
  41. # Define data and metadata directories
  42. # =============================================================================
  43. def get_monkey_datafile(monkey):
  44. if monkey == "Lilou":
  45. return "l101210-001" # ns2 (behavior) and ns5 present
  46. elif monkey == "Nikos2":
  47. return "i140703-001" # ns2 and ns6 present
  48. else:
  49. return ""
  50. # Enter your dataset directory here
  51. datasetdir = "../datasets_blackrock/"
  52. trialtype_colors = {
  53. 'SGHF': 'MediumBlue', 'SGLF': 'Turquoise',
  54. 'PGHF': 'DarkGreen', 'PGLF': 'YellowGreen',
  55. 'LFSG': 'Orange', 'LFPG': 'Yellow',
  56. 'HFSG': 'DarkRed', 'HFPG': 'OrangeRed',
  57. 'SGSG': 'SteelBlue', 'PGPG': 'LimeGreen',
  58. 'NONE': 'k', 'PG': 'k', 'SG': 'k', 'LF': 'k', 'HF': 'k'}
  59. event_colors = {
  60. 'TS-ON': 'Gray', # 'TS-OFF': 'Gray',
  61. 'WS-ON': 'Gray', # 'WS-OFF': 'Gray',
  62. 'CUE-ON': 'Gray',
  63. 'CUE-OFF': 'Gray',
  64. 'GO-ON': 'Gray', # 'GO-OFF': 'Gray',
  65. # 'GO/RW-OFF': 'Gray',
  66. 'SR': 'Gray', # 'SR-REP': 'Gray',
  67. 'RW-ON': 'Gray', # 'RW-OFF': 'Gray',
  68. 'STOP': 'Gray'}
  69. # =============================================================================
  70. # Plot helper functions
  71. # =============================================================================
  72. def force_aspect(ax, aspect=1):
  73. ax.set_aspect(abs(
  74. (ax.get_xlim()[1] - ax.get_xlim()[0]) /
  75. (ax.get_ylim()[1] - ax.get_ylim()[0])) / aspect)
  76. def get_arraygrid(signals, chosen_el):
  77. array_grid = np.ones((10, 10)) * 0.7
  78. rejections = np.logical_or(signals.array_annotations['electrode_reject_HFC'],
  79. signals.array_annotations['electrode_reject_LFC'],
  80. signals.array_annotations['electrode_reject_IFC'])
  81. for sig_idx in range(signals.shape[-1]):
  82. connector_aligned_id = signals.array_annotations['connector_aligned_ids'][sig_idx]
  83. x, y = int((connector_aligned_id -1)// 10), int((connector_aligned_id - 1) % 10)
  84. if np.asarray(signals.array_annotations['channel_ids'][sig_idx], dtype=int) == chosen_el:
  85. array_grid[x, y] = -0.7
  86. elif rejections[sig_idx]:
  87. array_grid[x, y] = -0.35
  88. else:
  89. array_grid[x, y] = 0
  90. return np.ma.array(array_grid, mask=np.isnan(array_grid))
  91. # =============================================================================
  92. # Load data and metadata for a monkey
  93. # =============================================================================
  94. # CHANGE this parameter to load data of the different monkeys
  95. monkey = 'Nikos2'
  96. # monkey = 'Lilou'
  97. chosen_el = {'Lilou': 71, 'Nikos2': 63}
  98. chosen_units = {'Lilou': range(1, 5), 'Nikos2': range(1, 5)}
  99. datafile = get_monkey_datafile(monkey)
  100. session = reachgraspio.ReachGraspIO(
  101. filename=os.path.join(datasetdir, datafile),
  102. odml_directory=datasetdir,
  103. verbose=False)
  104. block = session.read_block(lazy=True)
  105. segment = block.segments[0]
  106. # Displaying loaded data structure as string output
  107. print("\nBlock")
  108. print('Attributes ', block.__dict__.keys())
  109. print('Annotations', block.annotations)
  110. print("\nSegment")
  111. print('Attributes ', segment.__dict__.keys())
  112. print('Annotations', segment.annotations)
  113. print("\nEvents")
  114. for x in segment.events:
  115. print('\tEvent with name', x.name)
  116. print('\t\tAttributes ', x.__dict__.keys())
  117. print('\t\tAnnotation keys', x.annotations.keys())
  118. print('\t\ttimes', x.times[:20])
  119. if x.name == 'TrialEvents':
  120. for anno_key in ['trial_id', 'trial_timestamp_id', 'trial_event_labels',
  121. 'trial_reject_IFC']:
  122. print('\t\t'+anno_key, x.array_annotations[anno_key][:20])
  123. print("\nGroups")
  124. for x in block.groups:
  125. print('\tGroup with name', x.name)
  126. print('\t\tAttributes ', x.__dict__.keys())
  127. print('\t\tAnnotations', x.annotations)
  128. print("\nSpikeTrains")
  129. for x in segment.spiketrains:
  130. print('\tSpiketrain with name', x.name)
  131. print('\t\tAttributes ', x.__dict__.keys())
  132. print('\t\tAnnotations', x.annotations)
  133. print('\t\tchannel_id', x.annotations['channel_id'])
  134. print('\t\tunit_id', x.annotations['unit_id'])
  135. print('\t\tis sua', x.annotations['sua'])
  136. print('\t\tis mua', x.annotations['mua'])
  137. print("\nAnalogSignals")
  138. for x in segment.analogsignals:
  139. print('\tAnalogSignal with name', x.name)
  140. print('\t\tAttributes ', x.__dict__.keys())
  141. print('\t\tAnnotations', x.annotations)
  142. print('\t\tchannel_ids', x.array_annotations['channel_ids'])
  143. # get start and stop events of trials
  144. start_events = neo_utils.get_events(
  145. segment,
  146. **{
  147. 'name': 'TrialEvents',
  148. 'trial_event_labels': 'TS-ON',
  149. 'performance_in_trial': 255})
  150. stop_events = neo_utils.get_events(
  151. segment,
  152. **{
  153. 'name': 'TrialEvents',
  154. 'trial_event_labels': 'STOP',
  155. 'performance_in_trial': 255})
  156. # there should only be one event object for these conditions
  157. assert len(start_events) == 1
  158. assert len(stop_events) == 1
  159. # insert epochs between 10ms before TS to 50ms after RW corresponding to trails
  160. ep = neo_utils.add_epoch(
  161. segment,
  162. start_events[0],
  163. stop_events[0],
  164. pre=-250 * pq.ms,
  165. post=500 * pq.ms,
  166. trial_status='complete_trials')
  167. ep.array_annotate(trial_type=start_events[0].array_annotations['belongs_to_trialtype'],
  168. trial_performance=start_events[0].array_annotations['performance_in_trial'])
  169. # access single epoch of this data_segment
  170. epochs = neo_utils.get_epochs(segment, **{'trial_status': 'complete_trials'})
  171. assert len(epochs) == 1
  172. # remove spiketrains not belonging to chosen_electrode
  173. segment.spiketrains = segment.filter(targdict={'channel_id': chosen_el[monkey]},
  174. recursive=True, objects='SpikeTrainProxy')
  175. segment.spiketrains = [st for st in segment.spiketrains if st.annotations['unit_id'] in range(1, 5)]
  176. # replacing the segment with a new segment containing all data
  177. # to speed up cutting of segments
  178. segment = load_segment(segment, load_wavefroms=True, channel_indexes=[chosen_el[monkey]])
  179. # use most raw neuronal data if multiple versions are present
  180. max_sampling_rate = max([a.sampling_rate for a in segment.analogsignals])
  181. idx = 0
  182. while idx < len(segment.analogsignals):
  183. signal = segment.analogsignals[idx]
  184. if signal.annotations['neural_signal'] and signal.sampling_rate < max_sampling_rate:
  185. segment.analogsignals.pop(idx)
  186. else:
  187. idx += 1
  188. # neural_signals = []
  189. # behav_signals = []
  190. # for sig in segment.analogsignals:
  191. # if sig.annotations['neural_signal']:
  192. # neural_signals.append(sig)
  193. # else:
  194. # behav_signals.append(sig)
  195. #
  196. # chosen_raw = neural_signals[0]
  197. # for sig in neural_signals:
  198. # if sig.sampling_rate > chosen_raw.sampling_rate:
  199. # chosen_raw = sig
  200. #
  201. # segment.analogsignals = behav_signals + [chosen_raw]
  202. # cut segments according to inserted 'complete_trials' epochs and reset trial times
  203. cut_segments = neo_utils.cut_segment_by_epoch(segment, epochs[0], reset_time=True)
  204. # =============================================================================
  205. # Define data for overview plots
  206. # =============================================================================
  207. trial_index = {'Lilou': 0, 'Nikos2': 6}
  208. trial_segment = cut_segments[trial_index[monkey]]
  209. blackrock_elid_list = block.annotations['avail_electrode_ids']
  210. # get 'TrialEvents'
  211. event = trial_segment.events[2]
  212. start = np.where(event.array_annotations['trial_event_labels'] == 'TS-ON')[0][0]
  213. trialx_trty = event.array_annotations['belongs_to_trialtype'][start]
  214. trialx_trtimeid = event.array_annotations['trial_timestamp_id'][start]
  215. trialx_color = trialtype_colors[trialx_trty]
  216. # find trial index for next trial with opposite force type (for ax5b plot)
  217. if 'LF' in trialx_trty:
  218. trialz_trty = trialx_trty.replace('LF', 'HF')
  219. else:
  220. trialz_trty = trialx_trty.replace('HF', 'LF')
  221. for i, tr in enumerate(cut_segments):
  222. eventz = tr.events[2]
  223. nextft = np.where(eventz.array_annotations['trial_event_labels'] == 'TS-ON')[0][0]
  224. if eventz.array_annotations['belongs_to_trialtype'][nextft] == trialz_trty:
  225. trialz_trtimeid = eventz.array_annotations['trial_timestamp_id'][nextft]
  226. trialz_color = trialtype_colors[trialz_trty]
  227. trialz_seg = tr
  228. break
  229. # =============================================================================
  230. # Define figure and subplot axis for first data overview
  231. # =============================================================================
  232. fig = plt.figure()
  233. fig.set_size_inches(6.5, 10.) # (w, h) in inches
  234. gs = gridspec.GridSpec(
  235. nrows=5,
  236. ncols=4,
  237. left=0.05,
  238. bottom=0.07,
  239. right=0.9,
  240. top=0.975,
  241. wspace=0.3,
  242. hspace=0.5,
  243. width_ratios=None,
  244. height_ratios=[1, 3, 3, 6, 3])
  245. ax1 = plt.subplot(gs[0, :]) # top row / odml data
  246. # second row
  247. ax2a = plt.subplot(gs[1, 0]) # electrode overview plot
  248. ax2b = plt.subplot(gs[1, 1]) # waveforms unit 1
  249. ax2c = plt.subplot(gs[1, 2]) # waveforms unit 2
  250. ax2d = plt.subplot(gs[1, 3]) # waveforms unit 3
  251. ax3 = plt.subplot(gs[2, :]) # third row / spiketrains
  252. ax4 = plt.subplot(gs[3, :], sharex=ax3) # fourth row / raw signal
  253. ax5a = plt.subplot(gs[4, 0:3]) # fifth row / behavioral signals
  254. ax5b = plt.subplot(gs[4, 3])
  255. fontdict_titles = {'fontsize': 'small', 'fontweight': 'bold'}
  256. fontdict_axis = {'fontsize': 'x-small'}
  257. wf_time_unit = pq.ms
  258. wf_signal_unit = pq.microvolt
  259. plotting_time_unit = pq.s
  260. raw_signal_unit = wf_signal_unit
  261. behav_signal_unit = pq.V
  262. # =============================================================================
  263. # PLOT TRIAL SEQUENCE OF SUBSESSION
  264. # =============================================================================
  265. # load complete metadata collection
  266. odmldoc = odml.load(datasetdir + datafile + '.odml')
  267. # get total trial number
  268. trno_tot = odml_utils.get_TrialCount(odmldoc)
  269. trno_ctr = odml_utils.get_TrialCount(odmldoc, performance_code=255)
  270. trno_ertr = trno_tot - trno_ctr
  271. # get trial id of chosen trial (and next trial with opposite force)
  272. trtimeids = odml_utils.get_TrialIDs(odmldoc, idtype='TrialTimestampID')
  273. trids = odml_utils.get_TrialIDs(odmldoc)
  274. trialx_trid = trids[trtimeids.index(trialx_trtimeid)]
  275. trialz_trid = trids[trtimeids.index(trialz_trtimeid)]
  276. # get all trial ids for grip error trials
  277. trids_pc191 = odml_utils.get_trialids_pc(odmldoc, 191)
  278. # get all trial ids for correct trials
  279. trids_pc255 = odml_utils.get_trialids_pc(odmldoc, 255)
  280. # get occurring trial types
  281. octrty = odml_utils.get_OccurringTrialTypes(odmldoc, code=False)
  282. # Subplot 1: Trial sequence
  283. boxes, labels = [], []
  284. for tt in octrty:
  285. # Plot trial ids of current trial type into trial sequence bar plot
  286. left = odml_utils.get_trialids_trty(odmldoc, tt)
  287. height = np.ones_like(left)
  288. width = 1.
  289. if tt in ['NONE', 'PG', 'SG', 'LF', 'HF']:
  290. color = 'w'
  291. else:
  292. color = trialtype_colors[tt]
  293. B = ax1.bar(
  294. x=left, height=height, width=width, color=color, linewidth=0.001, align='edge')
  295. # Mark trials of current trial type (left) if a grip error occurred
  296. x = [i for i in list(set(left) & set(trids_pc191))]
  297. y = np.ones_like(x) * 2.0
  298. ax1.scatter(x, y, s=5, color='k', marker='*')
  299. # Mark trials of current trial type (left) if any other error occurred
  300. x = [i for i in list(
  301. set(left) - set(trids_pc255) - set(trids_pc191))]
  302. y = np.ones_like(x) * 2.0
  303. ax1.scatter(x, y, s=5, color='gray', marker='*')
  304. # Collect information for trial type legend
  305. if tt not in ['PG', 'SG', 'LF', 'HF']:
  306. boxes.append(B[0])
  307. if tt == 'NONE':
  308. # use errors for providing total trial number
  309. labels.append('total: # %i' % trno_tot)
  310. # add another box and label for error numbers
  311. boxes.append(B[0])
  312. labels.append('* errors: # %i' % trno_ertr)
  313. else:
  314. # trial type trial numbers
  315. labels.append(tt + ': # %i' % len(left))
  316. # mark chosen trial
  317. x = [trialx_trid]
  318. y = np.ones_like(x) * 2.0
  319. ax1.scatter(x, y, s=5, marker='D', color='Red', edgecolors='Red')
  320. # mark next trial with opposite force
  321. x = [trialz_trid]
  322. y = np.ones_like(x) * 2.0
  323. ax1.scatter(x, y, s=5, marker='D', color='orange', edgecolors='orange')
  324. # Generate trial type legend; bbox: (left, bottom, width, height)
  325. leg = ax1.legend(
  326. boxes, labels, bbox_to_anchor=(0., 1., 0.5, 0.1), loc=3, handlelength=1.1,
  327. ncol=len(labels), borderaxespad=0., handletextpad=0.4,
  328. prop={'size': 'xx-small'})
  329. leg.draw_frame(False)
  330. # adjust x and y axis
  331. xticks = list(range(1, 101, 10)) + [100]
  332. ax1.set_xticks(xticks)
  333. ax1.set_xticklabels([str(int(t)) for t in xticks], size='xx-small')
  334. ax1.set_xlabel('trial ID', size='x-small')
  335. ax1.set_xlim(1.-width/2., 100.+width/2.)
  336. ax1.yaxis.set_visible(False)
  337. ax1.set_ylim(0, 3)
  338. ax1.spines['top'].set_visible(False)
  339. ax1.spines['left'].set_visible(False)
  340. ax1.spines['right'].set_visible(False)
  341. ax1.tick_params(direction='out', top=False, left=False, right=False)
  342. ax1.set_title('sequence of the first 100 trials', fontdict_titles, y=2)
  343. ax1.set_aspect('equal')
  344. # =============================================================================
  345. # PLOT ELECTRODE POSITION of chosen electrode
  346. # =============================================================================
  347. neural_signals = [sig for sig in trial_segment.analogsignals if sig.annotations['neural_signal']]
  348. assert len(neural_signals) == 1
  349. neural_signals = neural_signals[0]
  350. arraygrid = get_arraygrid(neural_signals, chosen_el[monkey])
  351. cmap = plt.cm.RdGy
  352. ax2a.pcolormesh(
  353. arraygrid, vmin=-1, vmax=1, lw=1, cmap=cmap, edgecolors='k',
  354. #shading='faceted'
  355. )
  356. force_aspect(ax2a, aspect=1)
  357. ax2a.tick_params(
  358. bottom=False, top=False, left=False, right=False,
  359. labelbottom=False, labeltop=False, labelleft=False, labelright=False)
  360. ax2a.set_title('electrode pos.', fontdict_titles)
  361. # =============================================================================
  362. # PLOT WAVEFORMS of units of the chosen electrode
  363. # =============================================================================
  364. unit_ax_translator = {1: ax2b, 2: ax2c, 3: ax2d}
  365. unit_type = {1: '', 2: '', 3: ''}
  366. wf_lim = []
  367. # plotting waveform for all spiketrains available
  368. for spiketrain in trial_segment.spiketrains:
  369. unit_id = spiketrain.annotations['unit_id']
  370. # get unit type
  371. if spiketrain.annotations['sua']:
  372. unit_type[unit_id] = 'SUA'
  373. elif spiketrain.annotations['mua']:
  374. unit_type[unit_id] = 'MUA'
  375. elif unit_id in [0, 255]:
  376. continue
  377. else:
  378. raise ValueError(f'Found unit with id {unit_id}, that is not SUA or MUA.')
  379. # get correct ax
  380. ax = unit_ax_translator[unit_id]
  381. # get wf sampling time before threshold crossing
  382. left_sweep = spiketrain.left_sweep
  383. # plot waveforms in subplots according to unit id
  384. for st_id, st in enumerate(spiketrain):
  385. wf = spiketrain.waveforms[st_id]
  386. wf_lim.append((np.min(wf), np.max(wf)))
  387. wf_color = str(
  388. (st / spiketrain.t_stop).rescale('dimensionless').magnitude)
  389. times = range(len(wf[0])) * spiketrain.units - left_sweep
  390. ax.plot(
  391. times.rescale(wf_time_unit), wf[0].rescale(wf_signal_unit),
  392. color=wf_color)
  393. ax.set_xlim(
  394. times.rescale(wf_time_unit)[0], times.rescale(wf_time_unit)[-1])
  395. # adding xlabels and titles
  396. for unit_id, ax in unit_ax_translator.items():
  397. ax.set_title('unit %i (%s)' % (unit_id, unit_type[unit_id]),
  398. fontdict_titles)
  399. ax.tick_params(direction='in', length=3, labelsize='xx-small',
  400. labelleft=False, labelright=False)
  401. ax.set_xlabel(wf_time_unit.dimensionality.latex, fontdict_axis)
  402. xticklocator = ticker.MaxNLocator(nbins=5)
  403. ax.xaxis.set_major_locator(xticklocator)
  404. ax.set_ylim(np.min(wf_lim), np.max(wf_lim))
  405. force_aspect(ax, aspect=1)
  406. # adding ylabel
  407. ax2d.tick_params(labelsize='xx-small', labelright=True)
  408. ax2d.set_ylabel(wf_signal_unit.dimensionality.latex, fontdict_axis)
  409. ax2d.yaxis.set_label_position("right")
  410. # =============================================================================
  411. # PLOT SPIKETRAINS of units of chosen electrode
  412. # =============================================================================
  413. plotted_unit_ids = []
  414. # plotting all available spiketrains
  415. for st in trial_segment.spiketrains:
  416. unit_id = st.annotations['unit_id']
  417. plotted_unit_ids.append(unit_id)
  418. ax3.plot(st.times.rescale(plotting_time_unit),
  419. np.zeros(len(st.times)) + unit_id,
  420. 'k|')
  421. # setting layout of spiketrain plot
  422. ax3.set_ylim(min(plotted_unit_ids) - 0.5, max(plotted_unit_ids) + 0.5)
  423. ax3.set_ylabel(r'unit ID', fontdict_axis)
  424. ax3.yaxis.set_major_locator(ticker.MultipleLocator(base=1))
  425. ax3.yaxis.set_label_position("right")
  426. ax3.tick_params(axis='y', direction='in', length=3, labelsize='xx-small',
  427. labelleft=False, labelright=True)
  428. ax3.invert_yaxis()
  429. ax3.set_title('spiketrains', fontdict_titles)
  430. # =============================================================================
  431. # PLOT "raw" SIGNAL of chosen trial of chosen electrode
  432. # =============================================================================
  433. # get "raw" data from chosen electrode
  434. el_raw_sig = [a for a in trial_segment.analogsignals if a.annotations['neural_signal']]
  435. assert len(el_raw_sig) == 1
  436. el_raw_sig = el_raw_sig[0]
  437. # plotting raw signal trace of chosen electrode
  438. chids = np.asarray(el_raw_sig.array_annotations['channel_ids'], dtype=int)
  439. chosen_el_idx = np.where(chids == chosen_el[monkey])[0][0]
  440. ax4.plot(el_raw_sig.times.rescale(plotting_time_unit),
  441. el_raw_sig[:, chosen_el_idx].squeeze().rescale(raw_signal_unit),
  442. color='k')
  443. # setting layout of raw signal plot
  444. ax4.set_ylabel(raw_signal_unit.units.dimensionality.latex, fontdict_axis)
  445. ax4.yaxis.set_label_position("right")
  446. ax4.tick_params(axis='y', direction='in', length=3, labelsize='xx-small',
  447. labelleft=False, labelright=True)
  448. ax4.set_title('"raw" signal', fontdict_titles)
  449. ax4.set_xlim(trial_segment.t_start.rescale(plotting_time_unit),
  450. trial_segment.t_stop.rescale(plotting_time_unit))
  451. ax4.xaxis.set_major_locator(ticker.MultipleLocator(base=1))
  452. # =============================================================================
  453. # PLOT EVENTS across ax3 and ax4 and add time bar
  454. # =============================================================================
  455. # find trial relevant events
  456. startidx = np.where(event.array_annotations['trial_event_labels'] == 'TS-ON')[0][0]
  457. stopidx = np.where(event.array_annotations['trial_event_labels'][startidx:] == 'STOP')[0][0] + startidx + 1
  458. for ax in [ax3, ax4]:
  459. xticks = []
  460. xticklabels = []
  461. for ev_id, ev in enumerate(event[startidx:stopidx]):
  462. ev_labels = event.array_annotations['trial_event_labels'][startidx:stopidx]
  463. if ev_labels[ev_id] in event_colors.keys():
  464. ev_color = event_colors[ev_labels[ev_id]]
  465. ax.axvline(
  466. ev.rescale(plotting_time_unit), color=ev_color, zorder=0.5)
  467. xticks.append(ev.rescale(plotting_time_unit))
  468. if ev_labels[ev_id] == 'CUE-OFF':
  469. xticklabels.append('-OFF')
  470. elif ev_labels[ev_id] == 'GO-ON':
  471. xticklabels.append('GO')
  472. else:
  473. xticklabels.append(ev_labels[ev_id])
  474. ax.set_xticks(xticks)
  475. ax.set_xticklabels(xticklabels)
  476. ax.tick_params(axis='x', direction='out', length=3, labelsize='xx-small',
  477. labeltop=False, top=False)
  478. timebar_ypos = ax4.get_ylim()[0] + np.diff(ax4.get_ylim())[0] / 10
  479. timebar_labeloffset = np.diff(ax4.get_ylim())[0] * 0.01
  480. timebar_xmin = xticks[-2] + ((xticks[-1] - xticks[-2]) / 2 - 0.25 * pq.s)
  481. timebar_xmax = timebar_xmin + 0.5 * pq.s
  482. ax4.plot([timebar_xmin, timebar_xmax], [timebar_ypos, timebar_ypos], '-',
  483. linewidth=3, color='k')
  484. ax4.text(timebar_xmin + 0.25 * pq.s, timebar_ypos + timebar_labeloffset,
  485. '500 ms', ha='center', va='bottom', size='xx-small', color='k')
  486. # =============================================================================
  487. # PLOT BEHAVIORAL SIGNALS of chosen trial
  488. # =============================================================================
  489. # get behavioral signals
  490. ainp_signals = [nsig for nsig in trial_segment.analogsignals if not nsig.annotations['neural_signal']][0]
  491. chids = np.asarray(ainp_signals.array_annotations['channel_ids'], dtype=int)
  492. force_channel_idx = np.where(chids == 141)[0][0]
  493. ainp_trialz_signals = [a for a in trialz_seg.analogsignals if not a.annotations['neural_signal']]
  494. assert len(ainp_trialz_signals)
  495. ainp_trialz = ainp_trialz_signals[0][:, force_channel_idx]
  496. # find out what signal to use
  497. trialx_sec = odmldoc['Recording']['TaskSettings']['Trial_%03i' % trialx_trid]
  498. # get correct channel id
  499. trialx_chids = [143]
  500. FSRi = trialx_sec['AnalogEvents'].properties['UsedForceSensor'].values[0]
  501. FSRinfosec = odmldoc['Setup']['Apparatus']['TargetObject']['FSRSensor']
  502. if 'SG' in trialx_trty:
  503. sgchids = FSRinfosec.properties['SGChannelIDs'].values
  504. trialx_chids.append(min(sgchids) if FSRi == 1 else max(sgchids))
  505. else:
  506. pgchids = FSRinfosec.properties['PGChannelIDs'].values
  507. trialx_chids.append(min(pgchids) if FSRi == 1 else max(pgchids))
  508. # define time epoch
  509. startidx = np.where(event.array_annotations['trial_event_labels'] == 'SR')[0][0]
  510. stopidx = np.where(event.array_annotations['trial_event_labels'] == 'OBB')[0][0]
  511. sr = event[startidx].rescale(plotting_time_unit)
  512. stop = event[stopidx].rescale(plotting_time_unit) + 0.050 * pq.s
  513. startidx = np.where(event.array_annotations['trial_event_labels'] == 'FSRplat-ON')[0][0]
  514. stopidx = np.where(event.array_annotations['trial_event_labels'] == 'FSRplat-OFF')[0][0]
  515. fplon = event[startidx].rescale(plotting_time_unit)
  516. fploff = event[stopidx].rescale(plotting_time_unit)
  517. # define time epoch trialz
  518. startidx = np.where(eventz.array_annotations['trial_event_labels'] == 'FSRplat-ON')[0][0]
  519. stopidx = np.where(eventz.array_annotations['trial_event_labels'] == 'FSRplat-OFF')[0][0]
  520. fplon_trz = eventz[startidx].rescale(plotting_time_unit)
  521. fploff_trz = eventz[stopidx].rescale(plotting_time_unit)
  522. # plotting grip force and object displacement
  523. ai_legend = []
  524. ai_legend_txt = []
  525. for chidx, chid in enumerate(np.asarray(ainp_signals.array_annotations['channel_ids'], dtype=int)):
  526. ainp = ainp_signals[:, chidx]
  527. if int(ainp.array_annotations['channel_ids'][0]) in trialx_chids:
  528. ainp_times = ainp.times.rescale(plotting_time_unit)
  529. mask = (ainp_times > sr) & (ainp_times < stop)
  530. ainp_ampli = stats.zscore(ainp.magnitude[mask])
  531. if int(ainp.array_annotations['channel_ids'][0]) != 143:
  532. color = 'gray'
  533. ai_legend_txt.append('grip force')
  534. else:
  535. color = 'k'
  536. ai_legend_txt.append('object disp.')
  537. ai_legend.append(
  538. ax5a.plot(ainp_times[mask], ainp_ampli, color=color)[0])
  539. # get force load of this trial for next plot
  540. elif int(ainp.array_annotations['channel_ids'][0]) == 141:
  541. ainp_times = ainp.times.rescale(plotting_time_unit)
  542. mask = (ainp_times > fplon) & (ainp_times < fploff)
  543. force_av_01 = np.mean(ainp.rescale(behav_signal_unit).magnitude[mask])
  544. # setting layout of grip force and object displacement plot
  545. ax5a.set_title('grip force and object displacement', fontdict_titles)
  546. ax5a.yaxis.set_label_position("left")
  547. ax5a.tick_params(direction='in', length=3, labelsize='xx-small',
  548. labelleft=False, labelright=True)
  549. ax5a.set_ylabel('zscore', fontdict_axis)
  550. ax5a.legend(
  551. ai_legend, ai_legend_txt,
  552. bbox_to_anchor=(0.65, .85, 0.25, 0.1), loc=2, handlelength=1.1,
  553. ncol=len(labels), borderaxespad=0., handletextpad=0.4,
  554. prop={'size': 'xx-small'})
  555. # plotting load/pull force of LF and HF trial
  556. force_times = ainp_trialz.times.rescale(plotting_time_unit)
  557. mask = (force_times > fplon_trz) & (force_times < fploff_trz)
  558. force_av_02 = np.mean(ainp_trialz.rescale(behav_signal_unit).magnitude[mask])
  559. bar_width = [0.4, 0.4]
  560. color = [trialx_color, trialz_color]
  561. ax5b.bar([0, 0.6], [force_av_01, force_av_02], bar_width, color=color)
  562. ax5b.set_title('load/pull force', fontdict_titles)
  563. ax5b.set_ylabel(behav_signal_unit.units.dimensionality.latex, fontdict_axis)
  564. ax5b.set_xticks([0, 0.6])
  565. ax5b.set_xticklabels([trialx_trty, trialz_trty], fontdict=fontdict_axis)
  566. ax5b.yaxis.set_label_position("right")
  567. ax5b.tick_params(direction='in', length=3, labelsize='xx-small',
  568. labelleft=False, labelright=True)
  569. # =============================================================================
  570. # PLOT EVENTS across ax5a and add time bar
  571. # =============================================================================
  572. # find trial relevant events
  573. startidx = np.where(event.array_annotations['trial_event_labels'] == 'SR')[0][0]
  574. stopidx = np.where(event.array_annotations['trial_event_labels'] == 'OBB')[0][0]
  575. xticks = []
  576. xticklabels = []
  577. for ev_id, ev in enumerate(event[startidx:stopidx]):
  578. ev_labels = event.array_annotations['trial_event_labels'][startidx:stopidx + 1]
  579. if ev_labels[ev_id] in ['RW-ON']:
  580. ax5a.axvline(ev.rescale(plotting_time_unit), color='k', zorder=0.5)
  581. xticks.append(ev.rescale(plotting_time_unit))
  582. xticklabels.append(ev_labels[ev_id])
  583. elif ev_labels[ev_id] in ['OT', 'OR', 'DO', 'OBB', 'FSRplat-ON',
  584. 'FSRplat-OFF', 'HEplat-ON']:
  585. ev_color = 'k'
  586. xticks.append(ev.rescale(plotting_time_unit))
  587. xticklabels.append(ev_labels[ev_id])
  588. ax5a.axvline(
  589. ev.rescale(plotting_time_unit), color='k', ls='-.', zorder=0.5)
  590. elif ev_labels[ev_id] == 'HEplat-OFF':
  591. ev_color = 'k'
  592. ax5a.axvline(
  593. ev.rescale(plotting_time_unit), color='k', ls='-.', zorder=0.5)
  594. ax5a.set_xticks(xticks)
  595. ax5a.set_xticklabels(xticklabels, fontdict=fontdict_axis, rotation=90)
  596. ax5a.tick_params(axis='x', direction='out', length=3, labelsize='xx-small',
  597. labeltop=False, top=False)
  598. ax5a.set_ylim([-2.0, 2.0])
  599. timebar_ypos = ax5a.get_ylim()[0] + np.diff(ax5a.get_ylim())[0] / 10
  600. timebar_labeloffset = np.diff(ax5a.get_ylim())[0] * 0.02
  601. timebar_xmax = xticks[xticklabels.index('RW-ON')] - 0.1 * pq.s
  602. timebar_xmin = timebar_xmax - 0.25 * pq.s
  603. ax5a.plot([timebar_xmin, timebar_xmax], [timebar_ypos, timebar_ypos], '-',
  604. linewidth=3, color='k')
  605. ax5a.text(timebar_xmin + 0.125 * pq.s, timebar_ypos + timebar_labeloffset,
  606. '250 ms', ha='center', va='bottom', size='xx-small', color='k')
  607. # add time window of ax5a to ax4
  608. ax4.axvspan(ax5a.get_xlim()[0], ax5a.get_xlim()[1], facecolor=[0.9, 0.9, 0.9],
  609. zorder=-0.1, ec=None)
  610. # =============================================================================
  611. # SAVE FIGURE
  612. # =============================================================================
  613. fname = 'data_overview_1_%s' % monkey
  614. for file_format in ['eps', 'png', 'pdf']:
  615. fig.savefig(fname + '.%s' % file_format, dpi=400, format=file_format)