I am struggling to implement a performant version of the batch SOM algorithm on Spark / PySpark for a huge dataset with > 100 features. I have the feeling that I can either use RDDs, where I can/have to specify the parallelization myself, or I can use DataFrames, which should be more performant, but I see no way to use something like a local accumulation variable for each worker with DataFrames.
Ideas:
Any thoughts on the different options? Is there an even better option?
Or are all of these ideas not that good, and should I just preselect a maximally diverse subset of my dataset and train a SOM locally on that? Thanks!
This is exactly what I did last year, so I might be in a good position to give you an answer.
First, here is my Spark implementation of the batch SOM algorithm (it is written in Scala, but most things will be similar in PySpark).
I needed this algorithm for a project, and every implementation I found had at least one of two problems or limitations; in particular, none of them exposed a simple `fit()`/`transform()` API operating over DataFrames.

So there I went on to code it myself: the batch SOM algorithm in Spark ML style. The first thing I did was look at how k-means was implemented in Spark ML, because, as you know, the batch SOM is very similar to the k-means algorithm. Actually, I could reuse a large portion of the Spark ML k-means code, but I had to modify the core algorithm and the hyperparameters.
I can quickly summarize how the model is built:
- a `SOMParams` class, containing the SOM hyperparameters (size, training parameters, etc.)
- a `SOM` class, which inherits from Spark's `Estimator` and contains the training algorithm. In particular, it contains a `fit()` method that operates on an input `DataFrame` where the features are stored as a `spark.ml.linalg.Vector` in a single column. `fit()` then selects this column and unpacks the `DataFrame` to obtain the underlying `RDD[Vector]` of features, and calls the `run()` method on it. This is where all the computations happen, and, as you guessed, it uses `RDD`s, accumulators and broadcast variables (a rough sketch of this step follows this list). Finally, the `fit()` method returns a `SOMModel` object.
- `SOMModel` is a trained SOM model and inherits from Spark's `Transformer`/`Model`. It contains the map prototypes (center vectors) and a `transform()` method that can operate on `DataFrame`s by taking an input feature column and adding a new column with the predictions (projection on the map). This is done by a prediction UDF (a sketch of this UDF is given at the end of this answer).
- there is also a `SOMTrainingSummary` that collects things such as the objective function.

Here are the take-aways:
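To make the `run()` step more concrete, here is a minimal sketch of one batch SOM iteration, assuming this overall structure; it is illustrative, not my actual implementation, and the helper names `trainIteration`, `squaredDistance` and `neighborhood` are made up for the example. Each partition accumulates neighborhood-weighted sums locally, which is exactly the "local accumulation variable for each worker" asked about in the question, and a single `reduce` merges the partial results:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.ml.linalg.{Vector, Vectors}

// Inside fit(), the feature column is first unpacked to an RDD[Vector],
// roughly: dataset.select("features").rdd.map { case Row(v: Vector) => v }

def squaredDistance(a: Vector, b: Vector): Double = {
  var s = 0.0
  var i = 0
  while (i < a.size) { val d = a(i) - b(i); s += d * d; i += 1 }
  s
}

// One batch SOM iteration: every point is assigned to its best-matching unit
// (BMU), and each prototype is recomputed as the neighborhood-weighted mean
// of all points. neighborhood(bmu, j) is assumed to be e.g. a Gaussian
// kernel of the grid distance between units bmu and j.
def trainIteration(
    sc: SparkContext,
    data: RDD[Vector],
    prototypes: Array[Vector],
    neighborhood: (Int, Int) => Double): Array[Vector] = {

  val bcProto = sc.broadcast(prototypes)
  val k = prototypes.length
  val dim = prototypes(0).size

  val (sums, weights) = data.mapPartitions { points =>
    val proto = bcProto.value
    // Local accumulators: one numerator and one denominator per map unit.
    val localSums = Array.fill(k)(new Array[Double](dim))
    val localWeights = new Array[Double](k)
    points.foreach { x =>
      val bmu = (0 until k).minBy(j => squaredDistance(x, proto(j)))
      var j = 0
      while (j < k) {
        val h = neighborhood(bmu, j)
        if (h > 0.0) {
          var d = 0
          while (d < dim) { localSums(j)(d) += h * x(d); d += 1 }
          localWeights(j) += h
        }
        j += 1
      }
    }
    Iterator((localSums, localWeights))
  }.reduce { case ((s1, w1), (s2, w2)) =>
    // Merge the per-partition partial sums.
    for (j <- 0 until k) {
      var d = 0
      while (d < dim) { s1(j)(d) += s2(j)(d); d += 1 }
      w1(j) += w2(j)
    }
    (s1, w1)
  }

  bcProto.destroy()
  // Weighted mean per unit; keep the old prototype if a unit received no mass.
  prototypes.indices.map { j =>
    if (weights(j) > 0.0) Vectors.dense(sums(j).map(_ / weights(j)))
    else prototypes(j)
  }.toArray
}
```

Repeating this for a fixed number of iterations, usually while shrinking the neighborhood radius, gives the whole batch SOM; Spark ML's k-means training loop follows the same broadcast-and-aggregate pattern.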
- There is no real opposition between `RDD`s and `DataFrame`s (or rather `Dataset`s, but the difference between those two is of no real importance here). They are just used in different contexts. In fact, a `DataFrame` can be seen as an `RDD` specialized for manipulating structured data organized in columns (such as relational tables), allowing SQL-like operations and optimization of the execution plan (Catalyst optimizer).
- For structured data and SQL-like select/filter/aggregation operations, use `DataFrame`s, always.
- For more complex computations, such as training a machine learning algorithm, you need the `RDD` API and have to distribute your computations yourself, using `map`/`mapPartitions`/`foreach`/`reduce`/`reduceByKey` and so on. Look at how things are done in MLlib: it's only a nice wrapper around RDD manipulations!

I hope this answers your question. Concerning performance, as you asked for an efficient implementation: I have not run any benchmarks yet, but I use it at work and it crunches 500k/1M-row datasets in a couple of minutes on the production cluster.
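As a final illustration of the `DataFrame` side, the prediction UDF used by `transform()` can look like the sketch below; again, this is illustrative and not my actual implementation, and `addPredictions` and the column names are invented for the example:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.ml.linalg.Vector

// Wrap the trained prototypes in a UDF that maps each feature vector to the
// index of its best-matching unit, and append it as a prediction column.
def addPredictions(df: DataFrame,
                   prototypes: Array[Vector],
                   featuresCol: String = "features",
                   predictionCol: String = "prediction"): DataFrame = {
  val predict = udf { x: Vector =>
    prototypes.indices.minBy { j =>
      var s = 0.0
      var d = 0
      while (d < x.size) {
        val diff = x(d) - prototypes(j)(d)
        s += diff * diff
        d += 1
      }
      s
    }
  }
  df.withColumn(predictionCol, predict(col(featuresCol)))
}
```

Keeping the prediction in a UDF means `transform()` stays entirely in DataFrame-land, so the model composes with Spark ML pipelines like any other `Transformer`.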