DS.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. import numpy as np
  2. from time import time
  3. class DS:
  4. def __init__(self, votes=None, workers=None, instances=None, test=None,
  5. transform=None, classes=None, num_epochs=1000):
  6. if votes is None or workers is None or instances is None:
  7. votes, workers, instances, test = self.test_data()
  8. # data info
  9. self.V = len(votes)
  10. self.U = len(np.unique(workers))
  11. self.I = len(np.unique(instances))
  12. if classes is not None:
  13. self.C = len(classes)
  14. else:
  15. self.C = len(np.unique(votes))
  16. self.transform = transform
  17. self.instance_prior = np.zeros(self.C)
  18. self.eps = np.finfo(np.float64).eps
  19. # EM parameters
  20. self.max_epochs = num_epochs
  21. # info to save
  22. self.LL = np.nan * np.ones(self.max_epochs)
  23. if test is not None:
  24. self.accuracy = np.nan * np.ones(self.max_epochs)
  25. self.labels = np.zeros((self.I, self.C))
  26. # self.labels_cov = np.zeros((self.I, self.C, self.C))
  27. self.worker_skill = np.zeros((self.U, self.C, self.C))
  28. # estimate label means and covariances using ds
  29. self.ds(votes, workers, instances, test)
  30. # apply transform
  31. if transform == 'clr':
  32. def clr(self):
  33. continuous = np.log(self.labels + self.eps)
  34. continuous -= continuous.mean(1, keepdims=True)
  35. return continuous
  36. self.labels = clr(self)
  37. elif transform == 'alr':
  38. def alr(self):
  39. continuous = np.log(self.labels[:, :-1] / (self.labels[:, -1] + self.eps))
  40. return continuous
  41. self.labels = alr(self)
  42. elif transform == 'ilr':
  43. # make projection matrix
  44. self.projectionMatrix = np.zeros((self.C, self.C - 1), dtype=np.float32)
  45. for it in range(self.C - 1):
  46. i = it + 1
  47. self.projectionMatrix[:i, it] = 1. / i
  48. self.projectionMatrix[i, it] = -1
  49. self.projectionMatrix[i + 1:, it] = 0
  50. self.projectionMatrix[:, it] *= np.sqrt(i / (i + 1.))
  51. def ilr(self):
  52. continuous = np.log(self.labels + self.eps)
  53. continuous -= continuous.mean(1, keepdims=True)
  54. continuous = np.dot(continuous, self.projectionMatrix)
  55. return continuous
  56. self.labels = ilr(self)
  57. # # adjust label covariances and inverses
  58. # self.__estimate_label_cov_with_inv_wishart()
  59. # self.labels_cov *= (self.num_samples + 1) / self.num_samples
  60. # EM functions
  61. def calculate_log_like(self, worker_ind, vote_ind):
  62. # calculate log-likelihood (2.7 is DS)
  63. LL = 0
  64. for i in range(self.I):
  65. LL += np.log((self.worker_skill[worker_ind[i], :, vote_ind[i]].prod(0) * self.instance_prior).sum())
  66. LL /= self.I
  67. return LL
  68. # estimate the instance classes given the current parameters (these labels are treated as missing data here for EM)
  69. def e_step(self, worker_ind, vote_ind):
  70. # estimate instance classes (2.5 is DS)
  71. for i in range(self.I):
  72. self.labels[i, :] = self.worker_skill[worker_ind[i], :, vote_ind[i]].prod(0)
  73. self.labels *= self.instance_prior[np.newaxis, :]
  74. self.labels /= self.labels.sum(1, keepdims=True) + self.eps
  75. # update parameters to maximize the data likelihood
  76. def m_step(self, instance_ind):
  77. # argmax LL over class probabilities (2.4 is DS)
  78. self.instance_prior = self.labels.mean(0) + self.eps
  79. # argmax LL over worker skill (2.3 is DS)
  80. for u in range(self.U):
  81. for c in range(self.C):
  82. self.worker_skill[u, :, c] = self.labels[instance_ind[u][c], :].sum(0)
  83. self.worker_skill /= self.worker_skill.sum(2, keepdims=True) + self.eps
  84. # DS optimization using EM
  85. def ds(self, votes, workers, instances, test):
  86. # precalculate indices
  87. print('Generating indices...')
  88. worker_ind, vote_ind, instance_ind = [], [], []
  89. for i in range(self.I):
  90. _instance_ind = instances == i
  91. worker_ind.append(workers[_instance_ind])
  92. vote_ind.append(votes[_instance_ind])
  93. for u in range(self.U):
  94. instance_ind.append([])
  95. _worker_ind = workers == u
  96. for c in range(self.C):
  97. _vote_ind = votes == c
  98. instance_ind[u].append(instances[np.bitwise_and(_worker_ind, _vote_ind)])
  99. # DS
  100. start = time()
  101. for ep in range(self.max_epochs):
  102. # begin epoch
  103. print('starting epoch ' + str(ep + 1))
  104. if ep:
  105. time_to_go = (time() - start) * (self.max_epochs - ep) / ep
  106. if time_to_go >= 3600:
  107. print('Estimated time to finish: %.2f hours' % (time_to_go / 3600,))
  108. elif time_to_go >= 60:
  109. print('Estimated time to finish: %.2f minutes' % (time_to_go / 60,))
  110. else:
  111. print('Estimated time to finish: %.1f seconds' % (time_to_go,))
  112. ep_start = time()
  113. # EM
  114. print('E step...')
  115. if not ep:
  116. # initial estimates
  117. for i in range(self.I):
  118. ind = instances == i
  119. for c in range(self.C):
  120. self.labels[i, c] = np.count_nonzero(votes[ind] == c)
  121. self.labels /= self.labels.sum(1, keepdims=True) + self.eps
  122. # self.labels[np.random.choice(self.I, self.I/2, replace=False), :] = 1. / self.C
  123. else:
  124. self.e_step(worker_ind, vote_ind)
  125. print('M step...')
  126. self.m_step(instance_ind)
  127. # save information
  128. print('Calculating log-likelihood...')
  129. self.LL[ep] = self.calculate_log_like(worker_ind, vote_ind)
  130. print('Log-likelihood = %f' % (self.LL[ep],))
  131. # evaulation if test available
  132. if test is not None:
  133. self.accuracy[ep] = (self.labels.argmax(1) == test).mean()
  134. print('Accuracy = %f' % (self.accuracy[ep],))
  135. # ce = -np.log(self.labels[range(self.I), test] + self.eps).sum()
  136. # print 'Cross Entropy = %f' % (ce,)
  137. # print epoch duration
  138. print('Epoch completed in %.1f seconds' % (time() - ep_start,))
  139. time_total = time() - start
  140. if time_total >= 3600:
  141. print('DS completed in %.2f hours' % (time_total / 3600,))
  142. elif time_total >= 60:
  143. print('DS completed in %.2f minutes' % (time_total / 60,))
  144. else:
  145. print('DS completed in %.1f seconds' % (time_total,))
  146. # # generate covariance estimates using inverse Wishart prior
  147. # def __estimate_label_cov_with_inv_wishart(self):
  148. # # prepare parameters
  149. # self.inv_wishart_prior_scatter = 0.1 * np.eye(self.C - 1) * self.num_samples
  150. # self.inv_wishart_degrees_of_freedom = self.C - 1
  151. # scatter_matrix = self.labels_cov * self.num_samples
  152. #
  153. # # calculate multivariate student-t covariance based on normal with known mean and inverse Wishart prior
  154. # self.labels_cov_iwp = (self.inv_wishart_prior_scatter + scatter_matrix) \
  155. # / (self.inv_wishart_degrees_of_freedom + self.num_samples - (self.C - 1) - 1)
  156. #
  157. # # calculate covariance inverses for later use
  158. # self.labels_icov_iwp = np.linalg.inv(self.labels_cov_iwp )
  159. @staticmethod
  160. def test_data():
  161. """
  162. Sample data from the Dawid & Skene (1979) paper
  163. :return: (votes, workers, instances, true_class)
  164. """
  165. # data from DS section 4
  166. data = [[[1, 1, 1], 1, 1, 1, 1],
  167. [[3, 3, 3], 4, 3, 3, 4],
  168. [[1, 1, 2], 2, 1, 2, 2],
  169. [[2, 2, 2], 3, 1, 2, 1],
  170. [[2, 2, 2], 3, 2, 2, 2],
  171. [[2, 2, 2], 3, 3, 2, 2],
  172. [[1, 2, 2], 2, 1, 1, 1],
  173. [[3, 3, 3], 3, 4, 3, 3],
  174. [[2, 2, 2], 2, 2, 2, 3],
  175. [[2, 3, 2], 2, 2, 2, 3],
  176. [[4, 4, 4], 4, 4, 4, 4],
  177. [[2, 2, 2], 3, 3, 4, 3],
  178. [[1, 1, 1], 1, 1, 1, 1],
  179. [[2, 2, 2], 3, 2, 1, 2],
  180. [[1, 2, 1], 1, 1, 1, 1],
  181. [[1, 1, 1], 2, 1, 1, 1],
  182. [[1, 1, 1], 1, 1, 1, 1],
  183. [[1, 1, 1], 1, 1, 1, 1],
  184. [[2, 2, 2], 2, 2, 2, 1],
  185. [[2, 2, 2], 1, 3, 2, 2],
  186. [[2, 2, 2], 2, 2, 2, 2],
  187. [[2, 2, 2], 2, 2, 2, 1],
  188. [[2, 2, 2], 3, 2, 2, 2],
  189. [[2, 2, 1], 2, 2, 2, 2],
  190. [[1, 1, 1], 1, 1, 1, 1],
  191. [[1, 1, 1], 1, 1, 1, 1],
  192. [[2, 3, 2], 2, 2, 2, 2],
  193. [[1, 1, 1], 1, 1, 1, 1],
  194. [[1, 1, 1], 1, 1, 1, 1],
  195. [[1, 1, 2], 1, 1, 2, 1],
  196. [[1, 1, 1], 1, 1, 1, 1],
  197. [[3, 3, 3], 3, 2, 3, 3],
  198. [[1, 1, 1], 1, 1, 1, 1],
  199. [[2, 2, 2], 2, 2, 2, 2],
  200. [[2, 2, 2], 3, 2, 3, 2],
  201. [[4, 3, 3], 4, 3, 4, 3],
  202. [[2, 2, 1], 2, 2, 3, 2],
  203. [[2, 3, 2], 3, 2, 3, 3],
  204. [[3, 3, 3], 3, 4, 3, 2],
  205. [[1, 1, 1], 1, 1, 1, 1],
  206. [[1, 1, 1], 1, 1, 1, 1],
  207. [[1, 2, 1], 2, 1, 1, 1],
  208. [[2, 3, 2], 2, 2, 2, 2],
  209. [[1, 2, 1], 1, 1, 1, 1],
  210. [[2, 2, 2], 2, 2, 2, 2]]
  211. # solutions from DS section 4
  212. test = [1,
  213. 4,
  214. 2,
  215. 2,
  216. 2,
  217. 2,
  218. 1,
  219. 3,
  220. 2,
  221. 2,
  222. 4,
  223. 3,
  224. 1,
  225. 2,
  226. 1,
  227. 1,
  228. 1,
  229. 1,
  230. 2,
  231. 2,
  232. 2,
  233. 2,
  234. 2,
  235. 2,
  236. 1,
  237. 1,
  238. 2,
  239. 1,
  240. 1,
  241. 1,
  242. 1,
  243. 3,
  244. 1,
  245. 2,
  246. 2,
  247. 4,
  248. 2,
  249. 3,
  250. 3,
  251. 1,
  252. 1,
  253. 1,
  254. 2,
  255. 1,
  256. 2]
  257. # cl_transform to list format
  258. votes, workers, instances = [], [], []
  259. for it_patient, patient in enumerate(data):
  260. for it_doctor, doctor in enumerate(patient):
  261. if isinstance(doctor, list):
  262. for diagnosis in doctor:
  263. votes.append(diagnosis-1)
  264. workers.append(it_doctor)
  265. instances.append(it_patient)
  266. else:
  267. votes.append(doctor-1)
  268. workers.append(it_doctor)
  269. instances.append(it_patient)
  270. return np.array(votes), np.array(workers), np.array(instances), np.array(test) - 1
  271. if __name__ == '__main__':
  272. ds = DS()