I have the following Spark DataFrame:
id weekly_sale
1 40000
2 120000
3 135000
4 211000
5 215000
6 331000
7 337000
I need to find out which of the following intervals each value in the weekly_sale column falls into:
under 100000
between 100000 and 200000
between 200000 and 300000
more than 300000
So my desired output would be:
id weekly_sale label
1 40000 under 100000
2 120000 between 100000 and 200000
3 135000 between 100000 and 200000
4 211000 between 200000 and 300000
5 215000 between 200000 and 300000
6 331000 more than 300000
7 337000 more than 300000
Any PySpark, Spark SQL, or HiveContext implementation would help.
Assuming ranges and labels are defined as follows:
# n + 1 split points define n intervals; labels[i] names [splits[i], splits[i+1])
splits = [float("-inf"), 100000.0, 200000.0, 300000.0, float("inf")]
labels = [
    "under 100000", "between 100000 and 200000",
    "between 200000 and 300000", "more than 300000"]

# sc is the SparkContext (predefined in the pyspark shell)
df = sc.parallelize([
    (1, 40000.0), (2, 120000.0), (3, 135000.0),
    (4, 211000.0), (5, 215000.0), (6, 331000.0),
    (7, 337000.0)
]).toDF(["id", "weekly_sale"])
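Bucketizer requires the splits to be strictly increasing (and at least three of them), and n + 1 splits define n intervals, so there must be exactly one label per interval. A small plain-Python check along these lines can catch a mismatch before any Spark job runs. check_buckets is a hypothetical helper, not a PySpark API:

```python
import math

splits = [float("-inf"), 100000.0, 200000.0, 300000.0, float("inf")]
labels = [
    "under 100000", "between 100000 and 200000",
    "between 200000 and 300000", "more than 300000"]


def check_buckets(splits, labels):
    """Hypothetical helper: sanity-check splits and labels.

    Raises ValueError on malformed input. Returns True when the outer
    bounds are -inf/+inf, i.e. every finite value lands in some bucket.
    """
    if len(splits) < 3:
        raise ValueError("Bucketizer needs at least 3 split points")
    if any(a >= b for a, b in zip(splits, splits[1:])):
        raise ValueError("splits must be strictly increasing")
    if len(labels) != len(splits) - 1:
        raise ValueError(
            f"{len(splits)} splits define {len(splits) - 1} intervals, "
            f"but {len(labels)} labels were given")
    return math.isinf(splits[0]) and math.isinf(splits[-1])


print(check_buckets(splits, labels))
```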
One possible approach is to use Bucketizer:
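from pyspark.ml.feature import Bucketizer
from pyspark.sql.functions import array, col, lit

# Adds a "split" column holding the index (0.0, 1.0, ...) of the interval
# of splits that each weekly_sale falls into
bucketizer = Bucketizer(
    splits=splits, inputCol="weekly_sale", outputCol="split"
)
with_split = bucketizer.transform(df)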
and attach labels later:
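# Array of label literals, indexed by the integer bucket number
# ([] indexing: getItem with a Column key is deprecated as of Spark 3.0)
label_array = array(*(lit(label) for label in labels))
with_split.withColumn(
    "label", label_array[col("split").cast("integer")]
).show(10, False)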
## +---+-----------+-----+-------------------------+
## |id |weekly_sale|split|label                    |
## +---+-----------+-----+-------------------------+
## |1  |40000.0    |0.0  |under 100000             |
## |2  |120000.0   |1.0  |between 100000 and 200000|
## |3  |135000.0   |1.0  |between 100000 and 200000|
## |4  |211000.0   |2.0  |between 200000 and 300000|
## |5  |215000.0   |2.0  |between 200000 and 300000|
## |6  |331000.0   |3.0  |more than 300000         |
## |7  |337000.0   |3.0  |more than 300000         |
## +---+-----------+-----+-------------------------+
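Bucketizer places each value in a half-open interval [lower, upper), and only the last interval also includes its upper bound. A weekly_sale of exactly 100000 is therefore labelled "between 100000 and 200000", not "under 100000". The plain-Python sketch below imitates this rule with bisect and can be handy for checking boundary values. bucket_index is a hypothetical helper, not Spark code:

```python
from bisect import bisect_right

splits = [float("-inf"), 100000.0, 200000.0, 300000.0, float("inf")]
labels = [
    "under 100000", "between 100000 and 200000",
    "between 200000 and 300000", "more than 300000"]


def bucket_index(value, splits):
    """Plain-Python imitation of the bucket Bucketizer assigns to value.

    Buckets are half-open [lower, upper), except the last one, which
    also includes its upper bound.
    """
    if not splits[0] <= value <= splits[-1]:
        # Also catches NaN, since every comparison with NaN is False
        raise ValueError(f"{value!r} is outside the range of splits")
    if value == splits[-1]:
        return len(splits) - 2  # last bucket is closed on the right
    # bisect_right counts the splits <= value; subtract 1 for the bucket index
    return bisect_right(splits, value) - 1


for v in [99999.0, 100000.0, 200000.0, 337000.0]:
    print(v, labels[bucket_index(v, splits)])
```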
There are, of course, other ways to achieve the same goal. For example, you can create a lookup table:
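from toolz import sliding_window  # third-party; zip(splits, splits[1:]) works too
from pyspark.sql.functions import broadcast

# One (lower, upper, label) row per interval
mapping = [
    (lower, upper, label) for ((lower, upper), label)
    in zip(sliding_window(2, splits), labels)
]
lookup_df = sc.parallelize(mapping).toDF(["lower", "upper", "label"])

# Non-equi join: each row matches the interval with lower <= weekly_sale < upper.
# The lookup table is tiny, so it is broadcast instead of shuffled.
df.join(
    broadcast(lookup_df),
    (col("weekly_sale") >= col("lower")) & (col("weekly_sale") < col("upper"))
).drop("lower").drop("upper")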
or generate a lookup expression:
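from functools import reduce
from pyspark.sql.functions import when

def in_range(c):
    # Returns a reducer that wraps the accumulated expression in one more
    # when(lower <= c < upper, label).otherwise(...) branch
    def in_range_(acc, x):
        lower, upper, label = x
        return when(
            (c >= lit(lower)) & (c < lit(upper)), lit(label)
        ).otherwise(acc)
    return in_range_

# Fold over all intervals, starting from NULL for values that match none
label = reduce(in_range(col("weekly_sale")), mapping, lit(None))
df.withColumn("label", label)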
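Since the question also asks about spark.sql and Hive: the chain of when(...).otherwise(...) calls is just a SQL CASE expression. The sketch below builds the equivalent expression as a string from the same splits and labels. case_when_sql is a hypothetical helper, and it assumes the labels contain no single quotes:

```python
import math

splits = [float("-inf"), 100000.0, 200000.0, 300000.0, float("inf")]
labels = [
    "under 100000", "between 100000 and 200000",
    "between 200000 and 300000", "more than 300000"]


def case_when_sql(column, splits, labels):
    """Build a SQL CASE expression with one WHEN branch per interval.

    Assumes labels contain no single quotes. Infinite bounds are dropped,
    because SQL has no plain literal for them.
    """
    branches = []
    for lower, upper, label in zip(splits, splits[1:], labels):
        conditions = []
        if not math.isinf(lower):
            conditions.append(f"{column} >= {lower!r}")
        if not math.isinf(upper):
            conditions.append(f"{column} < {upper!r}")
        condition = " AND ".join(conditions) or "TRUE"
        branches.append(f"WHEN {condition} THEN '{label}'")
    # Values matching no branch get NULL, like the lit(None) seed of the reduce
    return "CASE " + " ".join(branches) + " END"


print(case_when_sql("weekly_sale", splits, labels))
```

The resulting string could be passed to df.selectExpr("*", expr + " AS label"), or used as spark.sql("SELECT *, " + expr + " AS label FROM sales") after registering df as a temporary view (createOrReplaceTempView, or registerTempTable with a Spark 1.x SQLContext/HiveContext). The view name sales is only an example.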
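The least efficient approach is a Python UDF, because every value has to be serialized between the JVM and a Python worker, whereas the approaches above are evaluated entirely inside the JVM.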
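For comparison, a UDF version might look like the sketch below. label_for is a hypothetical plain-Python function that scans the same (lower, upper, label) intervals row by row. make_label_udf wraps it with pyspark.sql.functions.udf and only imports PySpark when it is called:

```python
splits = [float("-inf"), 100000.0, 200000.0, 300000.0, float("inf")]
labels = [
    "under 100000", "between 100000 and 200000",
    "between 200000 and 300000", "more than 300000"]
# (lower, upper, label) triples, same intervals as the other approaches
mapping = list(zip(splits, splits[1:], labels))


def label_for(value):
    """Row-at-a-time lookup: the Python function a UDF would run."""
    if value is None:
        return None
    for lower, upper, label in mapping:
        if lower <= value < upper:
            return label
    return None


def make_label_udf():
    """Wrap label_for as a Spark UDF; imports pyspark only when called."""
    from pyspark.sql.functions import udf
    from pyspark.sql.types import StringType
    return udf(label_for, StringType())


print(label_for(40000.0), "|", label_for(337000.0))
```

Applied as df.withColumn("label", make_label_udf()(col("weekly_sale"))), it should produce the same label column as the other approaches.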