Snakefile 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import sys
  2. from pathlib import Path
  3. sys.path.append(str(Path.cwd().parents[0]))
  4. from project_utils.project_paths import project_paths
  configfile: 'config.yaml'
  # NOTE(review): this copies every top-level key of the loaded config into
  # the module namespace. Because it runs after the imports, a config key
  # named e.g. `project_paths` or `Path` would silently overwrite that import.
  # Consider reading values via `config[...]` explicitly instead.
  globals().update(config)

  # Regexes restricting what each wildcard may match.
  # Raw strings are required here: '\w', '\d' and '\|' are not valid Python
  # escape sequences. In plain literals they trigger a SyntaxWarning
  # (DeprecationWarning before Python 3.12) and are slated to become errors.
  # The regex text itself is unchanged.
  wildcard_constraints:
      plot_name = r'\w+',
      # a single digit, nothing, or the literal '_temp-test'
      flag = r'\d?|_temp-test',
      # the literal '|minimatrigger', a single digit, or nothing
      variant = r'\|minimatrigger|\d?'
  def all_figures(w):
      """Return the final figure paths that ``rule all`` requests.

      ``w`` is the wildcards object Snakemake passes to input functions; it is
      not used here. Only the '|minimatrigger' variant of the method
      comparison figure is requested.
      """
      variant = '|minimatrigger'
      return [
          # plain literal: the former f-prefix had no placeholders (ruff F541)
          project_paths.figures / 'dataset_comparison.pdf',
          project_paths.figures / f'method_comparison{variant}.pdf',
      ]
  rule all:
      # First rule in this file, so Snakemake uses it as the default target:
      # building it builds every figure returned by all_figures().
      input: all_figures
  def local_data(w, variant=''):
      """Path of the channel-wise wavefront measures table for ``variant``.

      With the default empty ``variant`` this resolves to
      'wavefronts_channel-wise_measures.csv' inside ``project_paths.dataframes``.
      ``w`` (the Snakemake wildcards) is accepted for the input-function
      signature but not used.
      """
      return (project_paths.dataframes
              / f'wavefronts_channel-wise{variant}_measures.csv')
  def alt_local_data(w):
      """Channel-wise measures table for the variant named by the ``variant``
      wildcard of the job being resolved."""
      requested = w.variant
      return local_data(w, variant=requested)
  rule plot_dataset_comparison:
      # Renders dataset_comparison.pdf with scripts/plot_dataset_comparison.py.
      # Inputs: the channel-wise and wave-wise measure tables, their trend
      # counterparts, and the '|macrodim11' channel-wise table.
      input:
          local_data = local_data,
          global_data = (project_paths.dataframes
                         / 'wavefronts_wave-wise_measures.csv'),
          local_trend_data = (project_paths.dataframes
                              / 'wavefronts_channel-wise_trend_measures.csv'),
          global_trend_data = (project_paths.dataframes
                               / 'wavefronts_wave-wise_trend_measures.csv'),
          local_subsampled_data = (project_paths.dataframes
                                   / 'wavefronts_channel-wise|macrodim11_measures.csv'),
          script = 'scripts/plot_dataset_comparison.py',
          # not passed on the command line; listed so Snakemake tracks it
          # as a dependency of this rule
          utils = 'scripts/plotting_utils.py',
      output:
          project_paths.figures / 'dataset_comparison.pdf'
      shell:
          """
          python {input.script} --local_data {input.local_data:q} \
              --global_data {input.global_data:q} \
              --local_trend_data {input.local_trend_data:q} \
              --global_trend_data {input.global_trend_data:q} \
              --local_subsampled_data {input.local_subsampled_data:q} \
              --output {output:q}
          """
  def signal_path(w, variant=''):
      """Path of 'waves.nix' under the 'stage04_wave_detection' directory of
      the 'LENS_M2_t1' pipeline output, with ``variant`` appended to the
      profile directory name.

      ``w`` (the Snakemake wildcards) is accepted for the input-function
      signature but not used.
      """
      base_profile = 'LENS_M2_t1'
      run_dir = f'{base_profile}{variant}'
      return (project_paths.pipeline_output / run_dir
              / 'stage04_wave_detection' / 'waves.nix')
  def alt_signal_path(w):
      """'waves.nix' signal path for the profile variant named by the
      ``variant`` wildcard of the job being resolved."""
      requested = w.variant
      return signal_path(w, variant=requested)
  rule plot_method_comparison:
      # Renders method_comparison{variant}.pdf with
      # scripts/plot_method_comparison.py.
      # - local_data and signal resolve to the default (empty) variant.
      # - alt_local_data, alt_global_data and alt_signal follow the {variant}
      #   wildcard taken from the requested output file name.
      input:
          local_data = local_data,
          global_data = (project_paths.dataframes
                         / 'wavefronts_wave-wise_measures.csv'),
          alt_local_data = alt_local_data,
          alt_global_data = (project_paths.dataframes
                             / 'wavefronts_wave-wise{variant}_measures.csv'),
          signal = signal_path,
          alt_signal = alt_signal_path,
          script = 'scripts/plot_method_comparison.py',
          # not passed on the command line; listed so Snakemake tracks it
          # as a dependency of this rule
          utils = 'scripts/plotting_utils.py',
      output:
          project_paths.figures / 'method_comparison{variant}.pdf'
      shell:
          """
          python {input.script} --local_data {input.local_data:q} \
              --global_data {input.global_data:q} \
              --alt_local_data {input.alt_local_data:q} \
              --alt_global_data {input.alt_global_data:q} \
              --signal {input.signal:q} \
              --alt_signal {input.alt_signal:q} \
              --output {output:q}
          """