HiveBrain v1.2.0

Sieve of Sundaram for Project Euler 7: Python implementation slower than C++ and R

Submitted by: @import:stackexchange-codereview

Problem

A friend of mine recently started learning R, and she complained to me that R is very slow. She was working on Project Euler Problem 7, which asks for the value of the 10001st prime number. For comparison, I decided to write the exact same program in C++ (which I use most frequently), R (which I do not use regularly, but used for comparison), and Python (because I wanted to advise her to switch her learning from R to Python).

The program follows these steps:

  • Find the upper bound on the Nth prime.
  • Fill a prime number sieve (Sieve of Sundaram) large enough to capture the upper bound.
  • Count the primes in the prime number sieve until the Nth prime is found.
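The first step relies on a known result (Rosser's theorem): for n ≥ 6 the nth prime is below n(ln n + ln ln n), which is the bound all three programs compute. A minimal Python sketch of that step follows; `nth_prime_upper_bound` and the slow trial-division checker `nth_prime_naive` are names chosen here for illustration, not part of the original code.

```python
from math import log


def nth_prime_upper_bound(n):
    """Upper bound on the nth prime: n(ln n + ln ln n), valid for n >= 6."""
    if n < 6:
        # log(log(1)) would be log(0); p_5 = 11, so 13 covers n = 1..5.
        return 13
    return int(n * (log(n) + log(log(n))))


def nth_prime_naive(n):
    """Slow trial-division reference, used only to spot-check the bound."""
    count, candidate = 0, 1
    while count < n:
        candidate += 1
        if all(candidate % d for d in range(2, int(candidate ** 0.5) + 1)):
            count += 1
    return candidate


# The bound must never fall below the true nth prime.
for n in (6, 10, 100, 1000):
    assert nth_prime_naive(n) <= nth_prime_upper_bound(n)

print(nth_prime_upper_bound(10001))  # 114319, above the answer 104743
```

For N = 10001 the bound is 114319, so a Sundaram sieve over the odd numbers needs 57159 slots. The small-n fallback is there because `log(log(1))` is `log(0)`, which raises `ValueError`.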



To my surprise, the Python solution was significantly slower than the R version. Here are the results from `time` for the three implementations:

  • C++: real 0.002s, user 0.001s, sys 0.001s
  • R: real 2.239s, user 2.065s, sys 0.033s
  • Python: real 13.588s, user 13.566s, sys 0.025s

I tried using numpy arrays instead of lists, but to little effect.

I would appreciate an explanation of why the Python implementation is so slow relative to the other two. I would also appreciate tips on how to use lists in R and Python more efficiently. Is there a better way to do this?

C++:

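#include <cmath>
#include <iostream>
#include <vector>
using namespace std;

const int N = 10001;

int main()
{
  // Upper bound on the Nth prime, N(ln N + ln ln N), valid for N >= 6.
  int max = floor(N*(log(N) + log(log(N))));
  // marked[k] == true means 2k + 1 is composite; index 0 is unused.
  vector<bool> marked(max/2, false);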
  for (size_t i = 1; i < marked.size(); i++) {
    for (size_t j = 1; j <= i; j++) {
      size_t m = i + j + 2*i*j;
      if (m < marked.size()) {
        marked[m] = true;
      } else {
        break;
      }
    }
  }
  int count = 1;
  for (size_t m = 1; m < marked.size(); m++) {
    if (not marked[m]) count++;
    if (count == N) {
      cout << 2*m + 1 << endl;
      break;
    }
  }
  return 0;
}


R:

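```
n <- 10001
max <- n*(log(n) + log(log(n)))
marked <- vector(mode="logical", length=max/2)
for (i in 1:length(marked)) {
  for (j in 1:i) {
    m <- i + j + 2*i*j
    if (m > length(marked)) break
    marked[m] <- T
  }
}
count <- 1
for (i in 1:length(marked)) {
  if (!marked[i]) count <- count + 1
  if (count == n) {
    print(2*i + 1)
    break
  }
}
```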

Solution

TL;DR: Just use PyPy; it gets you to about 10x the time of C++. If you really want to use CPython, a lot of clever optimizations (not algorithm changes) get you as fast as PyPy, and then using NumPy gets you close to C++ (2x the time).

The first thing of note is that your Python code is broken:

if m > len(marked):
    break


Remember that Python is 0-indexed, so the last valid index is len(marked) - 1. What about when m == len(marked)?
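A small illustration, assuming the question's loop indexed `marked[m]` directly behind the `m > len(marked)` guard (the question's Python code is not reproduced on this page). `mark_buggy` and `mark_fixed` are hypothetical helpers: with a 0-based list, `m == len(marked)` slips past the `>` guard and lands one past the last valid index.

```python
def mark_buggy(marked, m):
    """Guard as in the question: '>' lets m == len(marked) through."""
    if m > len(marked):
        return False
    marked[m] = True  # raises IndexError when m == len(marked)
    return True


def mark_fixed(marked, m):
    """'>=' stops at the first index past the end of a 0-based list."""
    if m >= len(marked):
        return False
    marked[m] = True
    return True


marked = [False] * 5            # valid indices are 0..4
print(mark_fixed(marked, 5))    # False: 5 is past the end, so stop
try:
    mark_buggy(marked, 5)
except IndexError as exc:
    print("mark_buggy:", exc)   # mark_buggy: list assignment index out of range
```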

So that's the first thing to fix. Going from the top, I'd translate the R code
like so:

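from math import log

n = 10001                              n <- 10001
maximum = n * (log(n) + log(log(n)))   max <- n*(log(n) + log(log(n)))
marked = [False] * int(maximum // 2)   marked <- vector(mode="logical", length=max/2)
for i in range(1, len(marked)+1):      for (i in 1:length(marked)) {
    for j in range(1, i+1):              for (j in 1:i) {
        m = i + j + 2*i*j                  m <- i + j + 2*i*j
        if m > len(marked): break          if (m > length(marked)) break
        marked[m-1] = True                 marked[m] <- T
                                         }
                                       }
count = 1                              count <- 1
for i in range(1, len(marked)+1):      for (i in 1:length(marked)) {
    if not marked[i-1]: count += 1       if (!marked[i]) count <- count + 1
    if count == n:                       if (count == n) {
        print(2*i + 1)                     print(2*i + 1)
        break                              break
                                         }
                                       }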


This is as direct a mapping as possible; instead of changing the comparison,
I just shifted the index when indexing. This isn't idiomatic, but it's direct.
It's largely the same as your code, but it's correct, which matters when we go to
larger N, where your code fails. It's also Python 3 compatible simply because
print is called with parentheses.
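The idiomatic alternative that paragraph sets aside is to keep the list 0-based and change the comparison instead of shifting every index, which is also how the C++ version works. A sketch under that assumption; `nth_prime_sundaram` is a hypothetical name, and like the original it assumes n ≥ 6 so the bound holds.

```python
from math import log


def nth_prime_sundaram(n):
    """Return the nth prime via the Sieve of Sundaram (assumes n >= 6)."""
    size = int(n * (log(n) + log(log(n))) // 2)
    # marked[k] is True when 2k + 1 is composite; index 0 (the number 1) is unused.
    marked = [False] * size
    for i in range(1, size):
        for j in range(1, i + 1):
            m = i + j + 2 * i * j
            if m >= size:  # 0-based: index `size` is already past the end
                break
            marked[m] = True
    count = 1  # the prime 2 has no slot in this sieve
    for k in range(1, size):
        if not marked[k]:
            count += 1
            if count == n:
                return 2 * k + 1
    raise ValueError("sieve too small for n")


print(nth_prime_sundaram(10001))  # 104743
```

Index k stands for the odd number 2k + 1, so the only changes from the direct translation are `>=` in the guard and the unshifted indices.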

Let's time them:

$ time Rscript r.r
[1] 104743
Rscript r.r  1.59s user 0.78s system 99% cpu 2.375 total

$ time python2 p.py
104743
python2 p.py  12.88s user 0.00s system 100% cpu 12.873 total

$ time python3 p.py
104743
python3 p.py  0.16s user 0.00s system 98% cpu 0.163 total

$ # A faster, very compatible Python interpreter
$ time pypy p.py   
104743
pypy p.py  0.04s user 0.01s system 98% cpu 0.051 total

$ time pypy3 p.py
104743
pypy3 p.py  0.05s user 0.01s system 99% cpu 0.054 total


So there we have it. Python is over an order of magnitude faster than R on three of
the four interpreters, and less than an order of magnitude slower in the worst case (CPython 2).

But why is it so slow with Python 2? line_profiler is a good tool for finding out:

import line_profiler
profiler = line_profiler.LineProfiler()

def main():
    pass  # ... the code ...

profiler.enable()
profiler.add_function(main)
main()
profiler.print_stats()


Giving:

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     4                                           def main():
     5         1          271    271.0      0.0      from math import log
     6                                           
     7         1            1      1.0      0.0      n = 10001
     8         1           22     22.0      0.0      maximum = n * (log(n) + log(log(n)))
     9         1          373    373.0      0.0      marked = [False] * int(maximum // 2)
    10     57160        37122      0.6      0.3      for i in range(1, len(marked)+1):
    11    188523      8444200     44.8     62.9          for j in range(1, i+1):
    12    188355       118260      0.6      0.9              m = i + j + 2*i*j
    13    188355      4672922     24.8     34.8              if m > len(marked): break
    14    131364        75962      0.6      0.6              marked[m-1] = True
    15                                           
    16                                           
    17         1            1      1.0      0.0      count = 1
    18     52371        25618      0.5      0.2      for i in range(1, len(marked)+1):
    19     52371        28393      0.5      0.2          if not marked[i-1]: count += 1
    20     52371        26003      0.5      0.2          if count == n:
    21         1           60     60.0      0.0              print(2*i + 1)
    22         1          329    329.0      0.0              break


So our most likely offending line is:

11    188523      8444200     44.8     62.9          for j in range(1, i+1):


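On Python 2 there are both range and xrange. Using

for i in range(...):


will build the whole list of numbers first and then loop over it, whereas

for i in xrange(...):


will just produce the next i each time it is needed. That matters a lot for the inner loop here: range(1, i+1) builds a list of i integers on every pass of the outer loop, even though the loop usually breaks after only a few iterations. So let's add: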

# Set range to xrange on Python 2
try:
    range = xrange
except NameError:
    pass


to the top of the script and retime:

$ time python2 p.py
104743
python2 p.py  0.09s user 0.00s system 96% cpu 0.090 total

$ time python3 p.py
104743
python3 p.py  0.16s user 0.00s system 99% cpu 0.161 total

$ time pypy p.py   
104743
pypy p.py  0.04s user 0.00s system 93% cpu 0.050 total

$ time pypy3 p.py
104743
pypy3 p.py  0.04s user 0.01s system 99% cpu 0.054 total


Yup. CPython 2 is much improved: from 12.88s down to 0.09s of user time.
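The size of that win follows from the profile. Assuming the 57159-element `marked` list of the profiled run (57160 hits on the outer `for` line), an eager Python 2 `range(1, i+1)` builds i integers on every outer pass, while the inner loop usually breaks after a handful of iterations. The hypothetical helper `eager_vs_used` below tallies both counts on Python 3, whose `range` is lazy like Python 2's `xrange`.

```python
def eager_vs_used(size):
    """Count ints an eager range(1, i+1) would build vs. inner-loop iterations run."""
    built = used = 0
    for i in range(1, size + 1):
        built += i                     # a Python 2 range() list holds i ints
        for j in range(1, i + 1):      # lazy here, like xrange on Python 2
            used += 1
            if i + j + 2 * i * j > size:
                break
    return built, used


built, used = eager_vs_used(57159)
print(built)  # 1633604220
print(used)   # 188355
```

About 1.6 billion integers get allocated and discarded to serve 188355 inner-loop iterations, which fits the profiler charging 62.9% of the run to the `for j in range(1, i+1)` line.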

So how would we make the code good? First you can transform the i+j+2ij
calculation to directly
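One possible way to compute the marked indices directly, sketched here as an assumption rather than taken from the answer: for a fixed i and j ≥ i, m = i + j + 2ij = i + j(2i + 1), so the marked indices form an arithmetic progression that starts at 2i(i + 1) and steps by 2i + 1. Taking j ≥ i instead of j ≤ i marks the same set, because i + j + 2ij is symmetric in i and j. Python slice assignment can then mark a whole progression in one C-level operation instead of one interpreted iteration per index. The sketch swaps in a `bytearray` for the list of booleans; `sundaram_flags` and `nth_prime_sliced` are hypothetical names.

```python
from math import log


def sundaram_flags(size):
    """flags[k] == 1 when 2k + 1 is composite (indices 0..size-1)."""
    flags = bytearray(size)
    i = 1
    while 2 * i * (i + 1) < size:
        start, step = 2 * i * (i + 1), 2 * i + 1
        # Mark i + j(2i + 1) for every j >= i in a single slice assignment.
        flags[start::step] = b"\x01" * len(range(start, size, step))
        i += 1
    return flags


def nth_prime_sliced(n):
    """nth prime for n >= 6, using the same bound as the question's code."""
    size = int(n * (log(n) + log(log(n))) // 2)
    flags = sundaram_flags(size)
    count = 1  # the prime 2
    for k in range(1, size):
        if not flags[k]:
            count += 1
            if count == n:
                return 2 * k + 1
    raise ValueError("sieve too small for n")


print(nth_prime_sliced(10001))  # 104743
```

The counting loop is still plain Python; only the marking phase is pushed into slice assignment.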

Code Snippets

from math import log

n = 10001                              n <- 10001
maximum = n * (log(n) + log(log(n)))   max <- n*(log(n) + log(log(n)))
marked = [False] * int(maximum // 2)   marked <- vector(mode="logical", length=max/2)
for i in range(1, len(marked)+1):      for (i in 1:length(marked)) {
    for j in range(1, i+1):              for (j in 1:i) {
        m = i + j + 2*i*j                  m <- i + j + 2*i*j
        if m > len(marked): break          if (m > length(marked)) break
        marked[m-1] = True                 marked[m] <- T
                                         }
                                       }
count = 1                              count <- 1
for i in range(1, len(marked)+1):      for (i in 1:length(marked)) {
    if not marked[i-1]: count += 1       if (!marked[i]) count <- count + 1
    if count == n:                       if (count == n) {
        print(2*i + 1)                     print(2*i + 1)
        break                              break
                                         }
                                       }

Context

StackExchange Code Review Q#71137, answer score: 30
