Tags: python, dask, dask-distributed

Understanding the task stream and speeding up distributed Dask


I have implemented some data analysis in Dask using dask-distributed, but its performance is far worse than the same analysis implemented in NumPy/pandas, and I am finding it difficult to understand the task stream and the memory consumption.

The class that sets up the cluster looks like:

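import os
from contextlib import contextmanager

# Module-level imports; the analysis methods of the class use these as well
import dask.array as da
import numpy as np
import xarray as xr
from dask import compute, delayed
from dask.distributed import Client, LocalCluster
from scipy.interpolate import CubicSpline


class some_class():

    def __init__(self, engine_kwargs: dict = None):
        engine_kwargs = engine_kwargs or {}
        self.distributed = engine_kwargs.get("distributed", False)
        # "processes" or "threads"; this key name is an assumption
        self.distributed_mode = engine_kwargs.get("distributed_mode", "processes")
        self.dask_client = None
        self.n_workers = engine_kwargs.get(
            "n_workers", int(os.getenv("SLURM_CPUS_PER_TASK", os.cpu_count()))
        )

    @contextmanager
    def dask_context(self):
        """Dask context manager to set up and close down client"""
        dask_cluster = None
        dask_client = None
        if self.distributed:
            processes = self.distributed_mode == "processes"
            dask_cluster = LocalCluster(n_workers=self.n_workers, processes=processes)
            dask_client = Client(dask_cluster)

        try:
            # Yield the client so callers can pass it as compute(scheduler=...)
            yield dask_client
        finally:
            if dask_client is not None:
                dask_client.close()
            if dask_cluster is not None:
                dask_cluster.close()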

And I have something like the following method, which does the analysis:

    def correct(
        self,
        segy_container: "SegyFileContainer",
        v_sb: int,
        interp_type: int = 1,
        brute_downsample: int = None,
    ):
        """
        :param segy_container: Container for the seg-y path and data
        :type segy_container: SegyFileContainer
        :param interp_type: Interpolation type either 1=linear or 3=cubic, defaults to 1
        :type interp_type: int, optional
        :param brute_downsample: If you wish to down sample the data to get a brute stack, defaults to None
        :type brute_downsample: int, optional
        :param v_sb: NMO velocity
        :type v_sb: int
        :return: NMO corrected gather
        :rtype: pd.DataFrame
        """
        cdp_series = segy_container.trace_headers["CDP"]
        min_cmp = cdp_series.values.min()
        max_cmp = cdp_series.values.max()

        # Group both the trace data and the trace headers by CDP number
        cdp_dataarray = xr.DataArray(cdp_series, dims=["trace"])
        dg_cmp = segy_container.segy_file.data.groupby(cdp_dataarray)
        dt_s = segy_container.segy_file.attrs["sample_rate"]
        hg_cmp = segy_container.trace_headers.groupby(cdp_series)
        tasks = [
            delayed(self._process_group)(
                segy_container, cmp_index, dg_cmp, hg_cmp, v_sb, interp_type, dt_s
            )
            for cmp_index in range(min_cmp, max_cmp + 1)
        ]

        with self.dask_context() as dc:
            results = compute(*tasks, scheduler=dc)
        return results

    def _process_group(
        self,
        segy_container,
        cmp_index,
        dg_cmp,
        hg_cmp,
        v_sb: int,
        interp_type: int,
        dt_s: int,
    ):

        cmp = (
            segy_container.segy_file.data[dg_cmp.groups[cmp_index]]
            .transpose()
            .compute()
        )
        offsets = hg_cmp.get_group(cmp_index)["offset"]
        nmo = self._nmo_correction(
            cmp=cmp,
            dt=dt_s / 1000,
            offsets=offsets,
            velocity=v_sb,
            interp_type=interp_type,
        )
        return nmo

    def _nmo_correction(
        self, cmp, dt: float, offsets, velocity: float, interp_type: int
    ):
        nmo_trace = da.zeros_like(cmp)
        nsamples = cmp.data.shape[0]
        times = da.arange(0, nsamples * dt, dt)

        for ind, offset in enumerate(offsets):
            reflected_times = self._reflection_time(times, offset, velocity)
            amplitude = self._sample_trace(
                reflected_times=reflected_times,
                trace=cmp.data[:, ind],
                dt=dt,
                interp_type=interp_type,
            )

            if amplitude is not None:
                nmo_trace[:, ind] = amplitude

        return nmo_trace

    def _reflection_time(self, t0, x, vnmo):
        t = da.sqrt(t0**2 + x**2 / vnmo**2)
        return t.compute()

    def _sample_trace(self, reflected_times, trace, dt, interp_type):
        times = np.arange(trace.size) * dt
        times = xr.DataArray(times)  
        reflected_times = xr.DataArray(reflected_times, dims="reflected_times")

        out_of_bounds = (reflected_times < times[0]) | (reflected_times > times[-1])
        if interp_type == 1:
            amplitude = np.interp(reflected_times, times, trace)
        elif interp_type == 3:
            polyfit = CubicSpline(times, trace)
            amplitude = polyfit(reflected_times)
        else:
            raise ValueError(
                f"Error in interpolating sample trace. interp_type should be either 1 or 3: {interp_type}"
            )

        amplitude[out_of_bounds.compute()] = 0.0
        return amplitude

I have the same analysis implemented with NumPy and pandas, and it runs in about 3 s. With Dask-distributed, as shown above, it takes around 15 min. If I use scheduler="processes" without the cluster, it takes about 4 min.
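
For comparison, a plain NumPy version of the per-gather NMO step might look like the minimal sketch below. The question does not show its NumPy/pandas code, so this is not that implementation: the function name nmo_correct_gather is hypothetical, it assumes each gather is a 2-D array of shape (nsamples, ntraces) as cmp.data[:, ind] implies, and it covers only linear interpolation (interp_type=1). Passing left=0.0 and right=0.0 to np.interp zeroes samples outside the recorded time range, which mirrors the out_of_bounds mask in _sample_trace.

```python
import numpy as np


def nmo_correct_gather(cmp, offsets, dt, velocity):
    """Linear NMO correction of one CMP gather using only NumPy (sketch).

    cmp: array of shape (nsamples, ntraces), one column per trace
    offsets: source-receiver offset of each trace (length ntraces)
    dt: sample interval in seconds
    velocity: constant NMO velocity
    """
    nsamples, ntraces = cmp.shape
    t0 = np.arange(nsamples) * dt  # zero-offset times of each sample
    out = np.zeros((nsamples, ntraces))
    for j in range(ntraces):
        # Hyperbolic moveout, as in _reflection_time: t = sqrt(t0^2 + x^2 / v^2)
        t = np.sqrt(t0**2 + offsets[j] ** 2 / velocity**2)
        # Sample the trace at the moved-out times; times outside the
        # recorded range get 0.0 through the left/right arguments
        out[:, j] = np.interp(t, t0, cmp[:, j], left=0.0, right=0.0)
    return out


gather = np.random.default_rng(0).normal(size=(1000, 48))
offsets = np.linspace(100.0, 2450.0, 48)
corrected = nmo_correct_gather(gather, offsets, dt=0.004, velocity=1500.0)
print(corrected.shape)  # (1000, 48)
```

Everything here stays in one process with no task graph, which is why this kind of code has essentially no scheduling or serialization overhead.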

I understand there will be overhead in setting up and using the cluster, but I am trying to understand how to improve the runtime.

The diagnostics in the Dask dashboard show some quite confusing graphs:

[Screenshots: Dask Overview, Dask Tasks, Dask Tasks 2]

I understand why there may be more streams than the 10 workers I have created in this case, but I am finding it hard to understand what exactly is going on. I also don't understand why the memory usage is so high, given that the file I am looking at is 715 MB.

Any advice or insight on how to

  1. Understand the task stream
  2. Speed up the Dask-distributed code
  3. Understand why the memory usage is so high

Would be very much appreciated!


Solution

  • When your data fits easily into memory, it should be much faster to use your numerical package (numpy, pandas, xarray, ...) directly without dask.

    The basic reason is that to get a task to a worker, you must copy all the data it will need to work on, transfer it to the worker, have it do the job, and then reverse the process to get the result back. Many workers are, of course, doing this simultaneously. Here are a few things to keep in mind:

    You might consider using the threaded scheduler for this particular workflow, since that avoids copying the underlying data and much of the overhead, but you will still have temporary arrays associated with several parallel tasks in memory at once.

    Other possible solutions:

    Keep in mind that none of this addresses what your actual code is doing, because I cannot follow it at all. You have compute() calls in multiple places, which suggests to me that you might have several Dask round trips going on.
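
Following the points above about the threaded scheduler and avoiding repeated compute() calls, the overall structure might look like this minimal sketch. The names process_by_cdp and gather_kernel are hypothetical, gather_kernel is only a stand-in for the real per-gather work, and the data are assumed to be a NumPy array of shape (nsamples, ntraces) with a matching array of per-trace CDP numbers. Each task receives only the traces of its own gather, there are no nested compute() calls inside tasks, and the whole set of tasks is computed in a single call.

```python
import dask
import numpy as np


def gather_kernel(gather):
    # Hypothetical stand-in for the real per-gather work (e.g. NMO correction).
    # Plain NumPy only: no dask arrays and no .compute() calls inside a task.
    return gather * 2.0


def process_by_cdp(data, cdp):
    """data: array of shape (nsamples, ntraces); cdp: CDP number of each trace."""
    tasks = []
    for value in np.unique(cdp):
        # Pass each task only its own traces, not the whole dataset
        gather = data[:, cdp == value]
        tasks.append(dask.delayed(gather_kernel)(gather))
    # Build all tasks first, then compute once. The threaded scheduler runs
    # in this process, so the arrays are shared rather than serialized and
    # copied to separate worker processes.
    return dask.compute(*tasks, scheduler="threads")


data = np.arange(12, dtype=float).reshape(3, 4)
cdp = np.array([100, 101, 100, 101])
results = process_by_cdp(data, cdp)
print(len(results), results[0].shape)  # 2 (3, 2)
```

Threads only give real parallelism when the per-task work spends its time in NumPy routines that release the GIL; a pure-Python loop over samples will not run in parallel this way.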
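
When tasks genuinely need a large shared input (in the question, every _process_group call receives the whole segy_container, dg_cmp and hg_cmp), Dask's delayed best-practices documentation recommends wrapping that input in dask.delayed once instead of passing the concrete object into every call. On a distributed cluster, this avoids sending the same data separately with each task. A minimal sketch, assuming hypothetical function names and a small NumPy array standing in for the large object:

```python
import dask
import numpy as np


def column_mean(arr, j):
    # Hypothetical task that needs read access to the whole shared array
    return float(arr[:, j].mean())


def column_means(big, ncols):
    # Wrap the large input once: it becomes a single node in the task graph
    # that every task depends on, instead of being embedded separately in
    # each of the ncols tasks (which, on a distributed cluster, means being
    # serialized and shipped once per task).
    big_d = dask.delayed(big)
    tasks = [dask.delayed(column_mean)(big_d, j) for j in range(ncols)]
    return dask.compute(*tasks, scheduler="threads")


big = np.arange(20, dtype=float).reshape(5, 4)
print(column_means(big, 4))  # (8.0, 9.0, 10.0, 11.0)
```

With a distributed Client, client.scatter(big) serves a similar purpose: it moves the data to the workers once and returns a Future that can be passed to client.submit calls.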