
How to convey "less than" constraint using type classes?


I am trying to write a version of take that works on length-indexed vectors. This requires the number of elements to take to be less than or equal to the length of the vector.

This is the current version of my code:

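{-# LANGUAGE DataKinds, GADTs, KindSignatures, TypeOperators,
             MultiParamTypeClasses, FlexibleInstances #-}
import Prelude hiding (take)
import Data.Kind (Type)

data Nat where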
    Zero :: Nat
    Succ :: Nat -> Nat

data SNat (n :: Nat) where
    SZero :: SNat Zero
    SSucc :: SNat n -> SNat (Succ n)

data Vec (n :: Nat) (a :: Type) where
    Nil  :: Vec Zero a
    Cons :: a -> Vec n a -> Vec (Succ n) a

class (m :: Nat) >= (n :: Nat)
instance m >= Zero
instance m >= n => (Succ m >= Succ n)

take :: (m >= n) => SNat n -> Vec m a -> Vec n a
take (SZero  ) _         = Nil
take (SSucc n) (x `Cons` xs) = x `Cons` (take n xs)

However, I get the following error, and I am not sure how to solve it:

    * Could not deduce (n2 >= n1) arising from a use of `take'
      from the context: m >= n
        bound by the type signature for:
                   take :: forall (m :: Nat) (n :: Nat) a.
                           (m >= n) =>
                           SNat n -> Vec m a -> Vec n a
        at src\AnotherOne.hs:39:1-48
      or from: (n :: Nat) ~ ('Succ n1 :: Nat)
        bound by a pattern with constructor:
                   SSucc :: forall (n :: Nat). SNat n -> SNat ('Succ n),
                 in an equation for `take'
        at src\AnotherOne.hs:41:7-13
      or from: (m :: Nat) ~ ('Succ n2 :: Nat)
        bound by a pattern with constructor:
                   Cons :: forall a (n :: Nat). a -> Vec n a -> Vec ('Succ n) a,
                 in an equation for `take'
        at src\AnotherOne.hs:41:17-27
      Possible fix:
        add (n2 >= n1) to the context of the data constructor `Cons'
    * In the second argument of `Cons', namely `(take n xs)'
      In the expression: x `Cons` (take n xs)
      In an equation for `take':
          take (SSucc n) (x `Cons` xs) = x `Cons` (take n xs)

I have tried several variations of the type class, including the OVERLAPS and even INCOHERENT pragmas, but none of them fixed it. HLS also reports that the pattern match is incomplete: the cases (SSucc SZero) Nil and (SSucc (SSucc _)) Nil are not matched.

However, if I write:

test = take (SSucc SZero) Nil

it correctly fails with Couldn't match type ‘'Zero’ with ‘'Succ 'Zero’. This, together with a few other tests, suggests that the function's type is right and that the problem lies specifically in its definition.

Lastly, someone suggested that I simply use a type family instead:

type (>=~) :: Nat -> Nat -> Bool
type family m >=~ n where
    m >=~ Zero        = True
    Succ m >=~ Succ n = m >=~ n
    _ >=~ _           = False
type m >= n = m >=~ n ~ True

This does work, but I was trying to solve the problem with type class instances. As a side question: is there any benefit of one approach over the other?
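For comparison, the type-family variant can be filled out into a complete module. This is a sketch under a few assumptions: the LANGUAGE pragmas, the Prelude hiding import and the toList helper (hypothetical, only there to inspect results) are additions, Nat is written in ordinary data syntax, promoted constructors are ticked, and StandaloneKindSignatures needs GHC 8.10 or later. It can be loaded into GHCi as-is.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, TypeFamilies, TypeOperators,
             ConstraintKinds, StandaloneKindSignatures #-}
import Prelude hiding (take)
import Data.Kind (Type)

data Nat = Zero | Succ Nat

data SNat (n :: Nat) where
  SZero :: SNat 'Zero
  SSucc :: SNat n -> SNat ('Succ n)

data Vec (n :: Nat) (a :: Type) where
  Nil  :: Vec 'Zero a
  Cons :: a -> Vec n a -> Vec ('Succ n) a

-- Boolean-valued comparison, as in the question (closed type family).
type (>=~) :: Nat -> Nat -> Bool
type family m >=~ n where
  m       >=~ 'Zero   = 'True
  'Succ m >=~ 'Succ n = m >=~ n
  _       >=~ _       = 'False

-- Constraint synonym: "m >=~ n must reduce to 'True".
type m >= n = (m >=~ n) ~ 'True

take :: (m >= n) => SNat n -> Vec m a -> Vec n a
take SZero     _           = Nil
take (SSucc n) (Cons x xs) = Cons x (take n xs)

-- Hypothetical helper for inspecting results (not part of the question).
toList :: Vec n a -> [a]
toList Nil         = []
toList (Cons x xs) = x : toList xs
```

Here the Succ case type-checks because the closed family reduces 'Succ m >=~ 'Succ n to m >=~ n, so the equality given by the signature is exactly what the recursive call needs.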


Solution

  • The problem is that the interface of your >= class doesn't in any way express what it means for a number to be at least as great as another.

    To do that, I would suggest refactoring the singleton type to clearly separate the two possible cases:

    data SZero (n :: Nat) where
      SZero :: SZero 'Zero
    
    data SPositive (n :: Nat) where
      SSucc :: SNat n -> SPositive ('Succ n)
    
    type SNat n = Either (SZero n) (SPositive n)
    

    Furthermore, we need a way to express rolling back the inductive steps at the type level. This does call for a type family, but a much simpler one than your >=~:

    type family Pred (n :: Nat) :: Nat where
      Pred ('Succ n) = n
    

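    Notice that this is not total: there is no equation for Pred 'Zero. That's fine, because a type family application with no matching equation simply stays stuck (unreduced) rather than causing an error. You can still use the family wherever the compiler can see that the one existing equation applies.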

    Now we can formulate the class. The crucial theorem, which you noticed was missing, is that in the Succ case you can apply induction over the predecessors. More precisely, we only need to know that n is positive in order to step the m ≥ n property down to the predecessors of both numbers. That is, the mathematical statement is

    m ≥ n ∧ positive(n) ⟹ pred(m) ≥ pred(n).

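    We can now express exactly that, using the CPS (continuation-passing style) trick to demote the implication arrow to the value level: the conclusion Pred m >= Pred n becomes a constraint that the caller's continuation is allowed to assume.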

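    -- Needs RankNTypes (a constraint inside an argument type) and
    -- AllowAmbiguousTypes (m only occurs under the type family Pred).
    class m >= n where
      atLeastAsPositive :: SPositive n -> (Pred m >= Pred n => r) -> r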
    

    For the Zero case this theorem doesn't even apply, but that's no problem: there are no SPositive 'Zero singletons anyway, so we can safely use an empty case match (which requires the EmptyCase extension):

    instance m >= 'Zero where
      atLeastAsPositive s = case s of {}
    

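    The interesting case is the one of positive numbers. Because of the way we formulated the method's type, the compiler can connect the threads by itself: Pred ('Succ m) reduces to m and Pred ('Succ n) reduces to n, so the constraint the continuation φ needs is exactly the instance context m >= n: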

    instance m >= n => ('Succ m >= 'Succ n) where
      atLeastAsPositive (SSucc _) φ = φ
    

    And finally, we invoke that theorem in your take function:

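    -- The explicit forall (with ScopedTypeVariables) lets @m, a visible
    -- type application (TypeApplications), refer to this signature's m.
    take :: forall m n a . (m >= n) => SNat n -> Vec m a -> Vec n a
    take (Left SZero) _         = Nil
    take (Right s@(SSucc n)) (x `Cons` xs)
      = atLeastAsPositive @m s (x `Cons` (take n xs))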
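Putting the pieces of the answer together gives roughly the following complete module. It is a sketch: the LANGUAGE pragmas and imports are additions (the answer does not list them), as is the hypothetical toList helper, and the empty-case instance is written with an explicit continuation argument, which should behave the same. GHC may still warn that take (Right _) Nil is not covered, because the pattern-match checker cannot use the m >= n class constraint to rule that case out.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, TypeFamilies, TypeOperators,
             MultiParamTypeClasses, FlexibleInstances, FlexibleContexts,
             RankNTypes, EmptyCase, AllowAmbiguousTypes,
             ScopedTypeVariables, TypeApplications #-}
import Prelude hiding (take)
import Data.Kind (Type)

data Nat = Zero | Succ Nat

data Vec (n :: Nat) (a :: Type) where
  Nil  :: Vec 'Zero a
  Cons :: a -> Vec n a -> Vec ('Succ n) a

-- Singleton split into its two cases, as in the answer.
data SZero (n :: Nat) where
  SZero :: SZero 'Zero

data SPositive (n :: Nat) where
  SSucc :: SNat n -> SPositive ('Succ n)

type SNat n = Either (SZero n) (SPositive n)

-- Partial on purpose: Pred 'Zero has no equation and never reduces.
type family Pred (n :: Nat) :: Nat where
  Pred ('Succ n) = n

-- "m >= n and n positive implies Pred m >= Pred n", with the conclusion
-- handed to a continuation (CPS).
class m >= n where
  atLeastAsPositive :: SPositive n -> (Pred m >= Pred n => r) -> r

-- No value of type SPositive 'Zero exists, so the empty case is total.
instance m >= 'Zero where
  atLeastAsPositive s _ = case s of {}

-- Pred ('Succ m) = m and Pred ('Succ n) = n, so the continuation needs
-- exactly this instance's context, m >= n.
instance m >= n => ('Succ m >= 'Succ n) where
  atLeastAsPositive (SSucc _) k = k

take :: forall m n a. (m >= n) => SNat n -> Vec m a -> Vec n a
take (Left SZero) _ = Nil
take (Right s@(SSucc n)) (Cons x xs) =
  atLeastAsPositive @m s (Cons x (take n xs))

-- Hypothetical helper for inspecting results (not part of the answer).
toList :: Vec n a -> [a]
toList Nil         = []
toList (Cons x xs) = x : toList xs
```

With this module loaded, take (Right (SSucc (Left SZero))) applied to a three-element vector keeps its first element, while applying the same singleton to Nil should be rejected at compile time, since no instance matches 'Zero >= 'Succ 'Zero.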
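The CPS trick behind atLeastAsPositive can be seen in isolation in a smaller, hypothetical example (IsInt, withNum and double are made-up names, not taken from the answer). Matching on a GADT constructor reveals a type equality, which lets the function discharge a constraint and pass it on to a continuation, just as the 'Succ instance does with φ.

```haskell
{-# LANGUAGE GADTs, RankNTypes #-}

-- Hypothetical evidence type: matching on IsInt tells GHC that a ~ Int.
data IsInt a where
  IsInt :: IsInt Int

-- The implication "IsInt a implies Num a", demoted to the value level:
-- the continuation k may assume Num a.
withNum :: IsInt a -> (Num a => r) -> r
withNum IsInt k = k   -- here a ~ Int, so Num a is solved by Num Int

-- The caller uses (+) at type a although its signature has no Num a.
double :: IsInt a -> a -> a
double ev x = withNum ev (x + x)
```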
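The remark that the Pred family may be partial can be checked directly with propositional equality from Data.Type.Equality. A minimal sketch; predOfTwo and predOfSucc are hypothetical names, and the commented-out bad definition is expected to be rejected because Pred 'Zero never reduces.

```haskell
{-# LANGUAGE DataKinds, KindSignatures, TypeFamilies, TypeOperators #-}
import Data.Type.Equality ((:~:)(Refl))

data Nat = Zero | Succ Nat

type family Pred (n :: Nat) :: Nat where
  Pred ('Succ n) = n

-- Accepted: Pred ('Succ ('Succ 'Zero)) reduces to 'Succ 'Zero.
predOfTwo :: Pred ('Succ ('Succ 'Zero)) :~: 'Succ 'Zero
predOfTwo = Refl

-- Accepted for any k: the equation fires as soon as the argument is
-- visibly a 'Succ, even when what is under the 'Succ is unknown.
predOfSucc :: Pred ('Succ k) :~: k
predOfSucc = Refl

-- Pred 'Zero is a well-kinded but stuck type: it equals nothing except
-- itself, so this would be rejected if uncommented.
-- bad :: Pred 'Zero :~: 'Zero
-- bad = Refl
```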