# load_website_data.py (3.5 KB)
import pdb
import sqlite3 as sql
import warnings
from collections import Counter, defaultdict
from contextlib import closing

import numpy as np
  4. def load_icl(db_path):
  5. # load sqlite data
  6. connection = sql.connect(db_path)
  7. cursor = connection.cursor()
  8. cursor.execute('SELECT * FROM users')
  9. db_combined = cursor.fetchall()
  10. cursor.execute('SELECT * FROM labels')
  11. db_labels = cursor.fetchall()
  12. db_labels_column_names = [x[0] for x in cursor.description]
  13. cursor.execute('SELECT * FROM images')
  14. db_images = cursor.fetchall()
  15. connection.close()
  16. del connection, cursor
  17. # remove users with not enough labels
  18. min_labels = 10
  19. user_labs = [x[1] for x in db_labels]
  20. user_labs_count = np.array([user_labs.count(x) for x in [x[0] for x in db_combined]])
  21. keep_users = np.where(user_labs_count >= min_labels)[0]
  22. db_combined = [db_combined[x] for x in keep_users]
  23. del user_labs_count
  24. # remove labels from users with not enough labels
  25. db_labels = [x for x in db_labels if x[1] in [y[0] for y in db_combined]]
  26. del keep_users
  27. # remove instances which only have "?" as an answer
  28. # find all images with a ?
  29. # for each of those images, find all labels
  30. # if the labels are only ?, remove
  31. remove = list()
  32. for it in np.unique([x[2] for x in db_labels if x[10]]):
  33. if not np.sum([x[3:10] for x in db_labels if x[2] == it]):
  34. remove.append(it)
  35. if remove:
  36. db_labels = [x for x in db_labels if x[2] not in remove]
  37. NotImplementedError('there are some dead answers that need input')
  38. # TODO: fix the above. doesn't catch everything
  39. # aggregate images
  40. db_images = [db_images[y-1] for y in np.unique([x[2] for x in db_labels])]
  41. # tabulate data
  42. I = len(set([x[2] for x in db_labels])) # number of images
  43. A = len(db_combined) # number of users and experts combined
  44. # dictionary for all
  45. combined_ind = [x[0] for x in db_combined]
  46. combined_dict = {x: y for x, y in zip(combined_ind, range(A))} # sqlite index to db_experts index
  47. # dictionary for images
  48. im_ind = list(set([x[2] for x in db_labels]))
  49. im_ind.sort()
  50. im_dict = {x: y for x, y in zip(im_ind, range(I))} # sqlite image_id to image index
  51. # separate votes_mat
  52. votes_mat = np.array([x[3:11] for x in db_labels])
  53. is_expert = np.array([x[4] for x in db_combined])
  54. # is_expert[0] = 0
  55. # index votes_mat
  56. iV = np.array([im_dict[x[2]] for x in db_labels])
  57. uV = np.array([combined_dict[x[1]] for x in db_labels])
  58. # tabulate more data
  59. V = len(votes_mat) # number of total votes_mat
  60. T = 7 # number of topics (estimated truth)
  61. C = T + 1 # number of categories (options for voting)
  62. # reshape votes_mat
  63. nz = np.nonzero(votes_mat)
  64. votes_vec = nz[1]
  65. votes_vec_workers = uV[nz[0]]
  66. votes_vec_instances = iV[nz[0]]
  67. VV = len(votes_vec)
  68. # dataset info
  69. instance_set_numbers = np.array([x[2] for x in db_images])
  70. instance_ic_numbers = np.array([x[3] for x in db_images])
  71. instance_ids = np.array([x[0] for x in db_images])
  72. return {'votes': votes_vec,
  73. 'workers': votes_vec_workers,
  74. 'instances': votes_vec_instances,
  75. 'is_expert': is_expert,
  76. 'instance_set_numbers': instance_set_numbers,
  77. 'instance_ic_numbers': instance_ic_numbers,
  78. 'instance_ids': instance_ids,
  79. 'worker_ids': np.array([x[1] for x in db_combined]),
  80. 'vote_ids': np.array(db_labels_column_names[3:11]),
  81. 'n_votes': V,
  82. 'n_classes': T,
  83. 'n_responses': C,
  84. 'n_instances': I,
  85. 'n_workers': A}