I am trying to read from a Kafka topic that contains messages with different Proto payloads, with the messageName set in the Kafka message key.
But when I try to:
df = spark.readStream.format(constants.KAFKA_INPUT_FORMAT) \
    .options(**options) \
    .load()
df = df.selectExpr("CAST(key AS STRING)").alias('key')
df = df.select(from_protobuf('value', df.key, desc_file_path).alias('value'))
I get a pyspark.errors.exceptions.base.PySparkTypeError: [NOT_ITERABLE] Column is not iterable error.
How can I dynamically set the messageName parameter of the from_protobuf function from the key of each Kafka message?
The direct approach fails because the messageName argument of from_protobuf is a plain Python string resolved when the query is analyzed, not a per-row Column, which is why passing df.key raises the [NOT_ITERABLE] error. I was able to address the requirement by first creating a descriptor file, containing all of the proto definitions, via the following protoc command:
protoc --include_imports --include_source_info --retain_options --descriptor_set_out=./event.desc --proto_path=./proto ./proto/event/content/monolog/v1/*.proto
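To sanity-check that the descriptor set really contains every message type, it can be loaded and inspected with the protobuf runtime. A minimal sketch, assuming the ./event.desc path from the command above:

from google.protobuf import descriptor_pb2

# Load the serialized FileDescriptorSet produced by protoc
with open('./event.desc', 'rb') as f:
    file_descriptor_set = descriptor_pb2.FileDescriptorSet()
    file_descriptor_set.ParseFromString(f.read())

# Print the fully qualified name of every message type in the set
for file_proto in file_descriptor_set.file:
    for message_proto in file_proto.message_type:
        print(f'{file_proto.package}.{message_proto.name}')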
My PySpark driver then uses a udf (user-defined function) to deserialize each payload:
from google.protobuf import descriptor_pb2, descriptor_pool, json_format, message_factory
from pyspark.sql.functions import udf

desc_file_path = './event.desc'
with open(desc_file_path, 'rb') as f:
    desc_file_bytes = f.read()

@udf
def deserialize_message_from_file_descriptor_set(message_headers, message_data):
    try:
        # Parse the FileDescriptorSet
        file_descriptor_set = descriptor_pb2.FileDescriptorSet()
        file_descriptor_set.ParseFromString(desc_file_bytes)
        # Create a DescriptorPool from the FileDescriptorSet
        pool = descriptor_pool.DescriptorPool()
        for file_descriptor_proto in file_descriptor_set.file:
            pool.Add(file_descriptor_proto)
        # The event type travels in one of the Kafka headers;
        # kafka_events_classes maps it to the fully qualified proto message name
        headersList = list(message_headers)
        eventTypeStr = headersList[4].value.decode("utf-8")
        class_name = kafka_events_classes[eventTypeStr]
        # Find the message descriptor
        message_descriptor = pool.FindMessageTypeByName(class_name)
        factory = message_factory.MessageFactory()
        # Create a message instance from the descriptor
        message_class = factory.GetPrototype(message_descriptor)
        # Deserialize the message data and render it as JSON
        message = message_class()
        message.ParseFromString(message_data)
        return json_format.MessageToJson(message)
    except Exception as e:
        # Return the error text so failures surface in the column
        return str(e)

df = df.selectExpr('CAST(key AS STRING)', 'offset', 'timestamp', 'headers', 'value')
df = df.withColumn('payload', deserialize_message_from_file_descriptor_set("headers", "value").alias("", metadata={"sqlType": "JSON"}))
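One possible refinement, sketched below under the assumption that the module defining the UDF is importable on the executors: since desc_file_bytes never changes, the FileDescriptorSet parsing and DescriptorPool construction can be hoisted out of the UDF so they run once per Python worker process instead of once per row (reusing the imports from the block above):

# Build the descriptor pool once at module level instead of per row
file_descriptor_set = descriptor_pb2.FileDescriptorSet()
file_descriptor_set.ParseFromString(desc_file_bytes)
pool = descriptor_pool.DescriptorPool()
for file_descriptor_proto in file_descriptor_set.file:
    pool.Add(file_descriptor_proto)

The UDF body would then start directly at the header handling and refer to the shared pool.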
In my case, pool.FindMessageTypeByName receives the right class name depending on the type of event consumed; this info is fetched from one of the event headers.
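One caveat worth noting: on newer protobuf runtimes (the 4.x line), MessageFactory().GetPrototype is deprecated in favour of the module-level message_factory.GetMessageClass helper, so the descriptor-to-class step would become:

# protobuf 4.x+: look the generated class up via the module-level helper
message_descriptor = pool.FindMessageTypeByName(class_name)
message_class = message_factory.GetMessageClass(message_descriptor)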