How do I efficiently run a custom function in chunks in a sparklyr environment?
I have a haversine function that calculates the distance between two sets of latitude/longitude coordinates within one data frame. As you can imagine, 10 customers and 10 store locations generate 100 rows (one per customer-store pair). I have 10 million customers and 500 stores, which is 5 billion rows. A join of this magnitude, followed by the distance calculation, may strain, if not crash, my Spark environment.
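To make the growth concrete, here is a small base-R sketch; the `customers` and `stores` tables are hypothetical stand-ins. `merge()` with `by = NULL` returns the Cartesian product, the same shape as the cross join performed below with `full_join(by = character())`:

```r
# Hypothetical stand-in tables: 10 customers and 10 stores
customers <- data.frame(customer_id = 1:10)
stores    <- data.frame(store_id = 1:10)

# merge() with by = NULL returns the Cartesian product (every customer x every store)
pairs <- merge(customers, stores, by = NULL)
nrow(pairs)

# The same multiplication at the scale described in the question
n_customers <- 10e6
n_stores    <- 500
n_customers * n_stores
```

The row count is always the product of the two table sizes, so the 10 million x 500 case yields 5e+09 rows before any distance is computed.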
My plan is to do the join first, which generates the 5 billion rows, then run the function with spark_apply on equal-sized chunks separately, and append the results back together. To do this locally in R, I would write a for loop that processes the chunks sequentially and appends each result. How should I do this in Spark?
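For reference, the local for-loop approach described above might look like the following base-R sketch. The `haversine()` helper, the small `joined` table (built from the question's sample coordinates), and the choice of 4 chunks are illustrative assumptions; the formula and the 6378137 m radius match the question's function:

```r
# Haversine distance in meters (same formula and radius as the question's function)
haversine <- function(long1, lat1, long2, lat2, radius = 6378137) {
  to_rad <- pi / 180
  long1 <- long1 * to_rad
  lat1  <- lat1 * to_rad
  long2 <- long2 * to_rad
  lat2  <- lat2 * to_rad
  a <- sin((lat2 - lat1) / 2)^2 +
    cos(lat1) * cos(lat2) * sin((long2 - long1) / 2)^2
  radius * 2 * atan2(sqrt(a), sqrt(1 - a))
}

# Illustrative joined table: 3 origins x 2 destinations via a cross join
joined <- merge(
  data.frame(long1 = c(27.472918, -90.476303, 34.679950),
             lat1  = c(63.293001, 54.239631, -2.855123)),
  data.frame(long2 = c(-65.691146, 4.533465),
             lat2  = c(-13.795272, 48.603949)),
  by = NULL
)

n_chunks <- 4
# Label each row with one of n_chunks contiguous, roughly equal chunks
chunk_id <- cut(seq_len(nrow(joined)), breaks = n_chunks, labels = FALSE)

results <- vector("list", n_chunks)
for (k in seq_len(n_chunks)) {
  chunk <- joined[chunk_id == k, ]
  chunk$distance <- haversine(chunk$long1, chunk$lat1, chunk$long2, chunk$lat2)
  results[[k]] <- chunk
}
# Append the processed chunks back together, in their original order
chunked <- do.call(rbind, results)
```

Because each distance depends only on its own row, the chunked result is identical to a single vectorized pass; chunking only limits the size of each intermediate computation.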
Here's what I've got so far. I'd appreciate help with the syntax to run spark_apply on each group defined by the equal_groups column, assuming this is the best strategy for running this efficiently.
If the answer involves the features described in the sparklyr 1.2 release post, I'm having trouble applying them to my code: https://www.rstudio.com/blog/sparklyr-1-2/
library(tidyverse)
#sample data
df <- tibble(
place=c("Finland", "Canada", "Tanzania", "Bolivia", "France"),
longitude=c(27.472918, -90.476303, 34.679950, -65.691146, 4.533465),
latitude=c(63.293001, 54.239631, -2.855123, -13.795272, 48.603949))
from <-
df[1:3,] %>% #pick first 3 rows
rename(long1 = longitude,
lat1 = latitude)
to <- df[4:5,] %>% #pick last 2 rows
rename(long2 = longitude,
lat2 = latitude)
#increase data size
n <- 100
from_many <-
do.call("rbind", replicate(n, from, simplify = FALSE)) %>%
mutate(place = row_number())
library(sparklyr)
sc <- spark_connect(master = "local")
from_many_sf <- copy_to(sc, from_many, overwrite = TRUE)
to_sf <- copy_to(sc, to, overwrite = TRUE)
# --- haversine distance function ---
get_geodesic_distance = function(x){
geolocation = function(long1, lat1, long2, lat2){
deg2rad <- function(deg) return(deg*pi/180)
# Convert degrees to radians
long1 <- deg2rad(long1)
lat1 <- deg2rad(lat1)
long2 <- deg2rad(long2)
lat2 <- deg2rad(lat2)
R = 6378137 # Earth's equatorial radius in meters (WGS84); the mean radius is ~6371 km
diff.long = (long2-long1)
diff.lat = (lat2-lat1)
a =(sin(diff.lat/2) * sin(diff.lat/2) + cos(lat1) * cos(lat2) * sin(diff.long/2)* sin(diff.long/2))
c= 2*atan2(sqrt(a),sqrt(1-a))
d = R*c
return(d) # Distance in meters
}
dist_vec = geolocation(x$long1, x$lat1, x$long2, x$lat2)
res = dplyr::mutate(x, distance = dist_vec)
res
}
#expansive join
full_sf <-
from_many_sf %>%
full_join(to_sf, by = character())
full_sf %>% tally
# application of distance calculation
full_sf %>%
spark_apply(get_geodesic_distance)
#cut data into groups
full_sf %>%
sdf_with_sequential_id(., id = "id", from = 1L) %>%
mutate(equal_groups = ntile(id, 4)) %>%
group_by(equal_groups) %>%
tally()
Results of the spark_apply call:
# Source: spark<?> [?? x 7]
place_x long1 lat1 place_y long2 lat2 distance
<int> <dbl> <dbl> <chr> <dbl> <dbl> <dbl>
1 1 27.5 63.3 Bolivia -65.7 -13.8 11545583.
2 1 27.5 63.3 France 4.53 48.6 2148220.
3 2 -90.5 54.2 Bolivia -65.7 -13.8 7929331.
4 2 -90.5 54.2 France 4.53 48.6 6111619.
5 3 34.7 -2.86 Bolivia -65.7 -13.8 11061340.
6 3 34.7 -2.86 France 4.53 48.6 6427714.
7 4 27.5 63.3 Bolivia -65.7 -13.8 11545583.
8 4 27.5 63.3 France 4.53 48.6 2148220.
9 5 -90.5 54.2 Bolivia -65.7 -13.8 7929331.
10 5 -90.5 54.2 France 4.53 48.6 6111619.
# … with more rows
# ℹ Use `print(n = ...)` to see more rows
Based on what you've shown, I recommend solving this problem purely in Spark. That means:
- No foreach package: Spark distributes the work across your workers.
- No sparklyr::spark_apply(). As Mastering Spark in R (Chapter 11) puts it: "You can run arbitrary R code in each worker node to run any computation ... If you are already familiar with R, you might be tempted to use this approach for all Spark operations; however, this is not the recommended use of spark_apply(). Previous chapters provided more efficient techniques and tools to solve well-known problems."
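Doing it all in Spark is possible because sparklyr translates dplyr verbs into Spark SQL, and Spark SQL provides built-in Hive functions (UDFs) with the same (or nearly the same) names as their R counterparts, e.g., sin, cos, atan2, and sqrt. Functions that dplyr does not recognize are passed through to Spark SQL as-is.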
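Because every step of the haversine formula is a column-wise expression, it needs no per-row R function at all. As a plain-R illustration (no Spark involved), here is the same calculation written with `within()`, which, like `mutate()`, lets later lines reference columns created earlier; the two-row `pairs` table is a hypothetical example using the question's Finland, Bolivia, and France coordinates:

```r
# Hypothetical two-row table: Finland to Bolivia and Finland to France
pairs <- data.frame(
  place_y = c("Bolivia", "France"),
  long1 = 27.472918, lat1 = 63.293001,   # Finland, recycled to both rows
  long2 = c(-65.691146, 4.533465),
  lat2  = c(-13.795272, 48.603949)
)

# Degrees -> radians for all four coordinate columns (like across(long1:lat2, ...))
deg_cols <- c("long1", "lat1", "long2", "lat2")
pairs[deg_cols] <- lapply(pairs[deg_cols], function(x) x * pi / 180)

# Each line is a vectorized expression over whole columns
pairs <- within(pairs, {
  diff.long <- long2 - long1
  diff.lat  <- lat2 - lat1
  a <- sin(diff.lat / 2)^2 + cos(lat1) * cos(lat2) * sin(diff.long / 2)^2
  c <- 2 * atan2(sqrt(a), sqrt(1 - a))
  d <- 6378137 * c                       # distance in meters
})
```

The resulting d values should match the d column in the Spark output (about 11545583 m to Bolivia and 2148220 m to France).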
In short, you can let Spark handle the hard part of distributing the work (including inside any for loop you use to split it into smaller chunks, if you keep one) by copying the body of your function directly into dplyr verbs:
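library(magrittr) # provides %>%, if not already loaded
# Optional: tag rows with 4 roughly equal groups. ntile() becomes a Spark window
# function with no PARTITION BY, which moves all rows into a single partition;
# at billions of rows, consider skipping this step.
full_sf <- full_sf %>%
  dplyr::mutate(group = dplyr::ntile(place_x, n = 4))
# The loop is optional too: Spark already spreads each mutate() across workers
for (grp in 1:4) {
  full_sf %>%
    dplyr::filter(group == grp) %>%
    # move the character column place_y out of the long1:lat2 range
    dplyr::relocate(place_y) %>%
    dplyr::mutate(
      # convert long1, lat1, long2, lat2 from degrees to radians
      dplyr::across(long1:lat2, ~ .x * pi/180),
      diff.long = long2 - long1,
      diff.lat = lat2 - lat1,
      a = (sin(diff.lat/2) * sin(diff.lat/2) +
             cos(lat1) * cos(lat2) * sin(diff.long/2) * sin(diff.long/2)),
      c = 2 * atan2(sqrt(a), sqrt(1 - a))) %>%
    dplyr::mutate(d = 6378137 * c) %>% # distance in meters
    print() # previews the first rows only; Spark evaluates lazily
}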
# Source: spark<?> [?? x 12]
place_y place_x long1 lat1 long2 lat2 group diff.long diff.lat a c d
<chr> <int> <dbl> <dbl> <dbl> <dbl> <int> <dbl> <dbl> <dbl> <dbl> <dbl>
1 Bolivia 1 0.479 1.10 -1.15 -0.241 1 -1.63 -1.35 0.619 1.81 11545583.
2 France 1 0.479 1.10 0.0791 0.848 1 -0.400 -0.256 0.0281 0.337 2148220.
3 Bolivia 2 -1.58 0.947 -1.15 -0.241 1 0.433 -1.19 0.339 1.24 7929331.
4 France 2 -1.58 0.947 0.0791 0.848 1 1.66 -0.0984 0.213 0.958 6111619.
5 Bolivia 3 0.605 -0.0498 -1.15 -0.241 1 -1.75 -0.191 0.581 1.73 11061340.
6 France 3 0.605 -0.0498 0.0791 0.848 1 -0.526 0.898 0.233 1.01 6427714.
7 Bolivia 4 0.479 1.10 -1.15 -0.241 1 -1.63 -1.35 0.619 1.81 11545583.
8 France 4 0.479 1.10 0.0791 0.848 1 -0.400 -0.256 0.0281 0.337 2148220.
9 Bolivia 5 -1.58 0.947 -1.15 -0.241 1 0.433 -1.19 0.339 1.24 7929331.
10 France 5 -1.58 0.947 0.0791 0.848 1 1.66 -0.0984 0.213 0.958 6111619.
# … with more rows
# ℹ Use `print(n = ...)` to see more rows
# Source: spark<?> [?? x 12]
place_y place_x long1 lat1 long2 lat2 group diff.long diff.lat a c d
<chr> <int> <dbl> <dbl> <dbl> <dbl> <int> <dbl> <dbl> <dbl> <dbl> <dbl>
1 France 8 -1.58 0.947 0.0791 0.848 2 1.66 -0.0984 0.213 0.958 6111619.
2 Bolivia 9 0.605 -0.0498 -1.15 -0.241 2 -1.75 -0.191 0.581 1.73 11061340.
3 France 9 0.605 -0.0498 0.0791 0.848 2 -0.526 0.898 0.233 1.01 6427714.
4 Bolivia 10 0.479 1.10 -1.15 -0.241 2 -1.63 -1.35 0.619 1.81 11545583.
5 France 10 0.479 1.10 0.0791 0.848 2 -0.400 -0.256 0.0281 0.337 2148220.
6 Bolivia 11 -1.58 0.947 -1.15 -0.241 2 0.433 -1.19 0.339 1.24 7929331.
7 France 11 -1.58 0.947 0.0791 0.848 2 1.66 -0.0984 0.213 0.958 6111619.
8 Bolivia 12 0.605 -0.0498 -1.15 -0.241 2 -1.75 -0.191 0.581 1.73 11061340.
9 France 12 0.605 -0.0498 0.0791 0.848 2 -0.526 0.898 0.233 1.01 6427714.
10 Bolivia 13 0.479 1.10 -1.15 -0.241 2 -1.63 -1.35 0.619 1.81 11545583.
# … with more rows
# ℹ Use `print(n = ...)` to see more rows
# Source: spark<?> [?? x 12]
place_y place_x long1 lat1 long2 lat2 group diff.long diff.lat a c d
<chr> <int> <dbl> <dbl> <dbl> <dbl> <int> <dbl> <dbl> <dbl> <dbl> <dbl>
1 Bolivia 16 0.479 1.10 -1.15 -0.241 3 -1.63 -1.35 0.619 1.81 11545583.
2 France 16 0.479 1.10 0.0791 0.848 3 -0.400 -0.256 0.0281 0.337 2148220.
3 Bolivia 17 -1.58 0.947 -1.15 -0.241 3 0.433 -1.19 0.339 1.24 7929331.
4 France 17 -1.58 0.947 0.0791 0.848 3 1.66 -0.0984 0.213 0.958 6111619.
5 Bolivia 18 0.605 -0.0498 -1.15 -0.241 3 -1.75 -0.191 0.581 1.73 11061340.
6 France 18 0.605 -0.0498 0.0791 0.848 3 -0.526 0.898 0.233 1.01 6427714.
7 Bolivia 19 0.479 1.10 -1.15 -0.241 3 -1.63 -1.35 0.619 1.81 11545583.
8 France 19 0.479 1.10 0.0791 0.848 3 -0.400 -0.256 0.0281 0.337 2148220.
9 Bolivia 20 -1.58 0.947 -1.15 -0.241 3 0.433 -1.19 0.339 1.24 7929331.
10 France 20 -1.58 0.947 0.0791 0.848 3 1.66 -0.0984 0.213 0.958 6111619.
# … with more rows
# ℹ Use `print(n = ...)` to see more rows
# Source: spark<?> [?? x 12]
place_y place_x long1 lat1 long2 lat2 group diff.long diff.lat a c d
<chr> <int> <dbl> <dbl> <dbl> <dbl> <int> <dbl> <dbl> <dbl> <dbl> <dbl>
1 France 23 -1.58 0.947 0.0791 0.848 4 1.66 -0.0984 0.213 0.958 6111619.
2 Bolivia 24 0.605 -0.0498 -1.15 -0.241 4 -1.75 -0.191 0.581 1.73 11061340.
3 France 24 0.605 -0.0498 0.0791 0.848 4 -0.526 0.898 0.233 1.01 6427714.
4 Bolivia 25 0.479 1.10 -1.15 -0.241 4 -1.63 -1.35 0.619 1.81 11545583.
5 France 25 0.479 1.10 0.0791 0.848 4 -0.400 -0.256 0.0281 0.337 2148220.
6 Bolivia 26 -1.58 0.947 -1.15 -0.241 4 0.433 -1.19 0.339 1.24 7929331.
7 France 26 -1.58 0.947 0.0791 0.848 4 1.66 -0.0984 0.213 0.958 6111619.
8 Bolivia 27 0.605 -0.0498 -1.15 -0.241 4 -1.75 -0.191 0.581 1.73 11061340.
9 France 27 0.605 -0.0498 0.0791 0.848 4 -0.526 0.898 0.233 1.01 6427714.
10 Bolivia 28 0.479 1.10 -1.15 -0.241 4 -1.63 -1.35 0.619 1.81 11545583.
# … with more rows
# ℹ Use `print(n = ...)` to see more rows