Way to speed up this modulo operation?

I need to do this for a programming contest, but I thought it would be a better fit here.

I need to perform the following computation multiple times:

$$ x = ((x \cdot a) + b) \% M $$

Basically I need to do:

for(int i=0;i<n;i++){
    x = (x*a)%M;
    x = (x+b)%M;
}

where $M = 1000000007$ and $n \approx 10000$. Is there a way to speed up this computation and get the value of $x$ directly?

Best answer

If we do not actually need to run the loop $n$ times, we can take advantage of the fact that the answer has a closed form. If $x_n$ denotes the value after the $n$-th iteration and $x$ is the initial value, then

$$ x_n \equiv a^n x + (1 + a + \cdots + a^{n-1})b \pmod{M} $$
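As a cross-check of this closed form, here is a minimal Python sketch that evaluates it directly. It uses a different technique from the recursion described in this answer: the geometric sum is computed as $(a^n - 1)/(a - 1)$ with a modular inverse, which assumes that $M$ is prime (true for $M = 1000000007$). The function names `closed_form` and `naive` are made up for illustration.

```python
M = 1_000_000_007  # prime modulus from the question


def closed_form(x, a, b, n, m=M):
    """Return x_n for x_{k+1} = (a*x_k + b) % m, assuming m is prime."""
    a %= m
    an = pow(a, n, m)  # a^n mod m via the built-in fast exponentiation
    if a == 1:
        # The geometric sum degenerates to 1 + 1 + ... + 1 = n.
        geom = n % m
    else:
        # (a^n - 1) / (a - 1) mod m; the inverse of (a - 1) comes from
        # Fermat's little theorem, which is why m must be prime.
        geom = (an - 1) * pow(a - 1, m - 2, m) % m
    return (an * x + geom * b) % m


def naive(x, a, b, n, m=M):
    """Reference implementation: run the loop from the question n times."""
    for _ in range(n):
        x = (a * x + b) % m
    return x
```

For example, `closed_form(12387429, 2384238, 39287433, 10000)` should agree with `naive` called on the same arguments. If $M$ were not prime, the inverse of $a - 1$ might not exist and this shortcut would not apply.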

Assuming that $a$ is an integer, the geometric sum can be computed quickly by a recursion that halves $n$ at each step. More precisely, if we define

$$ f(n, a) = \sum_{k=0}^{n-1} a^k = 1 + a + \cdots + a^{n-1}, $$

then for integers $a$ and $n \geq 1$, it satisfies

$$ f(n, a) \equiv \begin{cases} 1, & n = 1; \\ 1+a, & n = 2; \\ (1+a)f(n/2, a^2), & \text{$n$ is even}; \\ 1+(a+a^2)f((n-1)/2, a^2), & \text{$n$ is odd}; \end{cases} \pmod{M} $$
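To see why this holds, pair up consecutive terms: $a^{2k} + a^{2k+1} = (1+a)(a^2)^k$, which gives the even case. The odd case peels off the leading $1$ and pairs the remaining $n-1$ terms in the same way. As a small worked check, for $n = 5$ and $n = 6$:

$$ f(5, a) = 1 + (a + a^2) f(2, a^2) = 1 + (a + a^2)(1 + a^2) = 1 + a + a^2 + a^3 + a^4, $$

$$ f(6, a) = (1 + a) f(3, a^2) = (1 + a)(1 + a^2 + a^4) = 1 + a + a^2 + a^3 + a^4 + a^5. $$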

So, instead of running the loop $n$ times, it suffices to perform $\mathcal{O}(\log n)$ recursive steps. The power $a^n$ needs no separate computation, since $a^n = 1 + (a-1)f(n, a)$. For instance, here is a Python implementation of this idea:

import time

M = 1000000007
x = 12387429
a = 2384238
b = 39287433
n = int(input("Enter the value of n = "))

def benchmark(f):
    t0 = time.time()
    f()
    print("Elapsed time: {0:.4f} second(s)".format(time.time() - t0))
    print("")

def f1():
    res = x
    for i in range(0, n):
        res = (a * res + b) % M
    print("Result:", res)

print("First method:")
benchmark(f1)

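def f2():
    # recur(n, a) returns 1 + a + ... + a^(n-1) modulo M in O(log n) steps.
    def recur(n, a):
        if n < 0:
            raise ValueError("Invalid input")
        elif n == 0:
            return 0  # empty sum; also stops the recursion when n = 0
        elif n == 1:
            return 1
        elif n == 2:
            return (1+a) % M
        elif (n % 2) == 0:
            # pair the terms: a^(2k) + a^(2k+1) = (1+a) * (a^2)^k
            return (1+a)*recur(n//2, (a*a)%M)%M
        else:
            # peel off the leading 1, then pair the remaining n-1 terms
            return (1+a*(1+a)*recur(n//2, (a*a)%M))%M
    s = recur(n, a)
    # a^n = 1 + (a-1)*s, so x_n = a^n * x + s*b
    res = (s*b + (1+(a-1)*s)*x)%M
    print("Result:", res)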

print("Second method:")
benchmark(f2)

On my computer, this code prints:

>>> 
Enter the value of n = 10000000
First method:
Result: 239506769
Elapsed time: 2.3320 second(s)

Second method:
Result: 239506769
Elapsed time: 0.0090 second(s)

>>> 
Enter the value of n = 10000
First method:
Result: 982346606
Elapsed time: 0.0150 second(s)

Second method:
Result: 982346606
Elapsed time: 0.0185 second(s)

As the timings show, this method pays off only when $n$ is very large. For $n \approx 10000$, as in the question, the plain loop is just as fast.