For every row in input_table, X rows should be created in output_table, where X = the number of days in the year (counted from StartDate).
The Info field should contain Y characters, where Y = X*2; if it has fewer, the field should be padded with additional # characters.
In output_table, the AM and PM columns are filled with the Info characters in order (characters 1 and 2 go to the first day's AM and PM, characters 3 and 4 to the second day, and so on), so that each AM and PM field holds exactly 1 character.
Here is the code:
```python
from datetime import timedelta

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType, StringType, DateType, StructField, StructType, TimestampType, ArrayType

# Connection details for input table
url = "..."
user = "..."
password = "..."
input_table = "..."
output_table = "..."

# Define schema for input table
input_schema = StructType([
    StructField("ID1", IntegerType(), True),
    StructField("ID2", IntegerType(), True),
    StructField("StartDate", TimestampType(), True),
    StructField("Info", StringType(), True),
    StructField("Extracted", TimestampType(), True)
])

# Define schema for output table
output_schema = StructType([
    StructField("ID1", IntegerType(), True),
    StructField("ID2", IntegerType(), True),
    StructField("Date", DateType(), True),
    StructField("AM", StringType(), True),
    StructField("PM", StringType(), True),
    StructField("CurrentYear", StringType(), True)
])

# Initialize SparkSession
spark = SparkSession.builder.getOrCreate()

# Register UDF for padding Info with '#' up to `days` characters
pad_marks_udf = udf(lambda info, days: info.ljust(days, '#')[:days], StringType())

# Register UDF for creating one (Date, AM, PM) entry per pair of characters
create_rows_udf = udf(
    lambda start_date, marks, days: [
        (start_date + timedelta(days=i // 2), marks[i], marks[i + 1])
        for i in range(0, days, 2)
    ],
    ArrayType(StructType([
        StructField("Date", DateType(), True),
        StructField("AM", StringType(), True),
        StructField("PM", StringType(), True),
    ])))

# Define function to pad marks and create rows
def process_row(row):
    id1 = row["ID1"]
    id2 = row["ID2"]
    start_date = row["StartDate"]
    info = row["Info"]
    extracted = row["Extracted"]
    # Calculate number of characters: days in year * 2
    days = (start_date.year % 4 == 0 and 366 or 365) * 2
    # Pad info
    padded_info = pad_marks_udf(info, days)
    # Create rows
    rows = create_rows_udf(start_date, padded_info, days)
    # Prepare output rows
    output_rows = []
    for r in rows:
        date = r["Date"]
        am = r["AM"]
        pm = r["PM"]
        current_year = f"{start_date.year}/{start_date.year + 1}"
        output_rows.append((id1, id2, date, am, pm, current_year))
    return output_rows

# Load input table as DataFrame
df_input = spark.read \
    .format("jdbc") \
    .option("url", url) \
    .option("dbtable", input_table) \
    .option("user", user) \
    .option("password", password) \
    .schema(input_schema) \
    .load()

# Apply processing to input table
output_rows = df_input.rdd.flatMap(process_row)

# Create DataFrame from output rows
df_output = spark.createDataFrame(output_rows, output_schema)

# Write DataFrame to output table
df_output.write \
    .format("jdbc") \
    .option("url", url) \
    .option("user", user) \
    .option("password", password) \
    .option("dbtable", output_table) \
    .mode("append") \
    .save()
```
Similar code works in plain Python with no problems, but when translated to PySpark it throws an AssertionError. The job must not modify input_table; it should only append the rows generated from input_table to output_table.
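So the reason is that the code is not supposed to call Spark UDFs inside functions applied to RDDs (here, process_row passed to flatMap); plain Python functions should be used instead. A Spark UDF can only be used in Spark SQL and DataFrame expressions: calling it does not compute a value, it builds a Column expression, and that needs the driver-side SparkContext, which is not available inside executor tasks.
The reason the code worked on the local machine is that in local mode the executor runs in the same JVM as the driver.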
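Following that advice, a minimal sketch of process_row written with plain Python only (no UDF calls) might look like this. It assumes StartDate arrives as a datetime, that X is the number of days in StartDate's calendar year, and that CurrentYear is meant to be "<year>/<year + 1>" of StartDate (the question does not define it precisely); pad_info is a hypothetical helper.

```python
import calendar
from datetime import date, datetime, timedelta
from typing import Any, Iterator, Mapping, Optional, Tuple

OutputRow = Tuple[Any, Any, date, str, str, str]


def pad_info(info: Optional[str], length: int) -> str:
    # Plain Python: right-pad with '#' to `length` characters, truncate if longer
    return (info or "").ljust(length, "#")[:length]


def process_row(row: Mapping[str, Any]) -> Iterator[OutputRow]:
    # Works for a pyspark Row (supports row["Field"]) as well as a plain dict
    start = row["StartDate"]
    if isinstance(start, datetime):
        start = start.date()  # Date column is DateType, so drop the time part
    # X = days in StartDate's calendar year (assumption), Y = 2 * X characters
    x = 366 if calendar.isleap(start.year) else 365
    marks = pad_info(row["Info"], 2 * x)
    # Assumed format for CurrentYear, e.g. "2023/2024"
    current_year = f"{start.year}/{start.year + 1}"
    for day in range(x):
        # Characters 2*day and 2*day + 1 become that day's AM and PM marks
        yield (
            row["ID1"],
            row["ID2"],
            start + timedelta(days=day),
            marks[2 * day],
            marks[2 * day + 1],
            current_year,
        )
```

It can be plugged into the existing pipeline unchanged: output_rows = df_input.rdd.flatMap(process_row) followed by spark.createDataFrame(output_rows, output_schema). Because each row is yielded as a plain tuple in output_schema order, nothing inside the function touches the SparkContext.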