Tags: r, parallel-processing, targets-r-package, future.callr

Improve parallel performance with batching in a static-dynamic branching pipeline


BLUF: I am struggling to understand how to use batching in the R targets package to improve performance in a pipeline that combines static and dynamic branching and is processed in parallel with tar_make_future(). I presume that I need to batch within each dynamic branch, but I am unsure how to go about it.

Here's a reprex that uses dynamic branching nested inside static branching, similar to what my actual pipeline is doing. It first branches statically for each value in all_types, and then dynamically branches within each category. This code produces 1,000 branches and 1,010 targets total. In the actual workflow I obviously don't use replicate, and the dynamic branches vary in number depending on the type value.

# _targets.R

library(targets)
library(tarchetypes)
library(future)
library(future.callr)

plan(callr)

all_types = data.frame(type = LETTERS[1:10])

tar_map(values = all_types, names = "type",
  tar_target(
    make_data,
    replicate(100,
      data.frame(x = seq(1000) + rnorm(1000, 0, 5),
                 y = seq(1000) + rnorm(1000, 20, 20)),
      simplify = FALSE
    ),
    iteration = "list"
  ),
  tar_target(
    fit_model,
    lm(make_data),
    pattern = map(make_data),
    iteration = "list"
  )
)

And here's a timing comparison of tar_make() vs tar_make_future() with eight workers:

# tar_destroy()
t1 <- system.time(tar_make())
# tar_destroy()
t2 <- system.time(tar_make_future(workers = 8))

rbind(serial = t1, parallel = t2)

##          user.self sys.self elapsed user.child sys.child
## serial        2.12     0.11   25.59         NA        NA
## parallel      2.07     0.24  184.68         NA        NA

I don't think the user or system fields are meaningful here, since the work is dispatched to separate R processes, but the elapsed time of the parallel run is about 7 times longer than that of the serial run.

I presume this slowdown is caused by the large number of targets. Will batching improve performance in this case, and if so, how can I implement it within the dynamic branches?


Solution

  • You are on the right track with batching. In your case, that is a matter of breaking up your list of 100 datasets into groups of, say, 10 or so. You could do this with a nested list of datasets, but that's a lot of work. Luckily, there is an easier way.
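    For intuition, here is a hedged sketch (not part of the original answer) of what that manual approach would look like: split the list of datasets into groups and fit all the models in a group inside one branch. All names below are made up for illustration.

    ```r
    # Sketch: group 100 datasets into 10 batches of 10 by hand.
    datasets <- replicate(
      100,
      data.frame(x = rnorm(10), y = rnorm(10)),
      simplify = FALSE
    )
    batch_id <- rep(seq_len(10), each = 10) # 1,1,...,2,2,...,10,10
    batches <- split(datasets, batch_id)    # a list of 10 lists of 10 datasets
    # Each dynamic branch would then map over one batch (10 fits per target)
    # instead of one dataset, cutting the branch count tenfold.
    fits <- lapply(batches[[1]], function(d) lm(y ~ x, data = d))
    ```

    Wiring this up by hand for every pattern gets tedious, which is what the target factories below automate.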

    Your question is actually really well-timed. I just wrote some new target factories in tarchetypes that could help. To access them, you will need the development version of tarchetypes from GitHub:

    remotes::install_github("ropensci/tarchetypes")
    

    Then, with tar_map2_count(), it will be much easier to batch your list of 100 datasets for each scenario.

    library(targets)
    tar_script({
      library(broom)
      library(targets)
      library(tarchetypes)
      library(tibble)
    
      make_data <- function(n) {
        # All 100 datasets for one scenario; tar_map2_count() splits
        # these rows into batches downstream.
        datasets <- replicate(
          100,
          tibble(
            x = seq(n) + rnorm(n, 0, 5),
            y = seq(n) + rnorm(n, 20, 20)
          ),
          simplify = FALSE
        )
        tibble(dataset = datasets, rep = seq_along(datasets))
      }
    
      tar_map2_count(
        name = model,
        command1 = make_data(n = rows),
        command2 = tidy(lm(y ~ x, data = dataset)), # Need dataset[[1]] in tarchetypes 0.4.0
        values = tibble(
          scenario = LETTERS[seq_len(10)],
          rows = seq(10, 100, length.out = 10)
        ),
        columns2 = NULL,
        batches = 10
      )
    })
    tar_make(reporter = "silent")
    tar_read(model)
    #> # A tibble: 2,000 × 8
    #>    term        estimate std.error statistic   p.value scenario  rows tar_group
    #>    <chr>          <dbl>     <dbl>     <dbl>     <dbl> <chr>    <dbl>     <int>
    #>  1 (Intercept)   17.1      12.8       1.34  0.218     A           10        10
    #>  2 x              1.39      1.35      1.03  0.333     A           10        10
    #>  3 (Intercept)    6.42     14.0       0.459 0.658     A           10        10
    #>  4 x              1.75      1.28      1.37  0.209     A           10        10
    #>  5 (Intercept)   32.8       7.14      4.60  0.00176   A           10        10
    #>  6 x             -0.300     1.14     -0.263 0.799     A           10        10
    #>  7 (Intercept)   29.7       3.24      9.18  0.0000160 A           10        10
    #>  8 x              0.314     0.414     0.758 0.470     A           10        10
    #>  9 (Intercept)   20.0      13.6       1.47  0.179     A           10        10
    #> 10 x              1.23      1.77      0.698 0.505     A           10        10
    #> # … with 1,990 more rows
    

    Created on 2021-12-10 by the reprex package (v2.0.1)

    There is also tar_map_rep(), which may be easier if all your datasets are randomly generated, but I am not sure if I am overfitting your use case.

    library(targets)
    tar_script({
      library(broom)
      library(targets)
      library(tarchetypes)
      library(tibble)
    
      make_one_dataset <- function(n) {
        tibble(
          x = seq(n) + rnorm(n, 0, 5),
          y = seq(n) + rnorm(n, 20, 20)
        )
      }
    
      tar_map_rep(
        name = model,
        command = tidy(lm(y ~ x, data = make_one_dataset(n = rows))),
        values = tibble(
          scenario = LETTERS[seq_len(10)],
          rows = seq(10, 100, length.out = 10)
        ),
        batches = 10,
        reps = 10
      )
    })
    tar_make(reporter = "silent")
    tar_read(model)
    #> # A tibble: 2,000 × 10
    #>    term    estimate std.error statistic p.value scenario  rows tar_batch tar_rep
    #>    <chr>      <dbl>     <dbl>     <dbl>   <dbl> <chr>    <dbl>     <int>   <int>
    #>  1 (Inter…   37.5        7.50     5.00  0.00105 A           10         1       1
    #>  2 x         -0.701      1.17    -0.601 0.564   A           10         1       1
    #>  3 (Inter…   21.5        9.64     2.23  0.0567  A           10         1       2
    #>  4 x         -0.213      1.55    -0.138 0.894   A           10         1       2
    #>  5 (Inter…   20.6        9.51     2.17  0.0620  A           10         1       3
    #>  6 x          1.40       1.79     0.783 0.456   A           10         1       3
    #>  7 (Inter…   11.6       11.2      1.04  0.329   A           10         1       4
    #>  8 x          2.34       1.39     1.68  0.131   A           10         1       4
    #>  9 (Inter…   26.8        9.16     2.93  0.0191  A           10         1       5
    #> 10 x          0.288      1.10     0.262 0.800   A           10         1       5
    #> # … with 1,990 more rows, and 1 more variable: tar_group <int>
    

    Created on 2021-12-10 by the reprex package (v2.0.1)

    Unfortunately, futures do come with overhead on every target. It may be faster in your case to try tar_make_clustermq(), which keeps persistent workers alive across targets.
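
    If you do, the switch is one backend option plus a different make function. A minimal sketch, assuming the clustermq package is installed (the worker count is illustrative):

    ```r
    # In _targets.R, before the pipeline definition:
    library(targets)
    library(tarchetypes)
    options(clustermq.scheduler = "multiprocess") # local multi-process workers
    # ... same tar_map2_count() pipeline as above ...
    ```

    Then run tar_make_clustermq(workers = 8) in place of tar_make_future(workers = 8). Because clustermq reuses the same worker processes for many targets, it avoids the per-target startup cost that transient callr futures pay.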