RegMaxSN.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. import os
  2. import numpy as np
  3. from regmaxsn.core.iterativeRegistration import IterativeRegistration, composeRefSWC, calcOverlap, getRemainderScale
  4. import shutil
  5. import json
  6. import sys
  7. from regmaxsn.core.transforms import decompose_matrix
  8. from regmaxsn.core.swcFuncs import transSWC
  9. from regmaxsn.core.misc import parFileCheck
  10. from regmaxsn.core.occupancyBasedMeasure import occupancyEMD
  11. def normalizeFinally(ipFiles, resDir, opFiles, fnwrtName, maxIter):
  12. itersAll = sorted([int(fle[3:-4]) for fle in os.listdir(resDir) if fle.find('ref') == 0])
  13. iters = [x for x in itersAll if x <= maxIter]
  14. totalTrans = np.eye(4)
  15. for iter in iters:
  16. solFle = os.path.join(resDir, fnwrtName + str(iter) + 'Sol.txt')
  17. if os.path.exists(solFle):
  18. with open(solFle, 'r') as f:
  19. pars = json.load(f)
  20. totalTrans = np.dot(pars['finalTransMat'], totalTrans)
  21. iTrans = np.linalg.inv(totalTrans)
  22. for ipFile, opFile in zip(ipFiles, opFiles):
  23. transSWC(ipFile, iTrans[:3, :3], iTrans[:3, 3], opFile)
  24. ipDir, ipName = os.path.split(ipFile[:-4])
  25. partsDir = os.path.join(ipDir, ipName)
  26. if os.path.isdir(partsDir):
  27. normedPartsDir = opFile[:-4]
  28. os.mkdir(normedPartsDir)
  29. swcs = [x for x in os.listdir(partsDir) if x.endswith('.swc')]
  30. for swc in swcs:
  31. opPart = os.path.join(normedPartsDir, swc)
  32. ipPart = os.path.join(partsDir, swc)
  33. transSWC(ipPart, iTrans[:3, :3], iTrans[:3, 3], opPart)
  34. def runRegMaxSN(parFile, parNames):
  35. assert os.path.isfile(parFile), "{} not found".format(parFile)
  36. ch = raw_input('Using parameter File {}.\n Continue?(y/n)'.format(parFile))
  37. if ch != 'y':
  38. print('User Abort!')
  39. sys.exit()
  40. parsList = parFileCheck(parFile, parNames)
  41. for pars in parsList:
  42. resDir = pars['resDir']
  43. refSWC = pars['initRefSWC']
  44. swcList = pars['swcList']
  45. fnwrt = pars['finallyNormalizeWRT']
  46. if os.path.isdir(resDir):
  47. ch = raw_input('Folder exists: ' + resDir + '\nDelete(y/n)?')
  48. if ch == 'y':
  49. shutil.rmtree(resDir)
  50. else:
  51. quit()
  52. try:
  53. os.mkdir(resDir)
  54. except Exception as e:
  55. raise(IOError('Could not create {}'.format(resDir)))
  56. assert os.path.isfile(refSWC), 'Could not find {}'.format(refSWC)
  57. for swc in swcList:
  58. assert os.path.isfile(swc), 'Could not find {}'.format(swc)
  59. assert swc.endswith('.swc'), 'Elements of swcList must be of SWC format with extension \'.swc\''
  60. assert fnwrt in swcList, 'The parameter finallyNormalizeWRT must be an element of the parameter swcList'
  61. print('All parameters are acceptable. Starting the Reg-MaxS-N jobs...')
  62. for parInd, pars in enumerate(parsList):
  63. print('Starting Job # {}'.format(parInd + 1))
  64. print('Current Parameters:')
  65. for parN, parV in pars.iteritems():
  66. print('{}: {}'.format(parN, parV))
  67. resDir = pars['resDir']
  68. refSWC = pars['initRefSWC']
  69. swcList = pars['swcList']
  70. fnwrt = pars['finallyNormalizeWRT']
  71. usePartsDir = pars['usePartsDir']
  72. nIter = pars['maxIter']
  73. gridSizes = pars['gridSizes']
  74. rotBounds = pars['rotBounds']
  75. transBounds = pars['transBounds']
  76. scaleBounds = pars['scaleBounds']
  77. transMinRes = pars['transMinRes']
  78. minScaleStepSize = pars['minScaleStepSize']
  79. rotMinRes = pars['rotMinRes']
  80. nCPU = pars['nCPU']
  81. shutil.copyfile(refSWC, os.path.join(resDir, 'ref' + str(-1) + '.swc'))
  82. if usePartsDir:
  83. for swc in swcList:
  84. dirPath, expName = os.path.split(swc[:-4])
  85. partsDirO = os.path.join(dirPath, expName)
  86. if os.path.isdir(partsDirO):
  87. partsDirD = os.path.join(resDir, expName + str(-1))
  88. shutil.copytree(partsDirO, partsDirD)
  89. prevAlignedSWCs = swcList
  90. occupancyMeasureLargestGridSize = []
  91. bestIterInd = nIter - 1
  92. nrnScaleBounds = {swc: scaleBounds[:] for swc in swcList}
  93. for iterInd in range(nIter):
  94. iterReg = IterativeRegistration(refSWC, gridSizes, rotBounds, transBounds,
  95. transMinRes, minScaleStepSize, rotMinRes, nCPU)
  96. presAlignedSWCs = []
  97. dones = []
  98. for swcInd, swc in enumerate(swcList):
  99. dirPath, expName = os.path.split(swc[:-4])
  100. print('Doing Iter ' + str(iterInd) + ' : ' + expName)
  101. SWC2Align = prevAlignedSWCs[swcInd]
  102. if iterInd > 0:
  103. initGuessTypeT = 'nothing'
  104. else:
  105. initGuessTypeT = 'just_centroids'
  106. initVals = [calcOverlap(refSWC, SWC2Align, g) for g in gridSizes]
  107. if usePartsDir:
  108. inPartsDir = os.path.join(resDir, expName + str(iterInd - 1))
  109. outPartsDir = os.path.join(resDir, expName + str(iterInd))
  110. else:
  111. inPartsDir = None
  112. outPartsDir = None
  113. resFile = os.path.join(resDir, expName + str(iterInd) + '.swc')
  114. resSWC, resSol = iterReg.performReg(SWC2Align, resFile,
  115. scaleBounds=nrnScaleBounds[swc],
  116. inPartsDir=inPartsDir,
  117. outPartsDir=outPartsDir,
  118. initGuessType=initGuessTypeT,
  119. retainTempFiles=True)
  120. finalVals = [calcOverlap(refSWC, resSWC, gridSize) for gridSize in gridSizes]
  121. considerIteration = False
  122. for iv, fv in zip(initVals, finalVals):
  123. if fv < iv:
  124. considerIteration = True
  125. break
  126. if fv > iv:
  127. considerIteration = False
  128. break
  129. if not considerIteration:
  130. shutil.copy(SWC2Align, resSWC)
  131. shutil.rmtree(os.path.join(resDir, expName + str(iterInd) + 'trans'))
  132. if usePartsDir and os.path.exists(outPartsDir):
  133. shutil.rmtree(outPartsDir)
  134. shutil.copytree(inPartsDir, outPartsDir)
  135. os.remove(resSol)
  136. print('finalVal (' + str(finalVals) + ') >= initVal (' + str(initVals) + '). Doing Nothing!')
  137. done = True
  138. else:
  139. print('finalVal (' + str(finalVals) + ') < initVal (' + str(initVals) + '). Keeping the iteration!')
  140. with open(resSol, 'r') as fle:
  141. pars = json.load(fle)
  142. totalTrans = np.array(pars['finalTransMat'])
  143. done = np.allclose(np.eye(3), totalTrans[:3, :3], atol=1e-3)
  144. scale, shear, angles, trans, persp = decompose_matrix(totalTrans)
  145. nrnScaleBounds[swc] = getRemainderScale(scale, nrnScaleBounds[swc])
  146. dones.append(done)
  147. print('Finished ' + expName + ' : ' + str(done))
  148. print('Remainder scale: ' + str(nrnScaleBounds[swc]))
  149. presAlignedSWCs.append(resSWC)
  150. newRefSWC = os.path.join(resDir, 'ref' + str(iterInd) + '.swc')
  151. overallOverlap = composeRefSWC(presAlignedSWCs, newRefSWC, gridSizes[-1])
  152. occupancyMeasure = occupancyEMD(presAlignedSWCs, gridSizes[-1])
  153. occupancyMeasureLargestGridSize.append(occupancyMeasure)
  154. refSWC = newRefSWC
  155. prevAlignedSWCs = presAlignedSWCs
  156. if all(dones):
  157. break
  158. bestIterInd = np.argmin(occupancyMeasureLargestGridSize)
  159. bestMeasure = min(occupancyMeasureLargestGridSize)
  160. shutil.copy(os.path.join(resDir, 'ref' + str(bestIterInd) + '.swc'), os.path.join(resDir, 'finalRef.swc'))
  161. ipFiles = []
  162. opFiles = []
  163. thrash, fnwrtName = os.path.split(fnwrt[:-4])
  164. for swc in swcList:
  165. dirPath, expName = os.path.split(swc[:-4])
  166. ipFiles.append(os.path.join(resDir, '{}{}.swc'.format(expName, bestIterInd)))
  167. opFiles.append(os.path.join(resDir, '{}.swc'.format(expName)))
  168. normalizeFinally(ipFiles, resDir, opFiles, fnwrtName, bestIterInd)
  169. finalSolFile = os.path.join(resDir, "bestIterInd.json")
  170. with open(finalSolFile, 'w') as fle:
  171. json.dump({'finalVal': bestMeasure,
  172. 'bestIteration': bestIterInd}, fle)
  173. print ('Finished Job # {}'.format(parInd + 1))
  174. if __name__ == '__main__':
  175. from regmaxsn.core.RegMaxSPars import RegMaxSNParNames
  176. assert len(sys.argv) == 2, 'Improper usage! Please use as \'python RegMaxSN.py parFile\''
  177. parFile = sys.argv[1]
  178. runRegMaxSN(parFile, RegMaxSNParNames)