ViolinPlots_abdominal.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. import pandas as pd
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. import seaborn as sns
  5. from scipy.stats import ttest_ind, f_oneway
  6. import glob
  7. import os
  8. from scipy.stats import ttest_ind
  9. from statsmodels.stats.multitest import multipletests
  10. import re
  11. #% Function to read CSV files and store them in a dictionary
  # Conversion factor: a length in centimetres multiplied by ``cm`` gives inches,
  # the unit matplotlib expects for ``figsize``.
  cm = 1 / 2.54

  # Figures are written to a ``figures`` folder that sits next to the folder
  # containing this script.
  script_dir = os.path.dirname(__file__)
  out_path = os.path.join(script_dir, '..', 'figures')
  def read_csv_files(files):
      """Load CSV feature tables, keyed by their position in ``files``.

      The data type of each table is inferred from the tag ``"anat"``,
      ``"diff"`` or ``"func"``.  The file name is inspected first so that a
      parent directory that happens to contain one of the tags (for example
      ``.../anat_study/..._diff.csv``) cannot mislabel the table.  Only when
      the file name carries no tag is the full path searched, and ``"func"``
      is the final default.

      Parameters
      ----------
      files : sequence of str
          Paths of the CSV files to read.

      Returns
      -------
      dict
          ``{index: {data_type: DataFrame}}`` where ``index`` is the position
          of the path in ``files``.  An empty ``files`` yields ``{}``.
      """
      known_types = ("anat", "diff", "func")

      def detect(text):
          # First known tag contained in ``text`` (checked in the order above),
          # or None when no tag is present.
          return next((tag for tag in known_types if tag in text), None)

      data_dict = {}
      for index, file in enumerate(files):
          table = pd.read_csv(file)
          data_type = detect(os.path.basename(file)) or detect(file) or "func"
          data_dict[index] = {data_type: table}
      return data_dict
  23. # Function for statistical comparison and plotting
  def compare_and_plot(data, column_name, group_column):
      """Display a box plot of ``column_name`` grouped by ``group_column``.

      Only a figure is produced; no statistical test is run here.

      Parameters
      ----------
      data : DataFrame
          Table holding both columns.
      column_name : str
          Column drawn on the y-axis.
      group_column : str
          Column whose values define the boxes along the x-axis.
      """
      # "Set2" was chosen by the original author as a colour-blind friendly palette.
      sns.set_palette("Set2")

      figure = plt.figure(figsize=(9*cm, 6*cm))
      # boxplot draws on the current axes (created on demand inside the new
      # figure) and returns that axes object.
      axes = sns.boxplot(data=data, x=group_column, y=column_name)
      axes.set_xlabel(group_column)
      axes.set_ylabel(column_name)
      axes.set_title(f"Statistical Comparison of {column_name} by {group_column}")

      figure.tight_layout()
      plt.show()
  # Root folder of the AIDAqc outputs.  NOTE: machine-specific absolute path.
  Path = r"C:\Users\arefk\Desktop\Projects\AIDAqcOutput_of_all_Datasets"
  All_type = ["anat", "diff", "func"]


  def _dataset_files(pattern):
      """Return feature CSVs two folders below ``Path`` that belong to the
      datasets whose path contains "m_Rei" or "7_m_Lo"."""
      hits = glob.glob(os.path.join(Path, pattern), recursive=True)
      return [hit for hit in hits if "m_Rei" in hit or "7_m_Lo" in hit]


  anat_files = _dataset_files("*/*/*caculated_features_anat.csv")
  diff_files = _dataset_files("*/*/*caculated_features_diff.csv")
  func_files = _dataset_files("*/*/*caculated_features_func.csv")
  All_files = [anat_files, diff_files, func_files]

  # Load every CSV; All_Data[i] belongs to All_type[i] and All_files[i].
  All_Data = [read_csv_files(group) for group in All_files]
  anat_data, diff_data, func_data = All_Data
  #% Data statistics, figure 7: one violin/box/strip figure per data type and feature
  BINS = [8, 8, 8]  # requested number of y-axis tick bins per data type (anat, diff, func)
  features_to_compare = ["SNR Chang", "SNR Normal", "tSNR (Averaged Brain ROI)", "Displacement factor (std of Mutual information)"]
  # Alternative feature set left by the author: ["SpatRx", "SpatRy", "Slicethick"]
  Data_of_selected_feature2 = pd.DataFrame()  # not used in this section; kept in case later code reads it

  # CSV column name -> label used on the y-axis and in the output file name.
  _FEATURE_LABELS = {
      "SNR Normal": "SNR-Standard (dB)",
      "SNR Chang": "SNR-Chang (dB)",
      "tSNR (Averaged Brain ROI)": "tSNR (dB)",
      "Displacement factor (std of Mutual information)": "Motion severity (a.u)",
  }
  _MOTION_FEATURE = "Displacement factor (std of Mutual information)"


  def _collect_feature(tables, files, data_type, feature):
      """Stack one feature column from every table, tagged with its dataset.

      ``tables`` is the ``{index: {data_type: DataFrame}}`` mapping built by
      ``read_csv_files`` and ``files`` the path list it was built from, so
      ``files[index]`` is the file a table came from.  Tables lacking the
      data type or the column are skipped.  The dataset name is the third
      path component from the end.  Returns an empty DataFrame when no
      table provides the feature.
      """
      frames = []
      for index, entry in tables.items():
          try:
              values = entry[data_type][feature]
          except KeyError:
              continue
          # A fresh frame per file: assigning into a shared frame would align
          # this file's values to the first file's index and truncate or
          # NaN-pad them.
          frame = pd.DataFrame({feature: values})
          # Index by the table's own key so a skipped file cannot shift the
          # dataset labels of the files that follow it.
          frame["Dataset"] = files[index].split(os.sep)[-3]
          frames.append(frame)
      if not frames:
          return pd.DataFrame()
      return pd.concat(frames, ignore_index=True)


  # savefig does not create missing folders.
  os.makedirs(out_path, exist_ok=True)

  for dd, data in enumerate(All_Data):
      for feature in features_to_compare:
          Data_of_selected_feature = _collect_feature(data, All_files[dd], All_type[dd], feature)
          if Data_of_selected_feature.empty:
              continue

          # Order datasets by the first number found in their folder name.
          # NOTE(review): astype(int) raises if a folder name has no digits.
          Data_of_selected_feature['sort'] = Data_of_selected_feature['Dataset'].str.extract(r'(\d+)', expand=False).astype(int)
          Data_of_selected_feature = Data_of_selected_feature.sort_values('sort')

          if feature == _MOTION_FEATURE:
              BINS[dd] = 10
          label = _FEATURE_LABELS.get(feature, feature)
          if label != feature:
              Data_of_selected_feature.rename(columns={feature: label}, inplace=True)
              feature = label

          # Creating the combined violin / box / strip plot.
          plt.figure(figsize=(6*cm, 3*cm), dpi=600)
          sns.set_style('ticks')
          sns.set(font='Times New Roman', style=None)
          palette = 'Set2'
          ax = sns.violinplot(x="Dataset", y=feature, data=Data_of_selected_feature, hue="Dataset", dodge=False,
                              palette=palette,
                              scale="width", inner=None, linewidth=1)
          # Remember the limits; they are restored after the overlays are drawn.
          xlim = ax.get_xlim()
          ylim = ax.get_ylim()
          # Clip every collection drawn so far to the left half of its own
          # bounding box, turning each violin into a half violin.
          for violin in ax.collections:
              bbox = violin.get_paths()[0].get_extents()
              x0, y0, width, height = bbox.bounds
              violin.set_clip_path(plt.Rectangle((x0, y0), width / 2, height, transform=ax.transData))

          sns.boxplot(x="Dataset", y=feature, data=Data_of_selected_feature, saturation=1, showfliers=False,
                      width=0.3, boxprops={'zorder': 3, 'facecolor': 'none'}, ax=ax, linewidth=1)

          # Collections added from here on are the strip-plot dots; shift them
          # 0.12 to the right of the category centre.
          old_len_collections = len(ax.collections)
          sns.stripplot(x="Dataset", y=feature, data=Data_of_selected_feature, size=1.1, hue="Dataset", palette=palette, dodge=False, ax=ax)
          for dots in ax.collections[old_len_collections:]:
              dots.set_offsets(dots.get_offsets() + np.array([0.12, 0]))
          ax.set_xlim(xlim)
          ax.set_ylim(ylim)

          ax.locator_params(axis='y', nbins=BINS[dd])  # Set the number of ticks for the y-axis
          ax.set_xticklabels(ax.get_xticklabels(), rotation=45, fontsize=8)
          # Size the y tick labels without freezing their text: set_yticklabels
          # is only safe after the tick positions have been fixed.
          ax.tick_params(axis='y', labelsize=8)

          ax.set_xlabel('')
          ax.set_title(All_type[dd].capitalize(), weight='bold', fontsize=10)
          ax.set_ylabel(ax.get_ylabel(), fontsize=8)

          # Grid lines are requested with zero line width.
          for axis in (ax.xaxis, ax.yaxis):
              axis.grid(True, linestyle='-', which='major', color='gray', linewidth=0)
              axis.grid(True, linestyle='--', which='minor', color='gray', linewidth=0)

          for side in ('top', 'right', 'bottom', 'left'):
              ax.spines[side].set_linewidth(0.5)

          # Set axis tick appearance.
          ax.tick_params(axis='both', which='both', width=0.5, color='gray', length=2)
          plt.xticks(ha='right')

          file_stem = os.path.join(out_path, feature + "_" + All_type[dd] + "Abdominal")
          plt.savefig(file_stem + ".svg", format='svg', bbox_inches='tight', transparent=False)
          plt.savefig(file_stem + ".png", dpi=300, format='png', bbox_inches='tight', transparent=False)
          plt.show()