coq

Constructing proof terms inside a recursive function


Just starting with Coq, trying to write a list insertion function that will return either the new list, or if the provided index was illegal, a proof thereof using sumor. I'm fine with the recursion, but I'm struggling to write the proof terms.

Here's what I've got so far:

Fixpoint insert {A: Type} (lst: list A) (n: nat) (x: A): (list A)+{n >= length lst} :=
  match lst, n return (list A)+{n >= length lst} with
  | [ ], O => inleft [ x ]
  | [ ], S n' => inright (* proof term for `n >= length lst` *)
  | h :: t, S n' => match insert' _ t n' with
                    | inleft lst' => inleft (h :: lst')
                    | inright P (* : n' >= length lst' *) => inright _ (* proof term for n >= length lst *)
                    end
  | _, O => inleft x :: lst
  end.

Any pointers? I'm realizing I've got a real blind spot here, so in addition to concrete suggestions, I'd be interested in any tips on what I should be reading to understand this better.

Update: I gave myself three lemmas:

Lemma simple_test: forall (A: Type) (lst: list A) (n: nat), lst = [] /\ n > 0 -> n > length lst.
Lemma Sn_gt_0: forall n: nat, S n > 0.
Lemma Sn_gt_len_tail: forall (A: Type) (h: A) (t: list A) (n': nat),
    n' >= length t -> S n' >= length (h :: t).

Then tried:

Fixpoint insert' {A: Type} (lst: list A) (n: nat) (x: A): (list A)+{n >= length lst} :=
  match lst return (list A)+{n >= length lst} with
  | [ ] => match n with
           | O => inleft [ x ]
           | S n' => inright _
                       (* proof term for `n >= length lst` *)
                       (simple_test A lst n (conj (eq_refl lst) (Sn_gt_0 n')))
           end
  | h :: t => match n with
              | O => inleft x :: lst
              | S n' => fun _ => 
                          match insert' _ t n' with
                          | inleft lst' => inleft h :: lst'
                          | inright P (*: n' >= length t*) => inright _ (Sn_gt_len_tail P)
                          end
              end
  end.

This fails with:

The term conj eq_refl (Sn_gt_0 n') has type lst = lst /\ S n' > 0 while it is expected to have type lst = [] /\ n > 0.

I feel like I'm soooooooo close.


Solution

  • As a newcomer, you are giving yourself an exercise that is too big to grasp in one go, so your code contains several mistakes.

    First, you are using dependent pattern matching

    match lst return (list A) + {n >= length lst} with  ... end
    

    This is fine. But why did you not use the same kind of dependent pattern matching when matching on n? You should have written

    match n return (list A) + {n >= 0} with ... end
    

    in the third line of your definition of insert'.

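    One important point: I did not write n >= length lst, I wrote n >= 0, because in this branch lst is [], so the expected type n >= length [] computes to n >= 0. Here is the main lesson: if you use dependent pattern matching, the expected type in each branch does not refer to the matched variable but to the pattern; it is the return clause with the matched variable replaced by the pattern of that branch. So all occurrences of lst and n in line 4 should be replaced by [] and 0 respectively (fortunately, there are no such occurrences), and all occurrences of lst and n in lines 5-7 should be replaced by [] and (S n') respectively. In particular, on line 7 you must pass [] and (S n') to simple_test, and use eq_refl [] instead of eq_refl lst; this is exactly what the error message is complaining about.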

    Now, if you look at lines 10-17, all occurrences of lst should be replaced with (h :: t), and you should be able to guess what to do with the various occurrences of n in that part of the code too.
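
    The rule can be seen on a tiny standalone example. pred_or_zero is a hypothetical function used only for this illustration: because the return clause mentions n, the O branch is checked against nat + {0 = 0}, so eq_refl is accepted there. If that branch were checked against the unrefined type nat + {n = 0}, eq_refl : 0 = 0 would not fit.

    ```coq
    (* Hypothetical example: the predecessor of n, or a proof that n = 0. *)
    Definition pred_or_zero (n : nat) : nat + {n = 0} :=
      match n return nat + {n = 0} with
      | O => inright eq_refl   (* checked against nat + {0 = 0} *)
      | S m => inleft m        (* checked against nat + {S m = 0} *)
      end.
    ```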

    Once you fix this issue, other bugs surface. You say that in case of failure the insertion index is greater than or equal to the length of the list, but this is not true: inserting at index length lst simply appends x at the end. For instance, when the index is 0 and the list is nil, the insertion succeeds. So the failure condition is n > length lst, and the arithmetic of the comparisons needs to be fixed in a variety of places.
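
    To check where the boundary lies without any proofs in the way, here is a sketch of the same algorithm returning an option instead of a sumor (insert_opt is a hypothetical name, not part of the original code). Only an index strictly greater than the length of the list produces None.

    ```coq
    Require Import List.
    Import ListNotations.

    (* Hypothetical proof-free variant: None means the index was too large. *)
    Fixpoint insert_opt {A : Type} (lst : list A) (n : nat) (x : A) {struct lst}
      : option (list A) :=
      match lst, n with
      | _, O => Some (x :: lst)          (* insert in front of whatever is left *)
      | [], S _ => None                  (* index past the end *)
      | h :: t, S n' =>
          match insert_opt t n' x with   (* insert into the tail *)
          | Some l => Some (h :: l)
          | None => None
          end
      end.
    ```

    For example, insert_opt [1; 2] 2 9 evaluates to Some [1; 2; 9] (index equal to the length), while insert_opt [1; 2] 3 9 evaluates to None.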

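    Then there are a few other problems: you did not provide all the arguments to all the functions, so many inner terms end up having the wrong type. For instance, A is already implicit in insert', so in insert' _ t n' the _ is taken as the lst argument and x is missing; Sn_gt_len_tail P omits the four arguments A, h, t and n' that the lemma takes explicitly; inleft x :: lst parses as (inleft x) :: lst and needs parentheses; and the fun _ => in the S n' branch should not be there.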
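
    The effect of passing _ for an argument that is already implicit, as in the call insert' _ t n' from the question, can be reproduced on a small example. head_or is a hypothetical function introduced only for this illustration; Fail is the standard Coq command that succeeds exactly when the command following it fails.

    ```coq
    Require Import List.
    Import ListNotations.

    (* Hypothetical function with an implicit type argument, like insert'. *)
    Definition head_or {A : Type} (l : list A) (d : A) : A :=
      match l with
      | [] => d
      | h :: _ => h
      end.

    (* A is inferred: the explicit arguments are l and d. *)
    Check head_or [1; 2] 0.
    (* With @, the implicit argument A is given explicitly. *)
    Check @head_or nat [1; 2] 0.
    (* Passing _ for A shifts everything: _ is taken as l, [1; 2] as d,
       and 0 becomes one argument too many, so this is rejected. *)
    Fail Check head_or _ [1; 2] 0.
    ```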

    Here is the fixed code:

    Require Import List Lia.
    Import ListNotations.
    
    Lemma simple_test: forall (A: Type) (lst: list A) (n: nat), lst = [] /\ n > 0 -> n > length lst.
    Proof.
      intros A lst n [Hlst Hn].
      subst lst.
      simpl.
      exact Hn.
    Qed.
    
    Lemma Sn_gt_0: forall n: nat, S n > 0.
    Proof.
      intros n.
      lia.
    Qed.
    
    Lemma Sn_gt_len_tail: forall (A: Type) (h: A) (t: list A) (n': nat),
        n' > length t -> S n' > length (h :: t).
    Proof.
      intros A h t n' H.
      simpl.
      lia.
    Qed.
    
    Fixpoint insert' {A: Type} (lst: list A) (n: nat) (x: A):
       (list A)+{n > length lst} :=
      match lst return (list A)+{n > length lst} with
      | [ ] => match n return (list A)+{n > 0} with
               | O => inleft [ x ]
               | S n' => inright _
                           (* proof term for S n' > length [], i.e. S n' > 0 *)
                           (simple_test A [] (S n') (conj (eq_refl []) (Sn_gt_0 n')))
               end
      | h :: t => match n with
                  | O => inleft (x :: h :: t)
                  | S n' => match insert' t n' x with
                              | inleft _ lst' => inleft (h :: lst')
                              | inright _ P (* P : n' > length t *) =>
                                inright _ (Sn_gt_len_tail A h t _ P)
                              end
                  end
      end.
    

    The lesson here is that experienced users almost never write functions with this kind of type by hand, because it requires getting too many things right at the same time. Instead, they tend to use tools that make it easier to separate the algorithmic part from the proof part, like the Equations plugin, or they use the tactic machinery directly to build the program progressively, as if it were a proof.

    Here is an example:

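    Definition insert2 {A : Type} :
      forall (lst : list A) (n : nat) (x : A), (list A)+{n > length lst}.
    (* Recursion on the first argument (lst); insert2_rec stands for the recursive call. *)
    fix insert2_rec 1.
    destruct lst as [ | h t].
      (* lst = [] *)
      intros [ | n'] x.
        (* n = 0: inserting into the empty list succeeds *)
        apply inleft.
        exact [x].
      (* n = S n': failure, the goal is S n' > length [] *)
      right.
      apply Sn_gt_0.
    (* lst = h :: t *)
    intros [ | n'] x.
      (* n = 0: put x in front *)
      apply inleft.
      exact (x :: h :: t).
    (* n = S n': insert into the tail, then inspect the result *)
    destruct (insert2_rec t n' x) as [v | P].
      apply inleft.
      exact (h :: v).
    (* failure in the tail, with P : n' > length t *)
    apply inright.
    apply Sn_gt_len_tail.
    exact P.
    Defined.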
    

    If you type Print insert2., you will see the code that is actually generated by this definition by proof. It is very close to what you get if you type Print insert'.

    You may not like having definitions by proof in your scripts (it is harder to see what the algorithm is), but you can use a definition by proof to obtain the function, print it, and then rewrite it as a direct Fixpoint definition. The term obtained with the definition by proof will guide you.
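
    A middle ground between the two styles can be sketched with the standard refine tactic: write the program as a term, leave each proof argument as a hole _, and let refine turn the holes into goals that tactics then solve. insert3 is a hypothetical name, and the arithmetic goals are discharged with lia here instead of the lemmas above; this is one possible layout, not the only one.

    ```coq
    Require Import List Lia.
    Import ListNotations.

    Definition insert3 {A : Type} :
      forall (lst : list A) (n : nat) (x : A), (list A)+{n > length lst}.
    refine (fix rec (lst : list A) (n : nat) (x : A) {struct lst}
              : (list A)+{n > length lst} :=
      match lst return (list A)+{n > length lst} with
      | [ ] => match n return (list A)+{n > 0} with
               | O => inleft [ x ]
               | S n' => inright _            (* hole: S n' > 0 *)
               end
      | h :: t => match n return (list A)+{n > length (h :: t)} with
                  | O => inleft (x :: h :: t)
                  | S n' => match rec t n' x with
                            | inleft lst' => inleft (h :: lst')
                            | inright P => inright _  (* hole: S n' > length (h :: t), with P in context *)
                            end
                  end
      end).
    (* Each remaining hole is an arithmetic goal. *)
    all: simpl; lia.
    Defined.
    ```

    The algorithm stays readable as a term, while the arithmetic side conditions are delegated to tactics.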

    ADDENDUM:

    Once insert' is fixed, we can apply the same lessons to insert, making sure to put a well-typed term in each branch:

    Fixpoint insert {A: Type} (lst: list A) (n: nat) (x: A): (list A)+{n > length lst} :=
      match lst, n return (list A)+{n > length lst} with
      | [ ], O => inleft [ x ]
      | [ ], S n' => inright (Sn_gt_0 n')
      | h :: t, S n' => match insert t n' x with
                        | inleft _ lst' => inleft (h :: lst')
                        | inright P (* P : n' > length t *) =>
                          (* proof of S n' > length (h :: t) *)
                          inright _ (Sn_gt_len_tail _ _ _ _ P)
                        end
      | _, O => inleft (x :: lst)
      end.
    

    This code was tested with Coq 8.19.