For given integers $n$ and $m$, how can I determine whether the coefficient of the $x^m$ term in $(x^2+x+1)^n$ is even or odd?
For example, if $n=3$ and $m=4$, then $(x^2+x+1)^3 = x^6 + 3x^5 + 6x^4 + 7x^3 + 6x^2 + 3x + 1$, so the coefficient of the $x^4$ term is 6, which is even.
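A brute-force sketch makes it easy to check small cases like this one. It expands the power by repeated multiplication, so it needs time roughly proportional to $n^2$ and serves only as a test reference, not as a solution for $n$ near $10^{12}$. The helper name `trinomial_coefficients` is hypothetical.

```python
def trinomial_coefficients(n):
    """Return the coefficients of (1 + x + x^2)**n, lowest degree first.

    Naive repeated multiplication; only practical for small n.
    """
    coeffs = [1]
    for _ in range(n):
        # Multiply by (1 + x + x^2): each old coefficient c at degree k
        # contributes c to degrees k, k + 1 and k + 2.
        new = [0] * (len(coeffs) + 2)
        for k, c in enumerate(coeffs):
            new[k] += c
            new[k + 1] += c
            new[k + 2] += c
        coeffs = new
    return coeffs


print(trinomial_coefficients(3))         # [1, 3, 6, 7, 6, 3, 1]
print(trinomial_coefficients(3)[4] % 2)  # 0, i.e. even
```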
$n$ and $m$ can be as large as $10^{12}$, and I want the answer within a few seconds, so an algorithm that runs in time linear in $n$ or $m$ is too slow.
Is there an efficient algorithm?
Yes: one that runs in time linear in the number of bits of the input, i.e., about 40 steps for numbers near $10^{12}$.
The coefficients in question are the trinomial coefficients $T(n, m)$. For binomial coefficients we would use Lucas's theorem; let's work out the trinomial analog for $p = 2$.
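Working mod $2$ and following the proof of Nathan Fine, note that squaring is additive mod $2$ (because $(f + g)^2 = f^2 + 2fg + g^2 \equiv f^2 + g^2$), so
$$(1 + x + x^2)^{2^i} \equiv 1 + x^{2^i} + x^{2 \cdot 2^i} = 1 + x^{2^i} + x^{2^{i+1}} \pmod 2.$$
Write $n = \sum_i n_i 2^i$ with $n_i \in \{0, 1\}$ for all $i$ (i.e., $(n_i)$ is the binary representation of $n$). Then
$$\begin{aligned}
(1 + x + x^2)^n &= \prod_i (1 + x + x^2)^{2^i n_i} \\
&\equiv \prod_{i \,:\, n_i = 1} \left(1 + x^{2^i} + x^{2 \cdot 2^i}\right) \\
&= \prod_i \sum_{m_i = 0}^{2 n_i} x^{2^i m_i} \\
&= \sum_{(m_i)} \prod_i x^{2^i m_i} \pmod 2,
\end{aligned}$$
where the last sum is taken over sequences $(m_i)$ with $0 \le m_i \le 2 n_i$.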
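The first identity is easy to sanity-check by machine. This sketch stores polynomials over GF(2) as Python integers (bit $k$ is the coefficient of $x^k$), multiplies them carry-lessly, and compares $(1 + x + x^2)^{2^i}$ with $1 + x^{2^i} + x^{2^{i+1}}$ for small $i$. The names `gf2_mul`, `gf2_pow` and `TRINOMIAL` are hypothetical. Bit $m$ of `gf2_pow(TRINOMIAL, n)` is also $T(n, m) \bmod 2$, but that polynomial has about $2n$ bits, so this only works as a check for small $n$.

```python
def gf2_mul(p, q):
    """Multiply two GF(2)[x] polynomials stored as int bit masks."""
    result = 0
    while q:
        if q & 1:
            result ^= p  # adding polynomials over GF(2) is XOR
        p <<= 1
        q >>= 1
    return result


def gf2_pow(p, e):
    """Raise a GF(2)[x] polynomial to a nonnegative power by repeated squaring."""
    result = 1
    while e:
        if e & 1:
            result = gf2_mul(result, p)
        p = gf2_mul(p, p)
        e >>= 1
    return result


TRINOMIAL = 0b111  # 1 + x + x^2

# (1 + x + x^2)^(2^i) == 1 + x^(2^i) + x^(2^(i+1)) over GF(2)
for i in range(6):
    expected = 1 | (1 << 2**i) | (1 << 2**(i + 1))
    assert gf2_pow(TRINOMIAL, 2**i) == expected

print((gf2_pow(TRINOMIAL, 3) >> 4) & 1)  # 0: T(3, 4) = 6 is even
```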
In the binomial case, the next step is to observe that, for the coefficient of $x^m$, there is at most one choice of $(m_i)$ whose factors $x^{2^i m_i}$ multiply to $x^m$, namely the binary representation of $m$.
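For comparison, the binomial statement being generalized is the $p = 2$ case of Lucas's theorem: $\binom{n}{m}$ is odd exactly when every set bit of $m$ is also set in $n$. A minimal sketch, checked against `math.comb` (Python 3.8+); `binomial_mod_two` is a hypothetical name.

```python
import math


def binomial_mod_two(n, m):
    """Parity of C(n, m) via Lucas's theorem for p = 2."""
    # The only candidate (m_i) is the binary representation of m, and it
    # contributes iff m_i <= n_i for every bit, i.e. m is a submask of n.
    return int(m & ~n == 0)


# Exhaustive check against exact binomial coefficients for small n.
for n in range(64):
    for m in range(n + 1):
        assert binomial_mod_two(n, m) == math.comb(n, m) % 2
```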
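In the trinomial case, we have to consider binary pseudo-representations $(m_i)$ of $m$, i.e., $m = \sum_i m_i 2^i$ where each pseudo-bit $m_i$ can be zero, one, or two. Such a representation contributes to the sum if and only if $m_i = 0$ for all $i$ such that $n_i = 0$. Hence $T(n, m) \bmod 2$ is the parity of the number of these pseudo-representations.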
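The characterization can be tested directly by enumerating pseudo-representations. This is a sketch for small inputs only, since the recursion tree grows quickly; `count_pseudo_representations` is a hypothetical name. Note that the count itself is not $T(n, m)$; only its parity agrees. For $n = 3$, $m = 4$ it finds $4 = 0 \cdot 1 + 2 \cdot 2 = 2 \cdot 1 + 1 \cdot 2$, two representations, matching the even coefficient 6.

```python
def count_pseudo_representations(n, m):
    """Count sequences (m_i) with m == sum(m_i * 2**i), each m_i in {0, 1, 2},
    and m_i == 0 wherever bit i of n is 0.

    Plain recursion on the lowest bit; only suitable for small inputs.
    """
    if m == 0:
        return 1  # only the all-zero sequence remains
    if n == 0:
        return 0  # m > 0 but no position may be nonzero
    total = 0
    # Choose the pseudo-bit for the lowest position: 0, 1 or 2 if bit 0 of n
    # is set, otherwise only 0. The rest must represent (m - d) / 2.
    for d in ((0, 1, 2) if n & 1 else (0,)):
        if d <= m and (m - d) % 2 == 0:
            total += count_pseudo_representations(n >> 1, (m - d) >> 1)
    return total


print(count_pseudo_representations(3, 4))  # 2
```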
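We can write an automaton that scans $n$ and $m$ bit by bit, least significant bit first, guessing one pseudo-bit $m_i$ per step. State a means the pseudo-bits chosen so far sum exactly to the bits of $m$ read so far; state b means their sum overshoots by $2^i$, i.e., a carry into the current position is pending. State a is initial and accepting. In the rules below, the input x:y:nm' has x as the current bit of $n$, y as the current bit of $m$, and nm' as the rest of the input; [emit k] records the choice $m_i = k$. Combinations with no rule reject.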
a (0:0:nm') -> a nm' [emit 0]
a (1:0:nm') -> a nm' [emit 0]
            -> b nm' [emit 2]
a (1:1:nm') -> a nm' [emit 1]
b (0:1:nm') -> a nm' [emit 0]
b (1:0:nm') -> b nm' [emit 1]
b (1:1:nm') -> a nm' [emit 0]
            -> b nm' [emit 2]
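The rules above can be transcribed into a transition table. In this sketch, `TRANSITIONS` and `count_accepting_paths` are hypothetical names, and the counts are kept exact rather than reduced mod 2, only to make visible that the number of accepting paths equals the number of valid pseudo-representations.

```python
# (state, bit of n, bit of m) -> list of (next state, emitted m_i).
# State "a": no carry pending; state "b": a carry of 1 is pending.
# Combinations that are not listed have no transitions (the path dies).
TRANSITIONS = {
    ("a", 0, 0): [("a", 0)],
    ("a", 1, 0): [("a", 0), ("b", 2)],
    ("a", 1, 1): [("a", 1)],
    ("b", 0, 1): [("a", 0)],
    ("b", 1, 0): [("b", 1)],
    ("b", 1, 1): [("a", 0), ("b", 2)],
}


def count_accepting_paths(n, m):
    """Run the automaton on the bits of n and m, least significant first,
    and return the exact number of accepting paths."""
    counts = {"a": 1, "b": 0}
    while n or m:
        n1, n = n & 1, n >> 1
        m1, m = m & 1, m >> 1
        new_counts = {"a": 0, "b": 0}
        for state, paths in counts.items():
            for nxt, _emitted in TRANSITIONS.get((state, n1, m1), []):
                new_counts[nxt] += paths
        counts = new_counts
    return counts["a"]  # only state a is accepting


print(count_accepting_paths(3, 4) % 2)  # 0: T(3, 4) = 6 is even
```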
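Each accepting path emits exactly one valid pseudo-representation, and each valid pseudo-representation traces exactly one accepting path, so $T(n, m)$ is odd if and only if the number of accepting paths is odd. We can use dynamic programming to count the paths mod $2$ (so addition becomes XOR), keeping one parity per state. In code form: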
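def trinomial_mod_two(n, m):
    """Return T(n, m) mod 2, i.e. 1 if the coefficient is odd, else 0."""
    # a, b: parity of the number of paths in state a (no carry)
    # and state b (carry pending); start with one path in state a.
    a, b = 1, 0
    while m:
        n1, n = n & 1, n >> 1  # current bit of n
        m1, m = m & 1, m >> 1  # current bit of m
        if n1:
            if m1:
                a, b = a ^ b, b
            else:
                a, b = a, a ^ b
        elif m1:
            a, b = b, 0
        else:
            a, b = a, 0
    # The remaining bits of m are all zero: each path in a extends uniquely
    # within a, and paths in b can never return to a.
    return a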
Branchless version for giggles:
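def trinomial_mod_two_branchless(n, m):
    a, b = 1, 0
    while m:
        n1, n = n & 1, n >> 1
        m1, m = m & 1, m >> 1
        # a and b are 0 or 1, so only bit 0 of each mask matters,
        # and bit 0 of ~m1 is (1 - m1), i.e. "not m1".
        a, b = ((n1 | ~m1) & a) ^ (m1 & b), ((n1 & ~m1) & a) ^ (n1 & b)
    return a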
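Because the branchless formula is easy to get wrong, a small sketch can compare it with the branching update on all 16 combinations of `n1`, `m1`, `a` and `b`. The names `step_branchy` and `step_branchless` are hypothetical helpers that isolate a single loop iteration of `trinomial_mod_two` and `trinomial_mod_two_branchless`.

```python
from itertools import product


def step_branchy(n1, m1, a, b):
    """One update of (a, b), as in trinomial_mod_two."""
    if n1:
        return (a ^ b, b) if m1 else (a, a ^ b)
    return (b, 0) if m1 else (a, 0)


def step_branchless(n1, m1, a, b):
    """One update of (a, b), as in trinomial_mod_two_branchless."""
    return ((n1 | ~m1) & a) ^ (m1 & b), ((n1 & ~m1) & a) ^ (n1 & b)


# Every combination of one-bit inputs and one-bit state must agree.
for n1, m1, a, b in product((0, 1), repeat=4):
    assert step_branchy(n1, m1, a, b) == step_branchless(n1, m1, a, b)
print("branchless update matches")
```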