UPDATE 2024.03.16: Provided code that produces the correct output, but is still not tail-recursive.
How can I create a tail recursive merge method in Scala on a self-referential tree structure (or is it even possible)?
I have been working on this problem for several days now. I've read articles about approaching it, even in other languages. I've even submitted it to the various AIs (Bard, Copilot, AskCodi, etc.), and they return non-functioning code that STILL cannot be compiled with the @tailrec annotation.
I must be missing some simple mental leap on converting the merge method in the Node case class to be tail recursive. Any guidance would be appreciated.
I'd especially appreciate anything (links to books, videos, articles, etc.) that offers a meta-cognitive way to "think through" this style of solution, in which a self-referential structure is constructed from the bottom up.
And finally, I now know some problems cannot be solved with tail recursion. Is this the case with this one? And if so, why?
CORRECTED CODE:
This performs the desired effect, but isn't tail-recursive.
object Node {
  val Terminal: Node = Node(true, Map.empty)
}
final case class Node(
    isWord: Boolean
  , nodeByLetter: Map[Char, Node]
) {
  require(
      isWord || nodeByLetter.nonEmpty
    , s"either isWord [$isWord] or nodeByLetter.nonEmpty [${nodeByLetter.nonEmpty}] must be true")
  def merge(that: Node): Node = {
    //@tailrec
    def recursive(cursor: (Node, Node) = (this, that)): Node = {
      cursor match {
        case (Node.Terminal, Node.Terminal) =>
          Node.Terminal
        case (left, Node.Terminal) =>
          if (left.isWord)
            left
          else
            left.copy(isWord = true)
        case (Node.Terminal, right) =>
          if (right.isWord)
            right
          else
            right.copy(isWord = true)
        case (left, right) =>
          val lettersToMerge =
            left.nodeByLetter
              .keySet
              .filter(
                letter =>
                  right.nodeByLetter.keySet.contains(letter)
                    && (left.nodeByLetter(letter) != right.nodeByLetter(letter)))
          if (lettersToMerge.isEmpty)
            Node(
                left.isWord || right.isWord
              , right.nodeByLetter ++ left.nodeByLetter)
          else {
            val nodeKeysAll = (left.nodeByLetter.keySet ++ right.nodeByLetter.keySet)
              .toList
              .sorted
            val nodes = nodeKeysAll
              .map(
                letter =>
                  if (lettersToMerge.contains(letter)) {
                    //this call fails the @tailrec annotation
                    recursive(left.nodeByLetter(letter), right.nodeByLetter(letter))
                  } else
                    left.nodeByLetter.getOrElse(letter, right.nodeByLetter(letter))
              )
            val nodeByLetter = {
              nodes
                .zip(nodeKeysAll)
                .map(_.swap)
                .toMap
            }
            Node(
                left.isWord || right.isWord
              , nodeByLetter
            )
          }
      }
    }
    recursive()
  }
}
When the @tailrec line in method merge is uncommented, the line...
recursive(left.nodeByLetter(letter), right.nodeByLetter(letter))
... highlights with a red squiggly (in IntelliJ) and reports the error...
Recursive call not in tail position (in @tailrec annotated method)
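A minimal sketch of why the compiler objects, assuming a hypothetical Tree type standing in for Node (TailPositionDemo, countNodes, and depthOfFirstBranch are illustrative names, not part of the code above). A call is in tail position only when nothing remains to be done with its result. In countNodes the results are still consumed by map and sum, which is the same shape as the recursive call inside nodeKeysAll.map in merge:

```scala
import scala.annotation.tailrec

object TailPositionDemo {
  // Hypothetical minimal tree, standing in for Node.
  final case class Tree(children: List[Tree])

  // Adding @tailrec here would fail to compile: each recursive result is
  // still needed by `map` and `sum` after the call returns.
  def countNodes(tree: Tree): Int =
    1 + tree.children.map(countNodes).sum

  // Compiles with @tailrec: the recursive call is the very last action,
  // and all state travels in the parameters.
  @tailrec
  def depthOfFirstBranch(tree: Tree, depth: Int = 0): Int =
    tree.children match {
      case Nil        => depth
      case first :: _ => depthOfFirstBranch(first, depth + 1)
    }
}
```

The second method only works because it follows a single path; as soon as every child has to be visited, each frame has unfinished work waiting on its recursive calls.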
Here's the sample data I am using to ensure that the resulting function works:
object Main {
  def main(args: Array[String]): Unit = {
    //cat
    val t = Node.Terminal
    val at = Node(false, Map('t' -> t))
    val cat = Node(false, Map('a' -> at))
    val catRoot = Node(false, Map('c' -> cat))
    //camp - intentionally not in alpha order
    val p = Node.Terminal
    val mp = Node(true, Map('p' -> p))
    val amp = Node(false, Map('m' -> mp))
    val camp = Node(false, Map('a' -> amp))
    val campRoot = Node(false, Map('c' -> camp))
    val root = catRoot.merge(campRoot)
    println("----------------")
    println("root: " + root)
  }
}
And the output should look like this:
----------------
root: Node(false,Map(c -> Node(false,Map(a -> Node(false,Map(t -> Node(true,Map()), m -> Node(true,Map(p -> Node(true,Map())))))))))
ORIGINAL POSTED CODE WAS INCORRECT.
It doesn't produce the desired result, let alone do so tail-recursively. I've left it in place per the Stack Overflow rules on updating a question. The corrected code is above.
case class Node(isWord: Boolean, nodeByLetter: Map[Char, Node]) {
  //@tailrec
  final def merge(that: Node): Node = {
    val mergedIsWord = this.isWord || that.isWord
    val mergedNodes =
      (this.nodeByLetter.keySet ++ that.nodeByLetter.keySet)
        .map(letter =>
          (
              letter
            , (this.nodeByLetter.get(letter), that.nodeByLetter.get(letter)) match {
                case (Some(thisNode), Some(thatNode)) =>
                  thisNode.merge(thatNode)
                case (Some(thisNode), None) =>
                  thisNode
                case (None, Some(thatNode)) =>
                  thatNode
                case _ =>
                  throw new IllegalStateException("should never get here")
              }))
        .toMap
    Node(mergedIsWord, mergedNodes)
  }
}
Summary:
The answer to the parenthetical question in the OP title...
...(or is it even possible)?
...is "Yes".
The answer to the full question in the OP title...
How can I create a tail recursive merge method in Scala on a self-referential tree structure?
...is: "Move to a heap-based strategy whenever anything must happen after a recursive call returns and before the method itself returns, even if that 'anything' is just another recursive call."
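As a rough illustration of what "heap-based" can mean, here is a minimal sketch that assumes a hypothetical Tree type rather than the Node class from the question (ExplicitStackDemo and countNodes are illustrative names). The pending work that would otherwise sit in call-stack frames is kept in an ordinary List on the heap, so the helper has exactly one recursive call and it is in tail position:

```scala
import scala.annotation.tailrec

object ExplicitStackDemo {
  // Hypothetical minimal tree, standing in for Node.
  final case class Tree(children: List[Tree])

  def countNodes(root: Tree): Int = {
    // `pending` plays the role of the call stack, but lives on the heap.
    @tailrec
    def loop(pending: List[Tree], count: Int): Int =
      pending match {
        case Nil          => count
        case next :: rest => loop(next.children ::: rest, count + 1)
      }
    loop(List(root), 0)
  }
}
```

This works directly for a fold such as counting. Rebuilding a tree from its merged children, as merge must, needs more bookkeeping, which is where a trampoline becomes the more convenient form of the same idea.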
Details:
The DAWG (Directed Acyclic Word Graph) problem is solvable, but not (easily) using the @tailrec annotation. Instead, it is more simply solved with an FP recursion technique called a trampoline, which is closely related to CPS (continuation-passing style).
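For readers new to the idea, here is a minimal trampoline sketch using the standard library's scala.util.control.TailCalls. The Tree type and the names TrampolineDemo and countNodes are assumptions made for illustration and are not part of the solution itself. Each recursive call is wrapped in tailcall, so it is only described rather than executed, and result then runs the suspended steps in a loop instead of on the call stack:

```scala
import scala.util.control.TailCalls._

object TrampolineDemo {
  // Hypothetical minimal tree, standing in for Node.
  final case class Tree(children: List[Tree])

  // Returns a description of the computation; nothing recurses yet.
  def countNodes(tree: Tree): TailRec[Int] =
    tree.children.foldLeft(done(1): TailRec[Int]) { (accRec, child) =>
      for {
        acc <- accRec                       // count gathered so far
        n   <- tailcall(countNodes(child))  // suspended recursive call
      } yield acc + n
    }
}
```

Calling TrampolineDemo.countNodes(tree).result evaluates the whole chain; the work to do after each recursive call is stored in heap-allocated continuations rather than in stack frames.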
Thanks to this well-structured article on what a trampoline is and how to use one (except for its hand-waving over the mind-bending sequence method, which I worked through separately on Reddit), I was able to derive a fully working answer, shown in detail below. Here's another great (newbie-oriented) article I found.
import scala.annotation.tailrec
import scala.util.control.TailCalls._
object Node {
  val Terminal: Node = Node(true, Map.empty)
  def encode(word: String): Node =
    word
      .reverse
      .foldLeft(Node.Terminal) {
        (node, letter) =>
          Node(false, Map(letter -> node))
      }
}
final case class Node(
    isWord: Boolean
  , nodeByLetter: Map[Char, Node]
) {
  require(
      isWord || nodeByLetter.nonEmpty
    , s"either isWord [$isWord] or nodeByLetter.nonEmpty [${nodeByLetter.nonEmpty}] must be true")
  def find(chars: String): Boolean = {
    @tailrec
    def recursive(charsRemaining: String = chars, node: Node = this): Boolean =
      charsRemaining match {
        case "" => node.isWord
        case charsRemainder =>
          node.nodeByLetter.get(charsRemainder.head) match {
            case Some(nextNode) => recursive(charsRemainder.tail, nextNode)
            case None => false
          }
      }
    recursive()
  }
  def merge(that: Node): Node = {
    def sequence[A](listTailRecA: List[TailRec[A]]): TailRec[List[A]] =
      listTailRecA
        .reverse
        .foldLeft(done(Nil): TailRec[List[A]]) {
          (tailRecListA, tailRecA) =>
            tailRecA map ((_: A) :: (_: List[A])).curried flatMap tailRecListA.map
        }
    def recursive(cursor: (Node, Node) = (this, that)): TailRec[Node] = {
      cursor match {
        case (Node.Terminal, Node.Terminal) =>
          done(Node.Terminal)
        case (left, Node.Terminal) =>
          if (left.isWord)
            done(left)
          else
            done(left.copy(isWord = true))
        case (Node.Terminal, right) =>
          if (right.isWord)
            done(right)
          else
            done(right.copy(isWord = true))
        case (left, right) =>
          val lettersToMerge =
            left.nodeByLetter
              .keySet
              .filter(
                letter =>
                  right.nodeByLetter.keySet.contains(letter)
                    && (left.nodeByLetter(letter) != right.nodeByLetter(letter)))
          if (lettersToMerge.isEmpty)
            done(Node(
                left.isWord || right.isWord
              , right.nodeByLetter ++ left.nodeByLetter))
          else {
            val nodeKeysAll = (left.nodeByLetter.keySet ++ right.nodeByLetter.keySet)
              .toList
              .sorted
            val listTailRecNode = nodeKeysAll
              .map(
                letter =>
                  if (lettersToMerge.contains(letter))
                    tailcall(recursive(left.nodeByLetter(letter), right.nodeByLetter(letter)))
                  else
                    done(left.nodeByLetter.getOrElse(letter, right.nodeByLetter(letter)))
              )
            val tailRecListNode = sequence(listTailRecNode)
            tailRecListNode
              .map(
                ln => {
                  val nodeByLetter = ln.zip(nodeKeysAll).map(_.swap).toMap
                  Node(
                      left.isWord || right.isWord
                    , nodeByLetter
                  )
                })
          }
      }
    }
    recursive().result
  }
}
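The sequence helper in merge is the densest line in the solution, so here is a sketch of what I believe is an equivalent formulation written as a for-comprehension. SequenceDemo and the parameter names are illustrative, and the equivalence is my reading of the one-liner rather than something checked by the compiler. It turns a list of suspended computations into one suspended computation that yields every result in the original order:

```scala
import scala.util.control.TailCalls._

object SequenceDemo {
  def sequence[A](steps: List[TailRec[A]]): TailRec[List[A]] =
    steps
      .reverse
      // Start from an empty result and prepend, walking from the last
      // step to the first, so the final list keeps the input order.
      .foldLeft(done(List.empty[A]): TailRec[List[A]]) { (accRec, step) =>
        for {
          head <- step    // run this step first
          tail <- accRec  // then everything that follows it
        } yield head :: tail
      }
}
```

For example, SequenceDemo.sequence(List(done(1), tailcall(done(2)), done(3))).result yields List(1, 2, 3). In merge, the elements are the per-letter child merges, which is why the resulting list can be zipped back with nodeKeysAll.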