Segmented Sieve of Eratosthenes

Submitted by: @import:stackexchange-codereview

Problem

I implemented this for the Prime Generator problem on SPOJ, but I am only getting 0.01s run-time, and would like to be able to match the run-times of the top submissions, which are all 0.00s.

What are some suggestions for optimizations that I can do to improve the efficiency of this algorithm?

Input: The input begins with the number t of test cases in a single line (t<=10). In each of the next t lines there are two numbers m and n (1 <= m <= n <= 1000000000, n-m<=100000) separated by a space.

Output: For every test case print all prime numbers p such that m <= p <= n, one number per line, test cases separated by an empty line.

#include <cmath>
#include <cstdio>
#include <vector>

using namespace std;

typedef vector<bool> VB;
typedef vector<int> VI;

VI sieve(int n) {

  VB is_prime(n, true);
  int sqrtn = sqrt(n);

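  for (int i = 2; i <= sqrtn; ++i) {
    if (is_prime[i]) {
      for (int j = i * i; j < n; j += i) {
        is_prime[j] = false;
      }
    }
  }

  VI primes;
  for (int i = 2; i < n; ++i) {
    if (is_prime[i]) {
      primes.push_back(i);
    }
  }
  return primes;
}

void segmented_sieve(int m, int n) {

  int range = n - m;
  VB is_prime(range + 1, true);
  VI primes = sieve(range + 1);

  for (const auto &prime : primes) {
    for (int i = (m / prime) * prime; i <= n; i += prime) {
      if (i >= m && i != prime) {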
        is_prime[i - m] = false;
      }
    }
  }

  for (int k = 0; k <= range; ++k) {
    if (is_prime[k]) {
      printf("%d\n", m + k);
    }
  }
}

int main() {

  int t;
  scanf("%d", &t);

  while (t--) {
    int m, n;
    scanf("%d%d", &m, &n);
    if (m == 1) m++;
    segmented_sieve(m, n);
  }
  return 0;
}

Solution

Algorithm is incorrect

The algorithm you are using is incorrect because you do not find the proper range of primes to use for your segmented sieve. You are currently finding and using primes in the range 0..(n-m+1). But you actually need to be finding primes in the range 0..sqrt(n). Here is a trivial test case that demonstrates the problem:

1
9 9


For this test case, you find primes in the range 0..1, which finds no primes, which leads to an incorrect result where you print 9 as a prime. The correct range of primes should be 0..3 for this test case.
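
To make the failure concrete, here is a minimal sketch that sieves a segment with an explicit bound on the base primes. The helper name `primes_in_segment` and its `base_limit` parameter are illustrative only and do not appear in the original program; trial division is used to pick the base primes purely to keep the sketch short.

```cpp
#include <vector>

// Hypothetical helper: returns the primes in [m, n] (with m >= 2) after
// crossing off multiples of every prime p <= base_limit. The result is only
// correct when base_limit >= floor(sqrt(n)).
std::vector<int> primes_in_segment(int m, int n, int base_limit) {
    std::vector<bool> is_prime(n - m + 1, true);
    for (int p = 2; p <= base_limit; ++p) {
        // Trial division keeps the sketch short; a real solution would sieve.
        bool p_is_prime = true;
        for (int d = 2; d * d <= p; ++d)
            if (p % d == 0) { p_is_prime = false; break; }
        if (!p_is_prime) continue;

        // First multiple of p that is >= m and is not p itself.
        long long start = ((m + p - 1LL) / p) * p;
        if (start == p) start += p;
        for (long long i = start; i <= n; i += p)
            is_prime[i - m] = false;
    }
    std::vector<int> result;
    for (int k = 0; k <= n - m; ++k)
        if (is_prime[k]) result.push_back(m + k);
    return result;
}
```

With `base_limit` set to 1 (the bound the original code effectively uses for the input `9 9`), the call `primes_in_segment(9, 9, 1)` reports 9 as prime. With `base_limit` set to 3, which is `floor(sqrt(9))`, the same call returns an empty list.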

Find primes once

Given the maximum \$n\$ of \$10^9\$, you should do a single sieve to find all the primes under \$\sqrt {10^9}\$. There should be 3401 primes under 31623. In fact, you could even hardcode this list of primes into your program to avoid performing the sieve at runtime. It should only take 14KB to hardcode that list (7KB if you store them as short).
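
As a rough sketch of the "sieve once" idea, the base primes can be produced by a plain Sieve of Eratosthenes that runs a single time before the test cases are read. The function name `base_primes` is an assumption made for this example, not something from the original program.

```cpp
#include <vector>

// Plain Sieve of Eratosthenes returning all primes <= limit (limit >= 0).
// Intended to be called once, before the test-case loop.
std::vector<int> base_primes(int limit) {
    std::vector<bool> is_prime(limit + 1, true);
    std::vector<int> primes;
    for (int i = 2; i <= limit; ++i) {
        if (!is_prime[i]) continue;
        primes.push_back(i);
        // Start at i * i: smaller multiples have a smaller prime factor.
        for (long long j = 1LL * i * i; j <= limit; j += i)
            is_prime[j] = false;
    }
    return primes;
}
```

Calling `base_primes(31622)` once at startup yields every prime needed for any `n` up to \$10^9\$; printing the `size()` of the result is an easy way to check the count quoted above before hardcoding the list.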

Optimizing the segmented sieve

Let's take a look at your main loop:

for (const auto &prime : primes) {
    for (int i = (m / prime) * prime; i <= n; i += prime) {
        if (i >= m && i != prime) {
            is_prime[i - m] = false;
        }
    }
}


There are several things I notice that are suboptimal:

1. The loop increments by prime. You could roughly double the speed by handling only odd numbers and incrementing by 2*prime instead. The only even prime is 2, and you can handle it as a special case outside the sieve.

2. You have an if (i >= m) check, which could be removed if the loop started from the right value.

3. You have an if (i != prime) check, which could likewise be removed by starting the loop from the right value.

4. You do a subtraction i - m on every iteration.
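
The first point can be illustrated in isolation. The sketch below crosses off only the odd multiples of an odd prime by stepping 2*p at a time and reports how many entries it touched. The helper name `cross_off_odd_multiples` is hypothetical, and the sketch assumes `lo > p` so that the prime itself is never inside the range.

```cpp
#include <vector>

// Hypothetical helper: crosses off the odd multiples of an odd prime p in
// [lo, hi] and returns how many entries it touched. composite[k] stands for
// the number lo + k. Assumes p < lo <= hi.
int cross_off_odd_multiples(std::vector<bool>& composite, int lo, int hi, int p) {
    long long first = ((lo + p - 1LL) / p) * p;  // first multiple of p >= lo
    if (first % 2 == 0) first += p;              // even multiple: move to the next, odd one
    int touched = 0;
    for (long long i = first; i <= hi; i += 2LL * p) {
        composite[i - lo] = true;
        ++touched;
    }
    return touched;
}
```

For `lo = 100`, `hi = 200`, `p = 3`, the loop touches 16 entries (105, 111, ..., 195), whereas stepping by `p` would visit all 33 multiples of 3 in that range.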

So what is the correct starting value? You were close with your (m / prime) * prime expression. The problem is that it rounds down, so it can produce a value smaller than m. You should round up instead:

int start = ((m + prime - 1) / prime) * prime;


This gets rid of the i >= m check. But also, start should be at least prime * prime, because any multiple of prime less than that will already have been handled by a lower prime.

start = max(start, prime * prime);


This gets rid of the case where you might reach i == prime. Thirdly, start should be odd: if you are doing #1 above, you need to start at an odd value and step by an even increment, so that every value visited stays odd.

if ((start & 1) == 0)
    start += prime;


Lastly, by subtracting m from start and making your loop go until n-m, you can get rid of the subtraction inside the loop.

start -= m;
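
Put together, the four adjustments might look like the following helper. The name `first_offset` is made up for this sketch, and it assumes `prime` is odd and at most 31622 so that `prime * prime` fits in an `int`.

```cpp
#include <algorithm>

// Hypothetical helper combining the adjustments above: returns the offset
// (relative to m) of the first odd multiple of `prime` that is >= m and
// >= prime * prime. Assumes prime is odd and prime <= 31622.
int first_offset(int m, int prime) {
    int start = ((m + prime - 1) / prime) * prime;  // round m up to a multiple of prime
    start = std::max(start, prime * prime);         // smaller multiples belong to smaller primes
    if ((start & 1) == 0)                           // even multiple: step to the next odd one
        start += prime;
    return start - m;                               // index into the segment array
}
```

For example, with `m = 100` and `prime = 3`, rounding up gives 102, which is even, so the start moves to 105 and the returned offset is 5. With `m = 2` and `prime = 3`, the `prime * prime` rule dominates and the offset is 7 (the number 9).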


Rewrite

Here is a rewrite of your whole program using the above ideas:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

using namespace std;

#define MAX        1000000000

typedef vector<bool> VB;
typedef vector<int> VI;

// This holds the primes from 3..sqrt(MAX)
VI primes;

// Finds the primes from 3..n and stores them in the global "primes".
static void findPrimes(int n)
{
    int i;
    int sqrtn = sqrt(n);
    VB  is_prime(n+1, true);

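    for (i = 3; i <= sqrtn; i += 2) {
        if (is_prime[i]) {
            for (int j = i * i; j <= n; j += 2 * i)
                is_prime[j] = false;
        }
    }
    for (i = 3; i <= n; i += 2) {
        if (is_prime[i])
            primes.push_back(i);
    }
}

static void segmented_sieve(int m, int n)
{
    int range = n - m;
    VB  is_prime(range + 1, true);

    for (const auto &prime : primes) {
        // Find the first odd multiple of prime that is >= m, or prime^2,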
        // whichever is larger.  Also make sure start is odd.
        int start = max(((m + prime - 1) / prime) * prime, prime * prime);
        if ((start & 1) == 0)
            start += prime;
        // Adjust to range so we don't have to subtract m in the loop.
        start -= m;
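        for (int i = start; i <= range; i += 2 * prime)
            is_prime[i] = false;
    }

    // 2 is the only even prime, so it is handled outside the sieve.
    if (m <= 2 && n >= 2)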
        printf("2\n");
    for (int k = (m & 1) ? 0 : 1; k <= range; k += 2) {
        if (is_prime[k]) {
            printf("%d\n", m + k);
        }
    }
}

int main()
{
    int t;
    scanf("%d", &t);

    findPrimes(sqrt(MAX));
    while (t--) {
        int m, n;
        scanf("%d%d", &m, &n);
        if (m == 1) m++;
        segmented_sieve(m, n);
    }
    return 0;
}

Context

StackExchange Code Review Q#153447, answer score: 4
