# CLLDA_for_ICLabel.py
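"""Run or update crowd-labeling (CL) estimates for the ICLabel website.

The script loads component-label votes from the website's sqlite database and then:

1. fits (or updates) CLLDA models using heavier pseudo-vote priors for expert labelers and
   saves the combined estimates under the 'expert' tag,
2. repeats the fit using only the weaker non-expert prior, saved under the 'noexpert' tag, and
3. aggregates the votes of worker 0 alone ("luca" in the comments below) with MV and saves
   them under the 'onlyluca' tag.

Each run is written to the save directory as ICLabels_<tag>.json, .pkl, and .mat, with optional
compact JSON files for website viewing. CLLDA and MV come from the crowd_labeling package;
CLLDA is its crowd-labeling LDA-style model and MV is majority voting.
"""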
from load_website_data import load_icl
import numpy as np
try:
    from scipy.io import savemat
except ImportError:
    pass
from crowd_labeling import CLLDA, concurrent_cllda, combine_cllda
from crowd_labeling.MV import MV
import cPickle as pkl
import json
from os.path import join, isfile, isdir
import argparse
from copy import deepcopy
import sys

# parse input arguments
print('parsing arguments')
parser = argparse.ArgumentParser(description='Run or update CL estimates.')
parser.add_argument('database', type=str, help='Absolute reference to the sqlite database file.')
parser.add_argument('save', type=str, help='Directory in which to save results.')
parser.add_argument('-classifications', type=str, help='Directory in which to save results for website viewing.',
                    default=None)
args = parser.parse_args()
database = args.database
path = args.save
classifications_path = args.classifications
assert isfile(database), 'database path does not exist'
assert isdir(path), 'save path does not exist'
assert classifications_path is None or isdir(classifications_path), 'classifications path does not exist'

# load sqlite data
print('loading database')
icl_votes = load_icl(database)
votes = icl_votes['votes']
workers = icl_votes['workers']
instances = icl_votes['instances']
instance_set_numbers = icl_votes['instance_set_numbers']
instance_ic_numbers = icl_votes['instance_ic_numbers']
vote_ids = np.array(['Brain', 'Muscle', 'Eye', 'Heart', 'Line Noise', 'Chan Noise', 'Other', '?'])
instance_ids = icl_votes['instance_ids']
worker_ids = icl_votes['worker_ids']
T = icl_votes['n_classes']
C = icl_votes['n_responses']
A = icl_votes['n_workers']
is_expert = (icl_votes['is_expert']).astype(bool)  # type: np.ndarray


# append identifier to string
def add_identifier(string, identifier):
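    """Join `string` and `identifier` with an underscore, skipping any part that is None."""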
    return '_'.join((x for x in (string, identifier) if x is not None))


# run cllda
def run_cllda(save_path, vts, wks, its, vt_ids=None, it_ids=None, wk_ids=None, it_priors=None, wk_priors=None,
              it_set_numbers=None, it_ic_numbers=None, identifier=None):
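    """Fit or update CLLDA models and return a dict of combined estimates.

    If a pickled set of models for this identifier already exists in save_path, those models
    are loaded and updated with the current votes (800 epochs, no burn-in); otherwise four new
    models are fit in parallel with the 'none', 'ilr', 'clr', and 'alr' transforms (1000
    epochs, 200 burn-in). The individual models are re-pickled, merged with combine_cllda, and
    the merged labels, label covariances, and worker matrices are returned as a plain dict.
    """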
    if isfile(join(save_path, add_identifier('icl_cllda_models', identifier) + '.pkl')):
        # load for python
        with open(join(save_path, add_identifier('icl_cllda_models', identifier) + '.pkl'), 'rb') as f:
            cls = pkl.load(f)

        # update CLLDA with all transforms
        cls = concurrent_cllda(cls, vts, wks, its, nprocs=4, vote_ids=vt_ids, instance_ids=it_ids,
                               worker_ids=wk_ids, worker_prior=wk_priors, num_epochs=800, burn_in=0)
    else:
        # CLLDA with all transforms
        cls = concurrent_cllda(4, vts, wks, its, nprocs=4, vote_ids=vt_ids, instance_ids=it_ids,
                               worker_ids=wk_ids, worker_prior=wk_priors, instance_prior=it_priors,
                               transform=('none', 'ilr', 'clr', 'alr'), num_epochs=1000, burn_in=200)

    # save individual models for python
    with open(join(save_path, add_identifier('icl_cllda_models', identifier) + '.pkl'), 'wb') as f:
        pkl.dump(cls, f)

    # combine models
    cl = combine_cllda(cls)

    # aggregate data
    return {
        'instance_ids': cl.instance_ids,
        'worker_ids': cl.worker_ids,
        'vote_ids': cl.vote_ids,
        'instance_set_numbers': it_set_numbers.astype(int),
        'instance_ic_numbers': it_ic_numbers.astype(int),
        'transform': cl.transform,
        'labels': cl.labels,
        'labels_cov': cl.labels_cov,
        'worker_mats': cl.worker_mats,
    }


# save results in 3 different formats
def save_results(save_path, data, identifier=None):
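    """Save the aggregated results in three formats.

    `data` is written as JSON (for the website's PHP code), as a python pickle, and, when
    scipy is available, as a MATLAB .mat file, each named ICLabels_<identifier>.*. If a
    classifications directory was supplied on the command line, compact label files
    (website_<identifier>_icl_{index,classifications,classes}.json) are also written for
    website viewing.
    """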
    # save combined model for php
    print('saving for php')
    json_data = deepcopy(data)
    for key, val in json_data.iteritems():
        if isinstance(val, np.ndarray):
            json_data[key] = val.tolist()
        elif isinstance(val, list):
            for it, item in enumerate(val):
                if isinstance(item, np.ndarray):
                    val[it] = item.tolist()
            json_data[key] = val
    with open(join(save_path, add_identifier('ICLabels', identifier) + '.json'), 'wb') as f:
        json.dump(json_data, f)

    # save combined model for python
    print('saving for python')
    with open(join(save_path, add_identifier('ICLabels', identifier) + '.pkl'), 'wb') as f:
        pkl.dump(data, f)

    # save combined model for matlab (only if scipy.io was importable above)
    if 'scipy.io' in sys.modules:
        print('saving for matlab')
        for key, val in data.iteritems():
            if not isinstance(val, np.ndarray):
                try:
                    val = np.array(val)
                except ValueError:
                    # ragged lists cannot form a rectangular array; store as an object array
                    data[key] = np.empty(len(val), dtype=np.object)
                    for it, item in enumerate(val):
                        data[key][it] = item
                    continue
            if not np.issubdtype(val.dtype, np.number):
                data[key] = val.astype(np.object)
        savemat(join(save_path, add_identifier('ICLabels', identifier) + '.mat'), data)
    # optionally save classifications for website viewing
    if classifications_path and isdir(classifications_path) and all(
            (x in data.keys() for x in ('labels', 'vote_ids', 'instance_set_numbers', 'instance_ic_numbers'))):
        path_str = join(classifications_path, add_identifier('website', identifier) + '_icl_')
        with open(path_str + 'index.json', 'w') as f:
            json.dump(zip(json_data['instance_set_numbers'], json_data['instance_ic_numbers']), f)
        with open(path_str + 'classifications.json', 'w') as f:
            try:
                json.dump(json_data['labels'][np.where(np.array(data['transform']) == 'none')[0][0]], f)
            except KeyError:
                json.dump(json_data['labels'], f)
        with open(path_str + 'classes.json', 'w') as f:
            json.dump(json_data['vote_ids'][:-1], f)


# CLLDA settings
n_pseudovotes_e = 100
n_pseudovotes_u = 1
expert_prior = n_pseudovotes_e * (np.hstack((np.eye(T), np.zeros((T, 1))))) + 0.01
user_prior = n_pseudovotes_u * (np.hstack((np.eye(T), np.zeros((T, 1))))) + 0.01
all_priors = np.zeros((A, T, C))
all_priors[is_expert.astype(np.bool), :, :] = np.tile(expert_prior[np.newaxis], [is_expert.sum(), 1, 1])
all_priors[np.logical_not(is_expert), :, :] = np.tile(user_prior[np.newaxis], [np.logical_not(is_expert).sum(), 1, 1])
instance_prior = np.histogram(votes, range(C))[0] / 100. / np.histogram(votes, range(C))[0].sum()
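
# Shape note (toy values for illustration): expert_prior and user_prior are T x C pseudo-count
# matrices with one row per true class and one column per response option; the trailing '?'
# column receives only the 0.01 baseline. For example, with T = 2 and C = 3 the expert prior
# would evaluate to
#     [[100.01,   0.01, 0.01],
#      [  0.01, 100.01, 0.01]]
# all_priors stacks one such matrix per worker (shape A x T x C), with experts getting the
# heavier 100-pseudovote diagonal.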

# run and save CLLDA with experts
tag = 'expert'
print('Running CLLDA_' + tag + '...')
out = run_cllda(path, votes, workers, instances, vote_ids, instance_ids, worker_ids, instance_prior,
                all_priors, instance_set_numbers, instance_ic_numbers, tag)
print('Saved individual CLLDA_' + tag + ' models')
print('Saving combined results...')
save_results(path, out, tag)
print('Saved combined results')

# run CLLDA without experts
tag = 'noexpert'
print('Running CLLDA_' + tag + '...')
out = run_cllda(path, votes, workers, instances, vote_ids, instance_ids, worker_ids, instance_prior,
                user_prior, instance_set_numbers, instance_ic_numbers, tag)
print('Saved individual CLLDA_' + tag + ' models')
print('Saving combined results...')
save_results(path, out, tag)
print('Saved combined results')

# run and save with only luca

# remove non-luca votes
worker_ids_lu = worker_ids[0]
luca_ind = np.in1d(workers, (0,))
votes_lu = votes[luca_ind]
workers_lu = workers[luca_ind]
instances_lu = instances[luca_ind]

# remove instances with votes that are unsure
keep_index = np.logical_not(np.in1d(instances_lu, np.unique(instances_lu[votes_lu == 7])))
votes_lu = votes_lu[keep_index]
workers_lu = workers_lu[keep_index]
instances_lu = instances_lu[keep_index]
instance_ids_lu = instance_ids[np.unique(instances_lu)]

# reset instance numbering
instance_set_numbers_lu = np.array(instance_set_numbers)[np.unique(instances_lu)]
instance_ic_numbers_lu = np.array(instance_ic_numbers)[np.unique(instances_lu)]
instances_lu = np.array([{x: y for x, y in zip(np.unique(instances_lu),
                                               np.arange(np.unique(instances_lu).size))}[z]
                         for z in instances_lu])
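# Renumbering example (toy values): if instances_lu were [5, 9, 5, 12], np.unique gives
# [5, 9, 12] and the dict comprehension above remaps them to a contiguous range, yielding
# instances_lu = [0, 1, 0, 2].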

# run MV (majority vote) on the luca-only data
cl = MV(votes_lu, workers_lu, instances_lu)

# save results
save_results(path, {
    'instance_ids': instance_ids_lu,
    'worker_ids': worker_ids_lu,
    'vote_ids': vote_ids,
    'instance_set_numbers': instance_set_numbers_lu,
    'instance_ic_numbers': instance_ic_numbers_lu,
    'labels': cl.labels,
}, 'onlyluca')