Dafny Loop Invariant


I was writing a loop-based square root estimation method in Dafny. This is what I have:

method sqrt(val :int) returns (root:int)
requires val >= 0
ensures root * root >= val && (root - 1) * (root - 1) < val
{
        root := 0;
        var est := val;
        while (est > 0)
        invariant root * root >= val - est
        invariant (root-1) * (root-1) < val
        decreases est
        {
                root := root + 1;
                est := est - (2 * root - 1);
        }
}

Dafny is unable to verify this program because of the second loop invariant. I can kind of see why: I assume it's because root can be 0, so (0-1) * (0-1) < val is false when val is 0. But I can't see a solution, as I am new to Dafny. With the code above, I get these errors when verifying:

src\dafnypractice.dfy(9,34): Error: this loop invariant could not be proved on entry Related message: loop invariant violation | 9 | invariant (root-1) * (root-1) < val | ^

src\dafnypractice.dfy(9,34): Error: this invariant could not be proved to be maintained by the loop Related message: loop invariant violation | 9 | invariant (root-1) * (root-1) < val | ^

Any help is appreciated.


Solution

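  • Yes, you identified a case where the invariant is false: on entry root is 0, so (root-1) * (root-1) is 1, and 1 < val fails for val == 0 (and also for val == 1). You could state the strict invariant only for val > 1, but then you need to handle the remaining cases separately. A smaller change that makes it work is to weaken the comparison to less than or equal to val and guard it with val != 0. Verifying loop invariants is all about induction: each invariant has to hold on entry and be preserved by every iteration. In Dafny, everything the verifier knows about a variable modified by the loop has to come from the invariants, so they must state the values and relationships you rely on afterwards.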
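
    To make the entry failure concrete, here is a small sketch that only checks the original invariant at root == 0 for a few values of val. The predicate name StrictLowerBound and the method name are hypothetical, introduced just for this illustration.

    ```dafny
    // Hypothetical helper: the invariant from the question, written as a predicate.
    predicate StrictLowerBound(root: int, val: int) {
      (root - 1) * (root - 1) < val
    }

    method EntryCounterexamples() {
      // On loop entry root == 0, so the left-hand side is (0 - 1) * (0 - 1) == 1.
      assert !StrictLowerBound(0, 0);  // 1 < 0 is false
      assert !StrictLowerBound(0, 1);  // 1 < 1 is false
      assert StrictLowerBound(0, 2);   // 1 < 2 holds
    }
    ```

    So the invariant as written can only be established on entry when val > 1, which is why it has to be weakened or guarded.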

        function toOdd(n: nat): nat
            requires n > 0
        {
            2*n-1
        }
        
        function SumOfNOddNumbers(n: nat): nat {
            if n == 0 then 0 else toOdd(n)+SumOfNOddNumbers(n-1)
        }
    
        lemma SumOddIsSquared(n: nat)
            ensures SumOfNOddNumbers(n) == n*n
        {}
    
        method sqrt(val :nat) returns (root:nat)
            ensures val == 0 ==> root == 0
            ensures val != 0 ==> (root - 1) * (root - 1) <= val
            ensures val != 0 ==> root * root >= val
        {
                root := 0;
                var est: int := val;
                while (est > 0)
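                    // If val == 0 the loop body never runs, so root stays 0.
                    invariant val == 0 ==> root == 0
                    // invariant est == val ==> root == 0
                    // Ties est to root: est is val minus the sum of the first root odd numbers.
                    invariant est == val-SumOfNOddNumbers(root)
                    // Gives the postcondition root * root >= val once est <= 0.
                    invariant root * root >= val - est
                    // Weakened from < to <= and guarded, so it holds on entry (root == 0, left side is 1).
                    invariant val != 0 ==> (root-1) * (root-1) <= val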
                    decreases est
                {
                        root := root + 1;
                        ghost var oldEst := est;
                        est := est - (2 * root - 1);
                        assert val != 0 ==> (root-1) * (root-1) <= val by {
                            // assert oldEst == val - SumOfNOddNumbers(root-1);
                            assert oldEst > 0;
                            SumOddIsSquared(root-1);
                        }
                }
        }
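
    The empty body of SumOddIsSquared relies on Dafny's automatic induction. If it helps to see the inductive step spelled out, the following standalone sketch proves the same statement by hand. SumOdd is a hypothetical stand-in for SumOfNOddNumbers (with toOdd inlined), and the lemma name is also made up for this example.

    ```dafny
    // Hypothetical stand-in for SumOfNOddNumbers: 1 + 3 + ... + (2n - 1).
    function SumOdd(n: nat): nat {
      if n == 0 then 0 else (2 * n - 1) + SumOdd(n - 1)
    }

    // Same statement as SumOddIsSquared, with the induction written out.
    lemma SumOddIsSquaredExplicit(n: nat)
      ensures SumOdd(n) == n * n
    {
      if n > 0 {
        // Induction hypothesis: SumOdd(n - 1) == (n - 1) * (n - 1).
        SumOddIsSquaredExplicit(n - 1);
        // Unfolding SumOdd(n) and using the hypothesis leaves plain algebra:
        // (2 * n - 1) + (n - 1) * (n - 1) == n * n.
        assert SumOdd(n) == (2 * n - 1) + (n - 1) * (n - 1);
      }
    }
    ```

    This identity is what lets the loop conclude that est is val minus a perfect square, which is the link between est and root that the original invariants were missing.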
    

    A more complete version, which restores the strict bound (root - 1) * (root - 1) < val for every val != 0:

    module SOSqrt {
        function toOdd(n: nat): nat
            requires n > 0
        {
            2*n-1
        }
    
        function SumOfNOddNumbers(n: nat): nat {
            if n == 0 then 0 else toOdd(n)+SumOfNOddNumbers(n-1)
        }
    
        lemma SumOddIsSquared(n: nat)
            ensures SumOfNOddNumbers(n) == n*n
        {}
    
        lemma SquareOfGreatNLarger(root: nat, n: nat)
            requires n > root
            ensures SumOfNOddNumbers(n) > SumOfNOddNumbers(root)
        {}
    
        lemma lessSquared(a: nat, b: nat)
            requires a <= b
            ensures a*a <= b*b
        {
            if a == b {
                assert a*a <= b*b;
            }else{
                var diff := b-a; 
                assert diff > 0;
                calc {
                    b*b;
                    (a+diff)*(a+diff);
                    a*a +2*a*diff + diff * diff;
                }
                assert a*a <= b*b;
            }
        }
    
        method sqrt(val :nat) returns (root:nat)
            ensures val == 0 ==> root == 0
            ensures val == 1 ==> root == 1
            ensures val != 0 ==> root * root >= val
            ensures val != 0 ==> (root - 1) * (root - 1) < val
        {
            root := 0;
            var est: int := val;
            while (est > 0)
                invariant val == 0 ==> root == 0
                invariant est == val - SumOfNOddNumbers(root)
                invariant root * root >= val - est
                invariant val > 1 ==> (root-1) * (root-1) < val
                invariant val == 1 ==> (root-1) * (root-1) <= val
                invariant est <= 0 ==> forall n :nat :: n > root ==> est > val - SumOfNOddNumbers(n)
                invariant est <= 0 ==> forall n: nat :: n < root ==> val - SumOfNOddNumbers(n) > 0
                decreases est
            {
                root := root + 1;
                ghost var oldEst := est;
                est := est - (2 * root - 1);
    
                assert val > 1 ==> (root-1) * (root-1) < val by {
                    assert oldEst > 0;
                    SumOddIsSquared(root-1);
                }
                assert est <= 0 ==> forall n :nat :: n > root ==> est > val-SumOfNOddNumbers(n) by {
                    forall n | n > root 
                        ensures est > val-SumOfNOddNumbers(n)
                    {
                        assert n >= root +1;
                        SquareOfGreatNLarger(root, n);
                    }
                }
                assert est <= 0 ==> forall n :nat :: n < root ==> val-SumOfNOddNumbers(n) > 0 by {
                    if est <= 0 {
                        assert val >= 1;
                        forall n : nat | n < root 
                            ensures val-SumOfNOddNumbers(n) > 0
                        {
                            SumOddIsSquared(n);
                            lessSquared(n, root-1);
                            assert n*n <= (root-1) * (root-1);
                        }
                    }
                }
            }
        }
    }
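
    As a possible simplification (a sketch, not part of the versions above): since the sum of the first root odd numbers is root * root, the relationship between est and root can be stated directly as est == val - root * root, without the helper function. This assumes Dafny's default nonlinear arithmetic support can prove (root + 1) * (root + 1) == root * root + 2 * root + 1 on its own. If your setup cannot, the SumOfNOddNumbers approach is the fallback. The method name SqrtSketch is hypothetical, and the lower bound is guarded by root > 0 so that it holds trivially on entry.

    ```dafny
    // Sketch with a closed-form invariant; SqrtSketch is a hypothetical name.
    method SqrtSketch(val: nat) returns (root: nat)
      ensures root * root >= val
      ensures root > 0 ==> (root - 1) * (root - 1) < val
    {
      root := 0;
      var est: int := val;
      while est > 0
        // est is exactly what remains of val after subtracting root squared.
        invariant est == val - root * root
        // Vacuous on entry (root == 0); afterwards it follows from est > 0
        // at the start of the iteration that produced the current root.
        invariant root > 0 ==> (root - 1) * (root - 1) < val
        decreases est
      {
        root := root + 1;
        est := est - (2 * root - 1);
      }
    }
    ```

    For val == 0 the loop does not run, root stays 0, and the guarded postcondition says nothing about (root - 1) * (root - 1), which sidesteps the case the original specification could not satisfy.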