apache-spark pyspark fpgrowth

PySpark :: FP-growth algorithm ( raise ValueError("Params must be either a param map or a list/tuple of param maps, ")


I am a beginner with PySpark. I am using FPGrowth to compute association rules in PySpark. I followed the steps below.

Data Example

from pyspark.sql.session import SparkSession

spark = SparkSession.builder.getOrCreate()

# make some test data
columns = ['customer_id', 'product_id']
vals = [
     (370, 154),
     (41, 40),
     (109, 173),
     (18, 55),
     (105, 126),
     (370, 121),
     (41, 32323),
     (109, 22),
     (18, 55),
     (105, 133),
     (109, 22),
     (18, 55),
     (105, 133)
]

df = spark.createDataFrame(vals, columns)

df.show()
+-----------+----------+
|customer_id|product_id|
+-----------+----------+
|        370|       154|
|         41|        40|
|        109|       173|
|         18|        55|
|        105|       126|
|        370|       121|
|         41|     32323|
|        109|        22|
|         18|        55|
|        105|       133|
|        109|        22|
|         18|        55|
|        105|       133|
+-----------+----------+

### Prepare input data
from pyspark.sql.functions import collect_list, col

transactions = df.groupBy("customer_id")\
      .agg(collect_list("product_id").alias("product_ids"))\
      .rdd\
      .map(lambda x: (x.customer_id, x.product_ids))

transactions.collect()
[(370, [121, 154]),
 (41, [32323, 40]),
 (105, [133, 133, 126]),
 (18, [55, 55, 55]),
 (109, [22, 173, 22])]

## Convert .rdd to spark dataframe 
df2 = spark.createDataFrame(transactions)
df2.show()
+---+---------------+
| _1|             _2|
+---+---------------+
|370|     [121, 154]|
| 41|    [32323, 40]|
|105|[126, 133, 133]|
| 18|   [55, 55, 55]|
|109|  [22, 173, 22]|
+---+---------------+

df3 = df2.selectExpr("_1 as customer_id", "_2 as product_id")
df3.show()
df3.printSchema()
+-----------+---------------+
|customer_id|     product_id|
+-----------+---------------+
|        370|     [154, 121]|
|         41|    [32323, 40]|
|        105|[126, 133, 133]|
|         18|   [55, 55, 55]|
|        109|  [173, 22, 22]|
+-----------+---------------+

root
 |-- customer_id: long (nullable = true)
 |-- product_id: array (nullable = true)
 |    |-- element: long (containsNull = true)

 ## FPGrowth Model Building
 from pyspark.ml.fpm import FPGrowth
 fpGrowth = FPGrowth(itemsCol="product_id", minSupport=0.5, minConfidence=0.6)
 model = fpGrowth.fit(df3)

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-12-aa1f71745240> in <module>()
----> 1 model = fpGrowth.fit(df3)

/usr/lib/spark/python/pyspark/ml/base.py in fit(self, dataset, params)
     62                 return self.copy(params)._fit(dataset)
     63             else:
---> 64                 return self._fit(dataset)
     65         else:
     66             raise ValueError("Params must be either a param map or a list/tuple of param maps, "

/usr/lib/spark/python/pyspark/ml/wrapper.py in _fit(self, dataset)
    263 
    264     def _fit(self, dataset):
--> 265         java_model = self._fit_java(dataset)
    266         return self._create_model(java_model)
    267 

/usr/lib/spark/python/pyspark/ml/wrapper.py in _fit_java(self, dataset)
    260         """
    261         self._transfer_params_to_java()
--> 262         return self._java_obj.fit(dataset._jdf)
    263 
    264     def _fit(self, dataset):

/usr/lib/spark/python/lib/py4j-0.10.4-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1131         answer = self.gateway_client.send_command(command)
   1132         return_value = get_return_value(
-> 1133             answer, self.gateway_client, self.target_id, self.name)
   1134 
   1135         for temp_arg in temp_args:

/usr/lib/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
     61     def deco(*a, **kw):
     62         try:
---> 63             return f(*a, **kw)
     64         except py4j.protocol.Py4JJavaError as e:
     65             s = e.java_exception.toString()

/usr/lib/spark/python/lib/py4j-0.10.4-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    317                 raise Py4JJavaError(
    318                     "An error occurred while calling {0}{1}{2}.\n".
--> 319                     format(target_id, ".", name), value)
    320             else:
    321                 raise Py4JError(

I looked this up but could not figure out what went wrong. The only thing I can point to is that I converted the RDD to a DataFrame.

Can anybody point me to what I am doing wrong?


Solution

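  • Well, I just realised that FPGrowth from pyspark.ml.fpm takes a PySpark DataFrame, not an RDD, whereas the method mentioned above converted my dataset to an RDD. (Note: FPGrowth also requires the items within each transaction to be unique; collect_list keeps duplicates such as [55, 55, 55], which is most likely what made fit() fail, while collect_set removes them.)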

    I was able to avoid the problem by using PySpark's collect_set with groupBy to build a DataFrame directly and pass it to FPGrowth.

    from pyspark.sql.session import SparkSession
    
    # instantiate Spark
    spark = SparkSession.builder.getOrCreate()
    
    # make some test data
    columns = ['customer_id', 'product_id']
    vals = [
         (370, 154),
         (370, 40),
         (370, 173),
         (41, 55),
         (41, 126),
         (41, 121),
         (41, 321),
         (105, 22),
         (105, 55),
         (105, 133),
         (109, 22),
         (109, 55),
         (109, 133)    
    ]
    
    
    # create DataFrame
    df = spark.createDataFrame(vals, columns)
    
    df.show()
    +-----------+----------+
    |customer_id|product_id|
    +-----------+----------+
    |        370|       154|
    |        370|        40|
    |        370|       173|
    |         41|        55|
    |         41|       126|
    |         41|       121|
    |         41|       321|
    |        105|        22|
    |        105|        55|
    |        105|       133|
    |        109|        22|
    |        109|        55|
    |        109|       133|
    +-----------+----------+
    
    # Create dataframe for FPGrowth model input
    from pyspark.sql.functions import collect_list, col
    from pyspark.sql import functions as F 
    from pyspark.sql.functions import *
    # Group the purchases per customer into one array column. collect_set
    # (unlike collect_list) eliminates duplicate product ids, and FPGrowth
    # expects the items of each transaction to be unique.
    transactions = df.groupBy("customer_id")\
          .agg(F.collect_set("product_id"))
           
    transactions.show()
    +-----------+-----------------------+
    |customer_id|collect_set(product_id)|
    +-----------+-----------------------+
    |        370|         [154, 173, 40]|
    |         41|    [321, 121, 126, 55]|
    |        105|          [133, 22, 55]|
    |        109|          [133, 22, 55]|
    +-----------+-----------------------+
    
    # FPGrowth model
    from pyspark.ml.fpm import FPGrowth
    # itemsCol must name the array column produced by the aggregation above;
    # without an alias Spark calls it "collect_set(product_id)".
    fpGrowth = FPGrowth(itemsCol="collect_set(product_id)", minSupport=0.5, minConfidence=0.6)
    model_working = fpGrowth.fit(transactions)
    
    # Display frequent itemsets
    model_working.freqItemsets.show()
    +-------------+----+
    |        items|freq|
    +-------------+----+
    |         [55]|   3|
    |         [22]|   2|
    |     [22, 55]|   2|
    |        [133]|   2|
    |    [133, 22]|   2|
    |[133, 22, 55]|   2|
    |    [133, 55]|   2|
    +-------------+----+
    
    # Display generated association rules.
    model_working.associationRules.show()
    
    # transform examines the input items against all the association rules and summarise the
    # consequents as prediction
    model_working.transform(transactions).show()
    
    +----------+----------+------------------+
    |antecedent|consequent|        confidence|
    +----------+----------+------------------+
    |     [133]|      [22]|               1.0|
    |     [133]|      [55]|               1.0|
    | [133, 55]|      [22]|               1.0|
    | [133, 22]|      [55]|               1.0|
    |      [22]|      [55]|               1.0|
    |      [22]|     [133]|               1.0|
    |      [55]|      [22]|0.6666666666666666|
    |      [55]|     [133]|0.6666666666666666|
    |  [22, 55]|     [133]|               1.0|
    +----------+----------+------------------+
    
    +-----------+-----------------------+----------+
    |customer_id|collect_set(product_id)|prediction|
    +-----------+-----------------------+----------+
    |        370|         [154, 173, 40]|        []|
    |         41|    [321, 121, 126, 55]| [22, 133]|
    |        105|          [133, 22, 55]|        []|
    |        109|          [133, 22, 55]|        []|
    +-----------+-----------------------+----------+