
Dask performs recomputation in branched graphs


Suppose I create the following graph:

import dask
import time


@dask.delayed
def step_1():
    print("Running Step 1")
    time.sleep(1)
    return True

@dask.delayed
def step_2(prev_step):
    print("Running Step 2")
    time.sleep(1)
    return True

@dask.delayed
def step_3a(prev_step):
    print("Running Step 3a")
    time.sleep(1)
    return True

@dask.delayed
def step_3b(prev_step):
    print("Running Step 3b")
    time.sleep(1)
    return True

stp_1 = step_1()
stp_2 = step_2(stp_1)
stp_3a = step_3a(stp_2)
stp_3b = step_3b(stp_2)

from dask import visualize

visualize([stp_3a, stp_3b])

[Image: branched computation graph]

from dask.distributed import Client, LocalCluster
cluster = LocalCluster(n_workers=1, threads_per_worker=3, dashboard_address="localhost:27998")
client = Client(cluster)
client

Now I compute step_3a, which should take about 3 seconds (three chained one-second steps).


start = time.perf_counter()

stp_3a_futures = client.compute(stp_3a) # So that the future stays in memory

stp_3a_results = client.gather(stp_3a_futures)

duration = time.perf_counter() - start

print(duration)
[Out]: 3.1600782200694084

This makes sense. But when I then compute step_3b, I expect it to finish in about one second, since step_1 and step_2 have already been computed. Unfortunately, their results are not kept in memory, and computing step_3b also takes 3 seconds:


start = time.perf_counter()

stp_3b_futures = client.compute(stp_3b) # So that the future stays in memory

stp_3b_results = client.gather(stp_3b_futures)

duration = time.perf_counter() - start

print(duration)

[Out]: 3.0438701044768095
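For comparison, here is a minimal sketch (assuming only that dask is installed) showing that the extra time comes from intermediate results being released between calls, not from the graph itself. It uses Dask's local synchronous scheduler instead of the LocalCluster above, and the `calls` list is a hypothetical counter added purely for illustration. When both outputs are requested in a single `dask.compute` call, the graphs are merged and the shared step_1 and step_2 tasks run once; in two separate calls they run twice. On a distributed cluster the same thing happens because a future only keeps its own result alive, so intermediate results are released once no pending task or future needs them. This does not cover the use case in the question, where step_3a and step_3b are computed at different times.

```python
import dask

calls = []  # hypothetical counter: every task execution is recorded here


@dask.delayed
def step_1():
    calls.append("step_1")
    return True


@dask.delayed
def step_2(prev_step):
    calls.append("step_2")
    return True


@dask.delayed
def step_3a(prev_step):
    calls.append("step_3a")
    return True


@dask.delayed
def step_3b(prev_step):
    calls.append("step_3b")
    return True


stp_2 = step_2(step_1())
stp_3a = step_3a(stp_2)
stp_3b = step_3b(stp_2)

# Both targets in one call: the graphs are merged, so the shared
# step_1 and step_2 keys execute only once.
result_a, result_b = dask.compute(stp_3a, stp_3b, scheduler="sync")
shared = list(calls)
print(shared)

# Two separate calls: nothing is kept between them, so the upstream
# steps execute again for the second target.
calls.clear()
dask.compute(stp_3a, scheduler="sync")
dask.compute(stp_3b, scheduler="sync")
separate = list(calls)
print(separate)
```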

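Now, my question is: how can I make Dask reuse the results of step_1 and step_2 that it has already computed, instead of recomputing them for step_3b?

I know I can call client.persist() on stp_2, but that's not the answer I'm looking for. In my use case, when I compute step_3a, I won't have any reference to the delayed object for step_2.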

Thank you in advance to anyone who can answer. :)


Solution

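  • The graphchain library works well with recent Dask versions. It caches the result of each task, so a later computation that shares upstream tasks can load them from the cache instead of recomputing them: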

    from time import sleep
    
    from dask import delayed
    from dask.config import set as dask_set
    from graphchain import optimize
    
    
    @delayed
    def step_1():
        print("Running Step 1")
        sleep(1)
        return True
    
    
    @delayed
    def step_2(prev_step):
        print("Running Step 2")
        sleep(1)
        return True
    
    
    @delayed
    def step_3a(prev_step):
        print("Running Step 3a")
        sleep(1)
        return True
    
    
    @delayed
    def step_3b(prev_step):
        print("Running Step 3b")
        sleep(1)
        return True
    
    
    stp_1 = step_1()
    stp_2 = step_2(stp_1)
    stp_3a = step_3a(stp_2)
    stp_3b = step_3b(stp_2)
    

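    Now, the computations. Without graphchain, both calls rerun every step. With graphchain's optimize registered as the delayed optimizer, the first call fills the cache and the second call reuses the cached step_1 and step_2: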

    %time stp_3a.compute()
    # Running Step 1
    # Running Step 2
    # Running Step 3a
    # CPU times: user 330 ms, sys: 14.3 ms, total: 344 ms
    # Wall time: 3.01 s
    
    %time stp_3b.compute()
    # Running Step 1
    # Running Step 2
    # Running Step 3b
    # CPU times: user 6.4 ms, sys: 3.03 ms, total: 9.43 ms
    # Wall time: 3.01 s
    
    with dask_set(delayed_optimize=optimize):
        
        %time stp_3a.compute()
        # Running Step 1
        # Running Step 2
        # Running Step 3a
        # CPU times: user 364 ms, sys: 20.8 ms, total: 385 ms
        # Wall time: 3.04 s
        
        %time stp_3b.compute()
        # Running Step 3b
        # CPU times: user 5 ms, sys: 2.97 ms, total: 7.97 ms
        # Wall time: 1.01 s
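To see why the second call inside the `with` block only runs step_3b, here is a toy, standard-library-only model of the idea behind graphchain: each task gets a hash built from its function and the hashes of its dependencies, and results are stored under that hash. This is a sketch under stated assumptions, not graphchain's implementation: `run`, `graph` and the in-memory `cache` dict are hypothetical stand-ins (graphchain itself keeps a persistent cache), and hashing the function's name, bytecode and constants is a simplification chosen to keep the example self-contained.

```python
import hashlib

runs = []  # names of the steps that actually executed


def step_1():
    runs.append("step_1")
    return True


def step_2(prev_step):
    runs.append("step_2")
    return True


def step_3a(prev_step):
    runs.append("step_3a")
    return True


def step_3b(prev_step):
    runs.append("step_3b")
    return True


# Hypothetical graph format: key -> (function, list of dependency keys).
graph = {
    "step_1": (step_1, []),
    "step_2": (step_2, ["step_1"]),
    "step_3a": (step_3a, ["step_2"]),
    "step_3b": (step_3b, ["step_2"]),
}


def run(graph, key, cache):
    """Return (value, hash) for `key`, executing the task only on a cache miss."""
    func, deps = graph[key]
    dep_results = [run(graph, dep, cache) for dep in deps]
    # Simplified task hash: function name, bytecode and constants,
    # plus the hashes of all dependencies.
    h = hashlib.sha256(func.__qualname__.encode())
    h.update(func.__code__.co_code)
    h.update(repr(func.__code__.co_consts).encode())
    for _, dep_hash in dep_results:
        h.update(dep_hash.encode())
    digest = h.hexdigest()
    if digest not in cache:
        cache[digest] = func(*(value for value, _ in dep_results))
    return cache[digest], digest


cache = {}  # hypothetical stand-in for graphchain's persistent cache
run(graph, "step_3a", cache)
run(graph, "step_3b", cache)
print(runs)  # ['step_1', 'step_2', 'step_3a', 'step_3b']
```

The sketch still walks step_1 and step_2 when computing step_3b in order to rebuild their hashes, but it finds those hashes in the cache and skips executing them, which matches the single "Running Step 3b" line in the output above.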