Is there no Scala for-comprehension equivalent to this Haskell do-notation?


Consider the following Haskell code that builds a balanced binary tree:

data Tree a = Node a (Tree a) (Tree a) | Empty

build :: Int -> [(Tree Char, Int)] 
build n = do
  let k = (n - 1) `div` 2
  (l, i) <- build k
  (r, j) <- build (n - k - 1)
  let h = i + j + 1
  (Node 'x' l r, h) : [(Node 'x' r l, h) | abs (i - j) == 1]

Attempting to convert it to Scala 3 yields the following:

enum Tree[+A]:
  case Empty
  case Node(value: A, left: Tree[A], right: Tree[A])

object Tree:

  def build(n: Int): List[(Tree[Char], Int)] =
    val k = (n - 1) / 2
    for
      (l, i) <- build(k)
      (r, j) <- build(n - k - 1)
      h = i + j + 1
      xs = if math.abs(i - j) == 1
           then List((Node('x', r, l), h))
           else Nil
    yield (Node('x', l, r), h)

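The Scala code, however, doesn't do the same thing. yield produces a single tuple per iteration, so the extra list xs is silently ignored. Yielding (Node('x', l, r), h) :: xs instead doesn't compile, because yield then produces a list of tuples per iteration and the whole expression has type List[List[(Tree[Char], Int)]] rather than List[(Tree[Char], Int)].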

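In Haskell, the last statement of a do block must itself be an expression of the monadic type. You can use return (or pure) to "lift" a plain value into the enclosing monad, but you don't have to: here the last expression is already a list. In Scala, yield is the analogue of return: for ... yield e desugars the last generator into a map, so e is always wrapped in the enclosing type. Leaving out yield doesn't help either, because the comprehension then desugars into foreach calls and returns Unit.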
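To make the correspondence concrete, here is a small sketch of the (simplified) rewriting the compiler performs on a for-comprehension over lists: every generator except the last becomes flatMap, the last becomes map, and the yield expression is the mapped function. The function names are hypothetical, chosen only for illustration. The second pair shows the workaround for "the last expression is already a list": bind that list with one more `<-` and yield the bound element unchanged, which amounts to a flatMap because mapping the identity over a list returns the same list.

```scala
// for/yield over two lists and its hand-written desugaring.
def pairs(as: List[Int], bs: List[Char]): List[(Int, Char)] =
  for
    a <- as
    b <- bs
  yield (a, b)

def pairsDesugared(as: List[Int], bs: List[Char]): List[(Int, Char)] =
  as.flatMap(a => bs.map(b => (a, b)))

// Haskell's `do { a <- as; f a }`, where `f a` is already a list:
// bind that list with `<-` and yield each element as-is.
def expand(as: List[Int], f: Int => List[Int]): List[Int] =
  for
    a <- as
    r <- f(a)
  yield r

// The compiler's version of `expand`; equivalent to as.flatMap(f).
def expandDesugared(as: List[Int], f: Int => List[Int]): List[Int] =
  as.flatMap(a => f(a).map(r => r))
```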

The Scala code can be written with explicit flatMap calls to return the full list, but it's not as pretty as a for-comprehension:

val k = (n - 1) / 2
build(k).flatMap((l, i) =>
  build(n - k - 1).flatMap { (r, j) =>
    val h = i + j + 1
    val xs =
      if math.abs(i - j) == 1
      then List((Node('x', r, l), h))
      else Nil
    (Node('x', l, r), h) :: xs
  }
)

I also attempted to nest the for-comprehensions, but that doesn't seem to be valid syntax.
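For reference, a for-comprehension is an ordinary expression, so one can be nested inside another as long as it appears where an expression is expected, for example on the right-hand side of `<-`. A minimal sketch with a hypothetical `lowerTriangle` function (not from the question) that collects the pairs (row, col) with col <= row:

```scala
// The parenthesised inner for/yield builds one row's list; the outer
// generator `cell <- ...` then binds each element of that list.
def lowerTriangle(n: Int): List[(Int, Int)] =
  for
    row <- List.range(0, n)
    cell <- (for col <- List.range(0, row + 1) yield (row, col))
  yield cell
```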

My question is this: is it possible to use a for-comprehension to get the equivalent of the Haskell code?


Solution

  • My Scala is a bit rusty, but I think you'd have to bind the final list with one more generator and yield its elements:

    object Tree:
      def build(n: Int): List[(Tree[Char], Int)] =
        val k = (n - 1) / 2
        for
          (l, i) <- build(k)
          (r, j) <- build(n - k - 1)
          h = i + j + 1
          x <- (Node('x', l, r), h) :: (
                 if math.abs(i - j) == 1
                 then List((Node('x', r, l), h))
                 else Nil)
        yield x
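Here is a self-contained sketch of that approach that can be run as-is. Neither the Haskell nor the Scala code shown above has a base case, so the recursion never stops; the `build(0) = List((Empty, 0))` case below is an assumption added so the function terminates, treating the Int as the node count (which is what h = i + j + 1 computes).

```scala
enum Tree[+A]:
  case Empty
  case Node(value: A, left: Tree[A], right: Tree[A])

object Tree:
  def build(n: Int): List[(Tree[Char], Int)] =
    // Assumed base case (not in the original code): the empty tree.
    if n == 0 then List((Empty, 0))
    else
      val k = (n - 1) / 2
      for
        (l, i) <- build(k)
        (r, j) <- build(n - k - 1)
        h = i + j + 1
        // The list Haskell returns as the last statement of its do block:
        // bind each element with `<-` and yield it unchanged.
        x <- (Node('x', l, r), h) :: (
               if math.abs(i - j) == 1
               then List((Node('x', r, l), h))
               else Nil)
      yield x
```

Under that assumption, build(4) returns four distinct trees, and every returned count equals n.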