databricks, sparklyr

Partition sparklyr data frame by a column so that all observations with the same value for that column are in a single partition


The Problem

I am working in Databricks with sparklyr. As an example, I am working with on-time arrival data (from August 2013 to June 2014) which can be downloaded one month at a time here: https://www.transtats.bts.gov/DL_SelectFields.aspx?gnoyr_VQ=FGJ&QO_fu146_anzr=. I have compiled this data into a single table and saved the result to a Delta Lake table called flight_data. While this is my example data, I want the solution to this question to be generalizable to other datasets. My actual data will contain many more observations, and as a result the partitions will likely contain far more observations (presumably making the overhead worthwhile).

I am looking for a way to partition flight_data into subsets based on the value of a column. In this example, I want flights that fly the same route (that is, they have the same origin airport and the same destination airport) to always be in the same partition. For now, I do not care whether flights with different origin or destination airports end up in the same partition, as long as the above requirement holds (e.g. a single partition can contain all flights from JFK to LGA and all flights from DIA to DFW). For my own understanding, it would also be nice to know how to require that all observations in a given partition share the same origin-destination pair, but that's a second-order concern.

My Attempt

I will illustrate my attempt starting from loading flight_data (hopefully this is sufficient, as I didn't make any material changes to the data after aggregating it, other than adding a few columns that aren't used for this repartitioning).

First I read the data in as a Spark DataFrame, convert the flight date to the proper type, and add a column that gives me a unique origin-destination pair (a directional route). Finally, I (perhaps inefficiently) retrieve the number of unique origin-destination pairs. That number is 4,742.

flight_sdf <- spark_read_table(sc, "flight_data", memory = FALSE)

flight_sdf <- flight_sdf %>%
  mutate(
    FL_DATE = to_date(FL_DATE, "M/d/yyyy"),
    origin_dest = paste0(ORIGIN_AIRPORT_ID, "_", DEST_AIRPORT_ID)
  )

num_origin_dest <- length(unique(flight_sdf %>% pull(origin_dest)))
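
As an aside on the "perhaps inefficiently" point: the distinct count can also be computed inside Spark rather than pulling the whole column into R. This is just a sketch of that alternative, not what I originally ran:

# Sketch: count distinct routes in Spark instead of collecting the column into R
num_origin_dest <- flight_sdf %>%
  distinct(origin_dest) %>%
  sdf_nrow()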

Then (as an example) I try to repartition the data and simply calculate the average delay times by route (origin-destination pair) using spark_apply(). Once again, this is just an example so I can make sure the partitioning is as I want it to be. The functions I ultimately want to apply to each partition will be more sophisticated.

average_by_route <- function(df) {
  library(dplyr)
  df %>%
    group_by(origin_dest) %>%
    summarize(
      AVG_ARR_DELAY = mean(ARR_DELAY, na.rm = TRUE),
      AVG_DEP_DELAY = mean(DEP_DELAY, na.rm = TRUE)
    )
}

result <- flight_sdf %>%
  sdf_repartition(num_origin_dest, partition_by = "origin_dest") %>%
  spark_apply(average_by_route, packages = c("dplyr"))


sdf_nrow(result)

I would expect sdf_nrow(result) to return 4,742 (the number of unique routes). It instead returns 4,807. Looking at result, it appears that 65 routes show up in two rows each, despite the grouping by origin_dest within average_by_route(). This leads me to believe that not all flights corresponding to a single route are contained within one partition, despite my specifying the partitioning with sdf_repartition(num_origin_dest, partition_by = "origin_dest").
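
If it helps, here is a check I was considering to see whether routes really are split across partitions after the repartitioning. This is only a sketch; it relies on Spark SQL's spark_partition_id() function passing through dplyr to Spark untranslated, which I have not verified here:

# Sketch: count how many distinct partitions each route lands in after repartitioning
routes_split <- flight_sdf %>%
  sdf_repartition(num_origin_dest, partition_by = "origin_dest") %>%
  mutate(part_id = spark_partition_id()) %>%
  distinct(origin_dest, part_id) %>%
  count(origin_dest) %>%
  filter(n > 1)

sdf_nrow(routes_split)  # anything > 0 means some route spans multiple partitions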

What I've seen in other (5+ year old) responses to similar questions leads me to believe it could have something to do with some configuration option in the background, but I am not sure.

Thanks in advance for any explanation (or a pointer to a solution), and let me know if you need any additional information. I am relatively new to this and could be missing something obvious.


Solution

  • I believe I solved the problem for my particular circumstance. I had to increase the maximum number of records per Arrow batch when configuring my Spark session. The default is 10,000, so I changed it to 20,000, which is larger than the number of records for any single route in my data.

    conf <- spark_config()
    conf$spark.sql.execution.arrow.maxRecordsPerBatch <- 20000
    sc <- spark_connect(method = "databricks", config = conf)
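
    One way to pick that threshold would be to check the size of the largest route first. This is just a sketch of that check (using the flight_sdf from above), not part of the fix itself:

    # Sketch: record count of the largest route, to size maxRecordsPerBatch
    flight_sdf %>%
      count(origin_dest) %>%
      summarise(max_route_rows = max(n, na.rm = TRUE)) %>%
      collect()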
    

    In other circumstances, the maximum number of records may not be the limiting factor. If this doesn't work for you, I suggest checking out this list of configurations: https://spark.apache.org/docs/latest/configuration.html. I found it very helpful.

    From what I can tell, there are other limits on the size of partitions (e.g. a partition's memory footprint in bytes) that could be the limiting factor in other scenarios.
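
    For example (this is an assumption on my part, not something I had to change), Spark's spark.sql.files.maxPartitionBytes setting caps the number of bytes packed into an input partition when reading files, and it can be set through the same config pattern:

    # Sketch: setting a byte-based partition limit alongside the record limit;
    # spark.sql.files.maxPartitionBytes applies to input partitions when reading
    # files, and whether it matters depends on your pipeline (an assumption here)
    conf <- spark_config()
    conf$spark.sql.execution.arrow.maxRecordsPerBatch <- 20000
    conf$spark.sql.files.maxPartitionBytes <- 256 * 1024 * 1024  # 256 MB, for illustration
    sc <- spark_connect(method = "databricks", config = conf)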