
Apply command for complex functions and calculations on a dataset in R


I'm a reasonably experienced R user who has often struggled to use the apply family. I have some very slow iterative code whose performance I'm hoping to improve with this family, but I'm having difficulty. I have simplified the use case greatly here, so please assume there are no obvious workarounds.

I have a dataset with four observation columns, each of which assigns every row to one of 5 possible groups (the actual use case is 50,000 observations with 1110 possible groups), plus two output variables. For each observation column, I would like to group the rows by their assignment and then compute something from the outputs. To simplify, here that is the squared group sum of each output variable, added together; the actual calculation is much more complicated. My iterative approach gives me what I want, and looks like this:
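Written out, the toy statistic that cals() computes for group g in observation column j is:

$$
\text{answer}_{g,j} = \Big(\sum_{i\,:\,\text{obs}_j[i] = g} \text{val}_1[i]\Big)^2 + \Big(\sum_{i\,:\,\text{obs}_j[i] = g} \text{val}_2[i]\Big)^2, \qquad g = 1,\dots,5,\; j = 1,\dots,4
$$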

library(tidyverse)
set.seed(8675309)

#create toy data
dataset <- data.frame(obs_1 = round(runif(100, 1, 5)),
                      obs_2 = round(runif(100, 1, 5)),
                      obs_3 = round(runif(100, 1, 5)),
                      obs_4 = round(runif(100, 1, 5)),
                      val_1 = rnorm(100, 0, 5),
                      val_2 = rnorm(100, 0, 15))

#define a function to create the output for each group
cals <- function(df){
  var <- df %>%
    group_by(group) %>%
    summarise(x1 = sum(val_1),
              x2 = sum(val_2)) %>%
    mutate(x1 = x1^2,
           x2 = x2^2) %>%
    mutate(ans = x1  + x2) %>%
    pull(ans)
  return(var)
}

#initialize output matrix (5 groups x 4 observation columns)
answer <- matrix(NA_real_, nrow = 5, ncol = 4)

#loops -- ugh
for(i in 1:4){
#pull the i-th group column and the two output variables
  df_used <- dataset %>%
    select(all_of(i), val_1, val_2)

#give the group list a common name so the function can identify it
  names(df_used)[1] <- 'group'

#calculate output using the function
  cal <- cals(df_used)

#write this into the output matrix
  answer[, i] <- cal
}

answer
# Result:
          [,1]        [,2]       [,3]       [,4]
[1,]  1159.463  197.090174   302.4915   320.8285
[2,] 15820.498 1975.668791   294.3433  7070.0387
[3,]  2423.859  537.334344 13256.3443  1331.7600
[4,]  4646.915 1900.430230  1836.5904 17242.5160
[5,]  9403.906    4.785014  1449.9531  1588.6278

Surely there is a faster, less unsightly way?


Solution

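  • mapply is probably what you're after (with a single argument, as here, it behaves like sapply). Here is a data.table version; the \(x) lambda syntax requires R 4.1.0 or later: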

    library(data.table)
    
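    dt <- as.data.table(dataset)
    # for each obs column x: group the rows by x, compute sum(val_1)^2 + sum(val_2)^2,
    # sort the result by group id, and keep the computed column
    mapply(\(x) setorder(dt[,.(sum(val_1)^2 + sum(val_2)^2), x], x)[[2]], dt[,1:4])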
    #>           obs_1     obs_2     obs_3     obs_4
    #> [1,]   524.9378  1220.855 1780.1158  786.5803
    #> [2,]  2890.6006 10847.766 6224.3217 7760.9268
    #> [3,] 18436.0742  2667.610 3879.1027  466.2114
    #> [4,]  6888.7064  1774.418 2644.9105 1149.2653
    #> [5,]  3169.8326  3691.997  676.0297 2821.5822
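For comparison, the loop in the question can also be written with base R's apply family alone. This is a minimal sketch, assuming the toy data and the squared-group-sum statistic from the question; group_score() is a hypothetical helper name. Converting each observation column to a factor with fixed levels 1:5 keeps five rows per column even if a group happens to be empty in some column; such a group then gets 0 through tapply()'s default argument (available since R 3.4.0).

```r
set.seed(8675309)

# toy data, built the same way as in the question
dataset <- data.frame(obs_1 = round(runif(100, 1, 5)),
                      obs_2 = round(runif(100, 1, 5)),
                      obs_3 = round(runif(100, 1, 5)),
                      obs_4 = round(runif(100, 1, 5)),
                      val_1 = rnorm(100, 0, 5),
                      val_2 = rnorm(100, 0, 15))

# hypothetical helper: squared group sums of two value vectors for one grouping column
group_score <- function(g, v1, v2, groups = 1:5) {
  g  <- factor(g, levels = groups)       # fixed levels, so empty groups are kept
  s1 <- tapply(v1, g, sum, default = 0)  # per-group sums, 0 for empty groups
  s2 <- tapply(v2, g, sum, default = 0)
  as.vector(s1^2 + s2^2)                 # drop the 1-d array dims and names
}

# one call per observation column; vapply() checks that each result is numeric(5)
answer <- vapply(dataset[paste0("obs_", 1:4)], group_score, numeric(5),
                 v1 = dataset$val_1, v2 = dataset$val_2)
answer
```

vapply() is used instead of sapply() so that a column returning the wrong number of groups fails loudly rather than silently producing a list.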
    

    Parallel

    With 50K observation columns and only two value columns, you probably want to compute in parallel if possible. Below is an example with 50K observation columns, 1110 possible groups, and 2K rows of val_1 and val_2. It finished in about 15 seconds elapsed in the timing shown below.

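    # 2,000 rows x 50,000 observation columns, each entry a group id in 1:1110
    obs <- as.data.frame(
      matrix(sample(1110, 2e3 * 5e4, replace = TRUE), nrow = 2e3, ncol = 5e4,
             dimnames = list(NULL, paste0("obs_", 1:5e4)))
    )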
    vals <- data.table(val_1 = rnorm(2e3, 0, 5), val_2 = rnorm(2e3, 0, 15))
    
    system.time({
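      # insert your function here (this placeholder is a sum of squares; the
      # question's toy statistic would be sum(val1)^2 + sum(val2)^2)
      f <- function(val1, val2) sum(val1^2) + sum(val2^2)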
      library(parallel)
      cl <- makeCluster(detectCores() - 1)
      clusterExport(cl, c("f", "vals"))
      clusterEvalQ(cl, library(data.table))
      answer2 <- simplify2array(
        parLapply(
          cl, obs,
          \(x) {
            y <- numeric(1110)
            # setorder() sorts the groups ascending, so index y by the sorted group ids
            y[sort(unique(x))] <- setorder(vals[,.(f(val_1, val_2)), x], x)[[2]]
            y
          }
        )
      )
      
      stopCluster(cl)
    })
    #>    user  system elapsed 
    #>    1.06    1.96   15.11
    
    dim(answer2)
    #> [1]  1110 50000
    
    answer2[1:10, 1:5]
    #>           obs_1     obs_2     obs_3      obs_4      obs_5
    #>  [1,]  301.8518 378.50549 1604.9906    0.00000   62.03574
    #>  [2,] 1216.5158 280.03548    0.0000   79.42371  221.81035
    #>  [3,]    0.0000   0.00000  201.5036    0.00000  272.46706
    #>  [4,]    0.0000 102.12533  239.0345  769.74224 2008.39479
    #>  [5,]  956.5008  47.84919  251.6572 1967.67512 1510.94146
    #>  [6,]  257.8219  73.64866  213.9344  211.03523  647.27991
    #>  [7,]  811.1412 274.54819  428.2221  731.54683  839.51485
    #>  [8,]  958.2328 158.62962  358.5906  502.11146    0.00000
    #>  [9,]  556.0048 741.85957 1135.0711  924.31785  332.33795
    #> [10,] 1126.8460   0.00000  421.9577  209.50286  184.39162
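If the per-group statistic only needs group totals, as the toy one in the question does, base R's rowsum() can produce the sums for every group of a column in one vectorized call, with no per-group function calls inside the worker. This is a minimal sketch, assuming the question's squared-sum statistic and groups coded as integers 1:n_groups; col_stat() and the small simulated inputs are hypothetical. rowsum() returns one row per group present, sorted by group, with the group ids as row names, and the assignment into y relies on that.

```r
set.seed(1)

n_groups <- 1110
# small stand-ins for the answer's obs and vals (hypothetical sizes)
obs  <- as.data.frame(matrix(sample(n_groups, 2e3 * 5, replace = TRUE), nrow = 2e3,
                             dimnames = list(NULL, paste0("obs_", 1:5))))
vals <- data.frame(val_1 = rnorm(2e3, 0, 5), val_2 = rnorm(2e3, 0, 15))

# hypothetical helper: squared group sums for one observation column
col_stat <- function(x, v, n_groups) {
  s <- rowsum(v, x)        # group sums of both value columns, rows sorted by group id
  y <- numeric(n_groups)   # groups absent from this column stay 0
  y[as.integer(rownames(s))] <- s[, 1]^2 + s[, 2]^2
  y
}

v <- as.matrix(vals)       # rowsum() sums every column of a matrix in one call
answer3 <- simplify2array(lapply(obs, col_stat, v = v, n_groups = n_groups))
dim(answer3)
```

The same function can be dropped into the parallel version as parLapply(cl, obs, col_stat, v = v, n_groups = n_groups); arguments passed through parLapply()'s ... are sent to the workers along with the function.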