Tags: scala, implicit, scala-3, match-types

How to ask Scala whether evidence exists for all instantiations of a type parameter?


Given the following type-level addition function on Peano numbers:

sealed trait Nat
class O extends Nat
class S[N <: Nat] extends Nat

// the `<: Nat` result bound lets S[n plus b] satisfy the bound of S's type parameter
type plus[a <: Nat, b <: Nat] <: Nat = a match
  case O => b
  case S[n] => S[n plus b]
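To make the behaviour of plus concrete, here is a minimal sketch that asks the compiler to reduce it on small inputs. The aliases One, Two, Three and the value names are hypothetical, introduced only for illustration, and the sketch assumes the `<: Nat` bound on plus. Each summon compiles only if the match type reduces to the stated result.

```scala
sealed trait Nat
class O extends Nat
class S[N <: Nat] extends Nat

// Same definition as above, with an explicit `<: Nat` bound on the result
type plus[a <: Nat, b <: Nat] <: Nat = a match
  case O => b
  case S[n] => S[n plus b]

// Hypothetical aliases for readability
type One = S[O]
type Two = S[S[O]]
type Three = S[S[S[O]]]

// One plus Two  ~>  S[O plus Two]  ~>  S[Two]  =  Three   (second case, then first case)
val onePlusTwo = summon[(One plus Two) =:= Three]
// O plus Two  ~>  Two   (first case)
val zeroPlusTwo = summon[(O plus Two) =:= Two]

// =:= evidence is also a function; at runtime it returns its argument unchanged
val three: Three = new S[S[S[O]]]

@main def plusDemo(): Unit =
  println(onePlusTwo(three) eq three)
```

Because reduction happens during type checking, a wrong equation such as (One plus Two) =:= Two would simply fail to compile.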

Say we want to prove a theorem like

for all natural numbers n, n + 0 = n

which can perhaps be specified like so:

type plus_n_0 = [n <: Nat] =>> (n plus O) =:= n

Then, when it comes to providing evidence for the theorem, we can easily ask the Scala compiler for evidence in particular cases:

summon[plus_n_0[S[S[O]]]]  // ok, 2 + 0 = 2

But how can we ask Scala whether it can generate evidence for all instantiations of [n <: Nat], thus providing a proof of plus_n_0?
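A minimal sketch of the difficulty, using scala.compiletime.testing.typeChecks to ask the compiler both questions without failing the build. It assumes the `<: Nat` bound on plus; the names concreteOk, abstractOk and attempt are hypothetical. The concrete instance type-checks, while the universally quantified one does not, because the match type n plus O cannot be reduced when n is an unknown subtype of Nat.

```scala
import scala.compiletime.testing.typeChecks

sealed trait Nat
class O extends Nat
class S[N <: Nat] extends Nat

type plus[a <: Nat, b <: Nat] <: Nat = a match
  case O => b
  case S[n] => S[n plus b]

// Concrete n: S[S[O]] plus O reduces to S[S[O]], so refl evidence is found
val concreteOk: Boolean =
  typeChecks("summon[(S[S[O]] plus O) =:= S[S[O]]]")

// Abstract n: `n plus O` is stuck (n is neither known to be O nor an S[_]),
// so no =:= evidence can be summoned
val abstractOk: Boolean =
  typeChecks("def attempt[n <: Nat] = summon[(n plus O) =:= n]")

@main def checkPlusNO(): Unit =
  println(s"concrete: $concreteOk, abstract: $abstractOk")
```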


Solution

  • Here is one possible approach, which is an attempt at a literal interpretation of this paragraph:

    When proving a statement E:N→U about all natural numbers, it suffices to prove it for 0 and for succ(n), assuming it holds for n, i.e., we construct ez:E(0) and es:∏(n:N)E(n)→E(succ(n)).

    from the HoTT book (section 5.1).
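Instantiating the quoted principle for this question, with E(n) :≡ (n + 0 = n) and + defined by recursion on its first argument exactly as plus is, gives the two required pieces of evidence. This is a worked sketch in HoTT-book notation, where refl and ap are the usual reflexivity and action-on-paths operations:

```latex
% E(n) :\equiv (n + 0 = n);  0 + b \equiv b  and  \mathrm{succ}(n) + b \equiv \mathrm{succ}(n + b)
\begin{aligned}
e_z &:\equiv \mathsf{refl}_0 : 0 + 0 = 0
  && \text{since } 0 + 0 \equiv 0 \\
e_s(n, h) &:\equiv \mathsf{ap}_{\mathrm{succ}}(h) : \mathrm{succ}(n) + 0 = \mathrm{succ}(n)
  && \text{since } \mathrm{succ}(n) + 0 \equiv \mathrm{succ}(n + 0),\ h : n + 0 = n
\end{aligned}
```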

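    Here is the plan of what was implemented in the code below:

    - Forall[N, P] represents a proof that P[n] holds for every n <: N, as an inline method apply[n <: N]: P[n].
    - NatInductionPrinciple[P] implements apply by recursion on the type n, given base, a proof of P[O] (the ez above), and step, a polymorphic function from P[i] to P[S[i]] (the es above).
    - Proof instantiates the principle with P = NatPlusZeroEqualsNat; its step composes the definitional equality (S[i] plus O) =:= S[i plus O] with the induction hypothesis lifted under S.

    The code then looks like this: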

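    sealed trait Nat
    class O extends Nat
    class S[N <: Nat] extends Nat
    
    // the `<: Nat` bound lets the recursive case S[n plus b] type-check
    type plus[a <: Nat, b <: Nat] <: Nat = a match
      case O => b
      case S[n] => S[n plus b]
    
    // A proof that P[n] holds for every n <: N: evidence P[n] for each n
    trait Forall[N, P[n <: N]]:
      inline def apply[n <: N]: P[n]
    
    // Induction on Nat: from P[O] and P[i] => P[S[i]], build P[n] for a
    // concrete n by recursing on the type n at compile time
    trait NatInductionPrinciple[P[n <: Nat]] extends Forall[Nat, P]:
      def base: P[O]
      def step: [i <: Nat] => (P[i] => P[S[i]])
      inline def apply[n <: Nat]: P[n] =
        (inline compiletime.erasedValue[n] match
          case _: O => base
          case _: S[pred] => step(apply[pred])
        ).asInstanceOf[P[n]]
    
    // If A =:= B then S[A] =:= S[B] for a constructor S with an upper-bounded
    // parameter (implemented with a cast, which is sound because A and B are equal)
    given liftCoUpperbounded[U, A <: U, B <: U, S[_ <: U]](using ev: A =:= B):
      (S[A] =:= S[B]) = ev.liftCo[[X] =>> Any].asInstanceOf[S[A] =:= S[B]]
    
    // The theorem: n + 0 = n
    type NatPlusZeroEqualsNat[n <: Nat] = (n plus O) =:= n
    
    // Holds by unfolding the second case of `plus` once
    def trivialLemma[i <: Nat]: ((S[i] plus O) =:= S[i plus O]) =
      summon[(S[i] plus O) =:= S[i plus O]]
    
    object Proof extends NatInductionPrinciple[NatPlusZeroEqualsNat]:
      // 0 + 0 = 0 by the first case of `plus`
      val base = summon[(O plus O) =:= O]
      // S[i] plus O = S[i plus O] = S[i], using the induction hypothesis p
      val step: ([i <: Nat] => NatPlusZeroEqualsNat[i] => NatPlusZeroEqualsNat[S[i]]) = 
        [i <: Nat] => (p: NatPlusZeroEqualsNat[i]) =>
          given previousStep: ((i plus O) =:= i) = p
          given liftPreviousStep: (S[i plus O] =:= S[i]) =
            liftCoUpperbounded[Nat, i plus O, i, S]
          given definitionalEquality: ((S[i] plus O) =:= S[i plus O]) =
            trivialLemma[i]
          definitionalEquality.andThen(liftPreviousStep)
    
    def demoNat(): Unit = {
      println("Running demoNat...")
      type two = S[S[O]]
      val ev = Proof[two]                  // evidence for two plus O =:= two
      val twoInstance: two = new S[S[O]]
      println(ev(twoInstance) == twoInstance)
    }
    
    // entry point so that the demo actually runs
    @main def runDemoNat(): Unit = demoNat()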
    

    It compiles, runs, and prints:

    Running demoNat...
    true
    

    meaning that we have successfully invoked the recursively defined method Proof.apply and obtained an executable evidence term of type two plus O =:= two.
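The recursion inside NatInductionPrinciple.apply is an inline match on compiletime.erasedValue that calls itself on the predecessor type. A minimal sketch of the same mechanism on a simpler problem, using a hypothetical helper toInt that turns a type-level Nat into an Int: each inline match peels one S off the type, just as apply chains one call to step per S.

```scala
import scala.compiletime.erasedValue

sealed trait Nat
class O extends Nat
class S[N <: Nat] extends Nat

// Hypothetical helper: recursion on the *type* n, unrolled at compile time
inline def toInt[n <: Nat]: Int =
  inline erasedValue[n] match
    case _: O       => 0                  // base case, like `base`
    case _: S[pred] => 1 + toInt[pred]    // recursive case, like `step(apply[pred])`

@main def toIntDemo(): Unit =
  println(toInt[S[S[S[O]]]])
```

Since the match is resolved during inlining, both toInt and Proof.apply only work for concrete types such as S[S[O]]; for an abstract n the inline match cannot be reduced.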


    Some further comments