apache-spark, spark-streaming, spark-checkpoint

Spark Streaming from SQS with checkpointing enabled


I have gone through multiple sites, such as https://spark.apache.org/docs/latest/streaming-programming-guide.html

https://data-flair.training/blogs/spark-streaming-checkpoint/

https://docs.databricks.com/spark/latest/rdd-streaming/developing-streaming-applications.html

Some of these links describe how to write the code, but they are so abstract that I needed a lot of time to figure out how it actually works.


Solution

  • After a long fight I was able to set up the streaming code with checkpointing; I am adding it here to help others. The driver and the broadcast helper below are Scala, and the custom SQS receiver further down is Java.

    import java.net.URI
    import java.nio.file.{Files, Paths}
    import java.util.concurrent.Executors

    import com.amazonaws.auth.DefaultAWSCredentialsProviderChain
    import com.amazonaws.regions.Regions
    import com.amazonaws.services.sqs.model.Message
    import com.fasterxml.jackson.databind.ObjectMapper
    import org.apache.hadoop.fs.{FileSystem, Path}
    import org.apache.log4j.LogManager
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.broadcast.Broadcast
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.streaming.{Duration, Seconds, StreamingContext}
    
    object StreamingApp extends scala.Serializable {
      @transient private final val mapper = new ObjectMapper
      @transient private final val LOG = LogManager.getLogger(getClass)
      @transient private final val executor = Executors.newFixedThreadPool(Runtime.getRuntime.availableProcessors)
      var s3 = "s3"
      private var shutdownMarker: String = _
      private var stopFlag: Boolean = false
    
      def main(args: Array[String]): Unit = {
        val queueName = args(0)
        val region = args(1)
        val fetchMaxMessage = args(2).toInt
        val visibilityTimeOutSeconds = args(3).toInt
        val waitTimeoutInMillis = args(4).toLong
        val isLocal = args(5).toBoolean
        val bucket = args(6)
        if (args.length >= 10)
          shutdownMarker = args(9)
        val sparkConf = initialiseSparkConf(isLocal)
        sparkConf.set(Constants.QUEUE_NAME, queueName)
        sparkConf.set(Constants.REGION, region)
        sparkConf.set(Constants.FETCH_MAX_MESSAGE, fetchMaxMessage.toString)
        sparkConf.set(Constants.VISIBILITY_TIMEOUT_SECONDS, visibilityTimeOutSeconds.toString)
        sparkConf.set(Constants.WAIT_TIMEOUT_IN_MILLIS, waitTimeoutInMillis.toString)
    
        // Fall back to a default marker location only when args(9) was not supplied
        if (shutdownMarker == null)
          shutdownMarker = s"$s3://$bucket/streaming/shutdownmarker"
        val checkpointDirectory = s"$s3://$bucket/streaming/checkpoint/"
        var context: StreamingContext = null
    
        try {
          context = StreamingContext.getOrCreate(checkpointDirectory, () => createContext(sparkConf, waitTimeoutInMillis, checkpointDirectory, args))
          context.start()
          val checkIntervalMillis = 10000
          var isStopped = false
    
          while (!isStopped) {
            println("calling awaitTerminationOrTimeout")
            isStopped = context.awaitTerminationOrTimeout(checkIntervalMillis)
            if (isStopped)
              println("confirmed! The streaming context is stopped. Exiting application...")
            checkShutdownMarker(context.sparkContext)
            if (!isStopped && stopFlag) {
              println("stopping ssc right now")
              context.stop(stopSparkContext = true, stopGracefully = true)
              println("ssc is stopped!!!!!!!")
            }
          }
        }
        finally {
          LOG.info("Exiting the Application")
          if (context != null && org.apache.spark.streaming.StreamingContextState.STOPPED != context.getState) {
            context.stop(stopSparkContext = true, stopGracefully = true)
          }
          if (!executor.isShutdown)
            executor.shutdown()
        }
      }
    
      def checkShutdownMarker(sparkContext: SparkContext): Unit = {
        if (!stopFlag) {
          stopFlag = isFileExists(shutdownMarker, sparkContext)
        }
        println(s"Stop marker $shutdownMarker file found: $stopFlag at time ${System.currentTimeMillis()}")
      }
    
      def isFileExists(path: String, sparkContext: SparkContext): Boolean = {
        isValidPath(isDir = false, path, getFileSystem(path, sparkContext))
      }

      def getFileSystem(path: String, sparkContext: SparkContext): FileSystem = {
        FileSystem.get(URI.create(path), sparkContext.hadoopConfiguration)
      }

      def isValidPath(isDir: Boolean, path: String, fileSystem: FileSystem): Boolean = {
        LOG.info(s"Validating path $path")
        if (path.startsWith(Constants.S3) || path.startsWith(Constants.HDFS) || path.startsWith(Constants.FILE)) {
          val fsPath = new Path(path)
          if (isDir) {
            fileSystem.isDirectory(fsPath)
          } else {
            fileSystem.isFile(fsPath)
          }
        } else {
          // Fall back to the local file system for plain paths
          Files.exists(Paths.get(path))
        }
      }
    
      def createContext(sparkConf: SparkConf, waitTimeoutInMillis: Long, checkpointDirectory: String, args: Array[String]): StreamingContext = {
    
        val context = new StreamingContext(sparkConf, Duration(waitTimeoutInMillis + 1000))
        processMessage(context, args)
        context.checkpoint(checkpointDirectory) // set checkpoint directory
        context
      }
    
      def processMessage(context: StreamingContext, args: Array[String]): Unit = {
    
        val bucket = args(6)
        val wgPath = args(7)
        var stagingPath = args(8)
        val waitTimeoutInMillis = args(4).toLong
        if (context != null) {
    
          if (!stagingPath.endsWith("/")) {
            stagingPath = s"$stagingPath/"
          }
          val outputPrefix = s"$s3://$bucket/$stagingPath"
    
          LOG.info(s"Number of cores for driver: ${Runtime.getRuntime.availableProcessors}")
    
          val sparkContext: SparkContext = context.sparkContext
    
          val broadcasts = BroadCaster.getInstance(sparkContext, s"$s3://$bucket/$wgPath")
    
          val input = context.receiverStream(broadcasts(Constants.SQS_RECEIVER).value.asInstanceOf[SQSReceiver])
          //input.checkpoint(interval = Seconds(60))
          LOG.info(s"Scheduling mode ${sparkContext.getSchedulingMode.toString}")
          input.foreachRDD(r => {
            val sparkSession = SparkSession.builder.config(r.sparkContext.getConf).getOrCreate()

            // collect() brings the batch to the driver; each element is the JSON the receiver stored
            val messages = r.collect().map(message => mapper.readValue(message, classOf[Message]))
    
            val broadcasts = BroadCaster.getInstance(r.sparkContext, s"$s3://$bucket/$wgPath")
            //Application logic
          })
        }
      }
    
    
      def initialiseSparkConf(local: Boolean): SparkConf = {
        val sparkConf = new SparkConf()
          .setAppName("Spark Streaming")
          .set("spark.scheduler.mode", "FAIR")
          .set("spark.sql.parquet.filterPushdown", "true")
          .set("spark.executor.heartbeatInterval", "20s")
          .set("spark.streaming.driver.writeAheadLog.closeFileAfterWrite", "true")
          .set("spark.streaming.receiver.writeAheadLog.closeFileAfterWrite", "true")
          .set("spark.streaming.receiver.writeAheadLog.enable", "true")
          .set("spark.streaming.stopGracefullyOnShutdown", "true")
          .set("spark.streaming.backpressure.enabled", "true")
          .set("spark.streaming.backpressure.pid.minRate", "10") // SQS supports a batch of at most 10

        if (local) {
          s3 = "s3a" // use the s3a connector when running locally
          sparkConf.setMaster("local[*]")
        } else {
          sparkConf.set("hive.metastore.client.factory.class",
            "com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory")
        }
        sparkConf
      }
    }
    
    object BroadCaster {
    
      @volatile private var instance: Map[String, Broadcast[Any]] = _
    
      def getInstance(sparkContext: SparkContext, wgPath: String): Map[String, Broadcast[Any]] = {
        if (instance == null) {
          synchronized {
            if (instance == null) {
              instance = Utils.createBroadcastObjects(wgPath, sparkContext)
              instance += (Constants.SQS_RECEIVER -> sparkContext.broadcast(getSQSReceiver(sparkContext.getConf)))
            }
          }
        }
        instance
      }
    
      private def getSQSReceiver(conf: SparkConf): SQSReceiver = {
        new SQSReceiver(conf.get(Constants.QUEUE_NAME))
          .withRegion(Regions.fromName(conf.get(Constants.REGION)))
          .withCredential(new DefaultAWSCredentialsProviderChain())
          .withFetchMaxMessage(conf.getInt(Constants.FETCH_MAX_MESSAGE, 10))
          .withVisibilityTimeOutSeconds(conf.getInt(Constants.VISIBILITY_TIMEOUT_SECONDS, 1800))
          .withWaitTimeoutinMillis(conf.getLong(Constants.WAIT_TIMEOUT_IN_MILLIS, 1000))
      }
    }
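
The code above references a Constants object and a Utils.createBroadcastObjects helper that are not shown. Utils is application-specific, but a minimal sketch of Constants, with key strings that are my own placeholders chosen only to match how they are used, could look like this:

    // Hypothetical Constants object (not part of the original code);
    // the key values can be anything, as long as they are consistent across the app.
    object Constants {
      val QUEUE_NAME = "sqs.queue.name"
      val REGION = "sqs.region"
      val FETCH_MAX_MESSAGE = "sqs.fetch.max.message"
      val VISIBILITY_TIMEOUT_SECONDS = "sqs.visibility.timeout.seconds"
      val WAIT_TIMEOUT_IN_MILLIS = "sqs.wait.timeout.millis"
      val SQS_RECEIVER = "sqsReceiver"
      val S3 = "s3"     // matches both s3:// and s3a:// paths
      val HDFS = "hdfs"
      val FILE = "file"
    }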
    
    
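The custom receiver that polls SQS is implemented in Java. It long-polls the queue, serializes each message to JSON with Jackson, and hands it to Spark via store(); the receiver write-ahead log enabled above makes the stored blocks recoverable after a restart.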
    import java.util.List;
    
    import org.apache.log4j.Logger;
    import org.apache.spark.storage.StorageLevel;
    import org.apache.spark.streaming.receiver.Receiver;
    
    import com.amazonaws.auth.AWSCredentialsProvider;
    import com.amazonaws.regions.Regions;
    import com.amazonaws.services.sqs.AmazonSQS;
    import com.amazonaws.services.sqs.AmazonSQSClientBuilder;
    import com.amazonaws.services.sqs.model.DeleteMessageBatchRequest;
    import com.amazonaws.services.sqs.model.DeleteMessageRequest;
    import com.amazonaws.services.sqs.model.Message;
    import com.amazonaws.services.sqs.model.ReceiveMessageRequest;
    import com.fasterxml.jackson.core.JsonProcessingException;
    import com.fasterxml.jackson.databind.ObjectMapper;
    
    public class SQSReceiver extends Receiver<String> {
    
        private String queueName;
        private transient AWSCredentialsProvider credential;
        private Regions region = Regions.US_EAST_1;
        private Long waitTimeoutinMillis = 0L;
        private ObjectMapper mapper = new ObjectMapper();
        private transient Logger logger = Logger.getLogger(SQSReceiver.class);
        private boolean deleteOnReceipt = false;
        private int fetchMaxMessage = 10; // SQS allows at most 10 messages per receive call
        private int visibilityTimeOutSeconds = 60;
    
        private String sqsQueueUrl;
        private transient AmazonSQS amazonSQS;
    
        public SQSReceiver(String queueName) {
            this(queueName, false);
        }
    
        public SQSReceiver(String queueName, boolean deleteOnReceipt) {
            super(StorageLevel.MEMORY_AND_DISK_SER());
            this.queueName = queueName;
            this.deleteOnReceipt = deleteOnReceipt;
            setupSQS(queueName);
        }
    
        private void setupSQS(String queueName) {
            AmazonSQSClientBuilder amazonSQSClientBuilder = AmazonSQSClientBuilder.standard();
    
            if (credential != null) {
                amazonSQSClientBuilder.withCredentials(credential);
            }
            amazonSQSClientBuilder.withRegion(region);
            amazonSQS = amazonSQSClientBuilder.build();
            sqsQueueUrl = amazonSQS.getQueueUrl(queueName).getQueueUrl();
        }
    
        public void onStart() {
            new Thread(this::receive).start();
        }
    
        public void onStop() {
            // There is nothing much to do as the thread calling receive()
            // is designed to stop by itself once isStopped() returns true
        }
    
        private void receive() {
            try {
                setupSQS(queueName);
                ReceiveMessageRequest receiveMessageRequest = new ReceiveMessageRequest(sqsQueueUrl).withMaxNumberOfMessages(fetchMaxMessage).withVisibilityTimeout(visibilityTimeOutSeconds)
                        .withWaitTimeSeconds(20); //https://docs.aws.amazon.com/sdk-for-java/v1/developer-guide/examples-sqs-long-polling.html
                receiveMessagesFromSQS(amazonSQS, sqsQueueUrl, receiveMessageRequest);
            } catch (Throwable e) {
                stop("Error encountered while initializing", e);
            }
        }
    
        private void receiveMessagesFromSQS(final AmazonSQS amazonSQS, final String sqsQueueUrl,
                                            ReceiveMessageRequest receiveMessageRequest) {
            try {
                while (!isStopped()) {
                    List<Message> messages = amazonSQS.receiveMessage(receiveMessageRequest).getMessages();
                    if (deleteOnReceipt) {
                        // Store each body and delete the message as soon as it is stored;
                        // guard against the empty result a long poll can return
                        for (Message m : messages) {
                            store(m.getBody());
                            amazonSQS.deleteMessage(new DeleteMessageRequest(sqsQueueUrl, m.getReceiptHandle()));
                        }
                    } else {
                        messages.forEach(this::storeMessage);
                    }
                    if (waitTimeoutinMillis > 0L)
                        Thread.sleep(waitTimeoutinMillis);
                }
                restart("Trying to connect again");
            } catch (IllegalArgumentException | InterruptedException e) {
                restart("Could not connect", e);
            } catch (Throwable e) {
                restart("Error Receiving Data", e);
            }
        }
    
        private void storeMessage(Message m) {
            try {
                if (m != null)
                    store(mapper.writeValueAsString(m));
            } catch (JsonProcessingException e) {
                logger.error("Unable to write message to streaming context");
            }
        }
    
        public SQSReceiver withVisibilityTimeOutSeconds(int visibilityTimeOutSeconds) {
            this.visibilityTimeOutSeconds = visibilityTimeOutSeconds;
            return this;
        }
    
        public SQSReceiver withFetchMaxMessage(int fetchMaxMessage) {
            if (fetchMaxMessage > 10) {
                throw new IllegalArgumentException("FetchMaxMessage can't be greater than 10");
            }
            this.fetchMaxMessage = fetchMaxMessage;
            return this;
        }
    
        public SQSReceiver withWaitTimeoutinMillis(long waitTimeoutinMillis) {
            this.waitTimeoutinMillis = waitTimeoutinMillis;
            return this;
        }
    
        public SQSReceiver withRegion(Regions region) {
            this.region = region;
            return this;
        }
    
        public SQSReceiver withCredential(AWSCredentialsProvider credential) {
            this.credential = credential;
            return this;
        }
    
        public void deleteMessages(DeleteMessageBatchRequest request) {
            request.withQueueUrl(sqsQueueUrl);
            amazonSQS.deleteMessageBatch(request);
        }
    }
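
For reference, the driver reads its positional arguments in this order (reconstructed from the code above): queueName, region, fetchMaxMessage, visibilityTimeOutSeconds, waitTimeoutInMillis, isLocal, bucket, wgPath, stagingPath and, optionally, shutdownMarker as the tenth argument.

To stop the job gracefully, create the marker file that the driver polls for every ten seconds. A minimal sketch using the Hadoop FileSystem API (the ShutdownMarker object and the bucket name in the usage line are mine, not part of the code above):

    import java.net.URI

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.{FileSystem, Path}

    object ShutdownMarker {
      // Touch the marker file; the driver detects it and calls
      // context.stop(stopSparkContext = true, stopGracefully = true)
      def touch(marker: String): Unit = {
        val fs = FileSystem.get(URI.create(marker), new Configuration())
        val out = fs.create(new Path(marker), true) // overwrite if it already exists
        out.close()
      }
    }

    // Usage: ShutdownMarker.touch("s3a://my-bucket/streaming/shutdownmarker")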