icldata.py
from time import (time, gmtime, strftime)
import numpy as np
from sklearn.decomposition import PCA
import h5py
import os
from shutil import rmtree
from os.path import isdir, isfile, join, basename
import cPickle as pkl
import sqlite3
import joblib
from collections import OrderedDict
from copy import copy
from matplotlib import pyplot as plt
import webbrowser as wb
import urllib2


class ICLabelDataset:
    """
    This class provides an easy interface for downloading, loading, organizing, and processing the ICLabel dataset.
    The ICLabel dataset is intended for training and validating electroencephalographic (EEG) independent component
    (IC) classifiers.
    It contains an unlabeled training dataset, several collections of labels for a small subset of the training
    dataset, and a test dataset of 130 ICs in which each IC was labeled by 6 experts.
    Features included:
    * Scalp topography images (32x32 pixels, flattened to 740 elements after removing white-space)
    * Power spectral densities (1-100 Hz)
    * Autocorrelation functions (1 second)
    * Equivalent current dipole fits (1 and 2 dipole)
    * Hand-crafted features (some new and some from previously published classifiers)

    :Example:
    icl = ICLabelDataset()
    icldata = icl.load_semi_supervised()
    """
    def __init__(self, features='all', label_type='all', datapath='', n_test_datasets=50, n_val_ics=200,
                 transform='none', unique=True, do_pca=False, combine_output=False,
                 seed=np.random.randint(0, int(1e5))):
        """
        Initialize an ICLabelDataset object.
        :param features: The types of features to return.
        :param label_type: Which ICLabels to use.
        :param datapath: Where the dataset and cache are stored.
        :param n_test_datasets: How many unlabeled datasets to include in the test set.
        :param n_val_ics: How many labeled components to transfer to the validation set.
        :param transform: The inverse log-ratio transform to use for labels and their covariances.
        :param unique: Whether or not to use ICs with the same scalp topography. Non-unique is not implemented.
        :param do_pca: Whether to apply PCA whitening to the scalp topography features.
        :param combine_output: Determines whether output features are dictionaries or an array of combined features.
        :param seed: The seed for the pseudorandom shuffle of data points.
        :return: Initialized ICLabelDataset object.
        """
        # data parameters
        self.datapath = datapath
        self.features = features
        self.n_test_datasets = n_test_datasets
        self.n_val_ics = n_val_ics
        self.transform = transform
        self.unique = unique
        if not self.unique:
            raise NotImplementedError
        self.do_pca = do_pca
        self.combine_output = combine_output
        self.label_type = label_type
        assert label_type in ('all', 'luca', 'database')
        self.seed = seed
        self.psd_mean = None
        self.psd_mean_var = None
        self.psd_mean_kurt = None
        self.psd_limits = None
        self.psd_var_limits = None
        self.psd_kurt_limits = None
        self.pscorr_mean = None
        self.pscorr_std = None
        self.pscorr_limits = None
        self.psd_freqs = 100
        # training feature-sets
        self.train_feature_indices = OrderedDict([
            ('ids', np.arange(2)),
            ('topo', np.arange(2, 742)),
            ('handcrafted', np.arange(742, 760)),  # one lost due to removal in load_data
            ('dipole', np.arange(760, 780)),
            ('psd', np.arange(780, 880)),
            ('psd_var', np.arange(880, 980)),
            ('psd_kurt', np.arange(980, 1080)),
            ('autocorr', np.arange(1080, 1180)),
        ])
        self.test_feature_indices = OrderedDict([
            ('ids', np.arange(3)),
            ('topo', np.arange(3, 743)),
            ('handcrafted', np.arange(743, 761)),  # one lost due to removal in load_data
            ('dipole', np.arange(761, 781)),
            ('psd', np.arange(781, 881)),
            ('psd_var', np.arange(881, 981)),
            ('psd_kurt', np.arange(981, 1081)),
            ('autocorr', np.arange(1081, 1181)),
        ])
        # reorganize features
        if self.features == 'all' or 'all' in self.features:
            self.features = self.train_feature_indices.keys()
        if isinstance(self.features, str):
            self.features = [self.features]
        if 'ids' not in self.features:
            self.features = ['ids'] + self.features
        # visualization parameters
        # indices of the non-white-space pixels within the flattened 32x32 topography image,
        # one contiguous run per image row (740 indices in total)
        self.topo_ind = np.concatenate([
            np.arange(43, 53), np.arange(72, 88), np.arange(103, 121), np.arange(134, 154),
            np.arange(165, 187), np.arange(196, 220), np.arange(227, 253), np.arange(258, 286),
            np.arange(290, 318), np.arange(322, 350), np.arange(353, 383), np.arange(385, 415),
            np.arange(417, 447), np.arange(449, 479), np.arange(481, 511), np.arange(513, 543),
            np.arange(545, 575), np.arange(577, 607), np.arange(609, 639), np.arange(641, 671),
            np.arange(674, 702), np.arange(706, 734), np.arange(738, 766), np.arange(771, 797),
            np.arange(804, 828), np.arange(837, 859), np.arange(870, 890), np.arange(903, 921),
            np.arange(936, 952), np.arange(971, 981),
        ])
        self.psd_ind = np.arange(1, 101)
        self.max_grid_plot = 144
        self.base_url_image = 'labeling.ucsd.edu/images/'
        # data urls
        self.base_url_download = 'labeling.ucsd.edu/download/'
        self.feature_train_zip_url = self.base_url_download + 'features.zip'
        self.feature_train_urls = [
            self.base_url_download + 'features_0D1D2D.mat',
            self.base_url_download + 'features_PSD_med_var_kurt.mat',
            self.base_url_download + 'features_AutoCorr.mat',
            self.base_url_download + 'features_MI.mat',
        ]
        self.label_train_urls = [
            self.base_url_download + 'ICLabels_experts.pkl',
            self.base_url_download + 'ICLabels_onlyluca.pkl',
        ]
        self.feature_test_url = self.base_url_download + 'features_testset_full.mat'
        self.label_test_urls = self.base_url_download + 'ICLabels_test.pkl'
        self.db_url = self.base_url_download + 'anonymized_database.sqlite'
        self.cls_url = self.base_url_download + 'other_classifiers.mat'
    # util
    @staticmethod
    def __load_matlab_cellstr(f, var_name=''):
        var = []
        if var_name:
            for column in f[var_name]:
                row_data = []
                for row_number in range(len(column)):
                    row_data.append(''.join(map(unichr, f[column[row_number]][:])))
                var.append(row_data)
        return [str(x)[3:-2] for x in var]

    @staticmethod
    def __match_indices(*indices):
        """ Match sets of multidimensional ids/indices when there is a 1-1 relationship """
        # find matching indices
        index = np.concatenate(indices)  # array of values
        _, duplicates, counts = np.unique(index, return_inverse=True, return_counts=True, axis=0)
        duplicates = np.split(duplicates, np.cumsum([x.shape[0] for x in indices[:-1]]), 0)  # list of vectors of ints
        sufficient_counts = np.where(counts == len(indices))[0]  # vector of ints
        matching_indices = [np.where(np.in1d(x, sufficient_counts))[0] for x in duplicates]  # list of vectors of ints
        indices = [y[x] for x, y in zip(matching_indices, indices)]  # list of arrays of values
        # organize to match first index array
        try:
            sort_inds = [np.lexsort(np.fliplr(x).T) for x in indices]
        except ValueError:
            sort_inds = [np.argsort(x) for x in indices]
        out = np.array([x[y[sort_inds[0]]] for x, y in zip(matching_indices, sort_inds)])
        return out
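
    # Hedged illustration (not from the original source): __match_indices takes
    # several (n, k) id arrays, keeps only the ids present in every array, and
    # returns the matching row positions per array, ordered to line up with the
    # first array. Assuming [set, ic] id pairs:
    #     a = np.array([[1, 1], [1, 2], [2, 3]])
    #     b = np.array([[2, 3], [1, 2], [4, 4]])
    # the shared ids are (1, 2) and (2, 3), so the result pairs a-rows [1, 2]
    # with b-rows [1, 0], up to the method's internal ordering.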

    # data access
    def load_data(self):
        """
        Load the ICL dataset in an unprocessed form.
        Follows the settings provided during initialization.
        :return: Dictionary of unprocessed but matched feature-sets and labels.
        """
        start = time()
        # organize info
        if self.transform in (None, 'none'):
            if self.label_type == 'all':
                file_name = 'ICLabels_all.pkl'
            elif self.label_type == 'luca':
                file_name = 'ICLabels_onlyluca.pkl'
        processed_file_name = 'processed_dataset'
        if self.unique:
            processed_file_name += '_unique'
        if self.label_type == 'all':
            processed_file_name += '_all'
            self.check_for_download('train_labels')
        elif self.label_type == 'luca':
            processed_file_name += '_luca'
            self.check_for_download('train_labels')
        elif self.label_type == 'database':
            processed_file_name += '_database'
            self.check_for_download('database')
        processed_file_name += '.pkl'
        # load the processed data file if it exists
        if isfile(join(self.datapath, 'cache', processed_file_name)):
            dataset = joblib.load(join(self.datapath, 'cache', processed_file_name))
        # if not, create it
        else:
            # load features
            features = []
            feature_labels = []
            print('Loading full dataset...')
            self.check_for_download('train_features')
            # topo maps, old psd, dipole, and handcrafted
            with h5py.File(join(self.datapath, 'features', 'features_0D1D2D.mat'), 'r') as f:
                print('Loading 0D1D2D features...')
                features.append(np.asarray(f['features']).T)
                feature_labels.append(self.__load_matlab_cellstr(f, 'labels'))
            # new psd
            with h5py.File(join(self.datapath, 'features', 'features_PSD_med_var_kurt.mat'), 'r') as f:
                print('Loading PSD features...')
                features.append(list())
                for element in f['features_out'][0]:
                    data = np.array(f[element]).T
                    # if no data, skip
                    if data.ndim == 1 or data.dtype != np.float64:
                        continue
                    nyquist = (data.shape[1] - 2) / 3
                    nfreq = 100
                    # if more than nfreq frequencies, remove the extra
                    if nyquist > nfreq:
                        data = data[:, np.concatenate((range(2 + nfreq),
                                                       range(2 + nyquist, 2 + nyquist + nfreq),
                                                       range(2 + 2*nyquist, 2 + 2*nyquist + nfreq)))]
                    # if fewer than nfreq frequencies, repeat the last frequency value
                    elif nyquist < nfreq:
                        data = data[:, np.concatenate((range(2 + nyquist),
                                                       np.repeat(1 + nyquist, nfreq - nyquist),
                                                       range(2 + nyquist, 2 + 2*nyquist),
                                                       np.repeat(1 + 2*nyquist, nfreq - nyquist),
                                                       range(2 + 2*nyquist, 2 + 3*nyquist),
                                                       np.repeat(1 + 3*nyquist, nfreq - nyquist))
                                                      ).astype(int)]
                    features[-1].append(data)
                features[-1] = np.concatenate(features[-1], axis=0)
                feature_labels.append(['ID_set', 'ID_ic'] + ['psd_median']*nfreq
                                      + ['psd_var']*nfreq + ['psd_kurt']*nfreq)
            # autocorrelation
            with h5py.File(join(self.datapath, 'features', 'features_AutoCorr.mat'), 'r') as f:
                print('Loading AutoCorr features...')
                features.append(list())
                for element in f['features_out'][0]:
                    data = np.array(f[element]).T
                    if data.size > 2 and data.shape[1] == 102 and not len(data.dtype):
                        features[-1].append(data)
                features[-1] = np.concatenate(features[-1], axis=0)
                feature_labels.append(self.__load_matlab_cellstr(f, 'feature_labels')[:2] + ['Autocorr'] * 100)
            # find topomap duplicates
            print('Finding topo duplicates...')
            _, duplicate_order = np.unique(features[0][:, 2:742].astype(np.float32), return_inverse=True, axis=0)
            do_sortind = np.argsort(duplicate_order)
            do_sorted = duplicate_order[do_sortind]
            do_indices = np.where(np.diff(np.concatenate(([-1], do_sorted))))[0]
            group2indices = [do_sortind[do_indices[x]:do_indices[x + 1]] for x in range(0, duplicate_order.max())]
            del _
            # load labels
            if self.label_type == 'database':
                # load data from the database
                conn = sqlite3.connect(join(self.datapath, 'labels', 'database.sqlite'))
                c = conn.cursor()
                dblabels = c.execute('SELECT * FROM labels '
                                     'INNER JOIN images ON labels.image_id = images.id '
                                     'WHERE user_id IN '
                                     '(SELECT user_id FROM labels '
                                     'GROUP BY user_id '
                                     'HAVING COUNT(*) >= 30)'
                                     ).fetchall()
                conn.close()
                # reformat as list of ndarrays
                dblabels = [(x[1], np.array(x[15:17]), np.array(x[3:11])) for x in dblabels]
                dblabels = [np.stack(x) for x in zip(*dblabels)]
                # organize labels by image
                udb = np.unique(dblabels[1], return_inverse=True, axis=0)
                dblabels = [(dblabels[0][y], dblabels[1][y][0], dblabels[2][y])
                            for y in (udb[1] == x for x in range(len(udb[0])))]
                label_index = np.stack([x[1] for x in dblabels])
            elif self.label_type == 'luca':
                # load data from the database
                conn = sqlite3.connect(join(self.datapath, 'labels', 'database.sqlite'))
                c = conn.cursor()
                dblabelsluca = c.execute('SELECT * FROM labels '
                                         'INNER JOIN images ON labels.image_id = images.id '
                                         'WHERE user_id = 1').fetchall()
                conn.close()
                # remove low-confidence labels
                dblabelsluca = [x for x in dblabelsluca if x[10] == 0]
                # reformat as ndarray
                labels = np.array([x[3:10] for x in dblabelsluca]).astype(np.float32)
                labels /= labels.sum(1, keepdims=True)
                labels = [labels]
                label_index = np.array([x[15:17] for x in dblabelsluca])
                transforms = ['none']
            else:
                # load labels from files
                with open(join(self.datapath, 'labels', file_name), 'rb') as f:
                    print('Loading labels...')
                    data = pkl.load(f)
                if 'transform' in data.keys():
                    transforms = data['transform']
                else:
                    transforms = ['none']
                labels = data['labels']
                if isinstance(labels, np.ndarray):
                    labels = [labels]
                if 'labels_cov' in data.keys():
                    label_cov = data['labels_cov']
                label_index = np.stack((data['instance_set_numbers'], data['instance_ic_numbers'])).T
                del data
            # match components and labels
            print('Matching components and labels...')
            temp = self.__match_indices(label_index.astype(np.int), features[0][:, :2].astype(np.int))
            label2component = dict(zip(*temp))
            del temp
            # match feature-sets
            print('Matching features...')
            feature_inds = self.__match_indices(*[x[:, :2].astype(np.int) for x in features])
            # check which labels are not kept
            print('Rearranging components and labels...')
            kept_labels = [x for x, y in label2component.iteritems() if y in feature_inds[0]]
            dropped_labels = [x for x, y in label2component.iteritems() if y not in feature_inds[0]]
            # for each label, pick a new component that is kept (if any)
            ind_n_data_points = [x for x, y in enumerate(feature_labels[0]) if y == 'number of data points'][0]
            for ind in dropped_labels:
                group = duplicate_order[label2component[ind]]
                candidate_components = np.intersect1d(group2indices[group], feature_inds[0])
                # if more than one choice, pick the one from the dataset with the most samples unless one from this
                # group has already been found
                if len(candidate_components) >= 1:
                    if len(candidate_components) == 1:
                        new_index = features[0][candidate_components, :2]
                    else:
                        new_index = features[0][candidate_components[features[0][candidate_components,
                                                                                 ind_n_data_points].argmax()], :2]
                    if not (new_index == label_index[dropped_labels]).all(1).any() \
                            and not any([(x == label_index[kept_labels]).all(1).any()
                                         for x in features[0][candidate_components, :2]]):
                        label_index[ind] = new_index
            del label2component, kept_labels, dropped_labels, duplicate_order
            # feature labels (change with features)
            psd_lims = np.where(np.char.startswith(feature_labels[0], 'psd'))[0][[0, -1]]
            feature_labels = np.concatenate((feature_labels[0][:psd_lims[0]],
                                             feature_labels[0][psd_lims[1] + 1:],
                                             feature_labels[1][2:],
                                             feature_labels[2][2:]))
            # combine features, keeping only components with all features
            print('Combining feature-sets...')

            def index_features(data, new_index):
                return np.concatenate((data[0][feature_inds[0][new_index], :psd_lims[0]].astype(np.float32),
                                       data[0][feature_inds[0][new_index], psd_lims[1] + 1:].astype(np.float32),
                                       data[1][feature_inds[1][new_index], 2:].astype(np.float32),
                                       data[2][feature_inds[2][new_index], 2:].astype(np.float32)),
                                      axis=1)
            # rematch with labels
            print('Rematching components and labels...')
            ind_labeled_labels, ind_labeled_features = self.__match_indices(
                label_index.astype(np.int), features[0][feature_inds[0], :2].astype(np.int))
            del label_index
            # find topomap duplicates
            _, duplicate_order = np.unique(features[0][feature_inds[0], 2:742].astype(np.float32),
                                           return_inverse=True, axis=0)
            do_sortind = np.argsort(duplicate_order)
            do_sorted = duplicate_order[do_sortind]
            do_indices = np.where(np.diff(np.concatenate(([-1], do_sorted))))[0]
            group2indices = [do_sortind[do_indices[x]:do_indices[x + 1]] for x in range(0, duplicate_order.max())]
            # aggregate data
            dataset = dict()
            try:
                dataset['transform'] = transforms
            except UnboundLocalError:
                pass
            if self.label_type == 'database':
                dataset['labeled_labels'] = [dblabels[x] for x in np.where(ind_labeled_labels)[0]]
            else:
                dataset['labeled_labels'] = [x[ind_labeled_labels, :] for x in labels]
                if 'label_cov' in locals():
                    dataset['labeled_label_covariances'] = [x[ind_labeled_labels, :].astype(np.float32)
                                                            for x in label_cov]
            dataset['labeled_features'] = index_features(features, ind_labeled_features)
            # find equivalent datasets with the most samples
            unlabeled_groups = [x for it, x in enumerate(group2indices)
                                if not np.intersect1d(x, ind_labeled_features).size]
            ndata = features[0][feature_inds[0]][:, ind_n_data_points]
            ind_unique_unlabeled = [x[ndata[x].argmax()] for x in unlabeled_groups]
            dataset['unlabeled_features'] = index_features(features, ind_unique_unlabeled)
            # clean the workspace
            del features, group2indices
            try:
                del labels
            except NameError:
                del dblabels
            if 'label_cov' in locals():
                del label_cov
            # remove inf columns
            print('Cleaning data of infs...')
            inf_col = [ind for ind, x in enumerate(feature_labels) if x == 'SASICA snr'][0]
            feature_labels = np.delete(feature_labels, inf_col)
            dataset['unlabeled_features'] = np.delete(dataset['unlabeled_features'], inf_col, axis=1)
            dataset['labeled_features'] = np.delete(dataset['labeled_features'], inf_col, axis=1)
            # remove nan rows
            print('Cleaning data of nans...')
            # unlabeled
            unlabeled_not_nan_inf_index = np.logical_not(
                np.logical_or(np.isnan(dataset['unlabeled_features']).any(axis=1),
                              np.isinf(dataset['unlabeled_features']).any(axis=1)))
            dataset['unlabeled_features'] = dataset['unlabeled_features'][unlabeled_not_nan_inf_index, :]
            # labeled
            labeled_not_nan_inf_index = np.logical_not(
                np.logical_or(np.isnan(dataset['labeled_features']).any(axis=1),
                              np.isinf(dataset['labeled_features']).any(axis=1)))
            dataset['labeled_features'] = dataset['labeled_features'][labeled_not_nan_inf_index, :]
            if self.label_type == 'database':
                dataset['labeled_labels'] = [dataset['labeled_labels'][x]
                                             for x in np.where(labeled_not_nan_inf_index)[0]]
            else:
                dataset['labeled_labels'] = [x[labeled_not_nan_inf_index, :] for x in dataset['labeled_labels']]
            if 'labeled_label_covariances' in dataset.keys():
                dataset['labeled_label_covariances'] = [x[labeled_not_nan_inf_index, :, :]
                                                        for x in dataset['labeled_label_covariances']]
            if not self.unique:
                dataset['unlabeled_duplicates'] = dataset['unlabeled_duplicates'][unlabeled_not_nan_inf_index]
                dataset['labeled_duplicates'] = dataset['labeled_duplicates'][labeled_not_nan_inf_index]
            # save feature labels (names, e.g. psd)
            dataset['feature_labels'] = feature_labels
            # save the results
            print('Saving aggregated dataset...')
            joblib.dump(dataset, join(self.datapath, 'cache', processed_file_name), 0)
        # print time
        total = time() - start
        print('Time to load: ' + strftime("%H:%M:%S", gmtime(total)) +
              ':' + np.mod(total, 1).astype(str)[2:5] + '\t(HH:MM:SS:sss)')
        return dataset
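
    # Usage sketch (hedged; the datapath is a hypothetical example): with the feature
    # .mat files and label files available under datapath, load_data() returns a dict
    # with keys including 'labeled_features', 'labeled_labels', 'unlabeled_features',
    # and 'feature_labels':
    #
    #     icl = ICLabelDataset(label_type='luca', datapath='/data/icl/')
    #     dataset = icl.load_data()
    #     dataset['labeled_features'].shape    # (n_labeled_ics, 1180)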

    def load_semi_supervised(self):
        """
        Load the ICL dataset where only a fraction of data points are labeled.
        Follows the settings provided during initialization.
        :return: (train set unlabeled, train set labeled, sample test set (unlabeled), validation set (labeled),
            output labels)
        """
        rng = np.random.RandomState(seed=self.seed)
        start = time()
        # get data
        icl = self.load_data()
        # copy full dataset
        icl['unlabeled_features'] = \
            OrderedDict([(key, icl['unlabeled_features'][:, ind]) for key, ind
                         in self.train_feature_indices.iteritems() if key in self.features])
        icl['labeled_features'] = \
            OrderedDict([(key, icl['labeled_features'][:, ind]) for key, ind
                         in self.train_feature_indices.iteritems() if key in self.features])
        # set ids to int
        icl['unlabeled_features']['ids'] = icl['unlabeled_features']['ids'].astype(int)
        icl['labeled_features']['ids'] = icl['labeled_features']['ids'].astype(int)
        # decide how to split into train / validation / test
        # validation set of random labeled components for overfitting / convergence estimation
        try:
            valid_ind = rng.choice(icl['labeled_features']['ids'].shape[0], size=self.n_val_ics, replace=False)
        except ValueError:
            valid_ind = rng.choice(icl['labeled_features']['ids'].shape[0], size=self.n_val_ics, replace=True)
        # random unlabeled datasets for manual analysis
        test_datasets = rng.choice(np.unique(icl['unlabeled_features']['ids'][:, 0]),
                                   size=self.n_test_datasets, replace=False)
        test_ind = np.where(np.array([x == icl['unlabeled_features']['ids'][:, 0] for x in test_datasets]).any(0))[0]
        # normalize topo features
        if 'topo' in self.features:
            print('Normalizing topo features...')
            icl['unlabeled_features']['topo'], pca = self.normalize_topo_features(icl['unlabeled_features']['topo'])
            icl['labeled_features']['topo'] = self.normalize_topo_features(icl['labeled_features']['topo'], pca)[0]
        # normalize psd features
        if 'psd' in self.features:
            print('Normalizing psd features...')
            icl['unlabeled_features']['psd'] = self.normalize_psd_features(icl['unlabeled_features']['psd'])
            icl['labeled_features']['psd'] = self.normalize_psd_features(icl['labeled_features']['psd'])
        # normalize psd_var features
        if 'psd_var' in self.features:
            print('Normalizing psd_var features...')
            icl['unlabeled_features']['psd_var'] = self.normalize_psd_features(icl['unlabeled_features']['psd_var'])
            icl['labeled_features']['psd_var'] = self.normalize_psd_features(icl['labeled_features']['psd_var'])
        # normalize psd_kurt features
        if 'psd_kurt' in self.features:
            print('Normalizing psd_kurt features...')
            icl['unlabeled_features']['psd_kurt'] = self.normalize_psd_features(icl['unlabeled_features']['psd_kurt'])
            icl['labeled_features']['psd_kurt'] = self.normalize_psd_features(icl['labeled_features']['psd_kurt'])
        # normalize autocorr features
        if 'autocorr' in self.features:
            print('Normalizing autocorr features...')
            icl['unlabeled_features']['autocorr'] = \
                self.normalize_autocorr_features(icl['unlabeled_features']['autocorr'])
            icl['labeled_features']['autocorr'] = self.normalize_autocorr_features(icl['labeled_features']['autocorr'])
        # normalize dipole features
        if 'dipole' in self.features:
            print('Normalizing dipole features...')
            icl['unlabeled_features']['dipole'] = self.normalize_dipole_features(icl['unlabeled_features']['dipole'])
            icl['labeled_features']['dipole'] = self.normalize_dipole_features(icl['labeled_features']['dipole'])
        # normalize handcrafted features
        if 'handcrafted' in self.features:
            print('Normalizing hand-crafted features...')
            icl['unlabeled_features']['handcrafted'] = \
                self.normalize_handcrafted_features(icl['unlabeled_features']['handcrafted'],
                                                    icl['unlabeled_features']['ids'][:, 1])
            icl['labeled_features']['handcrafted'] = \
                self.normalize_handcrafted_features(icl['labeled_features']['handcrafted'],
                                                    icl['labeled_features']['ids'][:, 1])
        # normalize mi features
        if 'mi' in self.features:
            print('Normalizing mi features...')
            icl['unlabeled_features']['mi'] = self.normalize_mi_features(icl['unlabeled_features']['mi'])
            icl['labeled_features']['mi'] = self.normalize_mi_features(icl['labeled_features']['mi'])
        # recast labels
        if self.label_type != 'database':
            icl['labeled_labels'] = [x.astype(np.float32) for x in icl['labeled_labels']]
            if 'labeled_label_covariances' in icl.keys():
                icl['labeled_label_covariances'] = [x.astype(np.float32) for x in icl['labeled_label_covariances']]
        # separate data into train, validation, and test sets
        print('Splitting and shuffling data...')
        # unlabeled training set
        ind = rng.permutation(np.setdiff1d(range(icl['unlabeled_features']['ids'].shape[0]), test_ind))
        x_u = OrderedDict([(key, val[ind]) for key, val in icl['unlabeled_features'].iteritems()])
        y_u = None
        # labeled training set
        ind = rng.permutation(np.setdiff1d(range(icl['labeled_features']['ids'].shape[0]), valid_ind))
        x_l = OrderedDict([(key, val[ind]) for key, val in icl['labeled_features'].iteritems()])
        if self.label_type == 'database':
            y_l = [icl['labeled_labels'][x] for x in ind]
        else:
            y_l = [x[ind] for x in icl['labeled_labels']]
            if 'labeled_label_covariances' in icl.keys():
                c_l = [x[ind] for x in icl['labeled_label_covariances']]
        # validation set
        rng.shuffle(valid_ind)
        x_v = OrderedDict([(key, val[valid_ind]) for key, val in icl['labeled_features'].iteritems()])
        if self.label_type == 'database':
            y_v = [icl['labeled_labels'][x] for x in valid_ind]
        else:
            y_v = [x[valid_ind] for x in icl['labeled_labels']]
            if 'labeled_label_covariances' in icl.keys():
                c_v = [x[valid_ind] for x in icl['labeled_label_covariances']]
        # unlabeled test set
        rng.shuffle(test_ind)
        x_t = OrderedDict([(key, val[test_ind]) for key, val in icl['unlabeled_features'].iteritems()])
        y_t = None
        train_u = (x_u, y_u)
        if 'labeled_label_covariances' in icl.keys():
            train_l = (x_l, y_l, c_l)
            val = (x_v, y_v, c_v)
        else:
            train_l = (x_l, y_l)
            val = (x_v, y_v)
        test = (x_t, y_t)
        # print time
        total = time() - start
        print('Time to load: ' + strftime("%H:%M:%S", gmtime(total)) +
              ':' + np.mod(total, 1).astype(str)[2:5] + '\t(HH:MM:SS:sss)')
        return train_u, train_l, test, val, \
            ('train_unlabeled', 'train_labeled', 'test', 'validation', 'labels')
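
    # Usage sketch (hedged; the datapath is a hypothetical example): the splits
    # unpack as below; train_l and val gain a third element (label covariances)
    # when 'labeled_label_covariances' is present in the loaded dataset:
    #
    #     icl = ICLabelDataset(features=['topo', 'psd'], datapath='/data/icl/')
    #     train_u, train_l, test, val, names = icl.load_semi_supervised()
    #     x_l, y_l = train_l[0], train_l[1]
    #     x_l['topo'].shape    # (n_labeled_train_ics, 740)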

    def load_test_data(self, process_features=True):
        """
        Load the ICL test dataset used in the publication.
        Follows the settings provided during initialization.
        :param process_features: Whether to preprocess/normalize features.
        :return: (features, labels)
        """
        # check for files and download if missing
        self.check_for_download(('test_labels', 'test_features'))
        # load features
        with h5py.File(join(self.datapath, 'features', 'features_testset_full.mat'), 'r') as f:
            features = np.asarray(f['features']).T
            feature_labels = self.__load_matlab_cellstr(f, 'feature_label')
        # load labels
        with open(join(self.datapath, 'labels', 'ICLabels_test.pkl'), 'rb') as f:
            labels = pkl.load(f)
        # match features and labels
        _, _, ind = np.intersect1d(labels['instance_id'], labels['instance_number'], return_indices=True)
        label_id = np.stack((labels['instance_study_numbers'][ind],
                             labels['instance_set_numbers'][ind],
                             labels['instance_ic_numbers'][ind]), axis=1)
        feature_id = features[:, :3].astype(int)
        match = self.__match_indices(label_id, feature_id)
        features = features[match[1, :][match[0, :]], :]
        # remove inf columns
        print('Cleaning data of infs...')
        inf_col = [ind for ind, x in enumerate(feature_labels) if x == 'SASICA snr'][0]
        feature_labels = np.delete(feature_labels, inf_col)
        features = np.delete(features, inf_col, axis=1)
        # convert to ordered dict
        features = OrderedDict([(key, features[:, ind]) for key, ind
                                in self.test_feature_indices.iteritems() if key in self.features])
        # process features
        if process_features:
            # normalize topo features
            if 'topo' in self.features:
                print('Normalizing topo features...')
                features['topo'] = self.normalize_topo_features(features['topo'])[0]
            # normalize psd features
            if 'psd' in self.features:
                print('Normalizing psd features...')
                features['psd'] = self.normalize_psd_features(features['psd'])
            # normalize psd_var features
            if 'psd_var' in self.features:
                print('Normalizing psd_var features...')
                features['psd_var'] = self.normalize_psd_features(features['psd_var'])
            # normalize psd_kurt features
            if 'psd_kurt' in self.features:
                print('Normalizing psd_kurt features...')
                features['psd_kurt'] = self.normalize_psd_features(features['psd_kurt'])
            # normalize autocorr features
            if 'autocorr' in self.features:
                print('Normalizing autocorr features...')
                features['autocorr'] = self.normalize_autocorr_features(features['autocorr'])
            # normalize dipole features
            if 'dipole' in self.features:
                print('Normalizing dipole features...')
                features['dipole'] = self.normalize_dipole_features(features['dipole'])
            # normalize handcrafted features
            if 'handcrafted' in self.features:
                print('Normalizing hand-crafted features...')
                features['handcrafted'] = self.normalize_handcrafted_features(features['handcrafted'],
                                                                              features['ids'][:, 1])
        return features, labels
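
    # Usage sketch (hedged): the test set loads as an OrderedDict of normalized
    # feature arrays plus the expert-label structure from ICLabels_test.pkl:
    #
    #     features, labels = icl.load_test_data()
    #     features['psd'].shape    # (n_test_ics, 100)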

    def load_classifications(self, n_cls, ids=None):
        """
        Load classifications of the ICLabel training set made by several published and publicly available IC
        classifiers. The classifiers included are MARA, ADJUST, FASTER, IC_MARC, and EyeCatch. MARA and FASTER are
        only included in the 2-class case. ADJUST is also included in the 3-class case. IC_MARC and EyeCatch are
        included in all cases. Note that EyeCatch only has two classes (Eye and Not-Eye) but does not follow the
        pattern of label conflation used for the other classifiers, as it has no Brain IC class.
        :param n_cls: How many IC classes to consider. Must be 2, 3, or 5.
        :param ids: If only a subset of ICs are desired, the relevant IC IDs may be passed here as an (n by 2) ndarray.
        :return: Dictionary of classifications separated by classifier.
        """
        # check inputs
        assert n_cls in (2, 3, 5), 'n_cls must be 2, 3, or 5'
        # load raw classifications
        raw = self._load_classifications(ids)
        # format and limit to the number of desired classes
        # 2: brain, other
        # 3: brain, eye, other
        # 5: brain, muscle, eye, heart, other
        # exception for eye_catch, which is always [eye] where eye >= 0.93 is the threshold for detection
        classifications = {}
        for cls, lab in raw.iteritems():
            if cls == 'adjust':
                if n_cls == 2:
                    non_brain = raw[cls].max(1, keepdims=True)
                    classifications[cls] = np.concatenate((1 - non_brain, non_brain), 1)
                elif n_cls == 3:
                    brain = 1 - raw[cls].max(1, keepdims=True)
                    eye = raw[cls][:, :-1].max(1, keepdims=True)
                    other = raw[cls][:, -1:]
                    classifications[cls] = np.concatenate((brain, eye, other), 1)
            elif cls == 'mara':
                if n_cls == 2:
                    classifications[cls] = np.concatenate((1 - raw[cls], raw[cls]), 1)
            elif cls == 'faster':
                if n_cls == 2:
                    classifications[cls] = np.concatenate((1 - raw[cls], raw[cls]), 1)
            elif cls == 'ic_marc':  # ['blink', 'neural', 'heart', 'lat. eye', 'muscle', 'mixed']
                brain = raw[cls][:, 1:2]
                if n_cls == 2:
                    classifications[cls] = np.concatenate((brain, 1 - brain), 1)
                elif n_cls == 3:
                    eye = raw[cls][:, [0, 3]].sum(1, keepdims=True)
                    other = raw[cls][:, [2, 4, 5]].sum(1, keepdims=True)
                    classifications[cls] = np.concatenate((brain, eye, other), 1)
                elif n_cls == 5:
                    muscle = raw[cls][:, 4:5]
                    eye = raw[cls][:, [0, 3]].sum(1, keepdims=True)
                    heart = raw[cls][:, 2:3]
                    other = raw[cls][:, 5:]
                    classifications[cls] = np.concatenate((brain, muscle, eye, heart, other), 1)
            elif cls == 'eye_catch':
                classifications[cls] = raw[cls]
            else:
                raise UserWarning('Unknown classifier: {}'.format(cls))
        return classifications
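
    # Usage sketch (hedged, not from the original source): classifications come back
    # as a dict keyed by classifier name, each an (n_ics, n_cls) array of class
    # scores (EyeCatch excepted; it is a single thresholded eye score):
    #
    #     icl = ICLabelDataset(datapath='/data/icl/')    # hypothetical path
    #     cls3 = icl.load_classifications(3)             # brain / eye / other
    #     print(cls3.keys())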

    def _load_classifications(self, ids=None):
        # check for files and download if missing
        self.check_for_download('classifications')
        # load classifications
        classifications = {}
        with h5py.File(join(self.datapath, 'other', 'other_classifiers.mat'), 'r') as f:
            print('Loading classifications...')
            for cls, lab in f.iteritems():
                classifications[cls] = lab[:].T
        # match to given ids
        if ids is not None:
            for cls, lab in classifications.iteritems():
                _, ind_id, ind_lab = np.intersect1d((ids * [100, 1]).sum(1),
                                                    (lab[:, :2].astype(int) * [100, 1]).sum(1),
                                                    return_indices=True)
                classifications[cls] = np.empty((ids.shape[0], lab.shape[1] - 2))
                classifications[cls][:] = np.nan
                classifications[cls][ind_id] = lab[ind_lab, 2:]
        return classifications

    def generate_cache(self, refresh=False):
        """
        Generate all possible training set cache files to speed up later requests.
        :param refresh: If True, deletes previous cache files. Otherwise only missing cache files are generated.
        """
        if refresh:
            rmtree(join(self.datapath, 'cache'))
            os.mkdir(join(self.datapath, 'cache'))
        urexpert = copy(self.label_type)
        for label_type in ('luca', 'all', 'database'):
            self.label_type = label_type
            self.load_data()
        self.label_type = urexpert
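
    # Usage sketch (hedged): regenerate every cached processed_dataset_*.pkl from
    # scratch, e.g. after a corrupted download:
    #
    #     icl.generate_cache(refresh=True)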

    # download
    def _download(self, url, filename):
        CHUNK = 16 * 1024
        try:
            f = urllib2.urlopen(url)
            # open our local file for writing
            with open(filename, 'wb') as local_file:
                while True:
                    chunk = f.read(CHUNK)
                    if not chunk:
                        break
                    local_file.write(chunk)
            print('Done.')
        except urllib2.HTTPError as e:
            print('HTTP Error: {} {}'.format(e.code, url))
        except urllib2.URLError as e:
            print('URL Error: {} {}'.format(e.reason, url))

    def download_trainset_cllabels(self):
        """
        Download labels for the ICLabel training set.
        """
        print('Downloading individual ICLabel training set CL label files...')
        folder = 'labels'
        if not isdir(join(self.datapath, folder)):
            os.mkdir(join(self.datapath, folder))
        for it, url in enumerate(self.label_train_urls):
            print('Downloading label file {} of {}...'.format(it + 1, len(self.label_train_urls)))
            self._download(url, join(self.datapath, folder, basename(url)))

    def download_trainset_features(self, zip=True):
        """
        Download features for the ICLabel training set.
        :param zip: If True, downloads the zipped feature files. Otherwise individual files are downloaded.
        """
        print('Caution: this download is approximately 25GB and requires twice that space on your drive if unzipping!')
        folder = 'features'
        if not isdir(join(self.datapath, folder)):
            os.mkdir(join(self.datapath, folder))
        if zip:
            print('Downloading zipped ICLabel training set features...')
            zip_name = join(self.datapath, folder, 'features.zip')
            self._download(self.feature_train_zip_url, zip_name)
            print('Extracting zipped ICLabel training set features...')
            from zipfile import ZipFile
            with ZipFile(zip_name) as myzip:
                myzip.extractall(path=join(self.datapath, folder))
            print('Deleting zip archive...')
            os.remove(zip_name)
        else:
            print('Downloading individual ICLabel training set feature files...')
            for it, url in enumerate(self.feature_train_urls):
                print('Downloading feature file {} of {}...'.format(it + 1, len(self.feature_train_urls)))
                self._download(url, join(self.datapath, folder, basename(url)))

    def download_testset_cllabels(self):
        """
        Download labels for the ICLabel test set.
        """
        print('Downloading ICLabel test set CL label files...')
        folder = 'labels'
        if not isdir(join(self.datapath, folder)):
            os.mkdir(join(self.datapath, folder))
        self._download(self.label_test_urls, join(self.datapath, folder, 'ICLabels_test.pkl'))

    def download_testset_features(self):
        """
        Download features for the ICLabel test set.
        """
        print('Downloading ICLabel test set features...')
        folder = 'features'
        if not isdir(join(self.datapath, folder)):
            os.mkdir(join(self.datapath, folder))
        self._download(self.feature_test_url, join(self.datapath, folder, 'features_testset_full.mat'))

    def download_database(self):
        """
        Download the anonymized ICLabel website database.
        """
        print('Downloading anonymized ICLabel website database...')
        folder = 'labels'
        if not isdir(join(self.datapath, folder)):
            os.mkdir(join(self.datapath, folder))
        self._download(self.db_url, join(self.datapath, folder, 'database.sqlite'))

    def download_icclassifications(self):
        """
        Download precalculated classifications from several publicly available IC classifiers.
        """
        print('Downloading classifications for some publicly available classifiers...')
        folder = 'other'
        if not isdir(join(self.datapath, folder)):
            os.mkdir(join(self.datapath, folder))
        self._download(self.cls_url, join(self.datapath, folder, 'other_classifiers.mat'))

    def check_for_download(self, data_type):
        """
        Check if something has been downloaded and, if not, get it.
        :param data_type: What data to check for. Can be: 'train_labels', 'train_features', 'test_labels',
            'test_features', 'database', and/or 'classifications'.
        """
        if '__iter__' not in dir(data_type):
            data_type = [data_type]
        for val in data_type:
            if val == 'train_labels':
                for it, url in enumerate(self.label_train_urls):
                    if not isfile(join(self.datapath, 'labels', basename(url))):
                        self.download_trainset_cllabels()
            elif val == 'train_features':
                for it, url in enumerate(self.feature_train_urls):
                    assert isfile(join(self.datapath, 'features', basename(url))), \
                        'Missing training feature file "' + basename(url) + '" and possibly others. ' \
                        'It is a large download which you may accomplish by calling the method ' \
                        '"download_trainset_features()".'
            elif val == 'test_labels':
                if not isfile(join(self.datapath, 'labels', 'ICLabels_test.pkl')):
                    self.download_testset_cllabels()
            elif val == 'test_features':
                if not isfile(join(self.datapath, 'features', 'features_testset_full.mat')):
                    self.download_testset_features()
            elif val == 'database':
                if not isfile(join(self.datapath, 'labels', 'database.sqlite')):
                    self.download_database()
            elif val == 'classifications':
                if not isfile(join(self.datapath, 'other', 'other_classifiers.mat')):
                    self.download_icclassifications()
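
    # Usage sketch (hedged): a single name or any iterable of names works, e.g.:
    #
    #     icl.check_for_download('database')
    #     icl.check_for_download(('test_labels', 'test_features'))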

    # data normalization
    @staticmethod
    def _clip_and_rescale(vec, min, max):
        return (np.clip(vec, min, max) - min) * 2. / (max - min) - 1

    @staticmethod
    def _unscale(vec, min, max):
        return (vec + 1) * (max - min) / 2. + min
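
    # Worked example (hedged, not from the original source): _clip_and_rescale maps
    # [min, max] linearly onto [-1, 1] after clipping, and _unscale inverts it:
    #     _clip_and_rescale(0.25, -0.5, 1.)  ->  (0.25 + 0.5) * 2. / 1.5 - 1  ==  0.0
    #     _unscale(0.0, -0.5, 1.)            ->  (0.0 + 1) * 1.5 / 2. - 0.5   ==  0.25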

    @staticmethod
    def normalize_dipole_features(data):
        """
        Normalize dipole features.
        :param data: dipole features
        :return: normalized dipole features
        """
        # indices
        ind_dipole_pos = np.array([1, 2, 3, 8, 9, 10, 14, 15, 16])
        ind_dipole1_mom = np.array([4, 5, 6])
        ind_dipole2_mom = np.array([11, 12, 13, 17, 18, 19])
        ind_rv = np.array([0, 7])
        # normalize dipole positions
        data[:, ind_dipole_pos] /= 100
        # clip dipole position
        max_dist = 1.5
        data[:, ind_dipole_pos] = np.clip(data[:, ind_dipole_pos], -max_dist, max_dist) / max_dist
        # normalize single dipole moments
        data[:, ind_dipole1_mom] /= np.abs(data[:, ind_dipole1_mom]).max(1, keepdims=True)
        # normalize double dipole moments
        data[:, ind_dipole2_mom] /= np.abs(data[:, ind_dipole2_mom]).max(1, keepdims=True)
        # center residual variance
        data[:, ind_rv] = data[:, ind_rv] * 2 - 1
        return data.astype(np.float32)
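
    # Layout note, derived from the index arrays above: column 0 holds the
    # single-dipole residual variance, 1-3 its position, 4-6 its moment;
    # column 7 holds the two-dipole residual variance, 8-10 and 14-16 the two
    # positions, 11-13 and 17-19 the two moments. The /100 scaling presumably
    # converts positions from millimeters to roughly head-radius units.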

    def normalize_topo_features(self, data, pca=None):
        """
        Normalize scalp topography features.
        :param data: scalp topography features
        :param pca: a fitted PCA object to reuse for the test set when do_pca was set to True in __init__
        :return: (normalized scalp topography features, fitted PCA object or None)
        """
        # apply pca
        if self.do_pca:
            if pca is None:
                pca = PCA(whiten=True)
                data = pca.fit_transform(data)
            else:
                data = pca.transform(data)
            # clip extreme values
            data = np.clip(data, -2, 2)
        else:
            # normalize each topography to unit norm
            data /= np.linalg.norm(data, axis=1, keepdims=True)
        return data.astype(np.float32), pca
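
    # Usage sketch (hedged): when do_pca is True, fit the projection on the
    # training topographies and reuse it on the test set so both sets share
    # one basis:
    #     train_topo, pca = icl.normalize_topo_features(train_topo)
    #     test_topo, _ = icl.normalize_topo_features(test_topo, pca=pca)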

    def normalize_psd_features(self, data):
        """
        Normalize power spectral density features.
        :param data: power spectral density features
        :return: normalized power spectral density features
        """
        # undo notch filters around the line-noise bins (indices 49 and 59, i.e. 50 and 60 Hz
        # assuming 1 Hz resolution): if both neighboring bins exceed the notched bin by more
        # than 5, replace it with their mean
        for linenoise_ind in (49, 59):
            notch_ind = (
                data[:, [linenoise_ind - 1, linenoise_ind + 1]] - data[:, linenoise_ind, np.newaxis] > 5).all(1)
            data[notch_ind, linenoise_ind] = data[notch_ind][:, [linenoise_ind - 1, linenoise_ind + 1]].mean(1)
        # divide by max abs
        data /= np.amax(np.abs(data), axis=1, keepdims=True)
        return data.astype(np.float32)
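
    # Illustration of the notch-undo rule above (values are hypothetical):
    # if a row reads [..., 20, 10, 21, ...] centered on index 49, both
    # neighbors exceed the center by more than 5, so the center is replaced
    # with (20 + 21) / 2 = 20.5 before the max-abs scaling.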

    @staticmethod
    def normalize_autocorr_features(data):
        """
        Normalize autocorrelation function features.
        :param data: autocorrelation function features
        :return: normalized autocorrelation function features
        """
        # clip values above 1
        data[data > 1] = 1
        return data.astype(np.float32)

    def normalize_handcrafted_features(self, data, ic_nums):
        """
        Normalize hand-crafted features.
        :param data: hand-crafted features
        :param ic_nums: IC indices when sorted by power within their respective datasets. The 2nd ID number can be
            used for this in the training dataset.
        :return: normalized hand-crafted features
        """
        # autocorrelation
        data[:, 0] = self._clip_and_rescale(data[:, 0], -0.5, 1.)
        # SASICA focal topo
        data[:, 1] = self._clip_and_rescale(data[:, 1], 1.5, 12.)
        # SASICA snr REMOVED
        # SASICA ic variance
        data[:, 2] = self._clip_and_rescale(np.log(data[:, 2]), -6., 7.)
        # ADJUST diff_var
        data[:, 3] = self._clip_and_rescale(data[:, 3], -0.05, 0.06)
        # ADJUST Temporal Kurtosis
        data[:, 4] = self._clip_and_rescale(np.tanh(data[:, 4]), -0.5, 1.)
        # ADJUST Spatial Eye Difference
        data[:, 5] = self._clip_and_rescale(data[:, 5], 0., 0.4)
        # ADJUST spatial average difference
        data[:, 6] = self._clip_and_rescale(data[:, 6], -0.2, 0.25)
        # ADJUST General Discontinuity Spatial Feature
        # ADJUST maxvar/meanvar
        data[:, 8] = self._clip_and_rescale(data[:, 8], 1., 20.)
        # FASTER Median gradient value
        data[:, 9] = self._clip_and_rescale(data[:, 9], -0.2, 0.2)
        # FASTER Kurtosis of spatial map
        data[:, 10] = self._clip_and_rescale(data[:, 10], -50., 100.)
        # FASTER Hurst exponent
        data[:, 11] = self._clip_and_rescale(data[:, 11], -0.2, 0.2)
        # number of channels
        # number of ICs
        # ic number relative to number of channels
        ic_rel = self._clip_and_rescale(ic_nums * 1. / data[:, 13], 0., 1.)
        # topoplot plot radius (column 14, stored in column 12 so it survives the slice below)
        data[:, 12] = self._clip_and_rescale(data[:, 14], 0.5, 1)
        # epoched?
        # sampling rate
        # number of data points
        return np.hstack((data[:, :13], ic_rel.reshape(-1, 1))).astype(np.float32)
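
    # Resulting feature layout (sketch): columns 0-11 are the rescaled
    # features above, column 12 is the rescaled topoplot radius, and column 13
    # is the relative IC number; raw columns 13 and 14 are dropped by the
    # data[:, :13] slice.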

    # plotting functions
    @staticmethod
    def _plot_grid(data, function):
        # lay out an approximately square grid of subplots, one per row of data
        nax = data.shape[0]
        a = int(np.ceil(np.sqrt(nax)))
        b = int(np.ceil(1. * nax / a))
        f, axarr = plt.subplots(a, b, sharex='col', sharey='row')
        axarr = axarr.flatten()
        for x in range(nax):
            function(data[x], axis=axarr[x])
            axarr[x].set_title(str(x))

    def pad_topo(self, data):
        """
        Reshape scalp topography image features and pad them with zeros to make 32x32 pixel images.
        :param data: Scalp topography features as provided by load_data() and load_semisupervised_data().
        :return: Padded scalp topography images.
        """
        if data.ndim == 1:
            ntopo = 1
        else:
            ntopo = data.shape[0]
        topos = np.zeros((ntopo, 32 * 32))
        topos[:, self.topo_ind] = data
        topos = topos.reshape(-1, 32, 32).transpose(0, 2, 1)
        return np.squeeze(topos)
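
    # Shape sketch: an (n, 740) matrix (740 presumably being
    # len(self.topo_ind), the unpadded topography length) becomes an
    # (n, 32, 32) image stack; a single 740-vector yields one (32, 32) image
    # after np.squeeze.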

    def plot_topo(self, data, axis=plt):
        """
        Plot an IC scalp topography.
        :param data: Scalp topography vector (unpadded).
        :param axis: Optional matplotlib axis in which to plot.
        """
        topo = self.pad_topo(data)
        topo = np.flipud(topo)
        maxabs = np.abs(data).max()
        axis.matshow(topo, cmap='jet', aspect='equal', vmin=-maxabs, vmax=maxabs)

    def plot_topo_grid(self, data):
        """
        Plot a grid of IC scalp topographies.
        :param data: Matrix of scalp topography vectors (unpadded).
        """
        if data.ndim == 1:
            self.plot_topo(data)
        else:
            nax = data.shape[0]
            if nax == 740:
                # 740 rows match the length of a single unpadded topography vector,
                # so the matrix is assumed to be transposed (topographies in columns)
                data = data.T
                nax = data.shape[0]
            if nax > self.max_grid_plot:
                print('Too many plots requested.')
                return
            self._plot_grid(data, self.plot_topo)

    def plot_psd(self, data, axis=plt):
        """
        Plot an IC power spectral density.
        :param data: Power spectral density vector.
        :param axis: Optional matplotlib axis in which to plot.
        """
        # undo feature scaling and mean-centering, if their parameters are known
        if self.psd_limits is not None:
            data = self._unscale(data, *self.psd_limits)
        if self.psd_mean is not None:
            data = data + self.psd_mean
        axis.plot(self.psd_ind[:data.flatten().shape[0]], data.flatten())

    def plot_psd_grid(self, data):
        """
        Plot a grid of IC power spectral densities.
        :param data: Matrix of power spectral density vectors.
        """
        if data.ndim == 1:
            self.plot_psd(data)
        else:
            nax = data.shape[0]
            if nax > self.max_grid_plot:
                print('Too many plots requested.')
                return
            self._plot_grid(data, self.plot_psd)

    @staticmethod
    def plot_autocorr(data, axis=plt):
        """
        Plot an IC autocorrelation function.
        :param data: Autocorrelation function vector.
        :param axis: Optional matplotlib axis in which to plot.
        """
        # 100 evenly spaced lags spanning (0, 1], presumably in seconds
        axis.plot(np.linspace(0, 1, 101)[1:], data.flatten())

    def plot_autocorr_grid(self, data):
        """
        Plot a grid of IC autocorrelation functions.
        :param data: Matrix of autocorrelation function vectors.
        """
        if data.ndim == 1:
            self.plot_autocorr(data)
        else:
            nax = data.shape[0]
            if nax > self.max_grid_plot:
                print('Too many plots requested.')
                return
            self._plot_grid(data, self.plot_autocorr)
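
    # Usage sketch (hedged; `icl` is an instance of this class and the feature
    # matrices come from its loading methods):
    #     icl.plot_topo_grid(topo_features[:16])
    #     icl.plot_psd_grid(psd_features[:16])
    #     icl.plot_autocorr_grid(autocorr_features[:16])
    #     plt.show()
    # Each grid method refuses to draw more than self.max_grid_plot panels.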

    def web_image(self, component_id):
        """
        Open the component-properties image from the ICLabel website (iclabel.ucsd.edu) for an IC. Not all ICs
        have images available.
        :param component_id: ID for the component, which has either 2 or 3 numbers if from the training set or
            test set, respectively.
        """
        if len(component_id) == 2:
            wb.open_new_tab(self.base_url_image + '{0:0>6}_{1:0>3}.png'.format(*component_id))
        elif len(component_id) == 3:
            wb.open_new_tab(self.base_url_image + '{0:0>2}_{1:0>2}_{2:0>3}.png'.format(*component_id))
        else:
            raise ValueError('component_id must have 2 or 3 elements.')
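
    # Usage sketch (hedged; the ID numbers below are illustrative placeholders):
    #     icl.web_image((12, 3))      # two-number ID from the training set
    #     icl.web_image((1, 2, 3))    # three-number ID from the test set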