Browse Source

Minor fixes, python 3 compatibility and Neo 0.5dev compatibility

Michael Denker 6 years ago
parent
commit
dfaf917535
1 changed files with 32 additions and 16 deletions
  1. 32 16
      code/reachgraspio/reachgraspio.py

+ 32 - 16
code/reachgraspio/reachgraspio.py

@@ -51,7 +51,12 @@ import odml.tools
 import quantities as pq
 
 import neo
-from neo.io.blackrockio import BlackrockIO
+
+# Using old version of IO as long as available - should be deleted at some point
+try:
+    from neo.io import OldBlackrockIO as BlackrockIO
+except ImportError:
+    from neo.io.blackrockio import BlackrockIO
 
 
 class ReachGraspIO(BlackrockIO):
@@ -285,9 +290,9 @@ class ReachGraspIO(BlackrockIO):
         '65513': 'RW-ON (+CONF-LF)',
         '65514': 'RW-ON (+CONF-SG)'}
     event_labels_codes = dict(
-        [(k, []) for k in np.unique(event_labels_str.values())])
-    for k in event_labels_codes.keys():
-        for l, v in event_labels_str.iteritems():
+        [(k, []) for k in np.unique(list(event_labels_str.values()))]) #inefficient in Python2
+    for k in list(event_labels_codes):
+        for l, v in event_labels_str.items(): #inefficient in Python2
             if v == k:
                 event_labels_codes[k].append(l)
 
@@ -327,7 +332,7 @@ class ReachGraspIO(BlackrockIO):
         'RW-ON': 6,
         'STOP': 7}
     trial_const_sequence_str = dict(
-        (v, k) for k, v in trial_const_sequence_codes.iteritems())
+        (v, k) for k, v in trial_const_sequence_codes.items()) #inefficient in Python2
 
     # Create dictionaries for trial performances
     # (resulting decimal number from binary number created from trial_sequence)
@@ -340,7 +345,7 @@ class ReachGraspIO(BlackrockIO):
         'error<GO-ON': 175,
         'grip_error': 191,
         'correct_trial': 255}
-    performance_str = dict((v, k) for k, v in performance_codes.iteritems())
+    performance_str = dict((v, k) for k, v in performance_codes.items())  #inefficient in Python2
 
     def __init__(
             self, filename, odml_directory=None,
@@ -496,7 +501,7 @@ class ReachGraspIO(BlackrockIO):
         task_condition = 0
 
         if len(occurring_trtys) > 0:
-            for cnd, trtys in self.condition_str.iteritems():
+            for cnd, trtys in self.condition_str.items():  #inefficient in Python2
                 if set(trtys) == set(occurring_trtys):
                     # replace with detected task condition
                     task_condition = cnd
@@ -565,17 +570,24 @@ class ReachGraspIO(BlackrockIO):
 
         # Uncomment for event and trial sequence debugging
 #        for ev in events.labels:
-#            if ev in self.event_labels_str.keys():
+#            if ev in list(self.event_labels_str):
 #                print ev, self.event_labels_str[ev]
 #            else:
 #                print ev
 
         # Extract beginning of first complete trial
-        first_TSon_idx = list(
-            events.labels).index(self.event_labels_codes['TS-ON'][0])
+        tson_label = self.event_labels_codes['TS-ON'][0]
+        if tson_label in events.labels:
+            first_TSon_idx = list(events.labels).index(tson_label)
+        else:
+            first_TSon_idx = len(events.labels)
         # Extract end of last complete trial
-        last_WSoff_idx = len(events.labels) - list(events.labels[::-1]).index(
-            self.event_labels_codes['STOP'][0]) - 1
+        stop_label = self.event_labels_codes['STOP'][0]
+        if stop_label in events.labels:
+            last_WSoff_idx = len(events.labels) - \
+                             list(events.labels[::-1]).index(stop_label) - 1
+        else:
+            last_WSoff_idx = -1
 
         # Annotate events with modified labels, trial ids, and trial types
         trial_event_labels = []
@@ -655,7 +667,10 @@ class ReachGraspIO(BlackrockIO):
                             trialsequence[timestamp_id],
                             self.trial_const_sequence_codes['STOP'])
                     else:
-                        raise ValueError("Unknown trial event sequence.")
+                        trial_event_labels.append('STOP')
+                        trialsequence[timestamp_id] = self.__set_bit(
+                            trialsequence[timestamp_id],
+                            self.trial_const_sequence_codes['STOP'])
                 # interpretation of WS-ON/CUE-OFF
                 elif self.event_labels_str[l] == 'WS-ON/CUE-OFF':
                     trial_timestamp_ID.append(timestamp_id)
@@ -667,7 +682,8 @@ class ReachGraspIO(BlackrockIO):
                         trialsequence[timestamp_id] = self.__set_bit(
                             trialsequence[timestamp_id],
                             self.trial_const_sequence_codes['WS-ON'])
-                    elif prev_ev in self.event_labels_codes['CUE/GO']:
+                    elif (prev_ev in self.event_labels_codes['CUE/GO'] or
+                          prev_ev in self.event_labels_codes['GO/RW-OFF']):
                         trial_event_labels.append('CUE-OFF')
                         trialsequence[timestamp_id] = self.__set_bit(
                             trialsequence[timestamp_id],
@@ -747,7 +763,7 @@ class ReachGraspIO(BlackrockIO):
 
         # add modified belongs_to_trialtype to annotations
         for tid in trial_timestamp_ID:
-            if tid not in trialtypes.keys():
+            if tid not in list(trialtypes):
                 trialtypes[tid] = 'NONE'
         belongs_to_trialtype = [
             trialtypes[tid] for tid in trial_timestamp_ID]
@@ -1430,7 +1446,7 @@ class ReachGraspIO(BlackrockIO):
         bl.annotate(conditions=[])
         for seg in bl.segments:
             if load_events and not lazy:
-                if 'condition' in seg.annotations.keys():
+                if 'condition' in list(seg.annotations):
                     bl.annotations['conditions'].append(
                         seg.annotations['condition'])