I use PySpark in Azure Databricks to transform data before sending it to a sink. In this sink, any array must have a length of at most 100. In my data I have an array that is always of length 300 and a field specifying how many of these values are relevant (n_relevant). Example n_relevant values and the desired outcomes:
array: [1,2,3,4,5,...300]
n_relevant: 4
desired outcome: [1,2,3,4]

array: [1,2,3,4,5,...300]
n_relevant: 200
desired outcome: [1,3,5,...199]

array: [1,2,3,4,5,...300]
n_relevant: 300
desired outcome: [1,4,7,...298]

array: [1,2,3,4,5,...300]
n_relevant: 800
desired outcome: [1,4,7,...298]
This little program reflects the desired behavior:
from math import ceil

def subsample(array: list, n_relevant: int) -> list:
    # Fewer than 100 relevant values: simply truncate.
    if n_relevant < 100:
        return [x for i, x in enumerate(array) if i < n_relevant]
    # Between 100 and 299: keep every mod-th value within the first n_relevant.
    if 100 <= n_relevant < 300:
        mod = ceil(n_relevant / 100)
        return [x for i, x in enumerate(array) if i % mod == 0 and i < n_relevant]
    # 300 or more: keep every 3rd value of the whole array.
    else:
        return [x for i, x in enumerate(array) if i % 3 == 0]

n_relevant = <choose n>
t1 = [i for i in range(300)]
subsample(t1, n_relevant)
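A few quick checks of this reference implementation against the examples above (note that t1 = range(300) is 0-based, i.e. [0, 1, ..., 299], while the examples above use 1-based values):

t1 = list(range(300))
assert subsample(t1, 4) == [0, 1, 2, 3]
assert len(subsample(t1, 200)) == 100  # every 2nd value of the first 200
assert len(subsample(t1, 300)) == 100  # every 3rd value of all 300
assert len(subsample(t1, 800)) == 100  # capped, same result as n_relevant = 300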
What I have tried:
A transform that sets undesired values to 0, followed by array_remove to drop them, can subset with a specific modulo, BUT it cannot adapt to n_relevant. Specifically, you cannot hand a parameter to the lambda function, and you cannot dynamically change the function.
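As an aside, one possible workaround sketch (not one of the attempts above): a plain Python UDF takes n_relevant as a regular argument, so the subsample logic could be wrapped in a hypothetical subsample_udf, at the cost of row-by-row Python execution instead of native array functions. This assumes a DataFrame with array and n_relevant columns like the one constructed below.

from math import ceil
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, IntegerType

@F.udf(returnType=ArrayType(IntegerType()))
def subsample_udf(array, n_relevant):
    # Same logic as the pure-Python subsample above.
    if n_relevant < 100:
        return array[:n_relevant]
    if 100 <= n_relevant < 300:
        mod = ceil(n_relevant / 100)
        return [x for i, x in enumerate(array) if i % mod == 0 and i < n_relevant]
    return [x for i, x in enumerate(array) if i % 3 == 0]

# Usage, given the df defined below:
# df = df.withColumn("result", subsample_udf("array", "n_relevant"))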
You can filter by index as follows:

from pyspark.sql import functions as F
from pyspark.sql.types import StructField, StructType, IntegerType, ArrayType

df = spark.createDataFrame(
    [[list(range(300)), 4], [list(range(300)), 200], [list(range(300)), 300], [list(range(300)), 800]],
    schema=StructType(
        [
            StructField("array", ArrayType(IntegerType())),
            StructField("n_relevant", IntegerType()),
        ]
    ),
)

df = df.withColumn(
    "result",
    # Up to 100 relevant values: take the first n_relevant elements.
    F.when(F.col("n_relevant") <= 100, F.slice("array", 1, F.col("n_relevant")))
    # Up to 200: keep every 2nd element of the first n_relevant.
    .when(
        F.col("n_relevant") <= 200,
        F.filter(
            F.slice("array", 1, F.col("n_relevant")), lambda _, index: index % 2 == 0
        ),
    )
    # Otherwise: keep every 3rd element; slice caps at the array length of 300.
    .otherwise(
        F.filter(
            F.slice("array", 1, F.col("n_relevant")), lambda elem, index: index % 3 == 0
        )
    ),
)
display(df)
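As a quick sanity check (a small addition, not in the original snippet), every result array should contain at most 100 elements, as required by the sink:

df.select(
    F.min(F.size("result")).alias("min_len"),
    F.max(F.size("result")).alias("max_len"),
).show()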