I want to find all partitions of a set of n elements into k subsets. This is my algorithm, based on the recursive formula for Stirling numbers of the second kind:
```kotlin
fun main(args: Array<String>) {
    val s = mutableSetOf(1, 2, 3, 4, 5)
    val partitions = 3
    val res = mutableSetOf<MutableSet<MutableSet<Int>>>()
    partition(s, partitions, res)
    //println(res)
    println("Second kind stirling number ${res.size}")
}

fun partition(inputSet: MutableSet<Int>, numOfPartitions: Int, result: MutableSet<MutableSet<MutableSet<Int>>>) {
    if (inputSet.size == numOfPartitions) {
        val sets = inputSet.map { mutableSetOf(it) }.toMutableSet()
        result.add(sets)
    }
    else if (numOfPartitions == 1) {
        result.add(mutableSetOf(inputSet))
    }
    else {
        val popped: Int = inputSet.first().also { inputSet.remove(it) }
        val r1 = mutableSetOf<MutableSet<MutableSet<Int>>>()
        partition(inputSet, numOfPartitions, r1) //add popped to each set in solution (all combinations)
        for (solution in r1) {
            for (set in solution) {
                set.add(popped)
                result.add(solution.map { it.toMutableSet() }.toMutableSet()) //deep copy
                set.remove(popped)
            }
        }
        val r2 = mutableSetOf<MutableSet<MutableSet<Int>>>()
        partition(inputSet, numOfPartitions - 1, r2) //popped is single elem set
        r2.map { it.add(mutableSetOf(popped)) }
        r2.map { result.add(it) }
    }
}
```
The code works well for k = 2, but for larger n and k it loses some partitions, and I can't find the mistake.

Example: for n = 5 and k = 3 it outputs `Second kind stirling number 19`, but the correct output would be 25.
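A likely cause, offered as a hedged diagnosis: `partition` removes `popped` from `inputSet` and never puts it back, and every recursive level does the same to the *same* set object. So by the time `partition(inputSet, numOfPartitions - 1, r2)` runs, the first recursive call has already removed further elements from `inputSet`, and the `numOfPartitions == 1` branch stores that very object inside a result. Some partitions then come back with elements missing, and any that coincide are merged by the `Set`, lowering the count. For k = 2 the count can still look right: tracing n = 4 gives 7 partitions, but one of them is `{{3, 4}, {1}}`, which is missing the element 2. Passing a fresh copy to each call (for example `(inputSet - popped).toMutableSet()`) and copying `inputSet` in the base cases should avoid this. Below is a minimal Python sketch of the same recurrence, S(n, k) = k·S(n−1, k) + S(n−1, k−1), that never mutates its input; the name `partitions` is hypothetical and not part of the original code.

```python
def partitions(elements, k):
    """Return every partition of the list `elements` into exactly k non-empty blocks."""
    n = len(elements)
    if k < 1 or n < k:
        return []
    if k == 1:
        return [[list(elements)]]            # one block holding everything
    if n == k:
        return [[[x] for x in elements]]     # every element in its own block
    first, rest = elements[0], elements[1:]  # slicing builds a new list; `elements` is untouched
    result = []
    # Term k*S(n-1, k): `first` joins each of the k blocks of a k-partition of `rest`.
    for p in partitions(rest, k):
        for i in range(k):
            result.append([b + [first] if j == i else list(b) for j, b in enumerate(p)])
    # Term S(n-1, k-1): `first` forms a singleton block next to a (k-1)-partition of `rest`.
    for p in partitions(rest, k - 1):
        result.append([list(b) for b in p] + [[first]])
    return result

print(len(partitions([1, 2, 3, 4, 5], 3)))  # 25
```

Every block is copied rather than shared, so later changes to one partition cannot leak into another, which is exactly the aliasing the Kotlin version suffers from.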
If you can read Python, consider the following algorithm, which I quickly adapted from my implementation of partitioning a set into equal-size parts.
The recursive function fills k parts with n values.
The `lastfilled` parameter helps avoid duplicates: it ensures that the leading (smallest) elements of the parts form an increasing sequence, so each partition is generated only once.
The `empty` parameter prevents empty parts: once the number of remaining values equals the number of still-empty parts, each remaining value is forced into the next empty part.
```python
def genp(parts: list, empty, n, k, m, lastfilled):
    global c
    if m == n:
        print(parts)
        c += 1
        return
    if n - m == empty:
        start = k - empty
    else:
        start = 0
    for i in range(start, min(k, lastfilled + 2)):
        parts[i].append(m)
        if len(parts[i]) == 1:
            empty -= 1
        genp(parts, empty, n, k, m + 1, max(i, lastfilled))
        parts[i].pop()
        if len(parts[i]) == 0:
            empty += 1

def setkparts(n, k):
    parts = [[] for _ in range(k)]
    genp(parts, k, n, k, 0, -1)

c = 0
setkparts(5, 3)
#setkparts(7, 5)
print(c)
```
```
[[0, 1, 2], [3], [4]]
[[0, 1, 3], [2], [4]]
[[0, 1], [2, 3], [4]]
[[0, 1, 4], [2], [3]]
[[0, 1], [2, 4], [3]]
[[0, 1], [2], [3, 4]]
[[0, 2, 3], [1], [4]]
[[0, 2], [1, 3], [4]]
[[0, 2, 4], [1], [3]]
[[0, 2], [1, 4], [3]]
[[0, 2], [1], [3, 4]]
[[0, 3], [1, 2], [4]]
[[0], [1, 2, 3], [4]]
[[0, 4], [1, 2], [3]]
[[0], [1, 2, 4], [3]]
[[0], [1, 2], [3, 4]]
[[0, 3, 4], [1], [2]]
[[0, 3], [1, 4], [2]]
[[0, 3], [1], [2, 4]]
[[0, 4], [1, 3], [2]]
[[0], [1, 3, 4], [2]]
[[0], [1, 3], [2, 4]]
[[0, 4], [1], [2, 3]]
[[0], [1, 4], [2, 3]]
[[0], [1], [2, 3, 4]]
25
```
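If you need the partitions as data rather than printed lines, the same `lastfilled`/`empty` scheme can be written as a generator, which also removes the global counter. This is a minimal sketch of that restructuring, not the code above: `gen_partitions` and `stirling2` are hypothetical names, and the count is cross-checked against the recurrence S(n, k) = k·S(n−1, k) + S(n−1, k−1).

```python
from functools import lru_cache

def gen_partitions(n, k):
    """Yield each partition of range(n) into k non-empty parts (generator variant of genp)."""
    if not 1 <= k <= n:
        return  # no valid partition; without this guard k > n would yield empty parts
    parts = [[] for _ in range(k)]

    def rec(m, empty, lastfilled):
        if m == n:
            yield [p[:] for p in parts]  # copy, because `parts` keeps changing
            return
        # Same rule as genp: when remaining values == empty parts, fill the next empty part.
        start = k - empty if n - m == empty else 0
        for i in range(start, min(k, lastfilled + 2)):
            parts[i].append(m)
            now_empty = empty - 1 if len(parts[i]) == 1 else empty
            yield from rec(m + 1, now_empty, max(i, lastfilled))
            parts[i].pop()

    yield from rec(0, k, -1)

@lru_cache(maxsize=None)
def stirling2(n, k):
    """Stirling number of the second kind via S(n, k) = k*S(n-1, k) + S(n-1, k-1)."""
    if n == k:
        return 1
    if n == 0 or k == 0:
        return 0
    return k * stirling2(n - 1, k) + stirling2(n - 1, k - 1)

for n, k in [(5, 3), (7, 5)]:
    count = sum(1 for _ in gen_partitions(n, k))
    print(n, k, count, stirling2(n, k))
```

Because `empty` is passed down as an argument instead of being modified in place, nothing has to be restored after the recursive call except the popped value.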