
Finding a more idiomatic and shorter way of writing balanced binary tree insert in Haskell


I wrote a binary tree (NOT a search tree) data type in Haskell, and wrote the insert function so that the binary tree always remains balanced after each insert. Even though it works correctly, it does not look very pretty, and I can't see a way to make it more Haskell-like. Is there a way to make insert prettier, or is this it? Code:

data Tree a = Leaf
  | Node Integer (Tree a) a (Tree a)
    deriving (Show, Eq)

getHeight :: Tree a -> Integer
getHeight Leaf = -1
getHeight (Node height _ _ _) = height

insert :: a -> Tree a -> Tree a
insert val Leaf = Node 0 Leaf val Leaf
insert val (Node height lTree curVal rTree)
  | getHeight lTree <= getHeight rTree = let newLTree = insert val lTree in Node (1 + max (getHeight newLTree) (getHeight rTree)) newLTree curVal rTree
  | otherwise = let newRTree = insert val rTree in Node (1 + max (getHeight lTree) (getHeight newRTree)) lTree curVal newRTree

foldTree :: [a] -> Tree a
foldTree = foldr insert Leaf
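A small invariant checker can back up the claim that insert keeps the tree balanced. This is a minimal sketch. It assumes "balanced" means that the two subtrees of every node differ in height by at most one. actualHeight and wellFormed are hypothetical helper names, and insert is the question's code, only reformatted:

```haskell
data Tree a = Leaf
  | Node Integer (Tree a) a (Tree a)
    deriving (Show, Eq)

getHeight :: Tree a -> Integer
getHeight Leaf = -1
getHeight (Node height _ _ _) = height

-- The question's insert, reformatted but otherwise unchanged.
insert :: a -> Tree a -> Tree a
insert val Leaf = Node 0 Leaf val Leaf
insert val (Node height lTree curVal rTree)
  | getHeight lTree <= getHeight rTree =
      let newLTree = insert val lTree
      in Node (1 + max (getHeight newLTree) (getHeight rTree)) newLTree curVal rTree
  | otherwise =
      let newRTree = insert val rTree
      in Node (1 + max (getHeight lTree) (getHeight newRTree)) lTree curVal newRTree

foldTree :: [a] -> Tree a
foldTree = foldr insert Leaf

-- Hypothetical helper: height recomputed from the structure, ignoring stored labels.
actualHeight :: Tree a -> Integer
actualHeight Leaf = -1
actualHeight (Node _ l _ r) = 1 + max (actualHeight l) (actualHeight r)

-- Hypothetical invariant check: every stored height is accurate, and the two
-- subtrees of every node differ in height by at most one.
wellFormed :: Tree a -> Bool
wellFormed Leaf = True
wellFormed t@(Node h l _ r) =
  h == actualHeight t
    && abs (getHeight l - getHeight r) <= 1
    && wellFormed l
    && wellFormed r
```

For example, all (wellFormed . foldTree) [[1 .. n] | n <- [0 .. 300 :: Int]] should hold. "Balanced" here means that sibling heights differ by at most one, not that the height is minimal. foldTree [1 .. 7] has height 3, although seven elements fit in a perfect tree of height 2.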

Solution

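  • Well, the first thing is that in your second branch, you can know which argument max is going to return. You already compared the two heights, and inserting an element raises a tree's height by at most one, which means the new right subtree is never taller than the left one. So: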

    insert val (Node height lTree curVal rTree)
      | getHeight lTree <= getHeight rTree = let newLTree = insert val lTree in Node (1 + max (getHeight newLTree) (getHeight rTree)) newLTree curVal rTree
      | otherwise = let newRTree = insert val rTree in Node (1 + getHeight lTree) lTree curVal newRTree
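    The simplification rests on one fact: inserting an element raises a tree's height by at most one. A quick exhaustive check of that fact for small sizes is sketched below. heightStepsOk is a hypothetical helper name, and insert is the simplified version above:

```haskell
data Tree a = Leaf
  | Node Integer (Tree a) a (Tree a)
    deriving (Show, Eq)

getHeight :: Tree a -> Integer
getHeight Leaf = -1
getHeight (Node height _ _ _) = height

-- Simplified insert: the otherwise branch no longer calls max.
insert :: a -> Tree a -> Tree a
insert val Leaf = Node 0 Leaf val Leaf
insert val (Node _ lTree curVal rTree)
  | getHeight lTree <= getHeight rTree =
      let newLTree = insert val lTree
      in Node (1 + max (getHeight newLTree) (getHeight rTree)) newLTree curVal rTree
  | otherwise =
      let newRTree = insert val rTree
      in Node (1 + getHeight lTree) lTree curVal newRTree

-- Hypothetical check: each successive insert raises the height by 0 or 1.
heightStepsOk :: Int -> Bool
heightStepsOk n = and (zipWith step trees (drop 1 trees))
  where
    trees = scanl (flip insert) Leaf [1 .. n]
    step before after = (getHeight after - getHeight before) `elem` [0, 1]
```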
    

    In fact, if you split the <= case out into < and == cases, you can always know which one max will return. You can use compare, whose Ordering result (LT, EQ or GT) is a nice ADT distinguishing the three cases:

    insert val (Node height lTree curVal rTree) = case compare (getHeight lTree) (getHeight rTree) of
        LT -> Node (1 + getHeight rTree ) lTree' curVal rTree
        EQ -> Node (1 + getHeight lTree') lTree' curVal rTree
        GT -> Node (1 + getHeight lTree ) lTree  curVal rTree'
        where
        lTree' = insert val lTree
        rTree' = insert val rTree
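    A side-by-side comparison shows that the compare version builds exactly the same trees as the original. This is a minimal sketch. insertOriginal is a hypothetical name for the question's insert. sameAsOriginal is a hypothetical helper that compares the two versions on inputs of every length up to n:

```haskell
data Tree a = Leaf
  | Node Integer (Tree a) a (Tree a)
    deriving (Show, Eq)

getHeight :: Tree a -> Integer
getHeight Leaf = -1
getHeight (Node height _ _ _) = height

-- The question's version, renamed for comparison.
insertOriginal :: a -> Tree a -> Tree a
insertOriginal val Leaf = Node 0 Leaf val Leaf
insertOriginal val (Node _ lTree curVal rTree)
  | getHeight lTree <= getHeight rTree =
      let newLTree = insertOriginal val lTree
      in Node (1 + max (getHeight newLTree) (getHeight rTree)) newLTree curVal rTree
  | otherwise =
      let newRTree = insertOriginal val rTree
      in Node (1 + max (getHeight lTree) (getHeight newRTree)) lTree curVal newRTree

-- The compare-based version: each case knows which subtree max would pick.
insert :: a -> Tree a -> Tree a
insert val Leaf = Node 0 Leaf val Leaf
insert val (Node _ lTree curVal rTree) =
  case compare (getHeight lTree) (getHeight rTree) of
    LT -> Node (1 + getHeight rTree ) lTree' curVal rTree
    EQ -> Node (1 + getHeight lTree') lTree' curVal rTree
    GT -> Node (1 + getHeight lTree ) lTree  curVal rTree'
  where
    lTree' = insert val lTree
    rTree' = insert val rTree

-- Hypothetical check: both versions agree on [1 .. k] for every k from 0 to n.
sameAsOriginal :: Int -> Bool
sameAsOriginal n =
  and [ foldr insert Leaf xs == foldr insertOriginal Leaf xs | k <- [0 .. n], let xs = [1 .. k] ]
```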
    

    I've also split out the lets into a where block in this change. That's mostly an aesthetic change, driven in this case by the desire to make it visually obvious how the cases correspond with each other. Laziness will ensure that only the appropriate one of the two new trees gets calculated.

    Actually, I'd go even a bit further: in the LT and GT cases, you know the height of the outer node won't change. The new element goes into the strictly shorter subtree, which grows by at most one level and so can at most catch up with its sibling. That means you can reuse the existing height.

    insert val (Node height lTree curVal rTree) = case compare (getHeight lTree) (getHeight rTree) of
        LT -> Node height                 lTree' curVal rTree
        EQ -> Node (1 + getHeight lTree') lTree' curVal rTree
        GT -> Node height                 lTree  curVal rTree'
        where
        lTree' = insert val lTree
        rTree' = insert val rTree
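    The sketch below runs this final version and also illustrates the laziness point. poisoned is a hypothetical tree whose right subtree has undefined children, so insert would crash if it ever evaluated rTree'. This insertion takes the LT branch, so only lTree' is demanded. lazyDemo is likewise a hypothetical name:

```haskell
data Tree a = Leaf
  | Node Integer (Tree a) a (Tree a)
    deriving (Show, Eq)

getHeight :: Tree a -> Integer
getHeight Leaf = -1
getHeight (Node height _ _ _) = height

-- Final version: LT and GT reuse the existing height.
insert :: a -> Tree a -> Tree a
insert val Leaf = Node 0 Leaf val Leaf
insert val (Node height lTree curVal rTree) =
  case compare (getHeight lTree) (getHeight rTree) of
    LT -> Node height                 lTree' curVal rTree
    EQ -> Node (1 + getHeight lTree') lTree' curVal rTree
    GT -> Node height                 lTree  curVal rTree'
  where
    lTree' = insert val lTree
    rTree' = insert val rTree

-- Hypothetical demo tree: the right subtree reports height 0, but its children
-- are undefined, so inserting into it would crash.
poisoned :: Tree Int
poisoned = Node 1 Leaf 0 (Node 0 undefined 1 undefined)

-- Left height -1 < right height 0 selects LT, so rTree' is never demanded.
lazyDemo :: Tree Int
lazyDemo = case insert 5 poisoned of
  Node _ l _ _ -> l
  Leaf -> Leaf
```

    Evaluating lazyDemo gives Node 0 Leaf 5 Leaf without touching the undefined parts. Printing all of insert 5 poisoned would still crash, but only because of the undefined children already stored in poisoned itself.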
    

    If going the insert/foldTree route, I'd stop here, I think. That looks fairly idiomatic to me.

    But I think I'd also ponder whether it might be worthwhile to build the tree of interest directly, without iterated insertions. One way to do this might be to keep a stack of increasingly deep trees. Pushing an element coalesces it with the top two trees when those have equal height, and otherwise starts a new single-node tree:

    push :: a -> [Tree a] -> [Tree a]
    push a (t:t':ts) | h == h' = Node (h+1) t a t' : ts
        where [h, h'] = map getHeight [t, t']
    push a ts = Node 0 Leaf a Leaf : ts
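    To see what the stack looks like, it helps to list the heights of the trees it holds after pushing n elements. This sketch reuses push verbatim; stackHeights is a hypothetical helper. Every stacked tree is perfect, and the heights follow the skew binary number system, in which only the two smallest trees can share a height:

```haskell
data Tree a = Leaf
  | Node Integer (Tree a) a (Tree a)
    deriving (Show, Eq)

getHeight :: Tree a -> Integer
getHeight Leaf = -1
getHeight (Node height _ _ _) = height

-- Merge the two smallest trees with the new element if they have equal
-- height; otherwise start a new single-node tree on top of the stack.
push :: a -> [Tree a] -> [Tree a]
push a (t:t':ts) | h == h' = Node (h+1) t a t' : ts
    where [h, h'] = map getHeight [t, t']
push a ts = Node 0 Leaf a Leaf : ts

-- Hypothetical helper: heights of the stacked trees after pushing n elements.
stackHeights :: Int -> [Integer]
stackHeights n = map getHeight (foldr push [] [1 .. n])
```

    For example, stackHeights 5 gives [0,0,1], stackHeights 6 gives [1,1], and stackHeights 7 gives [2]. At seven elements, everything has coalesced into one perfect tree of height 2.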
    

    We can combine two trees from this stack by taking the smaller tree as the left branch, the root element of the larger tree as the new root, and (recursively) combining the two children of the larger tree as the right branch. So:

    collapse :: [Tree a] -> Tree a
    collapse [] = Leaf
    collapse [t] = t
    collapse (t:t':ts) = collapse $ case t' of
        Leaf -> ts
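        -- Use max: when t and t' are single nodes, r' is Leaf and t is the taller branch.
        Node _ l a r -> Node (1 + max (getHeight t) (getHeight r')) t a r' : ts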
            where r' = collapse [l, r]
    

    (It is a bit subtle that this is correct, actually! It's not at all obvious that this should produce minimal-height trees, nor that discarding t when t' is a Leaf is okay. But both are true. For the second, note that t is never taller than t', so if t' is a Leaf, t is one too.)
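    Both claims can at least be spot-checked. The sketch below reuses push and collapse, with collapse computing each node's height from both branches. That matters when t and t' are both single nodes: r' is then a Leaf, and t is the taller branch. For every n up to a bound, the sketch checks that the result has the minimal possible height and that every stored height is accurate. actualHeight, labelsOk, minimalHeight and checkUpTo are hypothetical helper names:

```haskell
data Tree a = Leaf
  | Node Integer (Tree a) a (Tree a)
    deriving (Show, Eq)

getHeight :: Tree a -> Integer
getHeight Leaf = -1
getHeight (Node height _ _ _) = height

push :: a -> [Tree a] -> [Tree a]
push a (t:t':ts) | h == h' = Node (h+1) t a t' : ts
    where [h, h'] = map getHeight [t, t']
push a ts = Node 0 Leaf a Leaf : ts

collapse :: [Tree a] -> Tree a
collapse [] = Leaf
collapse [t] = t
collapse (t:t':ts) = collapse $ case t' of
    Leaf -> ts
    -- Height from both branches: r' can be a Leaf while t is a single node.
    Node _ l a r -> Node (1 + max (getHeight t) (getHeight r')) t a r' : ts
        where r' = collapse [l, r]

-- Hypothetical helpers for spot-checking the claims.
actualHeight :: Tree a -> Integer
actualHeight Leaf = -1
actualHeight (Node _ l _ r) = 1 + max (actualHeight l) (actualHeight r)

labelsOk :: Tree a -> Bool
labelsOk Leaf = True
labelsOk t@(Node h l _ r) = h == actualHeight t && labelsOk l && labelsOk r

-- Smallest height a binary tree with n nodes can have (-1 for the empty tree).
minimalHeight :: Integer -> Integer
minimalHeight n = go (-1)
  where
    -- a tree of height h holds at most 2^(h+1) - 1 nodes
    go h | 2 ^ (h + 1) - 1 >= n = h
         | otherwise            = go (h + 1)

-- Does collapse . foldr push [] give accurate labels and minimal height?
checkUpTo :: Integer -> Bool
checkUpTo bound = and [ok (collapse (foldr push [] [1 .. n])) n | n <- [0 .. bound]]
  where
    ok t n = labelsOk t && actualHeight t == minimalHeight n
```

    checkUpTo 300 should return True. If the label is computed as 1 + getHeight r' alone, labelsOk already fails at n = 2.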

    Now we can combine these two algorithms to get foldTree:

    foldTree :: [a] -> Tree a
    foldTree = collapse . foldr push []
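    Putting the pieces together gives the self-contained sketch of the whole pipeline below, with collapse computing each node's height from both branches. treeToList and sameElements are hypothetical helpers. The check compares sorted contents only, because this construction does not preserve the input order:

```haskell
import Data.List (sort)

data Tree a = Leaf
  | Node Integer (Tree a) a (Tree a)
    deriving (Show, Eq)

getHeight :: Tree a -> Integer
getHeight Leaf = -1
getHeight (Node height _ _ _) = height

push :: a -> [Tree a] -> [Tree a]
push a (t:t':ts) | h == h' = Node (h+1) t a t' : ts
    where [h, h'] = map getHeight [t, t']
push a ts = Node 0 Leaf a Leaf : ts

collapse :: [Tree a] -> Tree a
collapse [] = Leaf
collapse [t] = t
collapse (t:t':ts) = collapse $ case t' of
    Leaf -> ts
    Node _ l a r -> Node (1 + max (getHeight t) (getHeight r')) t a r' : ts
        where r' = collapse [l, r]

foldTree :: [a] -> Tree a
foldTree = collapse . foldr push []

-- Hypothetical helper: in-order list of the stored elements.
treeToList :: Tree a -> [a]
treeToList Leaf = []
treeToList (Node _ l a r) = treeToList l ++ [a] ++ treeToList r

-- Hypothetical check: the tree holds exactly the input elements (as a multiset).
sameElements :: Ord a => [a] -> Bool
sameElements xs = sort (treeToList (foldTree xs)) == sort xs
```

    For example, getHeight (foldTree "abcdefg") is 2. foldTree [1, 2, 3] has 1 at the root, so treeToList gives [2, 1, 3] rather than the input order.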
    

    This method definitely requires more code, and more thought and insight. However, the advantage of this approach is that you get linear run time, rather than the linearithmic time of the repeated-insert version.
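    A rough accounting of where those bounds come from is sketched below. It counts node constructions, ignores constant factors, and assumes (as the balance invariant guarantees) that a tree of k elements has height O(log k):

```latex
\begin{aligned}
T_{\text{repeated insert}}(n) &= \sum_{k=1}^{n} O(\log k) = O(n \log n) \\
T_{\text{push}}(n) &= n \cdot O(1) = O(n) \\
T_{\text{collapse}}(n) &= \underbrace{O(\log n)}_{\text{trees on the stack}} \cdot \underbrace{O(\log n)}_{\text{one } \mathtt{collapse}\ [l, r]} = O(\log^2 n) \\
T_{\text{foldTree}}(n) &= O(n) + O(\log^2 n) = O(n)
\end{aligned}
```

    The stack stays logarithmic in length because push only builds perfect trees, of sizes 2^(h+1) - 1, and at most two of them share a height.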