I've been trying to find a way to count the number of times sets of Strings occur in a transaction database (implementing the Apriori algorithm in a distributed fashion). The code I have currently is as follows:
import scala.collection.mutable.ArrayBuffer

val cand_br = sc.broadcast(cand)
transactions.flatMap(trans => freq(trans, cand_br.value))
            .reduceByKey(_ + _)
}

def freq(trans: Set[String], cand: Array[Set[String]]): Array[(Set[String], Int)] = {
  val res = ArrayBuffer[(Set[String], Int)]()
  for (c <- cand) {
    if (c.subsetOf(trans)) {
      res += ((c, 1))
    }
  }
  res.toArray
}
transactions starts out as an RDD[Set[String]], and I'm trying to convert it to an RDD[(K, V)], with K every element in cand and V the number of occurrences of each element of cand in the transaction list.
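To make the shape of the data concrete, here is a tiny made-up example (values are hypothetical, not from my dataset):

// Two transactions and two candidate itemsets:
val transactions = Seq(Set("bread", "milk", "beer"), Set("bread", "milk"))
val cand         = Seq(Set("bread", "milk"), Set("beer"))
// Desired result, one pair per candidate with its occurrence count:
//   (Set("bread", "milk"), 2), (Set("beer"), 1)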
When watching performance in the UI, the flatMap stage alone takes about 3 min to finish, whereas the rest takes < 1 ms.
transactions.count() ~= 88000 and cand.length ~= 24000, to give an idea of the data I'm dealing with. I've tried different ways of persisting the data, but I'm pretty sure the problem I'm facing is algorithmic.
Is there a better way to solve this subproblem?
PS: I'm fairly new to Scala and the Spark framework, so there might be some strange constructions in this code.
Probably the right question to ask in this case is: "what is the time complexity of this algorithm?" I think it has very little to do with Spark's flatMap operation.
Given 2 collections of Sets of size m and n, this algorithm is counting how many elements of one collection are a subset of elements of the other collection, so it looks like complexity m x n. Looking one level deeper, we also see that subsetOf is linear in the number of elements of the subset: x.subsetOf(y) is essentially x.forall(y.contains), so the actual complexity is m x n x s, where s is the cardinality of the subsets being checked.
In other words, this flatMap operation has a lot of work to do: with m ~= 88000 and n ~= 24000 as in the question, that is already over 2 billion subsetOf checks per pass.
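To see where the extra factor s comes from, here is a minimal plain-Scala illustration (the sets are made up):

// subsetOf is linear in the size of the receiver set:
// c.subsetOf(trans) behaves like c.forall(trans.contains),
// i.e. one membership test per element of the candidate c.
val trans = Set("bread", "milk", "beer")
val c     = Set("bread", "milk")
assert(c.subsetOf(trans) == c.forall(trans.contains))  // both true here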
Now, going back to Spark, we can also observe that this algorithm is embarrassingly parallel, so we can use Spark's capabilities to our advantage.
To compare some approaches, I loaded the 'retail' dataset [1] and ran the algorithm with val cand = transactions.filter(_.size < 4).collect. The data size is close to that in the question:
Some comparative runs in local mode:

transactions partitioned up to the # of cores (8): 33 secs

I also tried an alternative implementation, using cartesian instead of flatMap:
transactions
  .cartesian(candRDD)   // candRDD: the candidates as an RDD, e.g. sc.parallelize(cand)
  .map { case (tx, cd) => (cd, if (cd.subsetOf(tx)) 1 else 0) }
  .reduceByKey(_ + _)
  .collect
But that resulted in much longer runs, as seen in the top 2 lines of the Spark UI (cartesian and cartesian with a higher number of partitions): 2.5 min. That is not surprising, as cartesian materializes and shuffles every (transaction, candidate) pair instead of shipping the candidate list once via a broadcast variable as in the original flatMap version.
Given I only have 8 logical cores available, going above that does not help.
Is there any added 'Spark flatMap time complexity'? Probably some, as it involves serializing closures and unpacking collections, but negligible in comparison with the function being executed.
Let's see if we can do a better job: I implemented the same algorithm in plain Scala:

val resLocal = reduceByKey(transLocal.flatMap(trans => freq(trans, cand)))

where the reduceByKey operation is a naive local implementation taken from [2].
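For reference, a naive sequential reduceByKey over a local collection could look roughly like this (my own sketch, not necessarily the exact code from [2]):

// Naive, single-threaded reduceByKey for local collections:
// group the pairs by key and sum the values per key.
def reduceByKey(pairs: Seq[(Set[String], Int)]): Map[Set[String], Int] =
  pairs.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).sum }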
Execution time: 3.67 seconds.
Spark gives you parallelism out of the box. This implementation is totally sequential and therefore takes longer to complete.
Last sanity check: a trivial flatMap operation:
transactions
  .flatMap(trans => Seq((trans, 1)))
  .reduceByKey(_ + _)
  .collect
Execution time: 0.88 secs
Spark is buying you parallelism and clustering, and this algorithm can take advantage of it. Use more cores and partition the input data accordingly.
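As a sketch of what I mean by partitioning accordingly (assuming sc, transactions, cand and freq as defined in the question, and local mode as in my runs):

// Spread the transactions over one partition per available core
// before running the expensive flatMap.
val numCores = Runtime.getRuntime.availableProcessors
val cand_br  = sc.broadcast(cand)

val counts = transactions
  .repartition(numCores)
  .flatMap(trans => freq(trans, cand_br.value))
  .reduceByKey(_ + _)

On a real cluster you would size the number of partitions to the total number of executor cores instead.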
There's nothing wrong with flatMap. The time complexity prize goes to the function inside it.