apache-spark apache-spark-sql

Why is the Spark count action executed in three stages?


I have loaded a CSV file, repartitioned it to 4 partitions, and then taken the count of the DataFrame. When I looked at the DAG, I saw that this action was executed in 3 stages.

[Spark UI screenshot: the DAG for the count job, showing 3 stages]

Why is this simple action executed in 3 stages? I suppose the 1st stage is to load the file and the 2nd is to find the count on each partition.

So what is happening in the 3rd stage?

Here is my code:

val sample = spark.read.format("csv")
  .option("header", "true")
  .option("inferSchema", "true")
  .option("delimiter", ";")
  .load("sample_data.csv")

sample.repartition(4).count()
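
For reference, the same stage split can be seen in the physical plan without running the job; this is a small sketch using the sample DataFrame above (groupBy().count() produces essentially the same plan as the count action):

// The *(n) prefixes in the printed plan are whole-stage code generation blocks,
// which here line up with the 3 stages, and each Exchange is a shuffle boundary.
sample.repartition(4).groupBy().count().explain()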

Solution

    1. The first stage = reading the file. Because of repartition (a wide transformation that requires shuffling), it can't be combined into a single stage with the partial_count (2nd stage).

    2. The second stage = the local count (calculating the count per partition).

    3. The third stage = aggregation of the results on the driver (see the sketch after this list).
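
    Conceptually (a rough sketch only, using the sample DataFrame from the question; Spark's actual implementation uses HashAggregate rather than RDD code), stages 2 and 3 do something like this:

        // Stage 2: count the rows inside each of the 4 partitions (the partial_count)
        val partialCounts = sample.repartition(4).rdd
          .mapPartitions(rows => Iterator(rows.size.toLong))

        // Stage 3: collect the per-partition counts on the driver and sum them into the final result
        val total = partialCounts.collect().sum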

    Spark generates a separate stage per action or wide transformation. To get more details about narrow/wide transformations and why wide transformations require a separate stage, take a look at "Wide Versus Narrow Dependencies" in High Performance Spark by Holden Karau, or this article.
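
    As a quick illustration of the difference (a small sketch, assuming a local SparkSession named spark), a narrow transformation keeps the lineage shuffle-free, while a wide one introduces a ShuffledRDD and therefore a new stage boundary:

        val rdd = spark.sparkContext.parallelize(1 to 8, 2)

        // Narrow: each output partition depends on exactly one input partition, so no shuffle
        println(rdd.map(_ * 2).toDebugString)

        // Wide: repartition shuffles the data, so the lineage contains a ShuffledRDD
        // and Spark has to start a new stage at that point
        println(rdd.repartition(4).toDebugString)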

    Let's test this assumption locally. First you need to create a dataset:

    dataset/test-data.json

    [
      { "key":  1, "value":  "a" },
      { "key":  2, "value":  "b" },
      { "key":  3, "value":  "c" },
      { "key":  4, "value":  "d" },
      { "key":  5, "value":  "e" },
      { "key":  6, "value":  "f" },
      { "key":  7, "value":  "g" },
      { "key":  8, "value":  "h" }
    ]
    

    Then run the following code:

        import org.apache.spark.sql.SparkSession;
        import org.apache.spark.sql.types.DataTypes;
        import org.apache.spark.sql.types.StructType;

        StructType schema = new StructType()
                .add("key", DataTypes.IntegerType)
                .add("value", DataTypes.StringType);

        SparkSession session = SparkSession.builder()
                .appName("sandbox")
                .master("local[*]")
                .getOrCreate();

        session
                .read()
                .schema(schema)
                .json("file:///C:/<your_path>/dataset")
                .repartition(4) // comment this line out on the second run
                // registerTempTable is deprecated since Spark 2.0; createOrReplaceTempView is the replacement
                .registerTempTable("df");

        session.sqlContext().sql("SELECT COUNT(*) FROM df").explain();
    

    The output will be:

    == Physical Plan ==
    *(3) HashAggregate(keys=[], functions=[count(1)])
    +- Exchange SinglePartition
       +- *(2) HashAggregate(keys=[], functions=[partial_count(1)])
          +- Exchange RoundRobinPartitioning(4)
             +- *(1) FileScan json [] Batched: false, Format: JSON, Location: InMemoryFileIndex[file:/C:/Users/iaroslav/IdeaProjects/sparksandbox/src/main/resources/dataset], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<>
    

    But if you comment out or remove the .repartition(4) line, note that the FileScan and partial_count are done within a single stage, and the output will be as follows:

    == Physical Plan ==
    *(2) HashAggregate(keys=[], functions=[count(1)])
    +- Exchange SinglePartition
       +- *(1) HashAggregate(keys=[], functions=[partial_count(1)])
          +- *(1) FileScan json [] Batched: false, Format: JSON, Location: InMemoryFileIndex[file:/C:/Users/iaroslav/IdeaProjects/sparksandbox/src/main/resources/dataset], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<>
    

    P.S. Note that the extra stage might have a significant impact on performance, since it requires disk I/O (take a look here) and acts as a kind of synchronization barrier that limits parallelization, meaning that in most cases Spark won't start stage 2 until stage 1 is completed. Still, if repartition increases the level of parallelism, it is probably worth it.
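
    If the goal of repartition(4) is only to get more read-time parallelism (an assumption, not something stated in the question), a possible alternative is to let the scan itself produce more partitions instead of shuffling afterwards, for example by lowering spark.sql.files.maxPartitionBytes; a minimal sketch:

        import org.apache.spark.sql.SparkSession

        // Smaller scan splits => more input partitions without the extra shuffle stage.
        // 128 MB is the default; 32 MB here is purely an illustrative value.
        val sparkSession = SparkSession.builder()
          .appName("sandbox")
          .master("local[*]")
          .config("spark.sql.files.maxPartitionBytes", 32L * 1024 * 1024)
          .getOrCreate()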