
Optimize usage of collect()


I have working code, but it takes 10 minutes for a task that my local computer can do in about 1 minute. So I think my code needs optimization, and I suspect I am not using Spark correctly, especially the Spark SQL limit() and collect() methods.

I want/need to move my problem to Spark (PySpark), because our old tools and computers cannot sensibly handle the sheer number of files we produce (and apparently they lack the resources to handle some of the biggest files we generate).

I am looking at CSV files, and for each file, i.e. each experiment, I need to know which sensor was triggered first/last and when these events occurred.

Reduced to the Spark-relevant code, I do the following:

from pyspark.sql.functions import col, substring_index, min, max  # note: shadows Python's built-in min/max

tgl = dataframe.filter("<this line is relevant>") \
        .select(\
            substring_index(col("Name"),"Sensor <Model> <Revision> ", -1)\
                .alias("Sensors"),\
            col("Timestamp").astype('bigint'))\
        .groupBy("Sensors").agg(min("Timestamp"),max("Timestamp"))

point_in_time = tgl.orderBy("min(Timestamp)", ascending=True).limit(2).collect()
[...]
point_in_time = tgl.orderBy("min(Timestamp)", ascending=False).limit(1).collect()
[...]
point_in_time = tgl.orderBy("max(Timestamp)", ascending=True).limit(1).collect()
[...]
point_in_time = tgl.orderBy("max(Timestamp)", ascending=False).limit(2).collect()
[...]

I did it this way because I read somewhere that using .limit() is often the smarter choice, because then not all data is collected on the driver, which can cost quite a bit of time, memory, and network capacity.

I test my code with a 2.5 GB file of about 3 × 10^7 (30 million) lines. The processing timeline looks like this:

[Figure: Timeline of exemplary test run]

The first interesting thing to note is that nearly every Spark task takes about 1.1 minutes. The code shown above is responsible for the first four calls to collect() in the timeline.

Since all four calls share the same DataFrame, which originates from filter().select().groupBy().agg(), I would have expected the last three calls to be much faster than the first. Apparently Spark does not recognise this and starts from the original DataFrame every time. How can I optimize this so that the last three calls to collect() benefit from the intermediate results of the first one?


Solution

  • Your observation that Spark re-executes the DAG each time is correct. It follows from the simple fact that Spark is lazy and has two types of operations:

    1. transformations (select, filter, groupBy, orderBy, withColumn, etc.), which describe how a DataFrame/Dataset will be transformed and add steps to the DAG
    2. actions (write, collect, count, etc.), which trigger execution of the DAG

    DataFrames do not hold data; they are a sort of virtual view that describes transformations of the input data. One way to avoid re-executing the DAG with each collect() is to cache tgl:

    tgl = dataframe.filter("<this line is relevant>") \
        .select(\
            substring_index(col("Name"),"Sensor <Model> <Revision> ", -1)\
                .alias("Sensors"),\
            col("Timestamp").astype('bigint'))\
        .groupBy("Sensors").agg(min("Timestamp"),max("Timestamp"))

    # persist() is lazy as well: tgl is materialized by the first collect() below
    # and reused by the following ones
    tgl.persist()

    point_in_time = tgl.orderBy("min(Timestamp)", ascending=True).limit(2).collect()
    [...]
    point_in_time = tgl.orderBy("min(Timestamp)", ascending=False).limit(1).collect()
    [...]
    point_in_time = tgl.orderBy("max(Timestamp)", ascending=True).limit(1).collect()
    [...]
    point_in_time = tgl.orderBy("max(Timestamp)", ascending=False).limit(2).collect()
    [...]


    This prevents re-execution of the DAG, but there is a price for caching tgl in RAM, and it might negate the benefits of the limit operation. Only experiment will show how large the impact is.

    Alternatively, if you define which questions you want your program to answer, I could try to help you write a specific query or program that answers them in one go.
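    One possible "one go" variant, assuming the number of distinct sensors is small enough to fit on the driver (tgl has one row per sensor after the groupBy): collect tgl once and do the four orderings locally in plain Python. The function name summarize_sensors and the result keys are hypothetical; each entry mirrors one of the four orderBy/limit calls above.

```python
import heapq
from typing import Dict, Iterable, List, Sequence

def summarize_sensors(rows: Iterable[Sequence]) -> Dict[str, List[Sequence]]:
    """Answer the four first/last questions from rows shaped like
    (sensor, min_timestamp, max_timestamp), i.e. the columns of tgl."""
    rows = list(rows)
    by_first = lambda r: r[1]  # min(Timestamp): first trigger of the sensor
    by_last = lambda r: r[2]   # max(Timestamp): last trigger of the sensor
    return {
        # orderBy("min(Timestamp)", ascending=True).limit(2)
        "earliest_first_trigger": heapq.nsmallest(2, rows, key=by_first),
        # orderBy("min(Timestamp)", ascending=False).limit(1)
        "latest_first_trigger": heapq.nlargest(1, rows, key=by_first),
        # orderBy("max(Timestamp)", ascending=True).limit(1)
        "earliest_last_trigger": heapq.nsmallest(1, rows, key=by_last),
        # orderBy("max(Timestamp)", ascending=False).limit(2)
        "latest_last_trigger": heapq.nlargest(2, rows, key=by_last),
    }
```

    Usage would be summarize_sensors(tgl.collect()), which replaces the four separate actions with a single Spark job. Since pyspark.sql.Row is a tuple, r[1] and r[2] refer to min(Timestamp) and max(Timestamp) in the column order produced by the agg() above.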