main.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. from importlib.resources import files
  2. from pathlib import Path
  3. import argparse
  4. from nipype.interfaces import fsl, freesurfer
  5. import nipype.pipeline as pe
  6. from hcpdiffpy.container import SimgCmd
  7. from hcpdiffpy.interfaces.data import InitData, SaveData
  8. from hcpdiffpy.interfaces.preproc import (
  9. BETMask, EddyIndex, EddyPostProc, ExtractB0, DilateMask, MergeBFiles, Rescale, RotateBVec2Str,
  10. PrepareTopup, WBDilate)
  11. from hcpdiffpy.interfaces.utilities import (
  12. CreateList, CombineStrings, DiffRes, FlattenList, ListItem, PickDiffFiles, SplitDiffFiles,
  13. UpdateDiffFiles)
  14. from hcpdiffpy import utilities
  15. def main() -> None:
  16. parser = argparse.ArgumentParser(
  17. description="HCP Pipeline for diffusion preprocessing",
  18. formatter_class=lambda prog: argparse.ArgumentDefaultsHelpFormatter(prog, width=100))
  19. required = parser.add_argument_group("required arguments")
  20. required.add_argument(
  21. "subject_dir", type=Path,
  22. help="Absolute path to the subject's data folder (organised in HCP-like structure)")
  23. required.add_argument("subject", type=str, help="Subject ID")
  24. required.add_argument(
  25. "echo_spacing", type=float, help="Echo spacing used for acquisition in ms")
  26. required.add_argument("--ndirs", required=True, nargs="+", help="List of numbers of directions")
  27. required.add_argument(
  28. "--phases", required=True, nargs="+", help="List of phase encoding directions")
  29. optional = parser.add_argument_group("optional arguments")
  30. optional.add_argument(
  31. "--work_dir", type=Path, default=Path.cwd(), help="Absolute path to work directory")
  32. optional.add_argument(
  33. "--output_dir", type=Path, default=Path.cwd(), help="Absolute path to output directory")
  34. optional.add_argument("--fsl_simg", type=Path, default=None, help="singularity image for FSL")
  35. optional.add_argument(
  36. "--fs_simg", type=Path, default=None, help="singularity image for FreeSurfer")
  37. optional.add_argument(
  38. "--wb_simg", type=Path, default=None, help="singularity image for Connectome Workbench")
  39. optional.add_argument("--condordag", action="store_true", help="Submit as DAG to HTCondor")
  40. config = vars(parser.parse_args())
  41. # Set-up
  42. config["output_dir"].mkdir(parents=True, exist_ok=True)
  43. config["tmp_dir"] = Path(config["work_dir"], f"hcpdiff_{config['subject']}_tmp")
  44. config["tmp_dir"].mkdir(parents=True, exist_ok=True)
  45. config["keys"] = [
  46. f"dir{ndir}_{phase}" for ndir in sorted(config["ndirs"])
  47. for phase in sorted(config["phases"])]
  48. d_iterables = [("ndir", config["ndirs"]), ("phase", config["phases"])]
  49. fsl_cmd = SimgCmd(config, config["fsl_simg"])
  50. fs_cmd = SimgCmd(config, config["fs_simg"])
  51. wb_cmd = SimgCmd(config, config["wb_simg"])
  52. topup_config = Path(files(utilities) / "b02b0.cnf")
  53. sch_file = Path(files(utilities) / "bbr.sch")
  54. fs_dir = Path(config["subject_dir"], "T1w")
  55. # Workflow set-up
  56. hcpdiff_wf = pe.Workflow(f"hcpdiff_{config['subject']}_wf", base_dir=config["work_dir"])
  57. hcpdiff_wf.config["execution"]["remove_node_directories"] = "true"
  58. hcpdiff_wf.config["execution"]["crashfile_format"] = "txt"
  59. hcpdiff_wf.config["execution"]["stop_on_first_crash"] = "true"
  60. # Get data
  61. init_data = pe.Node(InitData(config=config), "init_data")
  62. d_files = pe.Node(PickDiffFiles(), "d_files", iterables=d_iterables)
  63. hcpdiff_wf.connect([(init_data, d_files, [("d_files", "d_files")])])
  64. # 1. PreEddy
  65. # 1.1. Normalise intensity
  66. mean_dwi = pe.Node(
  67. fsl.ImageMaths(command=fsl_cmd.cmd("fslmaths"), args="-Xmean -Ymean -Zmean"), "mean_dwi")
  68. extract_b0s = pe.Node(ExtractB0(fsl_cmd=fsl_cmd, config=config), "extract_b0s")
  69. merge_b0s = pe.Node(fsl.Merge(command=fsl_cmd.cmd("fslmerge"), dimension="t"), "merge_b0s")
  70. mean_b0 = pe.Node(fsl.ImageMaths(command=fsl_cmd.cmd("fslmaths"), args="-Tmean"), "mean_b0")
  71. scale = pe.Node(fsl.ImageMeants(command=fsl_cmd.cmd("fslmeants")), "scale")
  72. rescale = pe.JoinNode(
  73. Rescale(fsl_cmd=fsl_cmd, config=config), "rescale",
  74. joinfield=["scale_files"], joinsource="d_files")
  75. hcpdiff_wf.connect([
  76. (d_files, mean_dwi, [("data_file", "in_file")]),
  77. (d_files, extract_b0s, [("bval_file", "bval_file")]),
  78. (mean_dwi, extract_b0s, [("out_file", "data_file")]),
  79. (extract_b0s, merge_b0s, [("roi_files", "in_files")]),
  80. (merge_b0s, mean_b0, [("merged_file", "in_file")]),
  81. (mean_b0, scale, [("out_file", "in_file")]),
  82. (init_data, rescale, [("d_files", "d_files")]),
  83. (scale, rescale, [("out_file", "scale_files")])])
  84. # 1.2. Prepare b0s and index files for topup
  85. update_d_files = pe.Node(UpdateDiffFiles(config=config), "update_d_files")
  86. rescaled_d_files = pe.Node(PickDiffFiles(), "split_rescaled", iterables=d_iterables)
  87. rescaled_b0s = pe.Node(ExtractB0(config=config, rescale=True, fsl_cmd=fsl_cmd), "rescaled_b0s")
  88. b0_list = pe.JoinNode(FlattenList(), "b0_list", joinfield="input", joinsource="split_rescaled")
  89. pos_b0_list = pe.JoinNode(
  90. FlattenList(), "pos_b0_list", joinfield="input", joinsource="split_rescaled")
  91. neg_b0_list = pe.JoinNode(
  92. FlattenList(), "neg_b0_list", joinfield="input", joinsource="split_rescaled")
  93. merge_rescaled_b0s = pe.Node(
  94. fsl.Merge(command=fsl_cmd.cmd("fslmerge"), dimension="t"), "merge_rescaled_b0s")
  95. merge_pos_b0s = pe.Node(
  96. fsl.Merge(command=fsl_cmd.cmd("fslmerge"), dimension="t"), "merge_pos_b0s")
  97. merge_neg_b0s = pe.Node(
  98. fsl.Merge(command=fsl_cmd.cmd("fslmerge"), dimension="t"), "merge_neg_b0s")
  99. hcpdiff_wf.connect([
  100. (init_data, update_d_files, [("d_files", "d_files")]),
  101. (rescale, update_d_files, [("rescaled_files", "data_files")]),
  102. (update_d_files, rescaled_d_files, [("d_files", "d_files")]),
  103. (rescaled_d_files, rescaled_b0s, [("bval_file", "bval_file"), ("data_file", "data_file")]),
  104. (rescaled_b0s, b0_list, [("roi_files", "input")]),
  105. (rescaled_b0s, pos_b0_list, [("pos_files", "input")]),
  106. (rescaled_b0s, neg_b0_list, [("neg_files", "input")]),
  107. (b0_list, merge_rescaled_b0s, [("output", "in_files")]),
  108. (pos_b0_list, merge_pos_b0s, [("output", "in_files")]),
  109. (neg_b0_list, merge_neg_b0s, [("output", "in_files")])])
  110. # 1.3. Topup
  111. prepare_topup = pe.Node(PrepareTopup(config=config), "prepare_topup")
  112. topup = pe.Node(fsl.TOPUP(command=fsl_cmd.cmd("topup"), config=str(topup_config)), "topup")
  113. pos_b01 = pe.Node(fsl.ExtractROI(command=fsl_cmd.cmd("fslroi"), t_min=0, t_size=1), "pos_b01")
  114. neg_b01 = pe.Node(fsl.ExtractROI(command=fsl_cmd.cmd("fslroi"), t_min=0, t_size=1), "neg_b01")
  115. b01_files = pe.Node(CreateList(), "b01_files")
  116. apply_topup = pe.Node(
  117. fsl.ApplyTOPUP(command=fsl_cmd.cmd("applytopup"), method="jac"), "apply_topup")
  118. nodiff_mask = pe.Node(BETMask(config=config, fsl_cmd=fsl_cmd), "nodiff_mask")
  119. hcpdiff_wf.connect([
  120. (update_d_files, prepare_topup, [("d_files", "d_files")]),
  121. (b0_list, prepare_topup, [("output", "roi_files")]),
  122. (merge_pos_b0s, prepare_topup, [("merged_file", "pos_b0_file")]),
  123. (merge_rescaled_b0s, topup, [("merged_file", "in_file")]),
  124. (prepare_topup, topup, [("enc_dir", "encoding_direction"), ("ro_time", "readout_times")]),
  125. (merge_pos_b0s, pos_b01, [("merged_file", "in_file")]),
  126. (merge_neg_b0s, neg_b01, [("merged_file", "in_file")]),
  127. (pos_b01, b01_files, [("roi_file", "input1")]),
  128. (neg_b01, b01_files, [("roi_file", "input2")]),
  129. (prepare_topup, apply_topup, [("indices_t", "in_index")]),
  130. (topup, apply_topup, [
  131. ("out_enc_file", "encoding_file"), ("out_fieldcoef", "in_topup_fieldcoef"),
  132. ("out_movpar", "in_topup_movpar")]),
  133. (b01_files, apply_topup, [("output", "in_files")]),
  134. (apply_topup, nodiff_mask, [("out_corrected", "in_file")])])
  135. # 2. Eddy
  136. d_filetype = pe.Node(SplitDiffFiles(), "split_d_filetype")
  137. merge_bfiles = pe.Node(MergeBFiles(config=config), "merge_bfiles")
  138. merge_dwi = pe.Node(fsl.Merge(command=fsl_cmd.cmd("fslmerge"), dimension="t"), "merge_dwi")
  139. eddy_index = pe.Node(EddyIndex(config=config), "eddy_index")
  140. eddy = pe.Node(fsl.Eddy(command=fsl_cmd.cmd("eddy"), fwhm=0, args="-v"), name="eddy")
  141. hcpdiff_wf.connect([
  142. (update_d_files, d_filetype, [("d_files", "d_files")]),
  143. (d_filetype, merge_bfiles, [
  144. ("bval_files", "bval_files"), ("bvec_files", "bvec_files")]),
  145. (d_filetype, merge_dwi, [("data_files", "in_files")]),
  146. (b0_list, eddy_index, [("output", "roi_files")]),
  147. (d_filetype, eddy_index, [("data_files", "dwi_files")]),
  148. (merge_dwi, eddy, [("merged_file", "in_file")]),
  149. (merge_bfiles, eddy, [("bval_merged", "in_bval"), ("bvec_merged", "in_bvec")]),
  150. (topup, eddy, [("out_enc_file", "in_acqp")]),
  151. (eddy_index, eddy, [("index_file", "in_index")]),
  152. (nodiff_mask, eddy, [("mask_file", "in_mask")]),
  153. (topup, eddy, [
  154. ("out_fieldcoef", "in_topup_fieldcoef"), ("out_movpar", "in_topup_movpar")])])
  155. # 3. PostEddy
  156. # 3.1. Postproc
  157. postproc = pe.Node(EddyPostProc(fsl_cmd=fsl_cmd, config=config), "postproc")
  158. fov_mask = pe.Node(
  159. fsl.ImageMaths(command=fsl_cmd.cmd("fslmaths"), args="-abs -Tmin -bin -fillh"), "fov_mask")
  160. mask_args = pe.Node(CombineStrings(input1="-mas "), "mask_args")
  161. fmask_data = pe.Node(fsl.ImageMaths(command=fsl_cmd.cmd("fslmaths")), "fmask_data")
  162. thr_data = pe.Node(fsl.ImageMaths(command=fsl_cmd.cmd("fslmaths"), args="-thr 0"), "thr_data")
  163. hcpdiff_wf.connect([
  164. (d_filetype, postproc, [
  165. ("bval_files", "bval_files"), ("bvec_files", "bvec_files"),
  166. ("data_files", "rescaled_files")]),
  167. (eddy, postproc, [
  168. ("out_corrected", "eddy_corrected_file"), ("out_rotated_bvecs", "eddy_bvecs_file")]),
  169. (postproc, fov_mask, [("combined_dwi_file", "in_file")]),
  170. (fov_mask, mask_args, [("out_file", "input2")]),
  171. (postproc, fmask_data, [("combined_dwi_file", "in_file")]),
  172. (mask_args, fmask_data, [("output", "args")]),
  173. (fmask_data, thr_data, [("out_file", "in_file")])])
  174. # 3.2. DiffusionToStructural
  175. # 3.2.1. nodiff-to-T1
  176. nodiff_brain = pe.Node(
  177. fsl.ExtractROI(command=fsl_cmd.cmd("fslroi"), t_min=0, t_size=1), "nodiff_brain")
  178. wm_seg = pe.Node(fsl.FAST(command=fsl_cmd.cmd("fast"), output_type="NIFTI_GZ"), "wm_seg")
  179. pve_file = pe.Node(ListItem(index=-1), "pve_file")
  180. wm_thr = pe.Node(
  181. fsl.ImageMaths(command=fsl_cmd.cmd("fslmaths"), args="-thr 0.5 -bin"), "wm_thr")
  182. flirt_init = pe.Node(fsl.FLIRT(command=fsl_cmd.cmd("flirt"), dof=6), "flirt_init")
  183. flirt_nodiff2t1 = pe.Node(
  184. fsl.FLIRT(command=fsl_cmd.cmd("flirt"), dof=6, cost="bbr", schedule=sch_file),
  185. "flirt_nodiff2t1")
  186. nodiff2t1 = pe.Node(
  187. fsl.ApplyWarp(command=fsl_cmd.cmd("applywarp"), interp="spline", relwarp=True), "nodiff2t1")
  188. bias_args = pe.Node(CombineStrings(input1="-div "), "bias_args")
  189. nodiff_bias = pe.Node(fsl.ImageMaths(command=fsl_cmd.cmd("fslmaths")), "nodiff_bias")
  190. hcpdiff_wf.connect([
  191. (thr_data, nodiff_brain, [("out_file", "in_file")]),
  192. (init_data, wm_seg, [("t1_brain_file", "in_files")]),
  193. (wm_seg, pve_file, [("partial_volume_files", "input")]),
  194. (pve_file, wm_thr, [("output", "in_file")]),
  195. (nodiff_brain, flirt_init, [("roi_file", "in_file")]),
  196. (init_data, flirt_init, [("t1_brain_file", "reference")]),
  197. (nodiff_brain, flirt_nodiff2t1, [("roi_file", "in_file")]),
  198. (init_data, flirt_nodiff2t1, [("t1_file", "reference")]),
  199. (wm_thr, flirt_nodiff2t1, [("out_file", "wm_seg")]),
  200. (flirt_init, flirt_nodiff2t1, [("out_matrix_file", "in_matrix_file")]),
  201. (nodiff_brain, nodiff2t1, [("roi_file", "in_file")]),
  202. (init_data, nodiff2t1, [("t1_file", "ref_file")]),
  203. (flirt_nodiff2t1, nodiff2t1, [("out_matrix_file", "premat")]),
  204. (init_data, bias_args, [("bias_file", "input2")]),
  205. (nodiff2t1, nodiff_bias, [("out_file", "in_file")]),
  206. (bias_args, nodiff_bias, [("output", "args")])])
  207. # 3.2.2. diff-to-struct
  208. bbr_epi2t1 = pe.Node(
  209. freesurfer.BBRegister(
  210. command=fs_cmd.cmd("bbregister", options=f"--env SUBJECTS_DIR={fs_dir}"),
  211. contrast_type="bold", dof=6, args="--surf white.deformed", subjects_dir=fs_dir,
  212. subject_id=config["subject"]),
  213. "bbr_epi2t1")
  214. tkr_diff2str = pe.Node(
  215. freesurfer.Tkregister2(command=fs_cmd.cmd("tkregister2"), noedit=True, fsl_out=True),
  216. "tkr_diff2str")
  217. diff2str = pe.Node(
  218. fsl.ConvertXFM(command=fsl_cmd.cmd("convert_xfm"), concat_xfm=True), "diff2str")
  219. hcpdiff_wf.connect([
  220. (nodiff_bias, bbr_epi2t1, [("out_file", "source_file")]),
  221. (init_data, bbr_epi2t1, [("eye_file", "init_reg_file")]),
  222. (nodiff_bias, tkr_diff2str, [("out_file", "moving_image")]),
  223. (bbr_epi2t1, tkr_diff2str, [("out_reg_file", "reg_file")]),
  224. (init_data, tkr_diff2str, [("t1_file", "target_image")]),
  225. (flirt_nodiff2t1, diff2str, [("out_matrix_file", "in_file")]),
  226. (tkr_diff2str, diff2str, [("fsl_file", "in_file2")])])
  227. # 3.2.3. resampling
  228. res_dil = pe.Node(DiffRes(), "res_dil")
  229. flirt_resamp = pe.Node(fsl.FLIRT(command=fsl_cmd.cmd("flirt")), "flirt_resamp")
  230. t1_resamp = pe.Node(
  231. fsl.ApplyWarp(command=fsl_cmd.cmd("applywarp"), interp="spline", relwarp=True), "t1_resamp")
  232. dilate_data = pe.Node(WBDilate(config=config, wb_cmd=wb_cmd), "dilate_data")
  233. resamp_data = pe.Node(
  234. fsl.FLIRT(command=fsl_cmd.cmd("flirt"), apply_xfm=True, interp="spline"), "resamp_data")
  235. resamp_mask = pe.Node(
  236. fsl.FLIRT(command=fsl_cmd.cmd("flirt"), interp="nearestneighbour"), "resamp_mask")
  237. resamp_fmask = pe.Node(
  238. fsl.FLIRT(command=fsl_cmd.cmd("flirt"), apply_xfm=True, interp="trilinear"), "resamp_fmask")
  239. hcpdiff_wf.connect([
  240. (thr_data, res_dil, [("out_file", "data_file")]),
  241. (init_data, flirt_resamp, [
  242. ("t1_restore_file", "in_file"), ("t1_restore_file", "reference")]),
  243. (res_dil, flirt_resamp, [("res", "apply_isoxfm")]),
  244. (init_data, t1_resamp, [("t1_restore_file", "in_file")]),
  245. (flirt_resamp, t1_resamp, [("out_file", "ref_file")]),
  246. (thr_data, dilate_data, [("out_file", "data_file")]),
  247. (res_dil, dilate_data, [("dilate", "dilate")]),
  248. (dilate_data, resamp_data, [("out_file", "in_file")]),
  249. (t1_resamp, resamp_data, [("out_file", "reference")]),
  250. (diff2str, resamp_data, [("out_file", "in_matrix_file")]),
  251. (init_data, resamp_mask, [("mask_file", "in_file"), ("mask_file", "reference")]),
  252. (res_dil, resamp_mask, [("res", "apply_isoxfm")]),
  253. (fov_mask, resamp_fmask, [("out_file", "in_file")]),
  254. (t1_resamp, resamp_fmask, [("out_file", "reference")]),
  255. (diff2str, resamp_fmask, [("out_file", "in_matrix_file")])])
  256. # 3.2.4. postprocessing
  257. dilate_mask = pe.Node(DilateMask(fsl_cmd=fsl_cmd), "dilate_mask")
  258. thr_fmask = pe.Node(
  259. fsl.ImageMaths(command=fsl_cmd.cmd("fslmaths"), args="-thr 0.999 -bin"), "thr_fmask")
  260. masks_args = pe.Node(CombineStrings(input1="-mas ", input3=" -mas "), "masks_args")
  261. mask_data = pe.Node(fsl.ImageMaths(command=fsl_cmd.cmd("fslmaths")), "mask_data")
  262. nonneg_data = pe.Node(
  263. fsl.ImageMaths(command=fsl_cmd.cmd("fslmaths"), args="-thr 0"), "nonneg_data")
  264. mean_mask = pe.Node(fsl.ImageMaths(command=fsl_cmd.cmd("fslmaths"), args="-Tmean"), "mean_mask")
  265. mean_args = pe.Node(CombineStrings(input1="-mas "), "mean_args")
  266. mask_mask = pe.Node(fsl.ImageMaths(command=fsl_cmd.cmd("fslmaths")), "mask_mask")
  267. rot_matrix = pe.Node(fsl.AvScale(command=fsl_cmd.cmd("avscale")), "rot_matrix")
  268. rotate_bvec = pe.Node(RotateBVec2Str(config=config), "rotate_bvec")
  269. hcpdiff_wf.connect([
  270. (resamp_mask, dilate_mask, [("out_file", "mask_file")]),
  271. (resamp_fmask, thr_fmask, [("out_file", "in_file")]),
  272. (dilate_mask, masks_args, [("out_file", "input2")]),
  273. (thr_fmask, masks_args, [("out_file", "input4")]),
  274. (resamp_data, mask_data, [("out_file", "in_file")]),
  275. (masks_args, mask_data, [("output", "args")]),
  276. (mask_data, nonneg_data, [("out_file", "in_file")]),
  277. (nonneg_data, mean_mask, [("out_file", "in_file")]),
  278. (mean_mask, mean_args, [("out_file", "input2")]),
  279. (dilate_mask, mask_mask, [("dil0_file", "in_file")]),
  280. (mean_args, mask_mask, [("output", "args")]),
  281. (diff2str, rot_matrix, [("out_file", "mat_file")]),
  282. (postproc, rotate_bvec, [("rot_bvecs", "bvecs_file")]),
  283. (rot_matrix, rotate_bvec, [("rotation_translation_matrix", "rot")])])
  284. # Save data
  285. save_data = pe.Node(SaveData(config=config), "save_data")
  286. hcpdiff_wf.connect([
  287. (postproc, save_data, [("rot_bvals", "bval_file")]),
  288. (nonneg_data, save_data, [("out_file", "data_file")]),
  289. (mask_mask, save_data, [("out_file", "mask_file")]),
  290. (rotate_bvec, save_data, [("rotated_file", "bvec_file")])])
  291. # Run workflow
  292. hcpdiff_wf.write_graph()
  293. if config["condordag"]:
  294. hcpdiff_wf.run(
  295. plugin="CondorDAGMan",
  296. plugin_args={
  297. "dagman_args": f"-outfile_dir {config['tmp_dir']} -import_env",
  298. "wrapper_cmd": Path(files(utilities) / "venv_wrapper.sh"),
  299. "override_specs": "request_memory = 5 GB\nrequest_cpus = 1"})
  300. else:
  301. hcpdiff_wf.run()
# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()