Eratosthenes Sieve optimized in C

Submitted by: @import:stackexchange-codereview

Tags: optimized, sieve, eratosthenes

Problem

I wrote yet another optimized single-threaded Eratosthenes Sieve implementation in C:

erato.c

// #include <stdio.h>
#include <stdlib.h>

#define N 1000000000
#define num_t unsigned long

int main() {
    register char *b = malloc(N * sizeof(char));
    for(num_t i = 0; i ^ N; ++i)
        b[i] = !(i & 1);

    // printf("2\n");
    for(num_t i = 3; i ^ 1 ^ N; i += 2) {
        if(!b[i - 2]) {
            // printf("%llu\n", i);
            const num_t increment = i << 1;
            num_t j = i;
            while(j < N) {
                b[j - 2] = 1;
                j += increment;
            }
        }
    }
    free(b);
    return 0;
}


The output is omitted because it takes most of the time to process.

Optional flags: -O3

Performance

The sieve covers the first 1,000,000,000 numbers in:


11.14s user 0.32s system 99% cpu 11.465 total

Could you please help me to improve my programming/mathematical skills by reviewing this piece of code?

Solution

Bug

At the end of your program, the b array is all full of 1s, which means that you didn't find any primes. The problem is here:

num_t j = i;


Because you start your j loop at i, you will mark i (which is prime) as non-prime. You should start your j loop at i*i instead (see below).
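
As a small sketch of the difference (the helper name `unmarked_odds` and its `start_at_square` flag are illustrative assumptions, not part of the original program, and the array is indexed by the number itself for clarity), the function below runs the marking loop both ways and reports how many odd numbers are still unmarked at the end:

```c
#include <stdlib.h>

/* Hypothetical helper: marks multiples of each unmarked odd i below n and
 * returns how many odd numbers >= 3 remain unmarked afterwards.
 * start_at_square == 0 starts marking at i (the bug described above);
 * any other value starts at i * i (the fix). */
unsigned long unmarked_odds(unsigned long n, int start_at_square)
{
    char *marked = calloc(n ? n : 1, 1);
    if (marked == NULL)
        return 0;                   /* allocation failed; sketch only */

    for (unsigned long i = 3; i < n; i += 2) {
        if (marked[i])
            continue;
        /* If i > n / i then i * i > n: nothing to mark, and skipping
         * here also keeps i * i from overflowing. */
        if (start_at_square && i > n / i)
            continue;
        unsigned long first = start_at_square ? i * i : i;
        for (unsigned long j = first; j < n; j += 2 * i)
            marked[j] = 1;
    }

    unsigned long unmarked = 0;
    for (unsigned long i = 3; i < n; i += 2)
        unmarked += !marked[i];

    free(marked);
    return unmarked;
}
```

With `start_at_square` set to 0 every odd number ends up marked, so the result is 0. With it set to 1 the unmarked entries are exactly the odd primes below `n`.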

Don't get tricky

Why write this:

for(num_t i = 0; i ^ N; ++i)


when you can write this:

for(num_t i = 0; i < N; ++i)


I had to stare at your code for a long time just to figure out whether it was correct, for a simple loop!

Also, you did it here:

for(num_t i = 3; i ^ 1 ^ N; i += 2) {


which should be:

for(num_t i = 3; i < N; i += 2) {
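
The XOR form is also fragile, not just hard to read: `i ^ 1 ^ N` is zero only when `i == (N ^ 1)`, so the loop ends only if the odd counter lands exactly on that value. The sketch below (the function names and the `cap` safety limit are assumptions added for illustration) counts iterations for both loop tests:

```c
#include <stdbool.h>

/* The question's loop test: nonzero unless i == (n ^ 1). */
bool xor_test(unsigned long i, unsigned long n)
{
    return (i ^ 1UL ^ n) != 0;
}

/* Iterations of for (i = 3; xor_test(i, n); i += 2). The cap stops the
 * sketch from spinning forever when the test can never become false. */
unsigned long iterations_xor(unsigned long n, unsigned long cap)
{
    unsigned long count = 0;
    for (unsigned long i = 3; xor_test(i, n) && count < cap; i += 2)
        ++count;
    return count;
}

/* Iterations of for (i = 3; i < n; i += 2). */
unsigned long iterations_less_than(unsigned long n)
{
    unsigned long count = 0;
    for (unsigned long i = 3; i < n; i += 2)
        ++count;
    return count;
}
```

For an even `n` such as 10 both loops run the same number of times. For an odd `n` such as 11, `n ^ 1` is even, the odd counter never equals it, and only the cap ends the XOR loop, while the `<` version still stops where expected.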


You can afford to waste 2 bytes

Here, you use index-2 to save 2 bytes:

if(!b[i - 2]) {
    // printf("%llu\n", i);
    const num_t increment = i << 1;
    num_t j = i;
    while(j < N) {
        b[j - 2] = 1;
        j += increment;
    }
}


There's no need to save 2 bytes. Just use b[index] instead of b[index-2]. You already allocated the full array anyways.
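
A minimal sketch of the byte-per-number sieve with direct indexing might look like this. The names `build_sieve` and `count_primes_below` are hypothetical, and the sketch also uses plain `<` style loop tests and starts marking at `i * i`:

```c
#include <stdlib.h>

/* Hypothetical helper: returns a table where table[k] != 0 means k is
 * not prime, for 0 <= k < n. The caller frees it. NULL on failure. */
char *build_sieve(unsigned long n)
{
    char *b = calloc(n ? n : 1, sizeof *b);
    if (b == NULL)
        return NULL;

    /* b is indexed by the number itself, so entries 0 and 1 exist and
     * are simply marked as non-prime. */
    for (unsigned long k = 0; k < n && k < 2; ++k)
        b[k] = 1;
    for (unsigned long k = 4; k < n; k += 2)
        b[k] = 1;                       /* even numbers above 2 */

    /* i <= n / i is the same test as i * i <= n, without overflow. */
    for (unsigned long i = 3; i <= n / i; i += 2) {
        if (b[i])
            continue;
        for (unsigned long j = i * i; j < n; j += 2 * i)
            b[j] = 1;                   /* b[j], not b[j - 2] */
    }
    return b;
}

/* Counts the unmarked entries, that is, the primes below n. */
unsigned long count_primes_below(unsigned long n)
{
    char *b = build_sieve(n);
    if (b == NULL)
        return 0;

    unsigned long count = 0;
    for (unsigned long k = 0; k < n; ++k)
        count += !b[k];

    free(b);
    return count;
}
```

Because the index is the number, a lookup such as `b[97]` needs no offset arithmetic, at the cost of two unused entries.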

Speed improvements

There are three things I can see that could speed up your program:

  • In your j loop, you can start j at i*i instead of just i. All non-primes less than i*i will already be marked.

  • You can stop the i loop at sqrt(N) instead of at N, because once you reach sqrt(N), the second loop won't mark any more entries.

  • Since you are finding primes up to 10^9, you are using 1GB of memory for the b array. You should get better performance if you use less memory, because the processor cache will be utilized more effectively if your array is smaller.

    a) You can track only odd numbers, which will reduce your memory usage to 1/2 of the original (512MB).

    b) You can use 1 bit per number instead of 1 byte per number. This will reduce your memory to 1/8 of the original. Combined with (a), it will be 1/16 of the original, or 64MB. This is much better than 1GB.
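
The memory arithmetic and the bit addressing behind (a) and (b) can be sketched as follows. All function names here are assumptions made for illustration. For N = 10^9 the last size comes to roughly 62.5 million bytes, in line with the rounded 64MB figure above.

```c
#include <stddef.h>
#include <stdint.h>

/* Bytes needed to cover the numbers below n under each layout. */
size_t bytes_one_byte_per_number(uint64_t n) { return (size_t)n; }
size_t bytes_one_byte_per_odd(uint64_t n)    { return (size_t)((n + 1) / 2); }
size_t bytes_one_bit_per_odd(uint64_t n)     { return (size_t)((n + 15) / 16); }

/* With one bit per odd number, odd k sits at bit position k / 2.
 * A 32-bit word therefore spans 64 consecutive integers:
 * word index k >> 6, bit (k >> 1) & 31 within that word. */
size_t odd_word(uint32_t k)
{
    return k >> 6;
}

uint32_t odd_mask(uint32_t k)
{
    /* An unsigned 1 avoids shifting into the sign bit at position 31. */
    return UINT32_C(1) << ((k >> 1) & 31);
}
```

For example, 1 maps to bit 0 of word 0, 63 to bit 31 of word 0, and 65 to bit 0 of word 1.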

Example using 1 bit per odd number

#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define N        1000000000

int main(void)
{
    int       arraySize = ((N + 63)/64 + 1) * sizeof(uint32_t);
    uint32_t *primes    = malloc(arraySize);

    // Each bit in primes is an odd number.
    // Bit 0 = 1, bit 1 = 3, bit 2 = 5, etc.
    memset(primes, 0xff, arraySize);

    // 1 is not a prime.
    primes[0] &= ~0x1;

    int sqrt_N = sqrt(N);
    for (int i = 3; i <= sqrt_N; i += 2) {
        // Odd i lives in word i >> 6, at bit (i >> 1) & 31.
        // The masks are unsigned so that shifting into bit 31 is well defined.
        int      iIndex = i >> 6;
        uint32_t iBit   = (UINT32_C(1) << ((i >> 1) & 31));
        if ((primes[iIndex] & iBit) != 0) {
            int increment = i+i;
            for (int j = i * i; j < N; j += increment) {
                int      jIndex = j >> 6;
                uint32_t jBit   = (UINT32_C(1) << ((j >> 1) & 31));
                primes[jIndex] &= ~jBit;
            }
        }
    }

    // Count the number of primes in order to verify that the above worked.
    // Start count at 1 to include 2 as the first prime, since we are only
    // going to count odd primes.
    int count = 1;
    for (int i = 3; i < N; i += 2) {
        int      iIndex = i >> 6;
        uint32_t iBit   = (UINT32_C(1) << ((i >> 1) & 31));
        if (primes[iIndex] & iBit)
            count++;
    }
    printf("%d\n", count);

    free(primes);
    return 0;
}


Edit: Even less memory usage

Pete Kirkham suggested using even less memory by using 2 out of every 6 numbers. In other words, using 1 bit per every 3 numbers instead of 1 bit per every 2 numbers. At first I was skeptical because this required using a division in the inner loop. However, after coding it up, it turned out to be faster. The code is quite a bit trickier however, because the inner loops need to avoid any multiples of 3, because all multiples of 3 are no longer in the primes array:

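#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define N        1000000000

int main(void)
{
    // One bit per 3 numbers, rounded up to whole 32-bit words (96 numbers
    // per word) so that the last word accessed lies inside the allocation.
    int       arraySize = (N/96 + 1) * sizeof(uint32_t);
    uint32_t *primes    = malloc(arraySize);

    // The bits in primes follow this pattern:
    //
    // Bit 0 = 5, bit 1 = 7, bit 2 = 11, bit 3 = 13, bit 4 = 17, etc.
    //
    // For even n, bit n represents 3*n + 5
    // For odd  n, bit n represents 3*n + 4
    // so a number k with no factor of 2 or 3 maps to bit k/3 - 1.
    memset(primes , 0xff, arraySize);

    int sqrt_N = sqrt(N);
    for (int i = 5; i <= sqrt_N; i += 4) {
        // Here i has the form 6k+5 (an even bit).
        int      iBitNumber = i / 3 - 1;
        int      iIndex     = iBitNumber >> 5;
        uint32_t iBit       = UINT32_C(1) << (iBitNumber & 31);
        if ((primes[iIndex] & iBit) != 0) {
            int increment = i+i;
            for (int j = i * i; j < N; j += increment) {
                int      jBitNumber = j / 3 - 1;
                int      jIndex     = jBitNumber >> 5;
                uint32_t jBit       = UINT32_C(1) << (jBitNumber & 31);
                primes[jIndex] &= ~jBit;

                j += increment;
                if (j >= N)
                    break;

                jBitNumber = j / 3 - 1;
                jIndex = jBitNumber >> 5;
                jBit   = UINT32_C(1) << (jBitNumber & 31);
                primes[jIndex] &= ~jBit;

                // Skip the next multiple, which is divisible by 3.
                j += increment;
            }
        }

        // Now handle i+2, of the form 6k+1 (the adjacent odd bit).
        i += 2;
        iBit <<= 1;
        if ((primes[iIndex] & iBit) != 0) {
            int increment = i+i;
            for (int j = i * i; j < N; j += increment) {
                int      jBitNumber = j / 3 - 1;
                int      jIndex     = jBitNumber >> 5;
                uint32_t jBit       = UINT32_C(1) << (jBitNumber & 31);
                primes[jIndex] &= ~jBit;

                // Skip the next multiple, which is divisible by 3.
                j += increment;

                j += increment;
                if (j >= N)
                    break;

                jBitNumber = j / 3 - 1;
                jIndex = jBitNumber >> 5;
                jBit   = UINT32_C(1) << (jBitNumber & 31);
                primes[jIndex] &= ~jBit;
            }
        }
    }

    // Start count at 2 to include 2 and 3, which have no bits in the array.
    int count = 2;
    for (int i = 5; i < N; i += 6) {
        int      iBitNumber = i / 3 - 1;
        int      iIndex     = iBitNumber >> 5;
        uint32_t iBit       = UINT32_C(1) << (iBitNumber & 31);
        if (primes[iIndex] & iBit) {
            count++;
        }
        iBit <<= 1;
        if (i + 2 < N && (primes[iIndex] & iBit)) {
            count++;
        }
    }
    printf("%d\n", count);

    free(primes);
    return 0;
}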


I'm sure the variant that works on multiples of 30 would be even faster than this, although the code would be even more complicated.
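
To give an idea of what a layout based on multiples of 30 could look like, here is a deliberately simplified sketch, not a tuned implementation. Only the 8 residues modulo 30 that share no factor with 2, 3 or 5 get a bit, so one byte covers 30 numbers. The names `wheel_mask` and `count_primes_wheel30` are hypothetical. The inner loop still visits every odd multiple and uses `%` and `/`, so it shows the storage scheme only and makes no claim about speed.

```c
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/* Bit mask for each residue modulo 30; 0 means the residue shares a
 * factor with 30 and therefore has no bit in the table. */
static const uint8_t wheel_mask[30] = {
    0, 1, 0, 0, 0, 0, 0, 2, 0, 0, 0, 4, 0, 8, 0,
    0, 0, 16, 0, 32, 0, 0, 0, 64, 0, 0, 0, 0, 0, 128
};

/* Hypothetical helper: counts the primes below n. */
uint32_t count_primes_wheel30(uint32_t n)
{
    /* 2, 3 and 5 have no bits in the table, so count them directly. */
    uint32_t count = 0;
    if (n > 2) ++count;
    if (n > 3) ++count;
    if (n > 5) ++count;
    if (n <= 7)
        return count;

    size_t   size  = (size_t)n / 30 + 1;    /* one byte per 30 numbers */
    uint8_t *table = malloc(size);
    if (table == NULL)
        return 0;                           /* allocation failed; sketch only */
    memset(table, 0xff, size);
    table[0] &= (uint8_t)~1u;               /* 1 is not a prime */

    for (uint32_t i = 7; (uint64_t)i * i < n; i += 2) {
        uint8_t m = wheel_mask[i % 30];
        if (m == 0 || !(table[i / 30] & m))
            continue;                       /* not stored, or already composite */
        for (uint64_t j = (uint64_t)i * i; j < n; j += 2 * i) {
            /* The mask is 0 for multiples of 3 or 5, so those
             * iterations leave the byte unchanged. */
            table[j / 30] &= (uint8_t)~wheel_mask[j % 30];
        }
    }

    for (uint32_t k = 7; k < n; k += 2) {
        uint8_t m = wheel_mask[k % 30];
        if (m != 0 && (table[k / 30] & m))
            ++count;
    }

    free(table);
    return count;
}
```

A faster variant would presumably step `j` only through the multiples that have a bit, in the same spirit as the skipping done in the 2-out-of-6 code above, which is where the extra complexity comes from.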

Timings

Here are the speeds of the various programs I ran

Context

StackExchange Code Review Q#112901, answer score: 10
