plotOccupancyBasedMeasureVsAverageRotScaleTransform.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import pandas as pd
  2. from regmaxsn.core.occupancyBasedMeasure import occupancyEMD
  3. import pathlib2
  4. import json
  5. import numpy as np
  6. from matplotlib import pyplot as plt
  7. import seaborn as sns
  8. from regmaxsn.core.matplotlibRCParams import mplPars
  9. import sys
  10. dirPath = pathlib2.Path("/media/ajay/ADATA_HD720/Ginjang/DataAndResults/morphology/OriginalData/Tests")
  11. expName = 'HSN-fluoro01.CNG'
  12. swcSetSize = 20
  13. N = 500
  14. resDir = pathlib2.Path("/media/ajay/ADATA_HD720/Ginjang/DataAndResults/morphology/OccupancyBaseTests")
  15. if not resDir.exists():
  16. resDir.mkdir()
  17. voxelSize = 10
  18. scaleDF = pd.DataFrame()
  19. rotDF = pd.DataFrame()
  20. transDF = pd.DataFrame()
  21. suffixes = ("RandRotY", "RandScaleY", "RandTranslateY")
  22. dfs = [rotDF, scaleDF, transDF]
  23. labels = ("Rotation (degrees)", "Scale", "Translation (um)")
  24. jsonKeys = ("angles", "scale", "translation")
  25. parInds = (1, 1, 1)
  26. figs = []
  27. def saveData():
  28. for suffixInd, suffix in enumerate(suffixes):
  29. print("Doing {}".format(labels[suffixInd]))
  30. parSWCDict = {}
  31. for ind in range(N):
  32. label = labels[suffixInd]
  33. jsonKey = jsonKeys[suffixInd]
  34. parInd = parInds[suffixInd]
  35. outFile = str(dirPath / '{}{}{}.swc'.format(expName, suffix, ind))
  36. transJSONFile = str(dirPath / '{}{}{}.json'.format(expName, suffix, ind))
  37. with open(transJSONFile, "r") as fle:
  38. transJSON = json.load(fle)
  39. jsonPars = transJSON[jsonKey]
  40. parSWCDict[jsonPars[1]] = outFile
  41. allPars = parSWCDict.keys()
  42. allParsSorted = np.sort(allPars)
  43. maxStepSize = int(np.floor(float(N) / float(swcSetSize)))
  44. baseSet = np.arange(0, swcSetSize)
  45. for stepSize in range(1, maxStepSize + 1):
  46. print("Doing StepSize {}/{}".format(stepSize, maxStepSize))
  47. windowSlideSize = int(stepSize * swcSetSize / 2)
  48. windowStarts = range(0, N - stepSize * swcSetSize + 1, windowSlideSize)
  49. for windowStart in windowStarts:
  50. print("Doing Window start {}/{}".format(windowStart, windowStarts))
  51. pars = allParsSorted[windowStart + stepSize * baseSet]
  52. swcFiles = [parSWCDict[par] for par in pars]
  53. metric = occupancyEMD(swcFiles, voxelSize)
  54. if suffix == "RandRotY":
  55. pars = np.rad2deg(pars)
  56. tempDict = {"mean of {}".format(label): np.mean(pars),
  57. "std of {}".format(label): np.std(pars),
  58. "metric": metric}
  59. dfs[suffixInd] = dfs[suffixInd].append(tempDict, ignore_index=True)
  60. outFile = str(resDir / "metricVs{}.xlsx".format(suffix))
  61. dfs[suffixInd].to_excel(outFile)
  62. def plotData():
  63. sns.set(rc=mplPars)
  64. figs = []
  65. for suffixInd, suffix in enumerate(suffixes):
  66. label = labels[suffixInd]
  67. dfXL = str(resDir / "metricVs{}.xlsx".format(suffix))
  68. df = pd.read_excel(dfXL)
  69. df["mean of {}".format(label)] = df["mean of {}".format(label)].apply(lambda x: round(x, 3))
  70. df["std of {}".format(label)] = df["std of {}".format(label)].apply(lambda x: round(x, 3))
  71. df2Plot = df.pivot(index="mean of {}".format(label),
  72. columns="std of {}".format(label),
  73. values="metric")
  74. fig, ax = plt.subplots(figsize=(14, 11.2))
  75. sns.heatmap(data=df2Plot, ax=ax, xticklabels=10, yticklabels=10, cmap=plt.cm.jet)
  76. ax.set_xticklabels(ax.get_xticklabels(), rotation="vertical")
  77. ax.set_yticklabels(ax.get_yticklabels(), rotation="horizontal")
  78. fig.tight_layout()
  79. outFile = str(resDir / "metricVs{}.png".format(suffix))
  80. fig.savefig(outFile, dpi=150)
  81. figs.append(fig)
  82. plt.show()
  83. if __name__ == "__main__":
  84. assert len(sys.argv) == 2, 'Improper usage! Please use as \'python {arg} save\' or' \
  85. '\'python {arg} plot'.format(arg=sys.argv[0])
  86. if sys.argv[1] == "save":
  87. saveData()
  88. elif sys.argv[1] == "plot":
  89. plotData()
  90. else:
  91. raise ValueError