CLLDA.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525
  1. import numpy as np
  2. try:
  3. from scipy.stats import mode
  4. except ImportError:
  5. def mode(a, axis=0):
  6. scores = np.unique(np.ravel(a)) # get ALL unique values
  7. testshape = list(a.shape)
  8. testshape[axis] = 1
  9. oldmostfreq = np.zeros(testshape)
  10. oldcounts = np.zeros(testshape)
  11. for score in scores:
  12. template = (a == score)
  13. counts = np.expand_dims(np.sum(template, axis), axis)
  14. mostfrequent = np.where(counts > oldcounts, score, oldmostfreq)
  15. oldcounts = np.maximum(counts, oldcounts)
  16. oldmostfreq = mostfrequent
  17. return mostfrequent, oldcounts
  18. import multiprocessing as mp
  19. from time import time
  20. from copy import deepcopy
  21. from crowd_labeling.logratio_transformations import \
  22. centered_log_ratio_transform as clr, \
  23. isometric_log_ratio_transform as ilr, \
  24. additive_log_ratio_transform as alr, \
  25. make_projection_matrix as mpm
  26. class CLLDA:
  27. """
  28. The :class: CLLDA is a python implementation of Crowd Labeling Latent Dirichlet Allocation.
  29. This algorithm processes crowd labeling (aka crowd consensus) data where workers label instances
  30. as pertaining to one or more classes. Allows for calculating resulting label estimates and
  31. covariances in multiple log-ratio transformed spaces.
  32. """
  33. def __init__(self, votes, workers, instances, vote_ids=None, worker_ids=None, instance_ids=None,
  34. worker_prior=None, instance_prior=None, transform=None,
  35. num_epochs=1000, burn_in=200, updateable=True, save_samples=False, seed=None):
  36. """
  37. Initializes settings for the model and automatically calls the inference function.
  38. :param votes: List of vote values.
  39. :param workers: List of uses who submitted :param votes.
  40. :param instances: List of instances to which the :param votes pertain.
  41. :param vote_ids: (optional) List of vote ids. If provided, :param votes should be a list of integers.
  42. :param instance_ids: (optional) List of instance ids. If provided, :param instances should be a list of integers.
  43. :param worker_ids: (optional) List of worker ids. If provided, :param workers should be a list of integers.
  44. :param worker_prior: (optional) Matrix prior for worker skill (pseudovotes).
  45. :param instance_prior: (optional) List of class priors (pseudovotes)
  46. :param transform: log-ratio transform to use.
  47. :param num_epochs: number of epochs to run for.
  48. :param burn_in: number of epochs to ignore for convergence of the Gibbs chain.
  49. :param updateable: If True, will save vote-classes between runs at the expense of memory.
  50. :param save_samples: option to save vote-classes after each epoch (very memory intensive).
  51. :param seed: seed for the random number generator if reproducibility is desired.
  52. """
  53. # set random seed
  54. if seed is None:
  55. seed = np.random.randint(int(1e8))
  56. self.rng = np.random.RandomState(seed=seed)
  57. # data info and priors
  58. self.V = len(votes)
  59. self.U = len(np.unique(workers))
  60. self.I = len(np.unique(instances))
  61. if worker_prior is not None:
  62. self.worker_prior = np.array(worker_prior)
  63. if self.worker_prior.ndim == 2:
  64. self.worker_prior = np.tile(self.worker_prior[np.newaxis, :, :], [self.U, 1, 1])
  65. self.C = self.worker_prior.shape[1]
  66. self.R = self.worker_prior.shape[2]
  67. else:
  68. if vote_ids is not None:
  69. self.C = len(vote_ids)
  70. self.R = self.C
  71. else:
  72. self.R = len(np.unique(votes))
  73. self.C = self.R
  74. self.worker_prior = (np.eye(self.R) + np.ones((self.R, self.R)) / self.R) * 3
  75. self.worker_prior = np.tile(self.worker_prior[np.newaxis, :, :], [self.U, 1, 1])
  76. if instance_prior is None:
  77. self.instance_prior = np.ones(self.C) / self.C / 4
  78. else:
  79. self.instance_prior = instance_prior
  80. # determine vote IDs
  81. if vote_ids is None:
  82. self.vote_ids = np.unique(votes)
  83. vote_dict = {y: x for x, y in enumerate(self.vote_ids)}
  84. votes = np.array([vote_dict[x] for x in votes])
  85. else:
  86. self.vote_ids = vote_ids
  87. # determine instance IDs
  88. if instance_ids is None:
  89. self.instance_ids = np.unique(instances)
  90. instance_dict = {y: x for x, y in enumerate(self.instance_ids)}
  91. instances = np.array([instance_dict[x] for x in instances])
  92. else:
  93. self.instance_ids = instance_ids
  94. # determine worker IDs
  95. if worker_ids is None:
  96. self.worker_ids = np.unique(workers)
  97. worker_dict = {y: x for x, y in enumerate(self.worker_ids)}
  98. workers = np.array([worker_dict[x] for x in workers])
  99. else:
  100. self.worker_ids = worker_ids
  101. # cl_transform info
  102. if not isinstance(transform, str) and hasattr(transform, '__iter__'):
  103. self.transform = tuple(transform)
  104. else:
  105. self.transform = (transform,)
  106. # Gibbs sampling parameters
  107. self.num_epochs = num_epochs
  108. self.burn_in = burn_in
  109. self.num_samples = num_epochs - burn_in
  110. # info to save
  111. self.LL = np.nan * np.ones(self.num_epochs)
  112. self.worker_mats = np.zeros((self.U, self.C, self.R))
  113. self.labels, self.labels_cov = list(), list()
  114. for transform in self.transform:
  115. if transform in (None, 'none', 'clr'):
  116. self.labels.append(np.zeros((self.I, self.C)))
  117. self.labels_cov.append(np.zeros((self.I, self.C, self.C)))
  118. elif transform in ('alr', 'ilr'):
  119. self.labels.append(np.zeros((self.I, self.C - 1)))
  120. self.labels_cov.append(np.zeros((self.I, self.C - 1, self.C - 1)))
  121. else:
  122. raise Exception('Unknown transform!')
  123. if save_samples:
  124. self.samples = np.zeros((self.num_epochs - self.burn_in, self.I, self.C - 1))
  125. self.updateable = updateable
  126. self.vote_classes = None
  127. # estimate label means and covariances using cllda
  128. self.cllda(votes, workers, instances)
  129. # clean up
  130. if not self.updateable:
  131. self.vote_classes = None
  132. else:
  133. self.votes = votes
  134. self.instances = instances
  135. self.workers = workers
  136. # CLLDA optimization using Gibbs sampling
  137. def cllda(self, votes, workers, instances, starting_epoch=0):
  138. """
  139. Performs inference on the :class: CLLDA model.
  140. :param votes: List of vote values.
  141. :param workers: List of workers who submitted :param votes.
  142. :param instances: List of instances to which the :param votes pertain.
  143. :param starting_epoch: How many epochs have already been incorporated in the averages.
  144. """
  145. # precalculate
  146. worker_prior_sum = self.worker_prior.sum(axis=2)
  147. instance_prior_sum = self.instance_prior.sum()
  148. # initial estimates
  149. if self.vote_classes is None:
  150. if self.C == self.R:
  151. self.vote_classes = votes.copy()
  152. else:
  153. self.vote_classes = self.rng.randint(0, self.C, self.V)
  154. # calculate vote weights
  155. temp = np.vstack((workers, instances)).T
  156. temp = np.ascontiguousarray(temp).view(np.dtype((np.void, temp.dtype.itemsize * temp.shape[1])))
  157. _, unique_counts = np.unique(temp, return_counts=True)
  158. weights = 1. / unique_counts[instances] # type: np.ndarray
  159. # initial counts
  160. counts_across_images = np.zeros(shape=(self.U, self.C, self.R))
  161. counts_across_workers_and_votes = np.zeros(shape=(self.I, self.C))
  162. for it_v in range(self.V):
  163. counts_across_images[workers[it_v], self.vote_classes[it_v], votes[it_v]] += weights[it_v]
  164. counts_across_workers_and_votes[instances[it_v], self.vote_classes[it_v]] += weights[it_v]
  165. counts_across_images_and_votes = counts_across_images.sum(axis=2)
  166. # set cl_transform
  167. transform = list()
  168. for tfm in self.transform:
  169. if tfm in (None, 'none'):
  170. transform.append(self.identity)
  171. elif tfm == 'clr':
  172. transform.append(clr)
  173. elif tfm == 'alr':
  174. transform.append(alr)
  175. elif tfm == 'ilr':
  176. transform.append(lambda comp: ilr(comp, mpm(self.C)))
  177. # LDA functions
  178. def get_data_like():
  179. like = np.zeros(self.V)
  180. for it_v in range(self.V):
  181. i = instances[it_v]
  182. k = self.vote_classes[it_v]
  183. u = workers[it_v]
  184. v = votes[it_v]
  185. w = weights[it_v] # type: np.ndarray
  186. like[it_v] = (counts_across_workers_and_votes[i, k] - w + self.instance_prior[k]) \
  187. * (counts_across_images[u, k, v] - w + self.worker_prior[u, k, v]) \
  188. / (counts_across_images_and_votes[u, k] - w + worker_prior_sum[u, k])
  189. return np.log(like).sum()
  190. def get_label_prob():
  191. like = (counts_across_workers_and_votes[i, :] + self.instance_prior[:]) \
  192. * (counts_across_images[u, :, v] + self.worker_prior[u, :, v]) \
  193. / (counts_across_images_and_votes[u, :] + worker_prior_sum[u, :])
  194. return like / like.sum()
  195. def update_labels():
  196. # create update
  197. numerator = counts_across_workers_and_votes + self.instance_prior
  198. denominator = counts_across_workers_and_votes.sum(axis=1) + instance_prior_sum
  199. update = numerator / denominator[:, np.newaxis]
  200. for it, tfm in enumerate(transform):
  201. tfmupdate = tfm(update)
  202. if hasattr(self, 'samples'):
  203. self.samples[ep - self.burn_in, :, :] = tfmupdate
  204. # update labels
  205. delta = (tfmupdate - self.labels[it]) / (ep - self.burn_in + 1)
  206. self.labels[it] += delta
  207. # update labels_M2
  208. delta_cov = delta[:, :, np.newaxis] * delta[:, :, np.newaxis].transpose(0, 2, 1)
  209. self.labels_cov[it] += (ep - self.burn_in) * delta_cov - self.labels_cov[it] / (ep - self.burn_in + 1)
  210. def update_worker_mats():
  211. # create update
  212. numerator = counts_across_images + self.worker_prior
  213. denominator = counts_across_images.sum(axis=2) + worker_prior_sum
  214. update = numerator / denominator[:, :, np.newaxis]
  215. # update labels
  216. delta = (update - self.worker_mats) / (ep - self.burn_in + 1)
  217. self.worker_mats += delta
  218. # CLLDA
  219. start = time()
  220. for ep in range(starting_epoch, starting_epoch + self.num_epochs):
  221. # begin epoch
  222. print('starting epoch ' + str(ep + 1))
  223. if ep > starting_epoch:
  224. time_to_go = (time() - start) * (self.num_epochs - ep) / ep
  225. if time_to_go >= 3600:
  226. print('Estimated time to finish: %.2f hours' % (time_to_go / 3600,))
  227. elif time_to_go >= 60:
  228. print('Estimated time to finish: %.1f minutes' % (time_to_go / 60,))
  229. else:
  230. print('Estimated time to finish: %.1f seconds' % (time_to_go,))
  231. ep_start = time()
  232. # gibbs sampling
  233. for it_v in self.rng.permutation(self.V).astype(np.int64):
  234. # get correct indices
  235. i = instances[it_v]
  236. k = self.vote_classes[it_v]
  237. u = workers[it_v]
  238. v = votes[it_v]
  239. w = weights[it_v]
  240. # decrement counts
  241. counts_across_images[u, k, v] -= w
  242. counts_across_workers_and_votes[i, k] -= w
  243. counts_across_images_and_votes[u, k] -= w
  244. # calculate probabilities of labels for this vote
  245. probs = get_label_prob()
  246. # sample new label
  247. k = self.rng.multinomial(1, probs).argmax()
  248. self.vote_classes[it_v] = k
  249. # increment counts
  250. counts_across_images[u, k, v] += w
  251. counts_across_workers_and_votes[i, k] += w
  252. counts_across_images_and_votes[u, k] += w
  253. # save information
  254. self.LL[ep] = get_data_like()
  255. if ep >= self.burn_in + starting_epoch:
  256. update_labels()
  257. update_worker_mats()
  258. # print epoch LL and duration
  259. print('Epoch completed in %.1f seconds' % (time() - ep_start,))
  260. print('LL: %.6f' % (self.LL[ep]))
  261. # adjust label covariances
  262. self.labels_cov = [x * self.num_samples / (self.num_samples - 1.) for x in self.labels_cov]
  263. time_total = time() - start
  264. if time_total >= 3600:
  265. print('CLLDA completed in %.2f hours' % (time_total / 3600,))
  266. elif time_total >= 60:
  267. print('CLLDA completed in %.1f minutes' % (time_total / 60,))
  268. else:
  269. print('CLLDA completed in %.1f seconds' % (time_total,))
  270. #
  271. def update(self, votes, workers, instances, vote_ids=None, instance_ids=None, worker_ids=None, worker_prior=None,
  272. num_epochs=1000, burn_in=200):
  273. # check that this is updateble
  274. assert self.updateable, 'This model is not updateable, presumable to conserve memory.'
  275. # determine IDs
  276. # for votes
  277. old_vote_ids = self.vote_ids.copy() # type: np.ndarray
  278. if vote_ids is None:
  279. self.vote_ids = np.unique(votes)
  280. vote_dict = {y: x for x, y in enumerate(self.vote_ids)}
  281. votes = np.array([vote_dict[x] for x in votes])
  282. else:
  283. self.vote_ids = vote_ids
  284. # for instances
  285. old_instance_ids = self.instance_ids.copy() # type: np.ndarray
  286. if instance_ids is None:
  287. self.instance_ids = np.unique(instances)
  288. instance_dict = {y: x for x, y in enumerate(self.instance_ids)}
  289. instances = np.array([instance_dict[x] for x in instances])
  290. else:
  291. self.instance_ids = instance_ids
  292. # for workers
  293. old_worker_ids = self.worker_ids.copy() # type: np.ndarray
  294. if worker_ids is None:
  295. self.worker_ids = np.unique(workers)
  296. worker_dict = {y: x for x, y in enumerate(self.worker_ids)}
  297. workers = np.array([worker_dict[x] for x in workers])
  298. else:
  299. self.worker_ids = worker_ids
  300. # update parameters
  301. self.V = len(votes)
  302. self.U = len(np.unique(workers))
  303. self.I = len(np.unique(instances))
  304. self.num_epochs = num_epochs
  305. self.burn_in = burn_in
  306. # add more samples to previous solution
  307. if np.array_equal(votes, self.votes) and np.array_equal(workers, self.workers) \
  308. and np.array_equal(instances, self.instances) and np.array_equal(self.vote_ids, old_vote_ids) \
  309. and np.array_equal(self.instance_ids, old_instance_ids) \
  310. and np.array_equal(self.worker_ids, old_worker_ids):
  311. # adjust label covariances
  312. self.labels_cov = [x * (self.num_samples - 1.) / self.num_samples for x in self.labels_cov]
  313. # update parameters
  314. self.LL = np.concatenate((self.LL, np.zeros(num_epochs)))
  315. old_num_samples = self.num_samples
  316. self.num_samples += num_epochs - burn_in
  317. self.votes = votes
  318. self.workers = workers
  319. self.instances = instances
  320. # update cllda
  321. self.cllda(votes, workers, instances, old_num_samples - 1)
  322. # keep only vote-classes and build off of them
  323. else:
  324. # insert old vote-classes and initialize new vote-classes
  325. old_vote_classes = self.vote_classes.copy()
  326. self.vote_classes = np.zeros_like(votes)
  327. old_dict = {y: x for x, y in enumerate(zip(self.votes, self.workers, self.instances))}
  328. for it, index in enumerate(zip(votes, workers, instances)):
  329. try:
  330. self.vote_classes[it] = old_vote_classes[old_dict[index]]
  331. except KeyError:
  332. if self.C == self.R:
  333. self.vote_classes[it] = votes[it]
  334. else:
  335. self.vote_classes[it] = self.rng.randint(self.C)
  336. # adjust worker_prior if necessary
  337. if not np.array_equal(self.worker_ids, old_worker_ids):
  338. assert worker_prior is not None, "Worker priors must be provided if worker_ids change."
  339. self.worker_prior = np.array(worker_prior)
  340. if self.worker_prior.ndim == 2:
  341. self.worker_prior = np.tile(self.worker_prior[np.newaxis, :, :], [self.U, 1, 1])
  342. # adjust info to save
  343. self.worker_mats = np.zeros((self.U, self.C, self.R))
  344. self.labels, self.labels_cov = list(), list()
  345. for transform in self.transform:
  346. if transform in (None, 'none', 'clr'):
  347. self.labels.append(np.zeros((self.I, self.C)))
  348. self.labels_cov.append(np.zeros((self.I, self.C, self.C)))
  349. elif transform in ('alr', 'ilr'):
  350. self.labels.append(np.zeros((self.I, self.C - 1)))
  351. self.labels_cov.append(np.zeros((self.I, self.C - 1, self.C - 1)))
  352. else:
  353. raise Exception('Unknown transform!')
  354. # update parameters
  355. self.LL = np.zeros(num_epochs)
  356. self.num_samples = num_epochs
  357. # update cllda
  358. self.cllda(votes, workers, instances)
  359. self.votes = votes
  360. self.instances = instances
  361. self.workers = workers
  362. # no cl_transform
  363. @staticmethod
  364. def identity(compositional):
  365. return compositional
  366. def concurrent_cllda(models, votes, workers, instances, nprocs=4, **kwargs):
  367. """
  368. Effortless parallelization of multiple CLLDA models.
  369. :param models: If creating new models, an integer denoting how many models to create.
  370. Otherwise, a list of existing models to update.
  371. :param votes: List of vote values.
  372. :param workers: List of uses who submitted :param votes.
  373. :param instances: List of instances to which the :param votes pertain.
  374. :param nprocs: Number of processors to use in the parallel pool.
  375. :param kwargs: Other possible inputs to either CLLDA.__init__ or CLLDA.update
  376. :return: List of new or updated CLLDA models.
  377. """
  378. # open parallel pool
  379. print('Starting multiprocessing pool...')
  380. pool = mp.Pool(processes=nprocs)
  381. # run CL-LDA
  382. if isinstance(models, int):
  383. print('Starting new CL-LDA models in parallel...')
  384. if 'seed' in kwargs.keys():
  385. np.random.seed(kwargs['seed'])
  386. kwargs = [deepcopy(kwargs) for x in range(models)]
  387. for it in range(models):
  388. kwargs[it]['seed'] = np.random.randint(int(1e8))
  389. out = pool.map(_new_cllda, [(votes, workers, instances, kwa) for kwa in kwargs])
  390. elif hasattr(models, '__iter__'):
  391. print('Updating CL-LDA models in parallel...')
  392. out = pool.map(_update_cllda, [(model, votes, workers, instances, kwargs) for model in models])
  393. else:
  394. pool.close()
  395. TypeError('Unknown type for input: models.')
  396. # close parallel pool
  397. pool.close()
  398. print('Multiprocessing pool closed.')
  399. return out
  400. def combine_cllda(models):
  401. """
  402. Combine multiple CLLDA instances.
  403. :param models: List of CLLDA models trained with the same settings.
  404. :return: CLLDA model which combines the input models.
  405. """
  406. # check models are equivalent
  407. assert np.equal(models[0].V, [model.V for model in models[1:]]).any(), 'Different number of votes!'
  408. assert np.equal(models[0].U, [model.U for model in models[1:]]).any(), 'Different number of workers!'
  409. assert np.equal(models[0].I, [model.I for model in models[1:]]).any(), 'Different number of instances!'
  410. assert np.equal(models[0].C, [model.C for model in models[1:]]).any(), 'Different number of classes!'
  411. assert np.equal(models[0].R, [model.R for model in models[1:]]).any(), 'Different number of responses!'
  412. assert np.equal(models[0].worker_prior,
  413. [model.worker_prior for model in models[1:]]).any(), 'Different worker priors!'
  414. assert np.equal(models[0].instance_prior,
  415. [model.instance_prior for model in models[1:]]).any(), 'Different instance priors!'
  416. assert np.all([models[0].transform == model.transform for model in models[1:]]), 'Different transforms!'
  417. # data info
  418. out = deepcopy((models[0]))
  419. # combine label estimates
  420. out.num_samples = np.sum([model.num_samples for model in models])
  421. # combine worker estimates
  422. out.worker_mats = np.sum([model.worker_mats * model.num_samples for model in models],
  423. axis=0) / out.num_samples
  424. if all([x.updateable for x in models]):
  425. out.vote_classes = mode(np.stack([x.vote_classes for x in models]))[0].flatten()
  426. # combine labels and label covariances
  427. for it in range(len(models[0].transform)):
  428. out.labels[it] = np.sum([model.labels[it] * model.num_samples for model in models], 0) / out.num_samples
  429. labels_corrmat = [(model.num_samples - 1.) / model.num_samples * model.labels_cov[it]
  430. + model.labels[it][..., np.newaxis] * model.labels[it][..., np.newaxis].transpose(0, 2, 1)
  431. for model in models]
  432. out.labels_cov[it] = np.sum([corrmat * model.num_samples for model, corrmat in zip(models, labels_corrmat)],
  433. 0) \
  434. / out.num_samples - out.labels[it][..., np.newaxis] * out.labels[it][
  435. ..., np.newaxis].transpose(0, 2, 1)
  436. # adjust label covariances
  437. out.labels_cov[it] *= out.num_samples / (out.num_samples - 1.)
  438. return out
  439. # map function
  440. def _new_cllda(inputs):
  441. return CLLDA(*inputs[:3], **inputs[3])
  442. # map function
  443. def _update_cllda(inputs):
  444. inputs[0].update(*inputs[1:4], **inputs[4])
  445. return inputs[0]
  446. # if __name__ == '__main__':
  447. # # test suite
  448. # from DS import DS
  449. # test_data = DS.test_data()
  450. # CLLDA(test_data[0], test_data[1], test_data[2], num_epochs=10, burn_in=2, transform=('none', 'alr', 'ilr', 'clr'))
  451. # cls = concurrent_cllda(4, test_data[0], test_data[1], test_data[2],
  452. # num_epochs=10, burn_in=2, transform=('none', 'alr', 'ilr', 'clr'))
  453. # cl = combine_cllda(cls)
  454. # a=1