Tags: apache-spark, pyspark, apache-spark-sql, query-optimization, database-performance

Need help understanding why Spark query takes longer to execute when GROUP BY is introduced


I have three tables in an Oracle database that I am trying to join and aggregate:

orders: (3000 + rows)
order_line_items: (5000 + rows)
items: (14 million rows)

When I run the following code in PySpark:

from pyspark.sql import functions as F

joined_df = (orders_df.alias("o")
             .join(orders_line_item_df.alias("oli"), F.col("o.order_id") == F.col("oli.order_id"), how="inner")
             .join(item_df.alias("iw"), F.col("oli.item_id") == F.col("iw.item_id"), how="inner")
             .filter(F.col("o.do_status").isin(["110"]))
)

display(joined_df.limit(100))

It completes within 40 seconds and generates the following query plan:

== Physical Plan ==
AdaptiveSparkPlan (25)
+- == Final Plan ==
   ResultQueryStage (14), Statistics(sizeInBytes=981.6 KiB, rowCount=100, ColumnStat: N/A, isRuntime=true)
   +- CollectLimit (13)
      +- BroadcastHashJoin Inner BuildLeft (12)
         :- AQEShuffleRead (10)
         :  +- ShuffleQueryStage (9), Statistics(sizeInBytes=10.9 MiB, rowCount=2.01E+3, ColumnStat: N/A, isRuntime=true)
         :     +- Exchange (8)
         :        +- BroadcastHashJoin Inner BuildLeft (7)
         :           :- AQEShuffleRead (5)
         :           :  +- ShuffleQueryStage (4), Statistics(sizeInBytes=4.9 MiB, rowCount=1.35E+3, ColumnStat: N/A, isRuntime=true)
         :           :     +- Exchange (3)
         :           :        +- Filter (2)
         :           :           +- Scan JDBCRelation(orders) [numPartitions=1]  (1)
         :           +- Scan JDBCRelation(order_line_item) [numPartitions=1]  (6)
         +- Scan JDBCRelation(item_wms) [numPartitions=1]  (11)

But when I add a GROUP BY with a MIN aggregate, the query does not finish:

SELECT
    o.order_id,
    MIN(
        CASE
            WHEN iw.code IN ('A', 'B') THEN '1'
            WHEN bill_id = '2' THEN '2'
            WHEN bill_id IN ('14', '63') THEN '3'
            WHEN bill_id IN ('76', '09') THEN '4'
            ELSE '5'
        END
    ) foo
FROM
    orders o
    JOIN order_line_item oli ON oli.order_id = o.order_id
    JOIN item iw ON oli.item_id = iw.item_id
WHERE
    o.status IN ('10')
GROUP BY
    o.order_id

The query plan generated is:

== Physical Plan ==
AdaptiveSparkPlan (41)
+- == Current Plan ==
   CollectLimit (22)
   +- SortAggregate (21)
      +- Sort (20)
         +- ShuffleQueryStage (19), Statistics(sizeInBytes=2.02E+22 B, ColumnStat: N/A)
            +- Exchange (18)
               +- SortAggregate (17)
                  +- * Sort (16)
                     +- * Project (15)
                        +- * BroadcastHashJoin Inner BuildLeft (14)
                           :- AQEShuffleRead (12)
                           :  +- ShuffleQueryStage (11), Statistics(sizeInBytes=94.3 KiB, rowCount=2.01E+3, ColumnStat: N/A, isRuntime=true)
                           :     +- Exchange (10)
                           :        +- * Project (9)
                           :           +- * BroadcastHashJoin Inner BuildLeft (8)
                           :              :- AQEShuffleRead (6)
                           :              :  +- ShuffleQueryStage (5), Statistics(sizeInBytes=52.8 KiB, rowCount=1.35E+3, ColumnStat: N/A, isRuntime=true)
                           :              :     +- Exchange (4)
                           :              :        +- * Project (3)
                           :              :           +- * Filter (2)
                           :              :              +- * Scan JDBCRelation(orders) [numPartitions=1]  (1)
                           :              +- * Scan JDBCRelation(order_line_item) [numPartitions=1]  (7)
                           +- * Scan JDBCRelation(item_wms) [numPartitions=1]  (13)

My understanding is:

Am I correct?

Further questions:

I am asking this as I am a beginner to Spark query optimization and need to know how to interpret a query plan.

Cluster info: single node, 4 cores, 32 GB.


Solution

  • I will try to answer some of your questions:

    1. In the first query, Spark processes only a small subset of the data because of limit(100). The query plan includes a CollectLimit operation, meaning Spark stops pulling rows once it has gathered 100.

    2. In the second query, the GROUP BY requires a shuffle (and, in this plan, a sort) across all partitions of the data. Spark performs an Exchange, shown as a ShuffleQueryStage, to redistribute the rows by order_id before the final aggregation. Every joined row has to be read before the MIN of any group is known, so the limit cannot stop the work early. This is why the query takes longer to complete. The SortAggregate, Sort, Exchange, and ShuffleQueryStage nodes in the plan represent that sorting and shuffling.
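    The difference between points 1 and 2 can be sketched in plain Python rather than Spark. This is only an analogy under stated assumptions: the scan function, its counter argument, and the generated rows are all made up for illustration and are not part of the original query.

```python
from itertools import islice


def scan(n, counter):
    """Lazily yield (order_id, priority) rows and count how many were read."""
    for i in range(n):
        counter["read"] += 1
        yield (i % 10, str(1 + i % 5))


# Case 1: a plain LIMIT stops pulling rows once it has enough.
limit_counter = {"read": 0}
first_rows = list(islice(scan(1000, limit_counter), 100))

# Case 2: MIN per group needs every row before any group's result is final,
# so the limit can only be applied after the whole input has been consumed.
agg_counter = {"read": 0}
minima = {}
for order_id, priority in scan(1000, agg_counter):
    if order_id not in minima or priority < minima[order_id]:
        minima[order_id] = priority
limited_groups = list(islice(minima.items(), 100))

print(limit_counter["read"], agg_counter["read"])  # prints "100 1000"
```

    The limit-only case reads 100 of the 1000 hypothetical rows, while the aggregate reads all of them even though the final result is still capped at 100 rows.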

    What does + mean?

    The + symbols in the query plan do not represent stages or jobs. Each +- marks a child operator in the plan tree, and the deeper an operator is nested, the further its +- is indented.
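    To make the nesting concrete, here is a minimal sketch that reads a few lines taken from the first plan in the question and reports how deep each operator sits. The function name operator_depths is hypothetical, and the sketch assumes the three-character indentation step visible in the plans above. In these plans, :- introduces a child that has another sibling further down (the first input of a join), and +- introduces the last child.

```python
import re

# Excerpt of the first plan from the question (statistics omitted).
PLAN = """\
CollectLimit (13)
+- BroadcastHashJoin Inner BuildLeft (12)
   :- AQEShuffleRead (10)
   :  +- ShuffleQueryStage (9)
   :     +- Exchange (8)
   +- Scan JDBCRelation(item_wms) [numPartitions=1]  (11)
"""

# A child operator is introduced by "+- " or ":- ".
MARKER = re.compile(r"[+:]- ")


def operator_depths(plan_text):
    """Return (depth, operator) pairs; the root operator has depth 0."""
    result = []
    for line in plan_text.splitlines():
        if not line.strip():
            continue
        match = MARKER.search(line)
        if match is None:
            result.append((0, line.strip()))
        else:
            # Assumes each nesting level indents the marker by three columns.
            result.append((match.start() // 3 + 1, line[match.end():].strip()))
    return result


depths = operator_depths(PLAN)
for depth, operator in depths:
    print(depth, operator)
```

    Both inputs of the BroadcastHashJoin come out at the same depth, which shows that indentation describes parent and child operators, not the order in which stages run.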

    How should I interpret when a new job or stage is created?

    You can use the Spark UI to visualize the jobs and stages and to inspect per-task metrics such as memory usage.
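    The plan text also gives a hint before you open the Spark UI: an Exchange operator is a shuffle, and a shuffle is where one stage ends and the next begins. As a rough sketch, the snippet below lists the Exchange operators in an excerpt of the second plan from the question. The helper name shuffle_boundaries is hypothetical, and the result is only a hint about stage boundaries, not a replacement for the job and stage views in the UI.

```python
import re

# Excerpt of the second plan from the question (statistics omitted).
PLAN = """\
SortAggregate (21)
+- Sort (20)
   +- ShuffleQueryStage (19)
      +- Exchange (18)
         +- SortAggregate (17)
            +- * BroadcastHashJoin Inner BuildLeft (14)
               :- AQEShuffleRead (12)
               :  +- ShuffleQueryStage (11)
               :     +- Exchange (10)
               +- * Scan JDBCRelation(item_wms) [numPartitions=1]  (13)
"""

# The word boundary keeps names such as BroadcastExchange from matching.
EXCHANGE = re.compile(r"\bExchange \((\d+)\)")


def shuffle_boundaries(plan_text):
    """Return the operator ids of shuffle Exchange nodes, top to bottom."""
    ids = []
    for line in plan_text.splitlines():
        match = EXCHANGE.search(line)
        if match:
            ids.append(int(match.group(1)))
    return ids


boundaries = shuffle_boundaries(PLAN)
print(boundaries)  # prints "[18, 10]"
```

    Exchange (18) is the shuffle introduced by the GROUP BY, so everything below it has to finish before the final SortAggregate can start.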