The problem is as follows: output (A^1 + A^2 + A^3 + ... + A^K) modulo 1,000,000,007, where A and K are integers with 1 ≤ A, K ≤ 10^9.
I am trying to write a program to solve this. I tried using the formula for the sum of a geometric sequence and then applying the modulo to the result. Since the result must also be an integer, I assumed that finding a modular inverse is not required.
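A small worked example shows why reducing A^K first and then doing an integer division does not work. It uses the modulus 7 in place of 1,000,000,007 so the numbers stay readable, with A = 3 and K = 3, and follows the same steps as the Pascal code below:

$$
\begin{aligned}
3 + 3^2 + 3^3 &= 39 \equiv 4 \pmod{7},\\
3^3 \bmod 7 &= 27 \bmod 7 = 6,\\
\left\lfloor \frac{3\,(6 - 1)}{3 - 1} \right\rfloor \bmod 7 &= \lfloor 15/2 \rfloor \bmod 7 = 7 \bmod 7 = 0 \neq 4.
\end{aligned}
$$

Once A^K has been reduced, A(A^K - 1) is in general no longer divisible by A - 1, so the integer division gives a wrong value. Dividing correctly modulo a prime requires a modular inverse, and for A = 1 the formula divides by zero.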
Below is my current code, in Pascal:
Var
  a, k, i: longint;
  power, sum: int64;
Begin
  Readln(a, k);
  power := 1;
  For i := 1 to k do
    power := ((power mod 1000000007) * a) mod 1000000007;
  sum := a * (power - 1) div (a - 1);
  Writeln(sum mod 1000000007);
End.
This task comes from my school, and they do not share their test data with students, so I do not know why or where my program goes wrong. I only know that it outputs wrong answers on their tests.
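For comparison, the closed form the question starts from, A(A^K - 1)/(A - 1), can be made to work because 1,000,000,007 is prime: the division becomes multiplication by (A - 1)^(p-2) mod p, the modular inverse given by Fermat's little theorem, and A = 1 is handled separately. This is a minimal sketch of that swapped-in technique; the class and method names are hypothetical.

```java
// Hypothetical class: evaluates A + A^2 + ... + A^K modulo a prime using the
// geometric-series closed form, with the division replaced by a modular inverse.
class GeometricSumInverse {
    static final long MOD = 1_000_000_007L; // prime, so Fermat's little theorem applies

    // Square-and-multiply: base^exp mod m; every product fits in 64 bits.
    static long powMod(long base, long exp, long m) {
        long result = 1 % m;
        base %= m;
        while (exp > 0) {
            if ((exp & 1) == 1) {
                result = result * base % m;
            }
            base = base * base % m;
            exp >>= 1;
        }
        return result;
    }

    // Returns (A^1 + ... + A^K) mod MOD.
    static long sumAtoAk(long a, long k) {
        a %= MOD;
        if (a == 1) {
            // Ratio 1: the closed form would divide by zero, but the sum is just K.
            return k % MOD;
        }
        // Numerator of the closed form: A * (A^K - 1), reduced mod MOD.
        long numerator = a * ((powMod(a, k, MOD) - 1 + MOD) % MOD) % MOD;
        // Inverse of (A - 1) via Fermat: (A - 1)^(MOD - 2) mod MOD.
        long inverse = powMod((a - 1 + MOD) % MOD, MOD - 2, MOD);
        return numerator * inverse % MOD;
    }

    public static void main(String[] args) {
        System.out.println(sumAtoAk(2, 3)); // 2 + 4 + 8 = 14
        System.out.println(sumAtoAk(1_000_000_000L, 1_000_000_000L));
    }
}
```

This needs two fast exponentiations, so it also runs in O(log K), unlike the O(K) loop in the question.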
If you want to do this without calculating a modular inverse, you can calculate it recursively using:
$$1 + A + A^2 + A^3 + \dots + A^k = 1 + (A + A^2)\left(1 + A^2 + (A^2)^2 + \dots + (A^2)^{k/2-1}\right)$$
That’s for even k. For odd k:
$$1 + A + A^2 + A^3 + \dots + A^k = (1 + A)\left(1 + A^2 + (A^2)^2 + \dots + (A^2)^{(k-1)/2}\right)$$
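Both identities come from grouping adjacent terms in pairs and factoring out the common factor. For even k, pair every term after the leading 1; each pair is (A^2)^j times (A + A^2):

$$1 + (A + A^2) + (A^3 + A^4) + \dots + (A^{k-1} + A^k) = 1 + (A + A^2)\left(1 + A^2 + \dots + (A^2)^{k/2-1}\right)$$

For odd k, pair from the start; each pair is (A^2)^j times (1 + A):

$$(1 + A) + (A^2 + A^3) + \dots + (A^{k-1} + A^k) = (1 + A)\left(1 + A^2 + \dots + (A^2)^{(k-1)/2}\right)$$

For example, with A = 2: k = 4 gives 1 + 2 + 4 + 8 + 16 = 31 = 1 + 6 · 5, and k = 3 gives 1 + 2 + 4 + 8 = 15 = 3 · 5.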
Since k is divided by 2 in each recursive call, the resulting algorithm has O(log k) complexity. In Java:
static int modSumAtoAk(int A, int k, int mod)
{
    // A + ... + A^k is (1 + A + ... + A^k) minus the leading 1
    return (modSum1ToAk(A, k, mod) + mod-1) % mod;
}
static int modSum1ToAk(int A, int k, int mod)
{
    long sum;
    if (k < 5) {
        //k is small -- just iterate
        sum = 0;
        long x = 1;
        for (int i=0; i<=k; ++i) {
            sum = (sum+x) % mod;
            x = (x*A) % mod;
        }
        return (int)sum;
    }
    //k is big
    int A2 = (int)( ((long)A)*A % mod );
    if ((k%2)==0) {
        // k even: 1 + (A + A^2)*(1 + A^2 + ... + (A^2)^(k/2-1))
        sum = modSum1ToAk(A2, (k/2)-1, mod);
        sum = (sum + sum*A) % mod;    // (1 + A) * S
        sum = ((sum * A) + 1) % mod;  // 1 + (A + A^2) * S
    } else {
        // k odd: (1 + A)*(1 + A^2 + ... + (A^2)^((k-1)/2))
        sum = modSum1ToAk(A2, (k-1)/2, mod);
        sum = (sum + sum*A) % mod;    // (1 + A) * S
    }
    return (int)sum;
}
Note that I’ve been very careful to make sure that each product is done in 64 bits, and to reduce by the modulus after each one.
With a little math, the recursive version above can be converted to an iterative version that needs only constant extra space (no recursion stack):
static int modSumAtoAk(int A, int k, int mod)
{
    // first, we calculate the sum 1 + A + ... + A^k
    // we'll refer to that as SUM1 in comments below
    long fac = 1;
    long add = 0;
    // INVARIANT: SUM1 = add + fac*(sum 1...A^k)
    // this will remain true as we change A and k
    while (k > 0) {
        // above INVARIANT is true here, too
        long newmul, newadd;
        if ((k%2) == 0) {
            // k is even. sum 1...A^k = 1 + A*(sum 1...A^(k-1))
            newmul = A;
            newadd = 1;
            k -= 1;
        } else {
            // k is odd. sum 1...A^k = (1+A)*(sum 1...(A^2)^((k-1)/2))
            newmul = A + 1L;
            newadd = 0;
            A = (int)(((long)A) * A % mod);
            k = (k-1)/2;
        }
        // with the updated A and k:
        // SUM1 = add + fac*(newadd + newmul*(sum 1...A^k))
        //      = add + fac*newadd + fac*newmul*(sum 1...A^k)
        add = (add + fac*newadd) % mod;
        fac = (fac*newmul) % mod;
        // INVARIANT is restored
    }
    // k == 0, so sum 1...A^k = 1 and SUM1 = add + fac
    long sum1 = fac + add;
    // subtract the leading 1 to get A + ... + A^k
    return (int)((sum1 + mod - 1) % mod);
}
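Because the school does not share its test data, one simple way to gain confidence in either fast version is to compare it with a direct O(K) loop on small inputs. A minimal sketch follows; the class name BruteForceCheck and the method bruteForce are hypothetical, and the loop is only practical for small K.

```java
// Hypothetical reference implementation for testing the fast versions.
class BruteForceCheck {
    static final int MOD = 1_000_000_007;

    // (A^1 + ... + A^K) mod m by direct iteration: O(K) time, small K only.
    static int bruteForce(int a, int k, int m) {
        long sum = 0;
        long power = 1;
        for (int i = 1; i <= k; i++) {
            power = power * a % m;   // power now holds A^i mod m
            sum = (sum + power) % m;
        }
        return (int) sum;
    }

    public static void main(String[] args) {
        System.out.println(bruteForce(2, 3, MOD)); // 2 + 4 + 8 = 14
        System.out.println(bruteForce(3, 3, 7));   // 39 mod 7 = 4
    }
}
```

For example, checking that bruteForce(A, K, mod) equals modSumAtoAk(A, K, mod) for every A and K from 1 to 50 exercises both the small-k loop and the even and odd branches of the recursive version.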