
Snakemake: How to traverse the DAG depth-first with an aggregate rule at the end?


Snakemake version: 6.3.0

I want my Snakemake workflow to traverse the DAG depth-first for a set of rules that each produce a similar number of files, up to a rule that aggregates several of those files according to a wildcard.

For now, I've applied the advice given in this thread: Snakemake: Tranverse DAG depth-first?

I've prioritized the steps so that my workflow runs depth-first up to the rule sort_and_index_binning. However, the files are still quite large at the end of the sorting step, and I'd like to reach the step that aggregates the depths (summarize_contig_depth) as quickly as possible for each sample, so that I'm not left holding temporary files from the previous steps. In the end, the idea would be to work through the src1 wildcard sample by sample, but I don't see how to do that automatically.

For example with 2 samples sample_1 and sample_2 I would like to do :

binning_mapping : sample_1_to_sample_2, sample_2_to_sample_2 (here sample_1 and sample_2 are the src wildcard values and sample_2 is the src1 value)

filter_bam : sample_1_to_sample_2, sample_2_to_sample_2

sort_and_index_binning : sample_1_to_sample_2, sample_2_to_sample_2

summarize_contig_depth : depth_sample_2

THEN

binning_mapping : sample_1_to_sample_1, sample_2_to_sample_1 (here sample_1 and sample_2 are the src wildcard values and sample_1 is the src1 value)

filter_bam : sample_1_to_sample_1, sample_2_to_sample_1

sort_and_index_binning : sample_1_to_sample_1, sample_2_to_sample_1

summarize_contig_depth : depth_sample_1

THEN

a rule using depth_sample_1 and depth_sample_2
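
For concreteness, that final rule would be something along the lines of the sketch below (the rule name, output path and shell command are placeholders, not my real code):

rule merge_depths:
    '''
        Hypothetical final rule combining the per-sample depth tables.
    '''
    output:
        os.path.join(intermediate_results_dir, "binning/{binning_strategy}/all_depths.txt")
    input:
        expand(os.path.join(intermediate_results_dir, "binning/{{binning_strategy}}/{src1}/depth.txt"),
               src1=list(samples.keys()))
    shell:
        "cat {input} > {output}"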


def input_cmd(wildcards):
    if wildcards.assembly == "single_assembly":
        list_reads = []
        for run in reads2use[wildcards.src]:
            list_reads.extend(reads2use[wildcards.src][run])
        return list_reads
    elif wildcards.assembly == "co_assembly":
        if simka_type == "None":
            return os.path.join(tmpdir, 'samples.txt')
        return os.path.join(intermediate_results_dir, "assembly/co_assembly/clusters/{src}.txt")
    else:
        raise ValueError

rule binning_mapping:
    '''
        Align the source reads files against the assembled contigs file to assess contigs' abundance.
    '''
    output:
        temp(os.path.join(tmpdir, "{assembly}/{src}_to_{src1}_" + f"{index}_bin_filtering.sam"))
    input:
        assembly = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src1}", f"contigs/{assembly}"),
        index1 = expand(os.path.join(intermediate_results_dir, "assembly/{{assembly}}", assembler, "{{src1}}/index/{{src1}}_" + index + "_filtering.{id}.bt2l"), id=range(1, 4)),
        index2 = expand(os.path.join(intermediate_results_dir, "assembly/{{assembly}}", assembler, "{{src1}}/index/{{src1}}_" + index + "_filtering.rev.{id}.bt2l"), id=range(1,2)),
        reads = input_cmd,
        finished_assembly = os.path.join(tmp, "assembly.checkpoint")
    params:
        prefix = os.path.join("intermediate_results/assembly/{assembly}", assembler, "{src1}/index/{src1}_" + f"{index}_filtering"),
        input_reads = lambda wildcards, input : cmdparser.cmd(wildcards.src, input.reads, reads2use, "bowtie2").cmd,
        cmd = lambda wildcards : conf.mapping_cmd(config, wildcards.assembly),
    threads: 5
    priority: 1
    conda:
        os.path.join(CONDAENV, "bowtie2.yaml")
    shell:
        "bowtie2 "
        "-p {threads} "             
        "--no-unal "                
        "-x {params.prefix} "      
        "{params.input_reads} "
        "{params.cmd} "
        "-S {output} "

rule filter_bam:
    """
    Filter reads based on mapping quality and identity.
    Output is temporary because it will be sorted.
    """
    output:
        temp(os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src1}/mapped_reads/{src}.filtered.sam")),
    input:
        os.path.join(tmpdir, "{assembly}/{src}_to_{src1}_" + f"{index}_bin_filtering.sam")
    conda:
        os.path.join(CONDAENV, "bamutils.yaml")
    priority: 2
    params:
        min_mapq = config["bam_filtering_before_binning"]["min_quality"],
        min_idt = config["bam_filtering_before_binning"]["min_identity"],
        min_len = config["bam_filtering_before_binning"]["min_len"],
        pp = config["bam_filtering_before_binning"]["properly_paired"],
    script:
        "../scripts/bamprocess.py"

rule sort_and_index_binning:
    output:
        temp(os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src1}/mapped_reads/{src}_to_{src1}.sorted.bam"))
    input:
        os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src1}/mapped_reads/{src}.filtered.sam"),
    threads: 1
    priority: 3
    conda:
        os.path.join(CONDAENV, "samtools.yaml")
    shell:
        "samtools view -u {input} | "
        "samtools sort "
        "-@ {threads} "             
        "-o {output[0]} "

def aggregate_bam_input(wildcards):
    if "CASB" in strategies or "CACB" in strategies:
        checkpoint_output_simka = checkpoints.cluster_simka.get(**wildcards).output[0]
        assembly_dict["co_assembly"] = glob_wildcards(os.path.join(checkpoint_output_simka, "{clusterid}.txt")).clusterid
        assembly_request = "co_assembly"

    if "SASB" in strategies or "SACB" in strategies:
        assembly_dict["single_assembly"] = list(samples.keys())
        assembly_request = "single_assembly"
    inputs = expand(os.path.join(intermediate_results_dir,
                    "assembly",
                    assembly_request,
                    assembler,
                    wildcards.src1,
                    "mapped_reads",
                    "{src}_to_" + wildcards.src1 + ".sorted.bam"
                    ), src=assembly_dict.get(assembly_request))
    # Return the list of sorted BAM files mapped against this src1 sample.
    return inputs

rule summarize_contig_depth:
    '''
        Compute reads coverage depth to perform binning.
    '''
    output:
        os.path.join(intermediate_results_dir, "binning/{binning_strategy}/{src1}/depth.txt"),
    input:
        aggregate_bam_input,
    params:
        bams = aggregate_bam_input,
    conda:
        os.path.join(CONDAENV, "metabat.yaml")
    threads: 5
    priority: 4
    shell:
        "jgi_summarize_bam_contig_depths --outputDepth {output} {params.bams}"

Solution

  • You can't normally* get this level of control over execution order in Snakemake, because it would involve setting priorities on jobs based on their wildcards. In Snakemake you currently have only two options:

    1. Set a numeric priority for a rule (which applies equally to all jobs resulting from that rule)
    2. Force one or more specific output files (i.e. specific jobs) to have maximum priority using the --prioritize/-P command-line option when running Snakemake, as sketched below
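
    For illustration, the two mechanisms look like this (the rule, file names and paths here are made up):

    # Option 1: a numeric priority on a rule applies equally to every job of that rule.
    rule sort_example:
        output: "sorted/{sample}.bam"
        input: "filtered/{sample}.bam"
        priority: 3
        shell: "samtools sort -o {output} {input}"

    # Option 2: give specific target files (i.e. specific jobs) top priority at run time:
    # snakemake -j 5 --prioritize sorted/sample_2.bam sorted/sample_1.bam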

    Snakemake doesn't traverse the DAG the way a conventional graph algorithm would. Rather, it looks at all the available jobs whose inputs are satisfied, then picks the "best" one, or else chooses one at random. The scheduler already tries to minimise the number of temporary files by prioritising jobs that allow temporary files to be removed (see https://github.com/snakemake/snakemake/pull/409), but in your case this is not smart enough to optimise your workflow.

    If you are short of space but not short of time, you could simply run Snakemake several times, creating each output in turn, which guarantees minimal storage overhead. But here it will be tricky to avoid deleting an output that a subsequent run still needs and then having to re-compute it.
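
    For example, something along these lines (the target paths are illustrative and need to match your actual intermediate_results_dir and binning strategy):

    # Hypothetical per-sample runs, each finishing one depth table before the next starts.
    snakemake -j 5 intermediate_results/binning/my_strategy/sample_2/depth.txt
    snakemake -j 5 intermediate_results/binning/my_strategy/sample_1/depth.txt
    # A final run then builds whatever rule combines the depth files.
    snakemake -j 5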

    So is there a way? Well...

    * When I say normally, there is a hack. It's a really dirty hack, but it can be made to work. Consider the following workflow:

    rule main:
        output: "all_combined.out"
        input: "a_plus_b.tmp", "b_plus_c.tmp", "c_plus_d.tmp", "d_plus_e.tmp", "e_plus_f.tmp"
        shell:
            "head -v {input} > {output}"
    
    rule merge1:
        output: temporary("{w1}_plus_{w2}.tmp")
        input: "{w1}.in", "{w2}.in"
        shell:
            """echo "There are $(ls *.in | wc -l) temporary input files."
               head -v {input} > {output}
            """
    
    rule gen_input:
        output: temporary("{x}.in")
        shell:
            r"sed s'/\(.\)/\1\1\1\1/' <<<{wildcards.x} >{output}"
    

    You can run this without any input files, since it generates its own:

    snakemake -F -j 1
    

    Most likely, you will see that at some points in the workflow there are 3 or more temporary input files, whereas if the DAG were evaluated in an optimal order there would only ever be 2.

    The hack is to use an onstart: handler to introspect the DAG and set priorities on the individual jobs. In this case I just get the jobs to run in alphabetical order of the x wildcard.

    onstart:
        # Fudge the priorities. This is a total hack ;-)
        for n, j in enumerate(sorted([ j for j in workflow.dag.needrun_jobs()
                                       if 'x' in j.wildcards_dict],
                              key=lambda j: j.wildcards_dict['x'],
                              reverse=True)):
            j.dag._priority[j] = -n
    

    I tested this with Snakemake 8 and it works there; I'm not sure about Snakemake 6.
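
    Adapted to your workflow, the same trick might look something like the sketch below (untested; it assumes the jobs you want to group all carry a src1 wildcard):

    onstart:
        # Same hack as above, but ranking jobs by their src1 wildcard: all jobs
        # belonging to one src1 value get contiguous, higher priorities than the
        # jobs of the next value, so each sample's branch is pushed through (and
        # its temporary files removed) before the following one starts.
        for n, j in enumerate(sorted([ j for j in workflow.dag.needrun_jobs()
                                       if 'src1' in j.wildcards_dict],
                              key=lambda j: j.wildcards_dict['src1'],
                              reverse=True)):
            j.dag._priority[j] = -n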