Tags: r, apache-spark-sql, apply, sparklyr, processing-efficiency

sparklyr: running spark_apply() efficiently on equal-sized groups


How do I run a custom function efficiently in chunks within sparklyr environment?

I have a haversine function that calculates the distance between two latitude/longitude pairs stored in the same data frame. As you can imagine, 10 customers crossed with 10 store locations generate 100 rows. I have 10 million customers and 500 stores, which is 5 billion rows. Running a join of this magnitude and then calculating the distance may strain, if not crash, my Spark environment.
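A quick back-of-the-envelope check of that size. This is a rough lower bound that assumes five 8-byte double columns (the four coordinates plus the distance) and ignores Spark's per-row and string overhead:

```r
# Rough size of the customer x store cross join.
# Assumption: 5 double columns (long1, lat1, long2, lat2, distance) at 8 bytes each;
# Spark's row and string overhead is ignored, so the real footprint is larger.
n_customers <- 1e7
n_stores <- 500
n_rows <- n_customers * n_stores
bytes_per_row <- 5 * 8
approx_gb <- n_rows * bytes_per_row / 1e9
cat(sprintf("%.0f rows, at least ~%.0f GB of raw doubles\n", n_rows, approx_gb))
```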

My idea is to first do the join, which generates the 5 billion rows, then spark_apply() the function to equal chunks separately, and then append the results back together. Locally in R, I would write a for loop that processes the chunks sequentially and appends them at the end. How should I do this in Spark?
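For reference, the local for-loop approach described above can be sketched in plain R as follows. `process_in_chunks()` is a hypothetical helper, not part of any package. In Spark this manual loop is usually unnecessary, because Spark already splits a table into partitions and processes them in parallel.

```r
# Local (in-memory) version of "process in chunks, then append":
# split the rows into n_chunks contiguous chunks, apply f to each in a for loop,
# and bind the pieces back together in their original order.
process_in_chunks <- function(df, f, n_chunks = 4) {
  # Integer chunk label 1..n_chunks for every row (equal-width cut of the row index)
  chunk_id <- cut(seq_len(nrow(df)), breaks = n_chunks, labels = FALSE)
  results <- vector("list", n_chunks)
  for (k in seq_len(n_chunks)) {
    results[[k]] <- f(df[chunk_id == k, , drop = FALSE])
  }
  out <- do.call(rbind, results)
  rownames(out) <- NULL
  out
}

# Usage with a toy function; any f that takes and returns a data frame works.
df <- data.frame(x = 1:10)
out <- process_in_chunks(df, function(d) transform(d, y = x * 2), n_chunks = 4)
print(out)
```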

Here's what I've got so far. I'd appreciate help with the syntax to run spark_apply() on each group defined by the equal_groups column, assuming this is the best strategy to run it efficiently.

If the answer involves the features described in this post, I'm having trouble applying them to my code: https://www.rstudio.com/blog/sparklyr-1-2/

library(tidyverse)

#sample data
df <- tibble(
  place=c("Finland", "Canada", "Tanzania", "Bolivia", "France"),
  longitude=c(27.472918, -90.476303, 34.679950, -65.691146, 4.533465),
  latitude=c(63.293001, 54.239631, -2.855123, -13.795272, 48.603949))

from <- 
  df[1:3,] %>%  #pick first 3 rows
  rename(long1 = longitude,
         lat1 = latitude)
to <- df[4:5,] %>%  #pick last 2 rows
  rename(long2 = longitude,
         lat2 = latitude)

#increase data size
n <- 100
from_many <- 
  do.call("rbind", replicate(n, from, simplify = FALSE)) %>% 
  mutate(place = row_number())


library(sparklyr)
sc <- spark_connect(master = "local")

from_many_sf <- copy_to(sc, from_many, overwrite = TRUE)
to_sf <- copy_to(sc, to, overwrite = TRUE)

# --- haversine distance function ---
get_geodesic_distance = function(x){
  
  geolocation = function(long1, lat1, long2, lat2){
    
    deg2rad <- function(deg) return(deg*pi/180)
    
    # Convert degrees to radians
    long1 <- deg2rad(long1)
    lat1 <- deg2rad(lat1)
    long2 <- deg2rad(long2)
    lat2 <- deg2rad(lat2)
    
    R = 6378137 # Earth's equatorial radius in meters (WGS84); 6371 would give km using the mean radius
    
    diff.long = (long2-long1)
    diff.lat = (lat2-lat1)
    
    a =(sin(diff.lat/2) * sin(diff.lat/2) + cos(lat1) * cos(lat2) * sin(diff.long/2)* sin(diff.long/2))
    c= 2*atan2(sqrt(a),sqrt(1-a))
    d = R*c
    
    return(d) # Distance in meters (same unit as R)
  }
  
  dist_vec = geolocation(x$long1, x$lat1, x$long2, x$lat2)
  
  res = dplyr::mutate(x, distance = dist_vec)
  res
}

# expansive (cross) join: every customer paired with every store
full_sf <- 
  from_many_sf %>% 
  full_join(to_sf, by = character())

full_sf %>% tally

# application of distance calculation
full_sf %>% 
  spark_apply(get_geodesic_distance)

# cut data into groups
full_sf %>% 
  sdf_with_sequential_id(., id = "id", from = 1L) %>% 
  mutate(equal_groups = ntile(id, 4)) %>% 
  group_by(equal_groups) %>% 
  tally()
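One caveat on the grouping step: ntile(id, 4) without a group_by() is a window over the whole table, and Spark evaluates such a window in a single partition (it logs a warning that no partition is defined for the window operation and that all data is moved to a single partition). Because sdf_with_sequential_id() already yields consecutive ids, the id modulo the number of groups gives near-equal groups without a window. The arithmetic is sketched below in plain R; that dbplyr translates `%%` into Spark SQL's `%` operator is an assumption worth confirming with show_query() on the real table.

```r
# Near-equal groups from a sequential id, without a window function.
# (id - 1) %% n_groups + 1 cycles through 1..n_groups, so group sizes
# differ by at most one.
n_groups <- 4
id <- 1:10
equal_groups <- (id - 1) %% n_groups + 1
print(table(equal_groups))
```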

Results of the spark_apply() call:

# Source: spark<?> [?? x 7]
   place_x long1  lat1 place_y  long2  lat2  distance
     <int> <dbl> <dbl> <chr>    <dbl> <dbl>     <dbl>
 1       1  27.5 63.3  Bolivia -65.7  -13.8 11545583.
 2       1  27.5 63.3  France    4.53  48.6  2148220.
 3       2 -90.5 54.2  Bolivia -65.7  -13.8  7929331.
 4       2 -90.5 54.2  France    4.53  48.6  6111619.
 5       3  34.7 -2.86 Bolivia -65.7  -13.8 11061340.
 6       3  34.7 -2.86 France    4.53  48.6  6427714.
 7       4  27.5 63.3  Bolivia -65.7  -13.8 11545583.
 8       4  27.5 63.3  France    4.53  48.6  2148220.
 9       5 -90.5 54.2  Bolivia -65.7  -13.8  7929331.
10       5 -90.5 54.2  France    4.53  48.6  6111619.
# … with more rows
# ℹ Use `print(n = ...)` to see more rows
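Because spark_apply() hands f each partition as an ordinary R data frame, the distance logic can be checked locally before it is shipped to Spark. A minimal sketch with a standalone, vectorised copy of the same haversine formula; `haversine_m()` and the two test rows are hypothetical, and the radius is the 6378137 m used in the question:

```r
# Haversine great-circle distance in meters, vectorised over its inputs.
# Uses the same radius as the question (6378137 m, the WGS84 equatorial radius).
haversine_m <- function(long1, lat1, long2, lat2, radius = 6378137) {
  to_rad <- pi / 180
  dlat <- (lat2 - lat1) * to_rad
  dlong <- (long2 - long1) * to_rad
  a <- sin(dlat / 2)^2 +
    cos(lat1 * to_rad) * cos(lat2 * to_rad) * sin(dlong / 2)^2
  radius * 2 * atan2(sqrt(a), sqrt(1 - a))
}

# A plain data frame shaped like one partition of full_sf:
# row 1 goes from the equator to the pole along a meridian, row 2 is a zero-length trip.
part <- data.frame(
  long1 = c(0, 27.472918),
  lat1  = c(0, 63.293001),
  long2 = c(0, 27.472918),
  lat2  = c(90, 63.293001)
)
part$distance <- haversine_m(part$long1, part$lat1, part$long2, part$lat2)
print(part)
```

The first row should equal a quarter of a great circle, radius * pi / 2, which makes a convenient sanity check.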

Solution

  • As far as I can see, you can solve this problem purely in Spark, without spark_apply(). The sparklyr authors recommend the same:

    You can run arbitrary R code in each worker node to run any computation ... If you are already familiar with R, you might be tempted to use this approach for all Spark operations; however, this is not the recommended use of spark_apply(). Previous chapters provided more efficient techniques and tools to solve well-known problems. (Mastering Spark with R, Chapter 11)

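    Doing it all in Spark is possible because sparklyr translates dplyr verbs into Spark SQL, and Spark SQL has many built-in (Hive-compatible) functions with the same, or nearly the same, names as their R counterparts, e.g., sin, cos, atan2, and sqrt. Functions that dbplyr does not recognize are passed through to Spark SQL unchanged.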
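To see this translation without a cluster, dbplyr can render an R expression as SQL against a simulated connection. This sketch uses dbplyr's generic backend (simulate_dbi()), not Spark's exact dialect, so details may differ; on a real tbl_spark, show_query() prints the SQL actually sent to Spark.

```r
library(dbplyr)

# Translate the core of the haversine formula with dbplyr's generic SQL backend.
# simulate_dbi() is a stand-in connection, not Spark; `a` is treated as a column name.
sql <- translate_sql(2 * atan2(sqrt(a), sqrt(1 - a)), con = simulate_dbi())
print(sql)
```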

    In other words, Spark does the hard part of distributing the work for you; you only need to move the body of your geolocation() function into a mutate() call. The for loop below is kept only to show processing in smaller chunks; it is not required:

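    library(magrittr)  # for %>%, if not already loaded
    
    # Assign each row to one of 4 groups (only needed if you still want chunks)
    full_sf <- full_sf %>% 
      dplyr::mutate(group = dplyr::ntile(place_x, n = 4))
    
    for(grp in 1:4) {
      full_sf %>% 
        dplyr::filter(group == grp) %>%
        # move place_y first so that long1:lat2 spans only the four coordinates
        dplyr::relocate(place_y) %>% 
        dplyr::mutate(
          # degrees -> radians for all four coordinate columns
          dplyr::across(long1:lat2, ~ .x * pi/180), 
          diff.long = long2 - long1, 
          diff.lat = lat2 - lat1, 
          a = (sin(diff.lat/2) * sin(diff.lat/2) + 
            cos(lat1) * cos(lat2) * sin(diff.long/2) * sin(diff.long/2)), 
          c = 2 * atan2(sqrt(a), sqrt(1 - a))) %>% 
        # d is the distance in meters (R = 6378137 m)
        dplyr::mutate(d = 6378137 * c) %>% 
        print()
    }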
    
    # Source: spark<?> [?? x 12]
       place_y place_x  long1    lat1   long2   lat2 group diff.long diff.lat      a     c         d
       <chr>     <int>  <dbl>   <dbl>   <dbl>  <dbl> <int>     <dbl>    <dbl>  <dbl> <dbl>     <dbl>
     1 Bolivia       1  0.479  1.10   -1.15   -0.241     1    -1.63   -1.35   0.619  1.81  11545583.
     2 France        1  0.479  1.10    0.0791  0.848     1    -0.400  -0.256  0.0281 0.337  2148220.
     3 Bolivia       2 -1.58   0.947  -1.15   -0.241     1     0.433  -1.19   0.339  1.24   7929331.
     4 France        2 -1.58   0.947   0.0791  0.848     1     1.66   -0.0984 0.213  0.958  6111619.
     5 Bolivia       3  0.605 -0.0498 -1.15   -0.241     1    -1.75   -0.191  0.581  1.73  11061340.
     6 France        3  0.605 -0.0498  0.0791  0.848     1    -0.526   0.898  0.233  1.01   6427714.
     7 Bolivia       4  0.479  1.10   -1.15   -0.241     1    -1.63   -1.35   0.619  1.81  11545583.
     8 France        4  0.479  1.10    0.0791  0.848     1    -0.400  -0.256  0.0281 0.337  2148220.
     9 Bolivia       5 -1.58   0.947  -1.15   -0.241     1     0.433  -1.19   0.339  1.24   7929331.
    10 France        5 -1.58   0.947   0.0791  0.848     1     1.66   -0.0984 0.213  0.958  6111619.
    # … with more rows
    # ℹ Use `print(n = ...)` to see more rows
    # Source: spark<?> [?? x 12]
       place_y place_x  long1    lat1   long2   lat2 group diff.long diff.lat      a     c         d
       <chr>     <int>  <dbl>   <dbl>   <dbl>  <dbl> <int>     <dbl>    <dbl>  <dbl> <dbl>     <dbl>
     1 France        8 -1.58   0.947   0.0791  0.848     2     1.66   -0.0984 0.213  0.958  6111619.
     2 Bolivia       9  0.605 -0.0498 -1.15   -0.241     2    -1.75   -0.191  0.581  1.73  11061340.
     3 France        9  0.605 -0.0498  0.0791  0.848     2    -0.526   0.898  0.233  1.01   6427714.
     4 Bolivia      10  0.479  1.10   -1.15   -0.241     2    -1.63   -1.35   0.619  1.81  11545583.
     5 France       10  0.479  1.10    0.0791  0.848     2    -0.400  -0.256  0.0281 0.337  2148220.
     6 Bolivia      11 -1.58   0.947  -1.15   -0.241     2     0.433  -1.19   0.339  1.24   7929331.
     7 France       11 -1.58   0.947   0.0791  0.848     2     1.66   -0.0984 0.213  0.958  6111619.
     8 Bolivia      12  0.605 -0.0498 -1.15   -0.241     2    -1.75   -0.191  0.581  1.73  11061340.
     9 France       12  0.605 -0.0498  0.0791  0.848     2    -0.526   0.898  0.233  1.01   6427714.
    10 Bolivia      13  0.479  1.10   -1.15   -0.241     2    -1.63   -1.35   0.619  1.81  11545583.
    # … with more rows
    # ℹ Use `print(n = ...)` to see more rows
    # Source: spark<?> [?? x 12]
       place_y place_x  long1    lat1   long2   lat2 group diff.long diff.lat      a     c         d
       <chr>     <int>  <dbl>   <dbl>   <dbl>  <dbl> <int>     <dbl>    <dbl>  <dbl> <dbl>     <dbl>
     1 Bolivia      16  0.479  1.10   -1.15   -0.241     3    -1.63   -1.35   0.619  1.81  11545583.
     2 France       16  0.479  1.10    0.0791  0.848     3    -0.400  -0.256  0.0281 0.337  2148220.
     3 Bolivia      17 -1.58   0.947  -1.15   -0.241     3     0.433  -1.19   0.339  1.24   7929331.
     4 France       17 -1.58   0.947   0.0791  0.848     3     1.66   -0.0984 0.213  0.958  6111619.
     5 Bolivia      18  0.605 -0.0498 -1.15   -0.241     3    -1.75   -0.191  0.581  1.73  11061340.
     6 France       18  0.605 -0.0498  0.0791  0.848     3    -0.526   0.898  0.233  1.01   6427714.
     7 Bolivia      19  0.479  1.10   -1.15   -0.241     3    -1.63   -1.35   0.619  1.81  11545583.
     8 France       19  0.479  1.10    0.0791  0.848     3    -0.400  -0.256  0.0281 0.337  2148220.
     9 Bolivia      20 -1.58   0.947  -1.15   -0.241     3     0.433  -1.19   0.339  1.24   7929331.
    10 France       20 -1.58   0.947   0.0791  0.848     3     1.66   -0.0984 0.213  0.958  6111619.
    # … with more rows
    # ℹ Use `print(n = ...)` to see more rows
    # Source: spark<?> [?? x 12]
       place_y place_x  long1    lat1   long2   lat2 group diff.long diff.lat      a     c         d
       <chr>     <int>  <dbl>   <dbl>   <dbl>  <dbl> <int>     <dbl>    <dbl>  <dbl> <dbl>     <dbl>
     1 France       23 -1.58   0.947   0.0791  0.848     4     1.66   -0.0984 0.213  0.958  6111619.
     2 Bolivia      24  0.605 -0.0498 -1.15   -0.241     4    -1.75   -0.191  0.581  1.73  11061340.
     3 France       24  0.605 -0.0498  0.0791  0.848     4    -0.526   0.898  0.233  1.01   6427714.
     4 Bolivia      25  0.479  1.10   -1.15   -0.241     4    -1.63   -1.35   0.619  1.81  11545583.
     5 France       25  0.479  1.10    0.0791  0.848     4    -0.400  -0.256  0.0281 0.337  2148220.
     6 Bolivia      26 -1.58   0.947  -1.15   -0.241     4     0.433  -1.19   0.339  1.24   7929331.
     7 France       26 -1.58   0.947   0.0791  0.848     4     1.66   -0.0984 0.213  0.958  6111619.
     8 Bolivia      27  0.605 -0.0498 -1.15   -0.241     4    -1.75   -0.191  0.581  1.73  11061340.
     9 France       27  0.605 -0.0498  0.0791  0.848     4    -0.526   0.898  0.233  1.01   6427714.
    10 Bolivia      28  0.479  1.10   -1.15   -0.241     4    -1.63   -1.35   0.619  1.81  11545583.
    # … with more rows
    # ℹ Use `print(n = ...)` to see more rows
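Since the whole computation is expressible in dplyr, the loop can be dropped altogether and Spark left to split the work across partitions. A minimal, self-contained sketch of the loop-free version on a local connection, assuming small stand-in tables built from the question's coordinates (`from_sdf`, `to_sdf`, and `dist_sdf` are illustrative names):

```r
library(sparklyr)
library(dplyr)

sc <- spark_connect(master = "local")

# Small stand-ins for the question's from/to tables
from <- data.frame(
  place_x = 1:3,
  long1 = c(27.472918, -90.476303, 34.679950),
  lat1  = c(63.293001, 54.239631, -2.855123)
)
to <- data.frame(
  place_y = c("Bolivia", "France"),
  long2 = c(-65.691146, 4.533465),
  lat2  = c(-13.795272, 48.603949)
)
from_sdf <- copy_to(sc, from, overwrite = TRUE)
to_sdf <- copy_to(sc, to, overwrite = TRUE)

# One lazy pipeline over the whole cross join; no manual chunking loop.
dist_sdf <- from_sdf %>%
  # cross join as in the question (newer dplyr prefers cross_join())
  full_join(to_sdf, by = character()) %>%
  mutate(
    rlat1 = lat1 * pi / 180,
    rlat2 = lat2 * pi / 180,
    dlat  = (lat2 - lat1) * pi / 180,
    dlong = (long2 - long1) * pi / 180,
    a = sin(dlat / 2) * sin(dlat / 2) +
      cos(rlat1) * cos(rlat2) * sin(dlong / 2) * sin(dlong / 2),
    distance = 6378137 * 2 * atan2(sqrt(a), sqrt(1 - a))  # meters
  ) %>%
  select(place_x, place_y, distance)

# Fine for 6 rows; do not collect() the full 5-billion-row result
res <- collect(dist_sdf)
print(res)

spark_disconnect(sc)
```

For the real table you would write the result out from Spark (for example with spark_write_parquet()) instead of collecting it into R.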