
How to fold over a constructor with special cases?


I have a tree that I want to collapse, where the nodes are of type

data Node = Node1 Node | Node2 Node Node | ... deriving Data

Apart from a few special cases, I want to collapse every node generically, something along the lines of

collapse SPECIALCASE1 = ...
collapse SPECIALCASE2 = ...
...
collapse node = foldl (++) [] $ gmapQ collapse node

Here the special cases produce lists of results, and the last case just recursively collapses the children. However, this doesn't compile: the function passed as the first argument of gmapQ must have type forall d. Data d => d -> u, not Node -> u, which as far as I can tell restricts me to functions that work on any Data instance.
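To see why a Node -> u (here Ast -> u) function is rejected, it helps to look at what gmapQ actually hands to its query: every immediate field of the constructor, whatever its type. A minimal sketch, assuming the Ast type from "Extra info 2" below (copied so the block compiles on its own) and a hypothetical helper fieldTypes; since the query may only assume Data d, a generic thing it can do is report each field's type.

```haskell
{-# LANGUAGE DeriveDataTypeable #-}
module Main where

import Data.Data (Data, gmapQ, typeOf)

-- Same shape as the Ast type from the question (assumed for illustration).
data Ast = Nr Int
         | Sum Ast Ast
         | Mul Ast Ast
         | Min Ast
         | If Ast Ast Ast
         | Let String Ast Ast
         | Var String
           deriving (Show, Eq, Data)

-- Hypothetical helper: gmapQ applies the query to each immediate field.
-- The query has type forall d. Data d => d -> String, so it must work for
-- the String and Int fields as well as the Ast ones.
fieldTypes :: Ast -> [String]
fieldTypes = gmapQ (\d -> show (typeOf d))

main :: IO ()
main = do
  print (fieldTypes (Let "x" (Nr 1) (Var "x")))  -- ["[Char]","Ast","Ast"]
  print (fieldTypes (Nr 42))                     -- ["Int"]
```

The String field in Let is exactly the value a function of type Ast -> u could not accept.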

Is there a way to coerce the values to the correct type, or perhaps a more lenient map function?

Extra info:

The function described above as collapse is actually named validate. It traverses an abstract syntax tree for a very simple language and finds unbound variables; the special cases are handled like this:

validate _ (Nr _) = []
validate env (Let var val expr) = validate env val ++ validate (var:env) expr
validate env (Var var) = if elem var env then [] else [var]

These encode the rules that number literals contain no variables, that a let expression binds a variable, and that a variable must be checked for whether it is bound. Every other construct in this toy language is just a combination of numbers and variables (e.g. summation, multiplication, etc.), so when checking for unbound variables I only need to traverse their subtrees and combine the results; hence the gmapQ.
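For comparison, here is a sketch of the fully explicit version that the generic default case is meant to replace, assuming the Ast type from "Extra info 2" below; freeVars is a hypothetical name chosen so it does not clash with validate. The first three clauses are the rules above; the remaining four are the "traverse the subtrees and combine the results" boilerplate.

```haskell
module Main where

-- Ast as in "Extra info 2" (Data is not needed for the explicit version).
data Ast = Nr Int
         | Sum Ast Ast
         | Mul Ast Ast
         | Min Ast
         | If Ast Ast Ast
         | Let String Ast Ast
         | Var String
           deriving (Show, Eq)

-- Hypothetical name: collects the variables that are not bound in env.
freeVars :: [String] -> Ast -> [String]
-- The three special cases from the question.
freeVars _   (Nr _)           = []
freeVars env (Let v val body) = freeVars env val ++ freeVars (v : env) body
freeVars env (Var v)          = if v `elem` env then [] else [v]
-- Boilerplate: every other constructor just combines its subtrees.
freeVars env (Sum a b)        = freeVars env a ++ freeVars env b
freeVars env (Mul a b)        = freeVars env a ++ freeVars env b
freeVars env (Min a)          = freeVars env a
freeVars env (If c t e)       = concatMap (freeVars env) [c, t, e]

main :: IO ()
main = print (freeVars [] (Let "x" (Nr 3) (Var "y")))  -- ["y"]
```

Every new constructor added to Ast needs another clause here, which is the maintenance cost a generic traversal avoids.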

Extra info 2:

The actual data type used instead of the Node example above is of the form

data Ast = Nr Int
         | Sum Ast Ast
         | Mul Ast Ast
         | Min Ast
         | If Ast Ast Ast
         | Let String Ast Ast
         | Var String
           deriving (Show, Eq, Data)

Solution

  • The direct way to do what you want is to write the default (catch-all) case of validate as:

    validate env expr = concat $ gmapQ ([] `mkQ` (validate env)) expr
    

    This uses mkQ from Data.Generics.Aliases (in the syb package). The whole point of mkQ is to build queries of type forall d. Data d => d -> u that behave differently on different Data instances: r `mkQ` f applies f to any value of f's argument type and returns the default r for everything else. There's no magic here, though; you could define the query by hand in terms of cast:

    validate env expr = concat $ gmapQ myQuery expr
      where myQuery :: Data d => d -> [String]
            myQuery d = case cast d of Just d -> validate env d
                                       _ -> []
    

    Still, I've generally found it clearer to use uniplate from the lens library. The idea is to create a default Plated instance:

    instance Plated Ast where
      plate = uniplate   -- uniplate from Data.Data.Lens 
    

    which gives you children :: Ast -> [Ast] (from Control.Lens.Plated), returning the direct Ast subterms of a node; fields of other types, such as the String in Let, are not included. You can then write your default validate case as:

    validate env expr = concatMap (validate env) (children expr)
    

    The full code, with a test that prints ["z"]:

    {-# LANGUAGE DeriveDataTypeable #-}
    
    module SpecialCase where
    
    import Control.Lens.Plated
    import Data.Data
    import Data.Data.Lens (uniplate)
    -- import Data.Generics.Aliases (mkQ)  -- only for the gmapQ/mkQ variant (syb package)
    
    data Ast = Nr Int
             | Sum Ast Ast
             | Mul Ast Ast
             | Min Ast
             | If Ast Ast Ast
             | Let String Ast Ast
             | Var String
               deriving (Show, Eq, Data)
    
    instance Plated Ast where
      plate = uniplate
    
    validate :: [String] -> Ast -> [String]
    validate env (Let var val expr) = validate env val ++ validate (var:env) expr
    validate env (Var var) = if elem var env then [] else [var]
    -- either use this uniplate version:
    validate env expr = concatMap (validate env) (children expr)
    -- or use the alternative, lens-free version
    -- (needs mkQ from Data.Generics.Aliases in the syb package):
    -- validate env expr = concat $ gmapQ ([] `mkQ` (validate env)) expr
    
    main :: IO ()
    main = print $ validate [] (Let "x" (Nr 3) (Let "y" (Var "x") 
                 (Sum (Mul (Var "x") (Var "z")) (Var "y"))))
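The commented-out lens-free variant can also stand on its own as a complete program. A sketch, assuming the syb package is installed (it provides Data.Generics.Aliases.mkQ); everything else comes from base, and the Plated instance is dropped because nothing here uses it.

```haskell
{-# LANGUAGE DeriveDataTypeable #-}
module Main where

import Data.Data (Data, gmapQ)
import Data.Generics.Aliases (mkQ)  -- from the syb package

data Ast = Nr Int
         | Sum Ast Ast
         | Mul Ast Ast
         | Min Ast
         | If Ast Ast Ast
         | Let String Ast Ast
         | Var String
           deriving (Show, Eq, Data)

validate :: [String] -> Ast -> [String]
validate env (Let var val expr) = validate env val ++ validate (var : env) expr
validate env (Var var)          = if var `elem` env then [] else [var]
-- Default case: query every immediate field. Ast fields recurse via
-- validate env; fields of other types (e.g. the Int in Nr) give [].
validate env expr = concat (gmapQ ([] `mkQ` validate env) expr)

main :: IO ()
main = print $ validate [] (Let "x" (Nr 3) (Let "y" (Var "x")
                 (Sum (Mul (Var "x") (Var "z")) (Var "y"))))
```

Nr needs no explicit clause here: its only field is an Int, so the mkQ default [] applies.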
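To back up the "no magic" remark about mkQ with something runnable, here is a sketch that needs only base: a hypothetical combinator mkQ' written with cast, which mirrors what mkQ does (apply the function when the value's type matches its argument type, otherwise return the default). It is an illustrative equivalent, not syb's own source.

```haskell
{-# LANGUAGE DeriveDataTypeable #-}
module Main where

import Data.Data (Data, Typeable, cast, gmapQ)

data Ast = Nr Int
         | Sum Ast Ast
         | Mul Ast Ast
         | Min Ast
         | If Ast Ast Ast
         | Let String Ast Ast
         | Var String
           deriving (Show, Eq, Data)

-- Hypothetical stand-in for syb's mkQ: if x is really a b, apply f;
-- otherwise fall back to the default result.
mkQ' :: (Typeable a, Typeable b) => r -> (b -> r) -> a -> r
mkQ' dflt f x = case cast x of
  Just y  -> f y
  Nothing -> dflt

validate :: [String] -> Ast -> [String]
validate env (Let var val expr) = validate env val ++ validate (var : env) expr
validate env (Var var)          = if var `elem` env then [] else [var]
-- Default case: Ast fields recurse, String/Int fields contribute [].
validate env expr = concat (gmapQ (mkQ' [] (validate env)) expr)

main :: IO ()
main = print (validate [] (Let "x" (Nr 3) (Sum (Var "x") (Var "q"))))  -- ["q"]
```

The Typeable constraints are satisfied inside gmapQ because Typeable is a superclass of Data.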
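To see what children returns for the Plated instance defined via uniplate, a small sketch, assuming the lens package is installed; the data type and instance mirror the full code in the answer, and universe (also from Control.Lens.Plated) is included only to contrast direct children with all descendants.

```haskell
{-# LANGUAGE DeriveDataTypeable #-}
module Main where

import Control.Lens.Plated (Plated (..), children, universe)
import Data.Data (Data)
import Data.Data.Lens (uniplate)

data Ast = Nr Int
         | Sum Ast Ast
         | Mul Ast Ast
         | Min Ast
         | If Ast Ast Ast
         | Let String Ast Ast
         | Var String
           deriving (Show, Eq, Data)

instance Plated Ast where
  plate = uniplate

main :: IO ()
main = do
  -- Only the Ast fields are direct children; the String "x" is not.
  print (children (Let "x" (Nr 1) (Var "x")))
  -- A leaf such as Nr has no Ast children.
  print (children (Nr 1))
  -- universe: the node itself plus all of its descendants.
  print (universe (Sum (Nr 1) (Min (Var "y"))))
```

Because children stops at the first layer, validate can still thread env through Let by recursing itself, which universe alone would not allow.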