Tags: python, mocking

Unable to replace a Python function definition with mock and unittest


I am trying to write an integration test for the following Python code:

import xx.settings.config as stg
from xx.infrastructure.utils import csvReader, dataframeWriter
from pyspark.sql import SparkSession
from typing import List
from awsglue.utils import getResolvedOptions
import sys


def main(argv: List[str]) -> None:
    args = getResolvedOptions(
        argv,
        ['JOB_NAME', 'S3_BRONZE_BUCKET_NAME', 'S3_PRE_SILVER_BUCKET_NAME', 'S3_BRONZE_PATH', 'S3_PRE_SILVER_PATH'],
    )

    s3_bronze_bucket_name = args['S3_BRONZE_BUCKET_NAME']
    s3_pre_silver_bucket_name = args['S3_PRE_SILVER_BUCKET_NAME']
    s3_bronze_path = args['S3_BRONZE_PATH']
    s3_pre_silver_path = args['S3_PRE_SILVER_PATH']

    spark = SparkSession.builder.getOrCreate() 
    spark.conf.set('spark.sql.sources.partitionOverwriteMode', 'dynamic')


    for table in list(stg.data_schema.keys()):
        schema = stg.data_schema[table].columns.to_dict()  # per-table schema passed to csvReader below
        df = csvReader(spark, s3_bronze_bucket_name, s3_bronze_path, table, schema, '\t')
        dataframeWriter(df, s3_pre_silver_bucket_name, s3_pre_silver_path, table, stg.data_schema[table].partitionKey)

if __name__ == '__main__':
    main(sys.argv)

I basically loop over a list of tables, read their content (CSV format) from S3, and write it back to S3 in Parquet format.

These are the definitions of csvReader and dataframeWriter:

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import StructType


def csvReader(spark: SparkSession, bucket: str, path: str, table: str, schema: StructType, sep: str) -> DataFrame:
    return (
        spark.read.format('csv')
        .option('header', 'true')
        .option('sep', sep)
        .schema(schema)
        .load(f's3a://{bucket}/{path}/{table}.csv')
    )



def dataframeWriter(df: DataFrame, bucket: str, path: str, table: str, partition_key: str) -> None:
    df.write.partitionBy(partition_key).mode('overwrite').parquet(f's3a://{bucket}/{path}/{table}/')

For my integration tests I would like to replace the S3 interaction with local file interaction (read the CSV from a local path and write the Parquet locally). This is what I did:

import os
from unittest import TestCase
from unittest.mock import patch, Mock
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType
import xx.settings.config as stg
import xx.application.perfmarket_pre_silver as perfmarket_pre_silver
from dvr_config_utils.config import initialize_settings


def local_csvReader(spark: SparkSession, table: str, schema: StructType, sep: str):
    """Mocked function that replaces real csvReader. this one reads from local rather than S3."""
    return (
        spark.read.format('csv')
        .option('header', 'true')
        .option('sep', sep)
        .schema(schema)
        .load(f'../input_mock/{table}.csv')
    )


def local_dataframeWriter(df, table: str, partition_key: str):
    """Mocked function that replaces real dataframeWriter. this one writes in local rather than S3."""
    output_dir = f'../output_mock/{table}/'
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    df.write.partitionBy(partition_key).mode('overwrite').parquet(output_dir)


class TestPerfmarketSilver(TestCase):
    @classmethod
    def setUpClass(cls):
        cls.spark = SparkSession.builder.master('local').appName('TestPerfmarketSilver').getOrCreate()
        cls.spark.conf.set('spark.sql.sources.partitionOverwriteMode', 'dynamic')

    @classmethod
    def tearDownClass(cls):
        """Clean up the Spark session and test data."""
        cls.spark.stop()
        os.system('rm -rf ../output_mock')

    @patch('xx.application.cc.getResolvedOptions')
    @patch('src.xx.infrastructure.utils.csvReader', side_effect=local_csvReader)
    @patch('xx.infrastructure.utils.dataframeWriter', side_effect=local_dataframeWriter)
    def test_main(self, mock_csvreader, mock_datawriter, mocked_get_resolved_options: Mock):
        expected_results = {'chemins': {'nbRows': 8}}

        mocked_get_resolved_options.return_value = {
            'JOB_NAME': 'perfmarket_pre_silver_test',
            'S3_BRONZE_BUCKET_NAME': 'test_bronze',
            'S3_PRE_SILVER_BUCKET_NAME': 'test_pre_silver',
            'S3_BRONZE_PATH': '../input_mock',
            'S3_PRE_SILVER_PATH': '../output_mock'
        }
        perfmarket_pre_silver.main([])

        for table in stg.data_schema.keys():
            # Verify that the output Parquet file is created
            output_path = f'../output_mock/{table}/'
            self.assertTrue(os.path.exists(output_path))

            # Read the written Parquet file and check the data
            written_df = self.spark.read.parquet(output_path)
            self.assertEqual(written_df.count(), expected_results[table]['nbRows'])  # Check row count
            # Check that the written columns match the configured bronze column names for this table
            self.assertEqual(
                [
                    column_data['bronze_name']
                    for column_data in stg.data_schema[table]['columns'].values()
                ],
                written_df.columns,
            )

What I wanted to do with these two lines:

@patch('src.xx.infrastructure.utils.csvReader', side_effect=local_csvReader)
@patch('xx.infrastructure.utils.dataframeWriter', side_effect=local_dataframeWriter)

is to replace the definition of csvReader with local_csvReader and the definition of dataframeWriter with local_dataframeWriter.
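
For reference, this is how I understand patch with side_effect to behave; a minimal, self-contained sketch on a standard-library function (the target and the replacement are purely illustrative, not part of my project):

import os.path
from unittest import TestCase
from unittest.mock import patch


def fake_exists(path: str) -> bool:
    """Stand-in that pretends every path exists."""
    return True


class PatchSketch(TestCase):
    # patch swaps the attribute found at the given dotted path for the
    # duration of the test; side_effect forwards each call to fake_exists
    @patch('os.path.exists', side_effect=fake_exists)
    def test_patched_lookup(self, mock_exists):
        self.assertTrue(os.path.exists('/definitely/not/there'))
        mock_exists.assert_called_once_with('/definitely/not/there')

That is the behaviour I expected the two decorators above to give me for csvReader and dataframeWriter.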

Unfortunately, the code returns:

py4j.protocol.Py4JJavaError: An error occurred while calling o39.load.
: java.lang.RuntimeException: java.lang.ClassNotFoundException: Class org.apache.hadoop.fs.s3a.S3AFileSystem not found

This is my project structure:

project/
│
├── src/
│   └── xx/
│       ├── application/
│       │   └── perfmarket_pre_silver.py
│       ├── __init__.py
│       ├── infrastructure/
│       │   ├── __init__.py
│       │   └── utils.py
│       └── other_modules/
└── tests/
    └── integration_tests/
        └── application/
            └── test_perfmarket_pre_silver.py

Both csvReader and dataframeWriter are defined in utils.py.

The error points to the csvReader call in the main code (first snippet).

So my replacement technique is clearly not working.

What am I doing wrong, please?


Solution

  • This is what I did in the end (note that main is now imported inside the test, after the patches have been applied):

    from moto import mock_aws
    
    def local_csvReader(spark: SparkSession, table: str, schema: StructType, sep: str):
        """Mocked function that replaces real csvReader. this one reads from local rather than S3."""
        return (
            spark.read.format('csv')
            .option('header', 'true')
            .option('sep', sep)
            .schema(schema)
            .load(f'../input_mock/{table}.csv')
        )
    
    
    def local_dataframeWriter(df, table: str, partition_key: str):
        """Mocked function that replaces real dataframeWriter. this one writes in local rather than S3."""
        output_dir = f'../output_mock/{table}/'
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        df.write.partitionBy(partition_key).mode('overwrite').parquet(output_dir)
    
    
    @mock_aws
    @patch('xx.infrastructure.utils.dataframeWriter')
    @patch('xx.infrastructure.utils.csvReader')
    def test_main(self, mock_csv_reader, mock_dataframe_writer):
        from xx.application.perfmarket_pre_silver import main
        mock_csv_reader.side_effect = local_csvReader
        mock_dataframe_writer.side_effect = local_dataframeWriter
        ...
        with patch('sys.argv', mock_args):
            main(mock_args)
        ...
        ...
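
  • For anyone not familiar with moto: as far as I understand it, the mock_aws decorator routes boto3 calls to an in-memory backend for the duration of the test, so nothing touches real AWS. A minimal standalone example (the bucket name is illustrative):

    import boto3
    from moto import mock_aws


    @mock_aws
    def test_bucket_is_faked():
        # with mock_aws active, this client talks to moto's in-memory backend
        s3 = boto3.client('s3', region_name='us-east-1')
        s3.create_bucket(Bucket='test-bucket')
        assert 'test-bucket' in [b['Name'] for b in s3.list_buckets()['Buckets']]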