
How to go from a value of a finite discrete type to a (Finite n) and back, using the type's derived Generic instance, in Haskell?


I have a library that currently requires its users to provide a helper function of type:

tEnum :: (KnownNat n) => MyType -> Finite n

so that the library implementation can use a very efficient sized vector representation of a function with type:

foo :: MyType -> a

(MyType is discrete and finite.)

Assuming that deriving a Generic instance for MyType is possible, is there a way to generate tEnum automatically, thus lifting that burden from my library's users?

I would also like to go the other way; that is, automatically derive:

tGen :: (KnownNat n) => Finite n -> MyType

Solution

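  • I have something working for at least the tEnum side of things. Since you did not specify your representation of Finite, I used my own Finite and Nat. Note that this Finite n ranges over 0..n inclusive, so a type with k values gets Size k - 1 (a single nullary constructor has Size 'Z).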

    I have included a full code snippet with an example at the bottom of the post, but will only discuss the generic programming parts, leaving out the reasonably standard construction of Peano arithmetic and various useful theorems about it.

    A typeclass is used to keep track of things that can be converted into and out of these finite enums. The important bits here are the default type signatures and the default definitions: they mean that if someone writes an empty EnumFin instance for a type that derives Generic, they don't have to write any code, because these defaults will be used. The defaults use methods from another class, which is implemented for the various representation types that GHC.Generics can produce. Notice that both the normal and the default signatures use (n ~ ...) => ... n instead of writing the size of the Finite directly in the type signature. Otherwise GHC would reject the default signatures for not matching the regular ones: Finite (Size a) and Finite (GSize (Rep a)) are different types, since an instance may define Size itself without defining fromFin or toFin:

    class EnumFin a where
      type Size a :: Nat
      type Size a = GSize (Rep a)
    
      toFin :: (n ~ Size a) => a -> Finite n
      default toFin :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
                    => a -> Finite n
      toFin = gToFin . from
    
      fromFin :: (n ~ Size a) => Finite n -> a
      default fromFin :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
                      => Finite n -> a
      fromFin = to . gFromFin
    

    The class also has a couple of other utility methods. The generic implementation uses them to obtain the minimum and maximum Finite n values of an instance (0 and n) without needing more typeclasses or propagating KnownNat-style constraints:

      zero :: (n ~ Size a) => Finite n
      default zero :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
                   => Finite n
      zero = gzero @(Rep a)
      gt :: (n ~ Size a) => Finite n
      default gt :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
                   => Finite n
      gt = ggt @(Rep a)
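A single zero or gt value is enough to recover the bound without KnownNat because this Finite stores an SNat in its ZF branch. The following trimmed-down sketch reuses the Nat, SNat, Finite and toSN definitions from the full file; snatToInt, finToInt and bound are hypothetical helpers added only for illustration:

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}

-- Same Peano naturals and singletons as in the full file.
data Nat = Z | S Nat

data SNat (n :: Nat) where
  ZS :: SNat 'Z
  SS :: SNat n -> SNat ('S n)

-- Finite n covers 0..n inclusive; ZF stores the bound as a singleton.
data Finite (n :: Nat) where
  ZF :: SNat n -> Finite n
  SF :: Finite n -> Finite ('S n)

-- Recover the bound n from any Finite n value.
toSN :: Finite n -> SNat n
toSN (ZF sn) = sn
toSN (SF f) = SS (toSN f)

-- Hypothetical helper: the Int denoted by a singleton.
snatToInt :: SNat n -> Int
snatToInt ZS = 0
snatToInt (SS n) = 1 + snatToInt n

-- Hypothetical helper: the number a Finite value denotes.
finToInt :: Finite n -> Int
finToInt (ZF _) = 0
finToInt (SF f) = 1 + finToInt f

-- Hypothetical helper: the bound n, read off the value with no KnownNat.
bound :: Finite n -> Int
bound = snatToInt . toSN
```

For example, SF (ZF (SS ZS)) has type Finite ('S ('S 'Z)); finToInt gives 1 for it while bound gives 2.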
    

    The declaration of the generic class is fairly simple; note, however, that its parameter has kind * -> *, not *:

    class GEnumFin f where
      type GSize f :: Nat
      gToFin :: f a -> Finite (GSize f)
      gFromFin :: Finite (GSize f) -> f a
      gzero :: Finite (GSize f)
      ggt :: Finite (GSize f)
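To see why the parameter has kind * -> * and which representation types the instances below must cover, it can help to pattern match on what from produces. This sketch uses the Baz type from the demo and assumes GHC's usual balanced nesting of sums (three constructors become E :+: (F :+: G)); path and roundTrip are hypothetical names:

```haskell
{-# LANGUAGE DeriveGeneric #-}
import GHC.Generics

data Foo = A | B deriving (Show, Generic)
data Bar = C | D deriving (Show, Generic)
data Baz = E Foo | F Bar | G Foo Bar deriving (Show, Generic)

-- `from x :: Rep Baz p` is a tree of types of kind * -> *: the outer M1 is
-- the datatype (D1), L1/R1 select a constructor, the next M1 is the
-- constructor (C1), and each field is an M1 (S1) wrapped around a K1.
path :: Baz -> String
path x = case from x of
  M1 (L1 (M1 (M1 (K1 _)))) -> "L1"
  M1 (R1 (L1 (M1 (M1 (K1 _))))) -> "R1 . L1"
  M1 (R1 (R1 (M1 (M1 (K1 _) :*: M1 (K1 _))))) -> "R1 . R1"

-- `to` inverts `from`, which is what lets fromFin rebuild a value.
roundTrip :: Baz -> Baz
roundTrip = to . from
```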
    

    This generic class must now be implemented for each of the relevant representation types. For example, U1 is a very simple one: it represents a constructor without fields, which is simply encoded as the Finite number 0:

    instance GEnumFin U1 where
      type GSize U1 = 'Z
      gToFin U1 = ZF ZS
      gFromFin (ZF ZS) = U1
      gzero = ZF ZS
      ggt = ZF ZS
    

    :*: is used to combine individual fields, so both parts need to be encoded (it encodes lhs*(m+1)+rhs where m is the max value of the rhs):

    instance forall a b. (GEnumFin a, GEnumFin b) => GEnumFin (a :*: b) where
      type GSize (a :*: b) = Plus (Times (GSize a) ('S (GSize b))) (GSize b)
      gToFin (a :*: b) = addFin (mulFin (gToFin a) (SF (ggt @b))) (gToFin b)
      gFromFin x = (gFromFin a :*: gFromFin b)
        where (a, b) = quotRemFin (toSN (ggt @a)) (toSN (ggt @b)) x
      gzero = addFin (mulFin (gzero @a) (SF (ggt @b))) (gzero @b)
      ggt = addFin (mulFin (ggt @a) (SF (ggt @b))) (ggt @b)
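The :*: encoding is ordinary mixed-radix arithmetic. Here is a value-level sketch on plain Int, where n and m are the maximum values of the left and right components, as in the instance above; encodePair, decodePair and maxPair are hypothetical names:

```haskell
-- Value-level version of the :*: encoding: x*(m+1) + y, where
-- 0 <= x <= n and 0 <= y <= m.
encodePair :: Int -> (Int, Int) -> Int
encodePair m (x, y) = x * (m + 1) + y

-- Inverse: the quotient is the left component, the remainder the right one.
decodePair :: Int -> Int -> (Int, Int)
decodePair m k = k `quotRem` (m + 1)

-- The largest code, n*(m+1) + m, mirrors
-- GSize (a :*: b) = Plus (Times (GSize a) ('S (GSize b))) (GSize b).
maxPair :: Int -> Int -> Int
maxPair n m = n * (m + 1) + m
```

For n = 2 and m = 3, the twelve pairs map onto exactly 0..11, with no gaps or collisions.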
    

    :+:, on the other hand, is used to represent sums, and so must be able to encode either of its constituents (it encodes the left-hand side as 0..n and the right-hand side as n+1..n+1+m, where n and m are the maximum values of the two sides):

    instance forall a b. (GEnumFin a, GEnumFin b) => GEnumFin (a :+: b) where
      type GSize (a :+: b) = 'S (Plus (GSize a) (GSize b))
      gToFin (L1 a) = case proofPlusComm (toSN (gzero @a)) (toSN (gzero @b)) of
                        Refl -> addFin (injFin (gzero @b)) (gToFin a)
      gToFin (R1 b) = addFin (SF (ggt @a)) (gToFin b)
      gFromFin x = case proofPlusComm (toSN (ggt @a)) (toSN (ggt @b)) of
                     Refl -> splitFin (toSN (ggt @b)) (toSN (ggt @a)) x
                                      (R1 . gFromFin @b) (L1 . gFromFin @a)
      gzero = addFin (injFin (gzero @a)) (gzero @b)
      ggt = addFin (SF (ggt @a)) (ggt @b)
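Similarly, the :+: encoding just shifts the right-hand side past the left-hand side's range. Here is a value-level sketch on Int, with Either standing in for L1/R1; encodeSum and decodeSum are hypothetical names. The largest code is n+1+m, matching GSize (a :+: b) = 'S (Plus (GSize a) (GSize b)):

```haskell
-- Value-level version of the :+: encoding, with Either standing in for L1/R1.
-- n and m are the maximum values of the left and right sides.
encodeSum :: Int -> Either Int Int -> Int
encodeSum _ (Left x) = x          -- left side keeps 0..n
encodeSum n (Right y) = n + 1 + y -- right side moves to n+1..n+1+m

-- Inverse: anything above n must have come from the right side.
decodeSum :: Int -> Int -> Either Int Int
decodeSum n k
  | k <= n = Left k
  | otherwise = Right (k - n - 1)
```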
    

    There is also an important instance for a single constructor field, which requires that the contained type also implement EnumFin:

    instance (EnumFin a) => GEnumFin (K1 i a) where
      type GSize (K1 i a) = Size a
      gToFin (K1 a) = toFin a
      gFromFin = K1 . fromFin
      gzero = zero @a
      ggt = gt @a
    

    Finally, we need an instance for M1, which attaches metadata to the generic tree that we don't care about here, so it is simply passed through:

    instance forall i c a. (GEnumFin a) => GEnumFin (M1 i c a) where
      type GSize (M1 i c a) = GSize a
      gToFin (M1 a) = gToFin a
      gFromFin = M1 . gFromFin
      gzero = gzero @a
      ggt = ggt @a
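If the type-level size is not needed, the same default-signature pattern and the same per-instance arithmetic can be written at the value level with plain Int indices, which may be a handy reference when reading the type-level version. This is a sketch, not the code above: EnumIx, GEnumIx, card, toIx and fromIx are hypothetical names, and indices run over 0..card-1 (an exclusive upper bound, unlike the inclusive Finite n):

```haskell
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
import GHC.Generics

-- Hypothetical value-level analogue of EnumFin: indices are 0 .. card - 1.
class EnumIx a where
  card :: Int
  default card :: (Generic a, GEnumIx (Rep a)) => Int
  card = gcard @(Rep a)

  toIx :: a -> Int
  default toIx :: (Generic a, GEnumIx (Rep a)) => a -> Int
  toIx = gToIx . from

  fromIx :: Int -> a
  default fromIx :: (Generic a, GEnumIx (Rep a)) => Int -> a
  fromIx = to . gFromIx

-- Works on the generic representation (kind * -> *), like GEnumFin.
class GEnumIx f where
  gcard :: Int
  gToIx :: f p -> Int
  gFromIx :: Int -> f p

-- A constructor without fields has exactly one value, index 0.
instance GEnumIx U1 where
  gcard = 1
  gToIx U1 = 0
  gFromIx _ = U1

-- Sums: left side first, right side offset by the left side's cardinality.
instance (GEnumIx f, GEnumIx g) => GEnumIx (f :+: g) where
  gcard = gcard @f + gcard @g
  gToIx (L1 x) = gToIx x
  gToIx (R1 y) = gcard @f + gToIx y
  gFromIx k
    | k < gcard @f = L1 (gFromIx k)
    | otherwise = R1 (gFromIx (k - gcard @f))

-- Products: mixed radix, with the right side as the low digit.
instance (GEnumIx f, GEnumIx g) => GEnumIx (f :*: g) where
  gcard = gcard @f * gcard @g
  gToIx (x :*: y) = gToIx x * gcard @g + gToIx y
  gFromIx k = let (q, r) = k `quotRem` gcard @g in gFromIx q :*: gFromIx r

-- A field defers to the field type's own instance.
instance EnumIx c => GEnumIx (K1 i c) where
  gcard = card @c
  gToIx (K1 x) = toIx x
  gFromIx = K1 . fromIx

-- Metadata wrappers are passed through unchanged.
instance GEnumIx f => GEnumIx (M1 i m f) where
  gcard = gcard @f
  gToIx (M1 x) = gToIx x
  gFromIx = M1 . gFromIx

data Foo = A | B deriving (Show, Eq, Generic)
data Bar = C | D deriving (Show, Eq, Generic)
data Baz = E Foo | F Bar | G Foo Bar deriving (Show, Eq, Generic)

-- Empty instances: every method comes from the generic defaults.
instance EnumIx Foo
instance EnumIx Bar
instance EnumIx Baz
```

With these definitions, mapping toIx over [E A, E B, F C, F D, G A C, G A D, G B C, G B D] gives [0 .. 7], and fromIx @Baz inverts it. Because the sum and product rules are the same as in the type-level instances, these indices should line up with the toNum values printed by the demo below.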
    

    For completeness, here is a full file that defines all of the Nat/Finite infrastructure used above and demonstrates the Generic implementation:

    {-# LANGUAGE TypeInType #-}
    {-# LANGUAGE GADTs #-}
    {-# LANGUAGE TypeFamilies #-}
    {-# LANGUAGE TypeOperators #-}
    {-# LANGUAGE UndecidableInstances #-}
    {-# LANGUAGE TypeApplications #-}
    {-# LANGUAGE ScopedTypeVariables #-}
    {-# LANGUAGE StandaloneDeriving #-}
    {-# LANGUAGE DefaultSignatures #-}
    {-# LANGUAGE FlexibleContexts #-}
    {-# LANGUAGE AllowAmbiguousTypes #-}
    {-# LANGUAGE DeriveGeneric #-}
    import GHC.Generics
    import Data.Type.Equality
    
    -- Fairly standard Peano naturals & various useful theorems about them:
    data Nat = Z | S Nat
    data SNat (n :: Nat) where
      ZS :: SNat 'Z
      SS :: SNat n -> SNat ('S n)
    deriving instance Show (SNat n)
    
    type family Plus (n :: Nat) (m :: Nat) where
      Plus 'Z m = m
      Plus ('S n) m = 'S (Plus n m)
    
    plus :: SNat n -> SNat m -> SNat (Plus n m)
    plus ZS m = m
    plus (SS n) m = SS (plus n m)
    
    proofPlusNZ :: SNat n -> Plus n 'Z :~: n
    proofPlusNZ ZS = Refl
    proofPlusNZ (SS n) = case proofPlusNZ n of Refl -> Refl
    
    proofPlusNS :: SNat n -> SNat m -> Plus n ('S m) :~: 'S (Plus n m)
    proofPlusNS ZS _ = Refl
    proofPlusNS (SS n) m = case proofPlusNS n m of Refl -> Refl
    
    proofPlusAssoc :: SNat n -> SNat m -> SNat o
                   -> Plus n (Plus m o) :~: Plus (Plus n m) o
    proofPlusAssoc ZS _ _ = Refl
    proofPlusAssoc (SS n) ZS _ = case proofPlusNZ n of Refl -> Refl
    proofPlusAssoc (SS n) (SS m) ZS =
      case proofPlusNZ m of
        Refl -> case proofPlusNZ (plus n (SS m)) of
          Refl -> Refl
    proofPlusAssoc (SS n) (SS m) (SS o) =
      case proofPlusAssoc n (SS m) (SS o) of Refl -> Refl
    
    proofPlusComm :: SNat n -> SNat m -> Plus n m :~: Plus m n
    proofPlusComm ZS ZS = Refl
    proofPlusComm ZS (SS m) = case proofPlusNZ m of Refl -> Refl
    proofPlusComm (SS n) ZS = case proofPlusNZ n of Refl -> Refl
    proofPlusComm (SS n) (SS m) =
      case proofPlusComm (SS n) m of
        Refl -> case proofPlusComm n (SS m) of
          Refl -> case proofPlusComm n m of
            Refl -> Refl
    
    type family Times (n :: Nat) (m :: Nat) where
      Times 'Z m = 'Z
      Times ('S n) m = Plus m (Times n m)
    
    times :: SNat n -> SNat m -> SNat (Times n m)
    times ZS _ = ZS
    times (SS n) m = plus m (times n m)
    
    proofMultNZ :: SNat n -> Times n 'Z :~: 'Z
    proofMultNZ ZS = Refl
    proofMultNZ (SS n) = case proofMultNZ n of Refl -> Refl
    
    proofMultNS :: SNat n -> SNat m -> Times n ('S m) :~: Plus n (Times n m)
    proofMultNS ZS ZS = Refl
    proofMultNS ZS (SS m) =
      case proofMultNZ (SS m) of
        Refl -> case proofMultNZ m of
          Refl -> Refl
    proofMultNS (SS n) ZS =
      case proofMultNS n ZS of Refl -> Refl
    proofMultNS (SS n) (SS m) =
      case proofMultNS (SS n) m of
        Refl -> case proofMultNS n (SS m) of
          Refl -> case proofMultNS n m of
            Refl -> case lemma1 n m (times n (SS m)) of
              Refl -> Refl
      where lemma1 :: SNat n -> SNat m -> SNat o -> Plus n ('S (Plus m o))
                                                    :~:
                                                    'S (Plus m (Plus n o))
            lemma1 n' m' o' =
              case proofPlusComm n' (SS (plus m' o')) of
                Refl -> case proofPlusComm m' (plus n' o') of
                  Refl -> case proofPlusAssoc m' o' n' of
                    Refl -> case proofPlusComm n' o' of
                      Refl -> Refl
    
    proofMultSN :: SNat n -> SNat m -> Times ('S n) m :~: Plus (Times n m) m
    proofMultSN ZS m = case proofPlusNZ m of Refl -> Refl
    proofMultSN (SS n) m =
      case proofPlusNZ (times n m) of
        Refl -> case proofPlusComm m (plus m (plus (times n m) ZS)) of
          Refl -> Refl
    
    proofMultComm :: SNat n -> SNat m -> Times n m :~: Times m n
    proofMultComm ZS ZS = Refl
    proofMultComm ZS (SS m) = case proofMultNZ (SS m) of
                                Refl -> case proofMultComm ZS m of
                                  Refl -> Refl
    proofMultComm (SS n) ZS = case proofMultComm n ZS of Refl -> Refl
    proofMultComm (SS n) (SS m) =
      case proofMultNS n m of
        Refl -> case proofMultNS m n of
          Refl -> case proofPlusAssoc m n (times n m) of
            Refl -> case proofPlusAssoc n m (times m n) of
              Refl -> case proofPlusComm n m of
                Refl -> case proofMultComm n m of
                  Refl -> Refl
    
    -- `Finite n` represents a number in 0..n (inclusive).
    --
    -- Notice that the "zero" branch includes an `SNat`; this is useful to be
    -- able to conveniently write `toSN` below (generally, to be able to
    -- reflect the `n` component to the value level) without needing to use a
    -- singleton typeclass & pass constraints around everywhere.
    --
    -- It should be possible to switch this out for other implementations of
    -- `Finite` with different choices, but may require rewriting many of
    -- the following functions.
    data Finite (n :: Nat) where
      ZF :: SNat n -> Finite n
      SF :: Finite n -> Finite ('S n)
    deriving instance Show (Finite n)
    
    toSN :: Finite n -> SNat n
    toSN (ZF sn) = sn
    toSN (SF f) = SS (toSN f)
    
    addFin :: forall n m. Finite n -> Finite m -> Finite (Plus n m)
    addFin (ZF n) (ZF m) = ZF (plus n m)
    addFin (ZF n) (SF b) =
      case proofPlusNS n (toSN b) of
        Refl -> SF (addFin (ZF n) b)
    addFin (SF a) b = SF (addFin a b)
    
    mulFin :: forall n m. Finite n -> Finite m -> Finite (Times n m)
    mulFin (ZF n) (ZF m) = ZF (times n m)
    mulFin (ZF n) (SF b) = case proofMultNS n (toSN b) of
                             Refl -> addFin (ZF n) (mulFin (ZF n) b)
    mulFin (SF a) b = addFin b (mulFin a b)
    
    quotRemFin :: SNat n -> SNat m -> Finite (Plus (Times n ('S m)) m)
            -> (Finite n, Finite m)
    quotRemFin nn mm xx = go mm xx nn mm (ZF ZS) (ZF ZS)
      where go :: forall n m s p q r.
                (  Plus q s ~ n, Plus r p ~ m)
                => SNat m
                -> Finite (Plus (Times s ('S m)) p)
                -> SNat s
                -> SNat p
                -> Finite q
                -> Finite r
                -> (Finite n, Finite m)
            go _ (ZF _) s p q r = (addFin q (ZF s), addFin r (ZF p))
            go m (SF x) s (SS p) q r =
              case proofPlusComm (SS p) (times s m) of
                Refl -> case proofPlusNS (times s (SS m)) p of
                  Refl -> case proofPlusNS (toSN r) p of
                    Refl -> go m x s p q (SF r)
            go m (SF x) (SS s) ZS q _ =
              case proofPlusNS (toSN q) s of
                Refl -> case proofMultSN s (SS m) of
                  Refl -> case proofPlusNS (times s (SS m)) m of
                    Refl -> case proofPlusComm (times s (SS m)) (SS m) of
                      Refl -> case proofPlusNZ (times (SS s) (SS m)) of
                        Refl -> go m x s m (SF q) (ZF ZS)
    
    splitFin :: forall n m a. SNat n -> SNat m -> Finite ('S (Plus n m))
             -> (Finite n -> a) -> (Finite m -> a) -> a
    splitFin nn mm xx f g = go nn mm xx mm (ZF ZS)
      where go :: forall r s. (Plus r s ~ m)
               => SNat n -> SNat m -> Finite ('S (Plus n s))
               -> SNat s -> Finite r -> a
            go _ _ (ZF _) s r = g (addFin r (ZF s))
            go n m (SF x) (SS s) r =
              case proofPlusNS (toSN r) s of
                Refl -> case proofPlusNS n s of
                  Refl -> go n m x s (SF r)
            go n _ (SF x) ZS _ = case proofPlusNZ n of Refl -> f x
    
    injFin :: Finite n -> Finite ('S n)
    injFin (ZF n) = ZF (SS n)
    injFin (SF a) = SF (injFin a)
    
    toNum :: (Num a) => Finite n -> a
    toNum (ZF _) = 0
    toNum (SF n) = 1 + toNum n
    
    -- The actual classes & Generic stuff:
    class EnumFin a where
      type Size a :: Nat
      type Size a = GSize (Rep a)
    
      toFin :: (n ~ Size a) => a -> Finite n
      default toFin :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
                    => a -> Finite n
      toFin = gToFin . from
    
      fromFin :: (n ~ Size a) => Finite n -> a
      default fromFin :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
                      => Finite n -> a
      fromFin = to . gFromFin
    
      zero :: (n ~ Size a) => Finite n
      default zero :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
                   => Finite n
      zero = gzero @(Rep a)
      gt :: (n ~ Size a) => Finite n
      default gt :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
                   => Finite n
      gt = ggt @(Rep a)

    class GEnumFin f where
      type GSize f :: Nat
      gToFin :: f a -> Finite (GSize f)
      gFromFin :: Finite (GSize f) -> f a
      gzero :: Finite (GSize f)
      ggt :: Finite (GSize f)
    
    instance GEnumFin U1 where
      type GSize U1 = 'Z
      gToFin U1 = ZF ZS
      gFromFin (ZF ZS) = U1
      gzero = ZF ZS
      ggt = ZF ZS
    
    instance forall a b. (GEnumFin a, GEnumFin b) => GEnumFin (a :*: b) where
      type GSize (a :*: b) = Plus (Times (GSize a) ('S (GSize b))) (GSize b)
      gToFin (a :*: b) = addFin (mulFin (gToFin a) (SF (ggt @b))) (gToFin b)
      gFromFin x = (gFromFin a :*: gFromFin b)
        where (a, b) = quotRemFin (toSN (ggt @a)) (toSN (ggt @b)) x
      gzero = addFin (mulFin (gzero @a) (SF (ggt @b))) (gzero @b)
      ggt = addFin (mulFin (ggt @a) (SF (ggt @b))) (ggt @b)
    
    instance forall a b. (GEnumFin a, GEnumFin b) => GEnumFin (a :+: b) where
      type GSize (a :+: b) = 'S (Plus (GSize a) (GSize b))
      gToFin (L1 a) = case proofPlusComm (toSN (gzero @a)) (toSN (gzero @b)) of
                        Refl -> addFin (injFin (gzero @b)) (gToFin a)
      gToFin (R1 b) = addFin (SF (ggt @a)) (gToFin b)
      gFromFin x = case proofPlusComm (toSN (ggt @a)) (toSN (ggt @b)) of
                     Refl -> splitFin (toSN (ggt @b)) (toSN (ggt @a)) x
                                      (R1 . gFromFin @b) (L1 . gFromFin @a)
      gzero = addFin (injFin (gzero @a)) (gzero @b)
      ggt = addFin (SF (ggt @a)) (ggt @b)
    
    instance forall i c a. (GEnumFin a) => GEnumFin (M1 i c a) where
      type GSize (M1 i c a) = GSize a
      gToFin (M1 a) = gToFin a
      gFromFin = M1 . gFromFin
      gzero = gzero @a
      ggt = ggt @a
    
    instance (EnumFin a) => GEnumFin (K1 i a) where
      type GSize (K1 i a) = Size a
      gToFin (K1 a) = toFin a
      gFromFin = K1 . fromFin
      gzero = zero @a
      ggt = gt @a
    
    -- Demo:
    data Foo = A | B deriving (Show, Generic)
    data Bar = C | D deriving (Show, Generic)
    data Baz = E Foo | F Bar | G Foo Bar deriving (Show, Generic)
    
    instance EnumFin Foo
    instance EnumFin Bar
    instance EnumFin Baz
    
    main :: IO ()
    main = do
      putStrLn $ show $ toNum @Integer $ gt @Baz
      putStrLn $ show $ toNum @Integer $ toFin $ E A
      putStrLn $ show $ toNum @Integer $ toFin $ E B
      putStrLn $ show $ toNum @Integer $ toFin $ F C
      putStrLn $ show $ toNum @Integer $ toFin $ F D
      putStrLn $ show $ toNum @Integer $ toFin $ G A C
      putStrLn $ show $ toNum @Integer $ toFin $ G A D
      putStrLn $ show $ toNum @Integer $ toFin $ G B C
      putStrLn $ show $ toNum @Integer $ toFin $ G B D
      putStrLn $ show $ fromFin @Baz $ toFin $ E A
      putStrLn $ show $ fromFin @Baz $ toFin $ E B
      putStrLn $ show $ fromFin @Baz $ toFin $ F C
      putStrLn $ show $ fromFin @Baz $ toFin $ F D
      putStrLn $ show $ fromFin @Baz $ toFin $ G A C
      putStrLn $ show $ fromFin @Baz $ toFin $ G A D
      putStrLn $ show $ fromFin @Baz $ toFin $ G B C
      putStrLn $ show $ fromFin @Baz $ toFin $ G B D
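To connect this back to the sized-vector representation of foo :: MyType -> a from the question: once a type has an index function, the function can be tabulated into an array once and then looked up in constant time. Here is a minimal sketch using the array package that ships with GHC, with a derived Enum/Bounded instance standing in for toFin (fromEnum plays the index role); Colour, Table, tabulate and lookupT are hypothetical names:

```haskell
import Data.Array (Array, listArray, (!))

-- Hypothetical example type; a derived Enum numbers constructors from 0.
data Colour = Red | Green | Blue deriving (Show, Eq, Enum, Bounded)

-- A function k -> v stored as one array slot per value of k.
-- The phantom k records which index type the table belongs to.
newtype Table k v = Table (Array Int v)

-- Evaluate f once per value, in index order.
tabulate :: (Enum k, Bounded k) => (k -> v) -> Table k v
tabulate f = Table (listArray (0, length xs - 1) (map f xs))
  where
    xs = [minBound .. maxBound]

-- Constant-time lookup through the index function.
lookupT :: Enum k => Table k v -> k -> v
lookupT (Table arr) k = arr ! fromEnum k
```

Swapping fromEnum for a generically derived index (and minBound .. maxBound for mapping the inverse over all indices) is what would let the library drop the hand-written tEnum.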