python, apache-spark, pyspark, count

Why does the .count() method return the wrong number of items?


I'm using PySpark, and when I call count() on a DataFrame I seem to get incorrect results.

I made a CSV, and I want to filter out the rows with an incorrect type. Everything works (I use .show() to check), but when I call count() the results I get are incorrect. I found out that columnPruning is part of the problem, but disabling it still returns the incorrect result. To get the correct results, I need to call dataFrame.cache().count(). My questions are:

  1. Why do I get the wrong result? What happens under the hood?
  2. Is this intended behaviour, or a bug?
  3. How should I handle it? Using .cache() works, but it is expensive, and I don't fully understand why it works.

Here's the code:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
#spark.conf.set('spark.sql.csv.parser.columnPruning.enabled', False)

myDf = spark.read.csv(
    'C:/example/myCsv.csv',
    header=True,
    inferSchema=False,
    schema='c1 INTEGER, c2 STRING, bad STRING',
    columnNameOfCorruptRecord='bad',
)

myDf.show()
print(myDf.count())

myCleanDf = myDf.filter(myDf.bad.isNull()).drop('bad')
myBadsDf = myDf.filter(myDf.bad.isNotNull()).select('bad')

myCleanDf.show()
myBadsDf.show()

#Wrong results
print(myCleanDf.count())
print(myBadsDf.count())

#Correct results
print(myCleanDf.cache().count())
print(myBadsDf.cache().count())

(The CSV contains sample data; adding a row with an incorrect type triggers the behaviour I describe. The results are incorrect both with and without column pruning.)
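
For reference, a minimal script like this (hypothetical sample data, not my actual file) is enough to reproduce the issue; the row whose c1 value is not an integer ends up in the bad column:

# Hypothetical sample data: "oops" cannot be parsed as INTEGER for c1,
# so Spark puts the whole raw row into the corrupt record column ("bad").
rows = [
    "c1,c2",
    "1,hello",
    "oops,world",
    "2,spark",
]
with open("C:/example/myCsv.csv", "w", newline="") as f:
    f.write("\n".join(rows) + "\n")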


Solution

  • This issue was first raised in SPARK-21610 and was subsequently 'fixed' by disallowing filters on the DataFrame when only the internal corrupt record column was referenced in the filter (via PR 19199).

    Consequently, the following error message was added in Spark 2.3:

    Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when
    the referenced columns only include the internal corrupt record column
    (named _corrupt_record by default). For example,
    spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()
    and spark.read.schema(schema).json(file).select("_corrupt_record").show().
    Instead, you can cache or save the parsed results and then send the same
    query. For example, val df = spark.read.schema(schema).json(file).cache()
    and then df.filter($"_corrupt_record".isNotNull).count().

    Reference: Migration Guide
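
    As a sketch, the guide's Scala workaround translates to PySpark roughly like this (the JSON path and schema below are placeholders, not taken from the question):

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col
    
    spark = SparkSession.builder.getOrCreate()
    
    # Placeholder schema; "_corrupt_record" is the default name of the
    # internal corrupt record column.
    schema = "c1 INTEGER, c2 STRING, _corrupt_record STRING"
    
    # In Spark 2.3 - 3.2 this style of query raised an AnalysisException,
    # because it references only the internal corrupt record column:
    #   spark.read.schema(schema).json("file.json") \
    #       .filter(col("_corrupt_record").isNotNull()).count()
    
    # Workaround from the migration guide: cache the parsed result first,
    # then query the corrupt record column.
    df = spark.read.schema(schema).json("file.json").cache()
    print(df.filter(col("_corrupt_record").isNotNull()).count())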

    The same bug was also raised in SPARK-22580.

    Apparently, this behaviour was changed by PR 35844 in 2022, with the release of Spark 3.3. While querying the corrupt record column no longer raises an exception, I believe this is still a bug that has remained in the codebase.

    You will get correct results if you add a cache() to your DataFrame read:

    from pyspark.sql import SparkSession
    
    spark = SparkSession.builder.getOrCreate()
    # spark.conf.set('spark.sql.csv.parser.columnPruning.enabled', False)
    
    myDf = spark.read.csv(
        "/workspaces/cubie/cubie/tests/test_data/temp.csv",
        header=True,
        inferSchema=False,
        schema="c1 INTEGER, c2 STRING, bad STRING",
        columnNameOfCorruptRecord="bad",
    ).cache()  # cache the parsed result so all downstream queries reuse it
    
    myCleanDf = myDf.filter(myDf.bad.isNull()).drop("bad")
    myBadsDf = myDf.filter(myDf.bad.isNotNull()).select("bad")
    
    myCleanDf.show()
    myBadsDf.show()
    
    # Correct results with cache()
    print(myCleanDf.count())
    print(myBadsDf.count())
    

    Remember, cache() is also lazily evaluated; show() or count() is what triggers the cached data to actually be materialized.
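
    If you would rather populate the cache up front instead of as a side effect of show(), one option (a sketch, reusing the path from the question) is to trigger an action immediately after the read:

    from pyspark.sql import SparkSession
    
    spark = SparkSession.builder.getOrCreate()
    
    myDf = spark.read.csv(
        "C:/example/myCsv.csv",
        header=True,
        schema="c1 INTEGER, c2 STRING, bad STRING",
        columnNameOfCorruptRecord="bad",
    ).cache()
    myDf.count()  # action: parses the whole file once and materializes the cache
    
    # Subsequent queries run against the cached, already-parsed rows,
    # so the counts agree with what show() displays.
    print(myDf.filter(myDf.bad.isNull()).count())
    print(myDf.filter(myDf.bad.isNotNull()).count())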