apache-spark, pyspark, apache-kafka, protocol-buffers, streaming

PySpark: How To Deserialise A Proto Payload From A Kafka Message With Variable Message Type


I am trying to read from a Kafka topic that contains messages with different Proto payloads, with the messageName set in the Kafka message key.

But when I try to:

df = spark.readStream.format(constants.KAFKA_INPUT_FORMAT) \
        .options(**options) \
        .load()
df = df.selectExpr("CAST(key AS STRING)").alias('key')
df = df.select(from_protobuf('value', df.key, desc_file_path).alias('value'))

I get the error pyspark.errors.exceptions.base.PySparkTypeError: [NOT_ITERABLE] Column is not iterable.
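
For reference, the call works when the message name is a hard-coded Python string rather than a Column (the message name below is just a placeholder):

from pyspark.sql.protobuf.functions import from_protobuf

# Works, because messageName is a plain string literal
df = df.select(from_protobuf('value', 'event.content.monolog.v1.SomeEvent', desc_file_path).alias('value'))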

How can I dynamically set the messageName parameter of the from_protobuf function based on the key of the Kafka message?


Solution

  • I was able to address the above requirement by creating a descriptor file, containing all of the proto definitions, via the following protoc command:

    protoc --include_imports --include_source_info --retain_options --descriptor_set_out=./event.desc --proto_path=./proto ./proto/event/content/monolog/v1/*.proto
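
    To sanity-check the generated event.desc file, you can load it back outside of Spark and list the message types it contains (a small standalone sketch, nothing Spark-specific):

        from google.protobuf import descriptor_pb2

        with open('./event.desc', 'rb') as f:
            fds = descriptor_pb2.FileDescriptorSet()
            fds.ParseFromString(f.read())

        # Print every top-level message, fully qualified with its package
        for file_proto in fds.file:
            for message_proto in file_proto.message_type:
                print(f'{file_proto.package}.{message_proto.name}')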
    

    My PySpark driver then uses a UDF (user-defined function) to process the payload:

        from pyspark.sql.functions import udf
        from google.protobuf import descriptor_pb2, descriptor_pool, json_format, message_factory

        desc_file_path = './event.desc'
        with open(desc_file_path, 'rb') as f:
            desc_file_bytes = f.read()

        @udf
        def deserialize_message_from_file_descriptor_set(message_headers, message_data):
            try:
                # Parse the FileDescriptorSet produced by protoc
                file_descriptor_set = descriptor_pb2.FileDescriptorSet()
                file_descriptor_set.ParseFromString(desc_file_bytes)

                # Create a DescriptorPool from the FileDescriptorSet
                pool = descriptor_pool.DescriptorPool()
                for file_descriptor_proto in file_descriptor_set.file:
                    pool.Add(file_descriptor_proto)

                # The event type travels in one of the Kafka headers;
                # kafka_events_classes maps it to the fully qualified message name
                headersList = list(message_headers)
                eventTypeStr = headersList[4].value.decode("utf-8")
                class_name = kafka_events_classes[eventTypeStr]

                # Find the message descriptor
                message_descriptor = pool.FindMessageTypeByName(class_name)

                # Create a message class from the descriptor
                # (newer protobuf releases also offer message_factory.GetMessageClass(message_descriptor))
                factory = message_factory.MessageFactory()
                message_class = factory.GetPrototype(message_descriptor)

                # Deserialize the message data
                message = message_class()
                message.ParseFromString(message_data)

                return json_format.MessageToJson(message)
            except Exception as e:
                return str(e)

        # 'headers' is only populated when the Kafka source option includeHeaders is 'true'
        df = df.selectExpr('CAST(key AS STRING)', 'offset', 'timestamp', 'headers', 'value')
        df = df.withColumn('payload', deserialize_message_from_file_descriptor_set("headers", "value").alias("", metadata={"sqlType": "JSON"}))
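
    Here kafka_events_classes is just a dictionary that maps the event type carried in the header to the fully qualified Protobuf message name; the entries below are hypothetical examples:

        # Hypothetical mapping; the actual keys/values depend on your own event types
        kafka_events_classes = {
            'UserCreated': 'event.content.monolog.v1.UserCreated',
            'OrderPlaced': 'event.content.monolog.v1.OrderPlaced',
        }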
    

    In my case, pool.FindMessageTypeByName receives the right class name depending on the type of event consumed; this information is fetched from one of the Kafka message headers.
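
    To run the stream and inspect the decoded payloads, a minimal query using the console sink (just for illustration) could look like this:

        query = df.select('key', 'timestamp', 'payload') \
            .writeStream \
            .format('console') \
            .option('truncate', 'false') \
            .start()
        query.awaitTermination()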