plotPairwiseDistance.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. import os
  2. import json
  3. import numpy as np
  4. import seaborn as sns
  5. import matplotlib.pyplot as plt
  6. import pandas as pd
  7. from regmaxsn.core.matplotlibRCParams import mplPars
  8. homeFolder = os.path.expanduser('~')
  9. def plotPairwiseDistances(parFile):
  10. plt.ion()
  11. sns.set(rc=mplPars)
  12. with open(parFile) as fle:
  13. parsList = json.load(fle)
  14. transErrs = pd.DataFrame(None, columns=['Exp. Name', 'Pairwise Distance in $\mu$m'])
  15. for par in parsList:
  16. refSWC = par['refSWC']
  17. resFile = par['resFile']
  18. testName = resFile[:-4]
  19. thresh = par['gridSizes'][-1]
  20. print('Doing ' + repr((refSWC, resFile)))
  21. refPts = np.loadtxt(refSWC)[:, 2:5]
  22. testPts = np.loadtxt(resFile)[:, 2:5]
  23. if refPts.shape[0] != testPts.shape[0]:
  24. print('Number of points do not match for ' + refSWC + 'and' + resFile)
  25. continue
  26. ptDiff = np.linalg.norm(refPts - testPts, axis=1)
  27. transErrs = transErrs.append(pd.DataFrame({'Pairwise Distance in $\mu$m': ptDiff,
  28. 'Exp. Name': testName}),
  29. ignore_index=True)
  30. transErrsGr = transErrs.groupby(by='Exp. Name')
  31. regErrs = transErrsGr['Pairwise Distance in $\mu$m'].agg({'\% of points closer than\n lowest grid size':
  32. lambda x: 100 * ((x <= thresh).sum()) / float(len(x))})
  33. fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(14, 11.2))
  34. fig1, ax1 = plt.subplots(nrows=1, ncols=1, figsize=(14, 11.2))
  35. with sns.axes_style('darkgrid'):
  36. sns.boxplot(x='Exp. Name', y='Pairwise Distance in $\mu$m',
  37. ax=ax, data=transErrs, whis='range', color=sns.color_palette()[0])
  38. ax1.plot(range(regErrs.size), regErrs['\% of points closer than\n lowest grid size'],
  39. color=sns.color_palette()[0], marker='o', linestyle='-', ms=10)
  40. ax.set_xlim(-1, len(regErrs))
  41. ax.plot(ax.get_xlim(), [thresh, thresh], 'r--')
  42. ax.set_ylim(0, 40)
  43. ax.set_xticklabels(['job {}'.format(x) for x in range(len(parsList))], rotation=90)
  44. ax.set_xlabel('')
  45. ax1.set_xlim(-1, len(regErrs))
  46. ax1.set_ylim(-10, 110)
  47. ax1.set_xticks(range(regErrs.size))
  48. ax1.set_xticklabels(['par{}'.format(x) for x in range(len(parsList))], rotation=90)
  49. ax1.set_ylabel('\% of points closer than\n lowest grid size')
  50. for ind, f in enumerate([fig, fig1]):
  51. # f.canvas.draw()
  52. f.tight_layout()
  53. return fig, fig1
  54. # ----------------------------------------------------------------------------------------------------------------------
  55. if __name__ == '__main__':
  56. import sys
  57. assert len(sys.argv) == 2, 'Improper usage! Please use as \'python plotPairwiseDistance.py parFile\''
  58. parFile = sys.argv[1]
  59. figs = plotPairwiseDistances(parFile)
  60. raw_input('Press any key to close figures and quit:')