I have gone through multiple sites like
https://data-flair.training/blogs/spark-streaming-checkpoint/
https://docs.databricks.com/spark/latest/rdd-streaming/developing-streaming-applications.html
Some of the links explain how to write the code, but they stay so abstract that it took me a long time to figure out how it actually works.
After a long struggle I managed to set up streaming code with checkpointing; I'm adding it here to help others.
import java.net.URI
import java.nio.file.{Files, Paths}
import java.util.concurrent.Executors
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain
import com.amazonaws.regions.Regions
import com.amazonaws.services.sqs.model.Message
import com.fasterxml.jackson.databind.ObjectMapper
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.log4j.LogManager
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.SparkSession
import org.apache.spark.streaming.{Duration, Seconds, StreamingContext}
object StreamingApp extends scala.Serializable {
@transient private final val mapper = new ObjectMapper
@transient private final val LOG = LogManager.getLogger(getClass)
@transient private final val executor = Executors.newFixedThreadPool(Runtime.getRuntime.availableProcessors)
var s3 = "s3" // URI scheme; switched to "s3a" for local runs in initialiseSparkConf
private var shutdownMarker: String = _
private var stopFlag: Boolean = false
def main(args: Array[String]): Unit = {
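// Expected positional arguments:
// 0: SQS queue name, 1: AWS region, 2: max messages per SQS fetch (<= 10),
// 3: SQS visibility timeout in seconds, 4: receiver wait timeout in millis,
// 5: run locally (true/false), 6: S3 bucket, 7: path of objects to broadcast,
// 8: staging output prefix, 9: (optional) shutdown marker path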
val queueName = args(0)
val region = args(1)
val fetchMaxMessage = args(2).toInt
val visibilityTimeOutSeconds = args(3).toInt
val waitTimeoutInMillis = args(4).toLong
val isLocal = args(5).toBoolean
val bucket = args(6)
if (args.length >= 10)
shutdownMarker = args(9)
val sparkConf = initialiseSparkConf(isLocal)
sparkConf.set(Constants.QUEUE_NAME, queueName)
sparkConf.set(Constants.REGION, region)
sparkConf.set(Constants.FETCH_MAX_MESSAGE, fetchMaxMessage.toString)
sparkConf.set(Constants.VISIBILITY_TIMEOUT_SECONDS, visibilityTimeOutSeconds.toString)
sparkConf.set(Constants.WAIT_TIMEOUT_IN_MILLIS, waitTimeoutInMillis.toString)
shutdownMarker = s"$s3://$bucket/streaming/shutdownmarker"
val checkpointDirectory = s"$s3://$bucket/streaming/checkpoint/"
var context: StreamingContext = null
try {
context = StreamingContext.getOrCreate(checkpointDirectory, () => createContext(sparkConf, waitTimeoutInMillis, checkpointDirectory, args))
context.start()
val checkIntervalMillis = 10000
var isStopped = false
while (!isStopped) {
println("calling awaitTerminationOrTimeout")
isStopped = context.awaitTerminationOrTimeout(checkIntervalMillis)
if (isStopped)
println("confirmed! The streaming context is stopped. Exiting application...")
checkShutdownMarker(context.sparkContext)
if (!isStopped && stopFlag) {
println("stopping ssc right now")
context.stop(stopSparkContext = true, stopGracefully = true)
println("ssc is stopped!!!!!!!")
}
}
}
finally {
LOG.info("Exiting the Application")
if (context != null && org.apache.spark.streaming.StreamingContextState.STOPPED != context.getState) {
context.stop(stopSparkContext = true, stopGracefully = true)
}
if (!executor.isShutdown)
executor.shutdown()
}
}
def checkShutdownMarker(sparkContext: SparkContext): Unit = {
if (!stopFlag) {
stopFlag = isFileExists(shutdownMarker, sparkContext)
}
println(s"Stop marker $shutdownMarker file found: $stopFlag at time ${System.currentTimeMillis()}")
}
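// Hypothetical helper (not in the original flow): create the marker object to
// request a graceful shutdown from outside the streaming job.
def createShutdownMarker(sparkContext: SparkContext): Unit = {
getFileSystem(shutdownMarker, sparkContext).create(new Path(shutdownMarker)).close()
}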
def isFileExists(path: String, sparkContext: SparkContext): Boolean = {
isValidPath(isDir = false, path, getFileSystem(path, sparkContext))
}
def getFileSystem(path: String, sparkContext: SparkContext): FileSystem = {
FileSystem.get(URI.create(path), sparkContext.hadoopConfiguration)
}
def isValidPath(isDir: Boolean, path: String, fileSystem: FileSystem): Boolean = {
LOG.info("Validating path {}", path)
if (path.startsWith(Constants.S3) || path.startsWith(Constants.HDFS) || path.startsWith(Constants.FILE)) {
val fsPath = new Path(path)
if (isDir) {
fileSystem isDirectory fsPath
} else {
fileSystem isFile fsPath
}
} else {
Files.exists(Paths.get(path))
}
}
def createContext(sparkConf: SparkConf, waitTimeoutInMillis: Long, checkpointDirectory: String, args: Array[String]): StreamingContext = {
val context = new StreamingContext(sparkConf, Duration(waitTimeoutInMillis + 1000))
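// The DStream graph must be fully defined inside this factory: when
// StreamingContext.getOrCreate finds an existing checkpoint, createContext is
// skipped and the graph is restored from the checkpoint instead.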
processMessage(context, args)
context.checkpoint(checkpointDirectory) // set checkpoint directory
context
}
def processMessage(context: StreamingContext, args: Array[String]): Unit = {
val bucket = args(6)
val wgPath = args(7)
var stagingPath = args(8)
val waitTimeoutInMillis = args(4).toLong
if (context != null) {
if (!stagingPath.endsWith("/")) {
stagingPath = s"$stagingPath/"
}
val outputPrefix = s"$s3://$bucket/$stagingPath"
LOG.info(s"Number of cores for driver: ${Runtime.getRuntime.availableProcessors}")
val sparkContext: SparkContext = context.sparkContext
val broadcasts = BroadCaster.getInstance(sparkContext, s"$s3://$bucket/$wgPath")
val input = context.receiverStream(broadcasts(Constants.SQS_RECEIVER).value.asInstanceOf[SQSReceiver])
//input.checkpoint(interval = Seconds(60))
LOG.info(s"Scheduling mode ${sparkContext.getSchedulingMode.toString}")
input.foreachRDD(r => {
val sparkSession = SparkSession.builder.config(r.sparkContext.getConf).getOrCreate()
val messages = r.collect().map(message => mapper.readValue(message, classOf[Message]))
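// Re-obtain the broadcasts here: after a restart from checkpoint the
// driver-side singleton is gone and is re-created lazily on the first batch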
val broadcasts = BroadCaster.getInstance(r.sparkContext, s"$s3://$bucket/$wgPath")
//Application logic
})
}
}
def initialiseSparkConf(local: Boolean): SparkConf = {
val sparkConf = new SparkConf()
.setAppName("Spark Streaming")
.set("spark.scheduler.mode", "FAIR")
.set("spark.sql.parquet.filterpushdown", "true")
.set("spark.executor.hearbeatInterval", "20")
.set("spark.streaming.driver.writeAheadLog.closeFileAfterWrite", "true")
.set("spark.streaming.receiver.writeAheadLog.closeFileAfterWrite", "true")
.set("spark.streaming.receiver.writeAheadLog.enable", "true")
.set("spark.streaming.stopGracefullyOnShutdown", "true")
.set("spark.streaming.backpressure.enabled","true")
.set("spark.streaming.backpressure.pid.minRate","10") //SQS support batch of 10
if (local) {
s3 = "s3a"
sparkConf.setMaster("local[*]")
} else {
sparkConf.set("hive.metastore.client.factory.class",
"com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory")
}
sparkConf
}
}
object BroadCaster {
@volatile private var instance: Map[String, Broadcast[Any]] = _
def getInstance(sparkContext: SparkContext, wgPath: String): Map[String, Broadcast[Any]] = {
if (instance == null) {
synchronized {
if (instance == null) {
instance = Utils.createBroadcastObjects(wgPath, sparkContext)
instance += (Constants.SQS_RECEIVER -> sparkContext.broadcast(getSQSReceiver(sparkContext.getConf)))
}
}
}
instance
}
private def getSQSReceiver(conf: SparkConf): SQSReceiver = {
new SQSReceiver(conf.get(Constants.QUEUE_NAME))
.withRegion(Regions.fromName(conf.get(Constants.REGION)))
.withCredential(new DefaultAWSCredentialsProviderChain())
.withFetchMaxMessage(conf.getInt(Constants.FETCH_MAX_MESSAGE, 10))
.withVisibilityTimeOutSeconds(conf.getInt(Constants.VISIBILITY_TIMEOUT_SECONDS, 1800))
.withWaitTimeoutinMillis(conf.getLong(Constants.WAIT_TIMEOUT_IN_MILLIS, 1000))
}
}
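The code above also references a Constants object and a Utils.createBroadcastObjects helper that I haven't shown. Here is a minimal sketch of what they could look like; the key names and the Utils stub are my assumptions, so replace them with your own implementation:
object Constants {
  val QUEUE_NAME = "spark.sqs.queueName"
  val REGION = "spark.sqs.region"
  val FETCH_MAX_MESSAGE = "spark.sqs.fetchMaxMessage"
  val VISIBILITY_TIMEOUT_SECONDS = "spark.sqs.visibilityTimeoutSeconds"
  val WAIT_TIMEOUT_IN_MILLIS = "spark.sqs.waitTimeoutInMillis"
  val SQS_RECEIVER = "sqsReceiver"
  val S3 = "s3"
  val HDFS = "hdfs"
  val FILE = "file"
}
object Utils {
  // Stub: load whatever reference data lives under wgPath and broadcast it
  def createBroadcastObjects(wgPath: String, sparkContext: SparkContext): Map[String, Broadcast[Any]] =
    Map.empty
}
The receiver itself is plain Java: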
import java.util.List;
import org.apache.log4j.Logger;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.streaming.receiver.Receiver;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.regions.Regions;
import com.amazonaws.services.sqs.AmazonSQS;
import com.amazonaws.services.sqs.AmazonSQSClientBuilder;
import com.amazonaws.services.sqs.model.DeleteMessageBatchRequest;
import com.amazonaws.services.sqs.model.DeleteMessageRequest;
import com.amazonaws.services.sqs.model.Message;
import com.amazonaws.services.sqs.model.ReceiveMessageRequest;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
public class SQSReceiver extends Receiver<String> {
private String queueName;
private transient AWSCredentialsProvider credential;
private Regions region = Regions.US_EAST_1;
private Long waitTimeoutinMillis = 0L;
private ObjectMapper mapper = new ObjectMapper();
private static final Logger logger = Logger.getLogger(SQSReceiver.class); // static, so it survives deserialization on executors
private boolean deleteOnReceipt = false;
private int fetchMaxMessage = 10; // SQS allows at most 10 messages per ReceiveMessage call
private int visibilityTimeOutSeconds = 60;
private String sqsQueueUrl;
private transient AmazonSQS amazonSQS;
public SQSReceiver(String queueName) {
this(queueName, false);
}
public SQSReceiver(String queueName, boolean deleteOnReceipt) {
super(StorageLevel.MEMORY_AND_DISK_SER());
this.queueName = queueName;
this.deleteOnReceipt = deleteOnReceipt;
// The SQS client is built lazily in receive(), after region and credentials
// have been configured through the builder methods
}
private void setupSQS(String queueName) {
AmazonSQSClientBuilder amazonSQSClientBuilder = AmazonSQSClientBuilder.standard();
if (credential != null) {
amazonSQSClientBuilder.withCredentials(credential);
}
amazonSQSClientBuilder.withRegion(region);
amazonSQS = amazonSQSClientBuilder.build();
sqsQueueUrl = amazonSQS.getQueueUrl(queueName).getQueueUrl();
}
public void onStart() {
new Thread(this::receive).start();
}
public void onStop() {
// There is nothing much to do, as the thread calling receive()
// is designed to stop by itself once isStopped() returns true
}
private void receive() {
try {
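// Build the SQS client here: amazonSQS is transient, and receive() runs on an
// executor after the receiver has been deserialized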
setupSQS(queueName);
ReceiveMessageRequest receiveMessageRequest = new ReceiveMessageRequest(sqsQueueUrl)
.withMaxNumberOfMessages(fetchMaxMessage)
.withVisibilityTimeout(visibilityTimeOutSeconds)
.withWaitTimeSeconds(20); // long polling; see https://docs.aws.amazon.com/sdk-for-java/v1/developer-guide/examples-sqs-long-polling.html
receiveMessagesFromSQS(amazonSQS, sqsQueueUrl, receiveMessageRequest);
} catch (Throwable e) {
stop("Error encountered while initializing", e);
}
}
private void receiveMessagesFromSQS(final AmazonSQS amazonSQS, final String sqsQueueUrl,
ReceiveMessageRequest receiveMessageRequest) {
try {
while (!isStopped()) {
List<Message> messages = amazonSQS.receiveMessage(receiveMessageRequest).getMessages();
if (deleteOnReceipt) {
// Store and delete each message with its own receipt handle
for (Message m : messages) {
store(m.getBody());
amazonSQS.deleteMessage(new DeleteMessageRequest(sqsQueueUrl, m.getReceiptHandle()));
}
} else {
messages.forEach(this::storeMessage);
}
if (waitTimeoutinMillis > 0L)
Thread.sleep(waitTimeoutinMillis);
}
restart("Trying to connect again");
} catch (IllegalArgumentException | InterruptedException e) {
restart("Could not connect", e);
} catch (Throwable e) {
restart("Error Receiving Data", e);
}
}
private void storeMessage(Message m) {
try {
if (m != null)
store(mapper.writeValueAsString(m));
} catch (JsonProcessingException e) {
logger.error("Unable to write message to streaming context", e);
}
}
public SQSReceiver withVisibilityTimeOutSeconds(int visibilityTimeOutSeconds) {
this.visibilityTimeOutSeconds = visibilityTimeOutSeconds;
return this;
}
public SQSReceiver withFetchMaxMessage(int fetchMaxMessage) {
if (fetchMaxMessage > 10) {
throw new IllegalArgumentException("fetchMaxMessage can't be greater than 10 (SQS ReceiveMessage limit)");
}
this.fetchMaxMessage = fetchMaxMessage;
return this;
}
public SQSReceiver withWaitTimeoutinMillis(long waitTimeoutinMillis) {
this.waitTimeoutinMillis = waitTimeoutinMillis;
return this;
}
public SQSReceiver withRegion(Regions region) {
this.region = region;
return this;
}
public SQSReceiver withCredential(AWSCredentialsProvider credential) {
this.credential = credential;
return this;
}
public void deleteMessages(DeleteMessageBatchRequest request) {
if (amazonSQS == null)
setupSQS(queueName); // rebuild the transient client on a deserialized copy
request.withQueueUrl(sqsQueueUrl);
amazonSQS.deleteMessageBatch(request);
}
}
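For completeness, the receiver can also be wired up directly, without the broadcast indirection. A minimal sketch, assuming streamingContext is your StreamingContext; the queue name and timeouts below are placeholders:
val receiver = new SQSReceiver("my-queue") // placeholder queue name
  .withRegion(Regions.US_EAST_1)
  .withCredential(new DefaultAWSCredentialsProviderChain())
  .withFetchMaxMessage(10)
  .withVisibilityTimeOutSeconds(1800)
  .withWaitTimeoutinMillis(1000)
val stream = streamingContext.receiverStream(receiver)
stream.foreachRDD(rdd => rdd.foreach(println)) // application logic goes here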