Performance of modular square root
Problem
Here's Project Euler problem 451:
Consider the number \$15\$.
There are eight positive numbers less than \$15\$ which are coprime to \$15\$:
\begin{align} 1, 2, 4, 7, 8, 11, 13, 14. \end{align}
The modular inverses of these numbers modulo \$15\$ are:
\begin{align} 1, 8, 4, 13, 2, 11, 7, 14 \end{align}
because
\begin{align*}
1 \times 1 \mod 15 &= 1 \\
2 \times 8 = 16 \mod 15 &= 1 \\
4 \times 4=16 \mod 15 &= 1 \\
7 \times 13=91 \mod 15 &= 1 \\
11 \times 11=121 \mod 15 &= 1 \\
14 \times 14=196 \mod 15 &= 1
\end{align*}
Let \$I(n)\$ be the largest positive number \$m\$ smaller than \$n - 1\$ such that the modular inverse of \$m\$ modulo \$n\$ equals \$m\$ itself.
So \$I(15) = 11\$.
Also \$I(100) = 51\$ and \$I(7) = 1\$.
Find \$\sum I(n)\$ for \$3 \leq n \leq 2 \times 10^7\$.
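As a reference point, the definition can be checked directly on small inputs. The sketch below is only an illustration, far too slow for the full range; the helper names `modular_inverses` and `brute_force_I` are made up for this example, and it assumes Python 3.8+ (where `pow` accepts a negative exponent together with a modulus).

```python
from math import gcd

# Hypothetical helper names, used only for this illustration.

def modular_inverses(n):
    """Return (m, inverse of m mod n) for every 0 < m < n coprime to n."""
    return [(m, pow(m, -1, n)) for m in range(1, n) if gcd(m, n) == 1]


def brute_force_I(n):
    """Largest m < n - 1 whose modular inverse mod n is m itself (n >= 3)."""
    for m in range(n - 2, 0, -1):
        # Only numbers coprime to n have an inverse at all.
        if gcd(m, n) == 1 and pow(m, -1, n) == m:
            return m
    raise ValueError("n must be at least 3")


print(modular_inverses(15))
print(brute_force_I(15), brute_force_I(100), brute_force_I(7))
```

This reproduces the inverses listed for 15 and the three values of \$I(n)\$ quoted in the problem statement.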
I realised that because the solutions are symmetric (if \$m\$ is its own inverse modulo \$n\$, so is \$n - m\$), it is enough to search for the smallest one, skipping the trivial case \$1^2 \equiv 1 \pmod n\$. I also realised that if \$n\$ is divisible by a prime, I don't need to test multiples of that prime. So far I have implemented this for the first four primes only, as I think they have the most impact. Anyway, here's my code. Be warned that in trying to optimise for speed I made it very memory-inefficient (it needs around 2.5 GB).
```
upto = 2*10**7 + 1
a = [True] * upto  # sieve flags: a[n] stays True while n is presumed prime
p = []
for n in range(2, upto):
    if a[n]:
        p.append(n)
        for k in range(2, (upto+n-1)//n):
            a[k*n] = False
p = set(p)
su = 0
squ = list(map(lambda x: x*x, range(upto)))
print('primes and squares ready')
s2 = set(range(0, upto, 2))  # for testing divisibility
s3 = set(range(0, upto, 3))
s5 = set(range(0, upto, 5))
s7 = set(range(0, upto, 7))
r=[]
r.append([])
r.append([])
r[0].append([])
r[0].append([])
r[1].append([])
r[1].append([])
r[0][0].append([])
r[0][0].append([])
r[0][1].append([])
r[0][1].append([])
r[1][0].append([])
r[1][0].append([])
r[1][1].append([])
r[1][1].append([])
r[0][0][0].append(range
```
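The two observations in the question (the solutions are symmetric around \$n/2\$, and every solution is coprime to \$n\$) can be checked on small cases. The following is a minimal sketch for illustration only; `self_inverse_residues` and `I_from_smallest` are hypothetical names, not taken from the question's code.

```python
from math import gcd

# Hypothetical helper names, used only for this illustration.

def self_inverse_residues(n):
    """All m in [1, n) with m*m congruent to 1 modulo n."""
    return [m for m in range(1, n) if (m * m) % n == 1]


def I_from_smallest(n):
    """I(n) computed as n minus the smallest solution greater than 1 (n >= 3)."""
    for m in range(2, n):
        if (m * m) % n == 1:
            return n - m  # m = n - 1 always qualifies, so this is always reached


roots = self_inverse_residues(15)
print(roots)                                   # solutions modulo 15
print(all(15 - m in roots for m in roots))     # symmetry: m and n - m together
print(all(gcd(m, 15) == 1 for m in roots))     # every solution is coprime to n
print(I_from_smallest(15), I_from_smallest(100), I_from_smallest(7))
```

Because \$n - 1\$ is always a solution, the smallest solution greater than 1 mirrors to the largest solution below \$n - 1\$, which is exactly \$I(n)\$.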
Solution
I'm not really sure I understand your code; it isn't organized in a way that aids readability. Use functions, so that I don't need to spend a minute or more reading the first few lines of the program to work out that they amount to

    def computePrimesUpTo(n):
        sieve = [True] * n
        p = []
        for i in range(2, n):  # In Python 2, use xrange instead
            if sieve[i]:
                p.append(i)
                for multiple in range(2 * i, n, i):
                    sieve[multiple] = False
        return set(p)

    primes = computePrimesUpTo(upTo)

Now at a glance, I can see that `if n in primes` checks whether `n` is a prime number, and that you're using the Sieve of Eratosthenes to compute the primes. (Note: `p` is not the worst offender here; in fact, it's the easiest to figure out. It took me several reads through your code to figure out what `r` is supposed to be.)

However, it looks like you're precomputing `x*x` for `x` in [0, 2*10^7], as well as sets of all multiples of 2, 3, 5, and 7 in the same range. This betrays a fundamental misunderstanding of what is expensive on modern processors. `x` will be in a register, or in the worst case in the L1 cache, and squaring it requires a single instruction (less than a nanosecond). `squ[x]` will frequently not be in any of your CPU's memory caches; as a rule of thumb, a cache miss costs 100 ns. (And that's if it hasn't been swapped out of main memory onto disk! See this list of numbers every programmer should know.) Not only that, but by reading from these tables you'll push something else out of your cache, slowing down everything else. `squ`, `s2`, `s3`, `s5`, and `s7` are likely slowing your program down by a noticeable amount.

I am less sure about your `r` lists; they may help or hurt. However, the computation of `r[1][1][0]`, etc. is not as fast as it might be; the `in` operation on Python lists is O(n). Try implementing this function:

    def list_intersection(a, b):
        """Computes the intersection of two sorted iterables.

        Returns a list; requires O(len(a) + len(b)) time.
        """
        pass

OK, now let's talk algorithm.
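A side note before the algorithm itself: the `list_intersection` exercise above could be filled in with a two-pointer merge. This is only one possible sketch, and it assumes both inputs are sorted in strictly increasing order, as `range` objects with a positive step are.

```python
def list_intersection(a, b):
    """Computes the intersection of two sorted iterables.

    Returns a list; requires O(len(a) + len(b)) time.
    Assumption for this sketch: both inputs are strictly increasing.
    """
    a, b = list(a), list(b)  # materialize so that indexing works for any iterable
    result = []
    i = j = 0
    while i < len(a) and j < len(b):
        if a[i] < b[j]:
            i += 1            # a[i] cannot appear later in b
        elif a[i] > b[j]:
            j += 1            # b[j] cannot appear later in a
        else:
            result.append(a[i])
            i += 1
            j += 1
    return result


# Odd numbers below 30 that are also multiples of 3.
print(list_intersection(range(1, 30, 2), range(0, 30, 3)))
```

Each step advances at least one of the two indices, which is where the linear bound comes from, in contrast to repeated `in` tests against a list.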
First of all, when computing \$I(n)\$ there's no need to consider candidates \$m\$ with \$1 < m \leq \sqrt{n}\$: if \$m > 1\$ and \$m^2 \equiv 1 \pmod n\$, then \$m^2 \geq n + 1\$. To take advantage of that, you might reverse the `r[w][x][y][z]` lists and iterate over them backward, with a check for each `n` that `r[d2][d3][d5][d7][-1] > sqrt(n)`. This will give you a small benefit.

I think you had the right insight: if you compute the coprimes separately for each `n`, then

    for c in [c for c in range(2, n) if is_coprime(c, n)]:
        if (c*c) % n == 1:
            break

is much slower than

    for c in range(2, n):
        if (c*c) % n == 1:
            break

Filtering down `range(2, n)` is useful, but there are diminishing returns: filtering out the multiples of a prime `p` saves at most a 1/p fraction of the work when computing \$I(n)\$, at an additional cost that's O(upTo). Still, any test you do inside your `for x in r[...]` loop will only slow you down (since `if (x*x) % n == 1` is essentially as fast as you can get), so the only speedup you can gain is by reducing the size of the `r` lists.
Code Snippets
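One more snippet, not taken from the answer: a minimal sketch that combines the symmetry observation from the question with the square-root bound discussed in the solution. The name `I_sqrt_start` is hypothetical, `math.isqrt` needs Python 3.8+, and this plain loop is meant to illustrate the idea rather than to be fast enough for the full range up to \$2 \times 10^7\$.

```python
from math import isqrt

# Hypothetical name, used only for this illustration.

def I_sqrt_start(n):
    """I(n) as n minus the smallest m > 1 with m*m congruent to 1 modulo n (n >= 3).

    Any such m satisfies m*m >= n + 1, so the search can start just above sqrt(n).
    """
    m = isqrt(n) + 1
    while (m * m) % n != 1:
        m += 1  # terminates at the latest when m == n - 1
    return n - m


print(I_sqrt_start(15), I_sqrt_start(100), I_sqrt_start(7))
```

For prime `n` the loop still runs all the way to `n - 1`, so the bound trims only the start of the search; that matches the "small benefit" described in the solution.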
    def computePrimesUpTo(n):
        sieve = [True] * n
        p = []
        for i in range(2, n):  # In Python 2, use xrange instead
            if sieve[i]:
                p.append(i)
                for multiple in range(2 * i, n, i):
                    sieve[multiple] = False
        return set(p)

    primes = computePrimesUpTo(upTo)

    def list_intersection(a, b):
        """Computes the intersection of two sorted iterables.

        Returns a list; requires O(len(a) + len(b)) time.
        """
        pass

    for c in [c for c in range(2, n) if is_coprime(c, n)]:
        if (c*c) % n == 1:
            break

    for c in range(2, n):
        if (c*c) % n == 1:
            break

Context
StackExchange Code Review Q#39015, answer score: 3