Tags: apache-spark, parallel-processing, pyspark, som

Efficient implementation of SOM (self-organizing map) on PySpark


I am struggling to implement a performant version of the SOM batch algorithm on Spark / PySpark for a huge dataset with > 100 features. My impression is that I can either use RDDs, where I can/have to specify the parallelization myself, or use DataFrames, which should be more performant; but with DataFrames I see no way to use something like a local accumulation variable on each worker.
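
To illustrate what I mean by a local accumulation variable, here is a minimal sketch of the per-partition pattern I could use with RDDs (toy data, placeholder names); I see no equivalent when staying purely in the DataFrame API:

    import numpy as np
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.appName("som-sketch").getOrCreate()
    sc = spark.sparkContext

    n_features = 100
    data = sc.parallelize([np.random.rand(n_features) for _ in range(10000)])

    # (running sum, count): each partition accumulates this locally on its
    # worker; only the small partial results are shipped and merged.
    zero = (np.zeros(n_features), 0)

    def seq_op(acc, x):
        s, n = acc
        return s + x, n + 1                  # local, per-partition update

    def comb_op(a, b):
        return a[0] + b[0], a[1] + b[1]      # merge the partials

    total, count = data.treeAggregate(zero, seq_op, comb_op)
    mean = total / count                     # e.g. the feature-wise mean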

Ideas:

Any thoughts on the different options? Is there an even better option?

Or are none of these ideas any good, and should I instead just preselect a maximally diverse subset of my dataset and train a SOM on it locally? Thanks!


Solution

  • This is exactly what I did last year, so I might be in a good position to give you an answer.

    First, here is my Spark implementation of the batch SOM algorithm (it is written in Scala, but most things will be similar in PySpark).

    I needed this algorithm for a project, and every implementation I found had at least one of these two problems or limitations:

    So I went on to code it myself: the batch SOM algorithm, in Spark ML style. The first thing I did was look at how k-means is implemented in Spark ML, because, as you know, the batch SOM is very similar to the k-means algorithm. I was actually able to re-use a large portion of the Spark ML k-means code, but I had to modify the core algorithm and the hyperparameters.
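
    To make the analogy concrete, here is a rough PySpark sketch of a single batch-SOM update (the names and signatures are illustrative, not my actual Scala code). Structurally it is one Lloyd iteration of k-means: broadcast the current centers, aggregate weighted partial sums over the RDD, recompute the centers; the only real change is the Gaussian neighborhood weighting over the map grid.

        import numpy as np

        def som_batch_step(sc, rdd, prototypes, grid_coords, sigma):
            """One synchronous batch update of the SOM prototypes.

            rdd         : RDD of 1-D numpy arrays (the feature vectors)
            prototypes  : (k, d) array of current prototype vectors
            grid_coords : (k, 2) array of prototype positions on the map grid
            sigma       : neighborhood radius for this epoch
            """
            bc_protos = sc.broadcast(prototypes)
            bc_grid = sc.broadcast(grid_coords)
            k, d = prototypes.shape

            def seq_op(acc, x):
                num, den = acc
                protos = bc_protos.value
                # Best matching unit: the same assignment step as in k-means.
                bmu = np.argmin(((protos - x) ** 2).sum(axis=1))
                # Unlike k-means, the point contributes to *all* prototypes,
                # weighted by a Gaussian neighborhood on the map grid.
                g = bc_grid.value
                h = np.exp(-((g - g[bmu]) ** 2).sum(axis=1) / (2 * sigma ** 2))
                return num + h[:, None] * x, den + h

            def comb_op(a, b):
                return a[0] + b[0], a[1] + b[1]

            zero = (np.zeros((k, d)), np.zeros(k))
            num, den = rdd.treeAggregate(zero, seq_op, comb_op)
            # The Gaussian weights are strictly positive, so den is never zero.
            return num / den[:, None]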

    I can summarize quickly how the model is built:

    1. A SOMParams class, containing the SOM hyperparameters (size, training parameters, etc.)
    2. A SOM class, which inherits from Spark's Estimator and contains the training algorithm. In particular, it contains a fit() method that operates on an input DataFrame whose features are stored as a spark.ml.linalg.Vector in a single column. fit() selects this column and unpacks the DataFrame to obtain the underlying RDD[Vector] of features, then calls the run() method on it. This is where all the computations happen, and as you guessed, it uses RDDs, accumulators and broadcast variables. Finally, fit() returns a SOMModel object (a rough PySpark skeleton of this structure is sketched right after this list).
    3. SOMModel is the trained SOM model and inherits from Spark's Transformer/Model. It contains the map prototypes (center vectors), as well as a transform() method that operates on DataFrames by taking an input feature column and adding a new column with the predictions (the projection on the map). This is done by a prediction UDF.
    4. There is also a SOMTrainingSummary that collects statistics such as the objective function.
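
    To make this structure concrete, here is a rough PySpark skeleton of the Estimator/Model pair. The class names (SOM, SOMModel) mirror my code, but the bodies are simplified placeholders rather than the actual implementation: the training loop reuses the som_batch_step sketch from above, and the initialization and radius schedule are illustrative choices.

        import numpy as np
        from pyspark.ml import Estimator, Model
        from pyspark.sql import functions as F
        from pyspark.sql.types import IntegerType

        class SOM(Estimator):
            """Skeleton Estimator: fit() unpacks the feature column and trains."""

            def __init__(self, features_col="features", grid=(10, 10), max_iter=20):
                super().__init__()
                self.features_col = features_col
                self.grid = grid
                self.max_iter = max_iter

            def _fit(self, dataset):
                # Select the single ml.linalg.Vector column and drop down to
                # the underlying RDD of numpy feature vectors.
                rdd = (dataset.select(self.features_col).rdd
                              .map(lambda row: row[0].toArray()).cache())
                height, width = self.grid
                k = height * width
                grid_coords = np.array([(i, j) for i in range(height)
                                               for j in range(width)], dtype=float)
                # Initialize the prototypes from a random sample (assumes >= k rows).
                prototypes = np.array(rdd.takeSample(False, k, seed=1))
                sc = rdd.context
                for it in range(self.max_iter):
                    # Shrink the neighborhood radius as training progresses.
                    sigma = max(height, width) / 2.0 * (1.0 - it / self.max_iter) + 0.5
                    prototypes = som_batch_step(sc, rdd, prototypes, grid_coords, sigma)
                return SOMModel(prototypes, self.features_col)

        class SOMModel(Model):
            """Skeleton Model: transform() appends the map projection via a UDF."""

            def __init__(self, prototypes, features_col="features"):
                super().__init__()
                self.prototypes = prototypes
                self.features_col = features_col

            def _transform(self, dataset):
                protos = self.prototypes  # captured in the UDF closure

                @F.udf(returnType=IntegerType())
                def project(v):
                    x = v.toArray()
                    return int(np.argmin(((protos - x) ** 2).sum(axis=1)))

                return dataset.withColumn("prediction", project(F.col(self.features_col)))

    Usage then looks like any other Spark ML estimator:

        model = SOM(features_col="features", grid=(10, 10)).fit(df)
        projected = model.transform(df)  # adds a "prediction" column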

    Here are the take-aways:

    I hope this answers your question. Concerning performance, since you asked for an efficient implementation: I have not run any benchmarks yet, but I use it at work, and it crunches datasets of 500k to 1M rows in a couple of minutes on the production cluster.