
MSD radix sort in Java for parallel arrays


Problem

I have an MSD (most-significant-digit) radix sort that sorts Entry objects, each holding a sorting key of type long and a satellite datum. It handles the sign bit so that the resulting permutation honours signed order: every Entry with the sign bit set precedes all Entry objects with the sign bit clear.
It is highly efficient on large arrays: almost 3 times faster than java.util.Arrays.sort when sorting 1e7 Entry objects.

So what do you think?

Arrays.java:

```
package net.coderodde.util;

public class Arrays {

    private static final int BUCKETS = 256;
    private static final int BITS_PER_BYTE = 8;
    private static final int RIGHT_SHIFT_AMOUNT = 56;
    private static final int MOST_SIGNIFICANT_BYTE_INDEX = 7;
    private static final int MERGESORT_THRESHOLD = 4096;
    private static final int LEAST_SIGNED_BUCKET_INDEX = 128;

    public static final <E> void sort(final Entry<E>[] array,
                                      final int fromIndex,
                                      final int toIndex) {
        if (toIndex - fromIndex < 2) {
            return;
        }

        final Entry<E>[] buffer = array.clone();
        sortTopImpl(array, buffer, fromIndex, toIndex);
    }

    public static final <E> void sort(final Entry<E>[] array) {
        sort(array, 0, array.length);
    }

    public static final <E extends Comparable<? super E>>
    boolean isSorted(final E[] array,
                     final int fromIndex,
                     final int toIndex) {
        for (int i = fromIndex; i < toIndex - 1; ++i) {
            if (array[i].compareTo(array[i + 1]) > 0) {
                return false;
            }
        }

        return true;
    }

    public static final <E extends Comparable<? super E>>
    boolean isSorted(final E[] array) {
        return isSorted(array, 0, array.length);
    }

    public static final boolean areEqual(final Entry[]... arrays) {
        for (int i = 0; i < arrays.length - 1; ++i) {
            if (arrays[i].length != arrays[i + 1].length) {
                return false;
            }
        }

        for (int i = 0; i < arrays[0].length; ++i) {
            for (int j = 0; j < arrays.length - 1; ++j) {
                if (arrays[j][i] != arrays[j + 1][i]) {
                    return false;
                }
            }
        }

        return true;
    }

    /**
     * @param <E>       the type of satellite data of each entry.
     * @param array     the actual array to sort.
     * @param buffer    the auxiliary buffer.
     * @param fromIndex the least index of the range to sort.
     * @param toIndex   the index past the last entry of the range to sort.
     */
    private static <E> void sortTopImpl(final Entry<E>[] array,
                                        final Entry<E>[] buffer,
                                        final int fromIndex,
                                        final int toIndex) {
        // ...
    }
}
```

Solution

Entry

Your Entry compareTo(...) method is fine, but you should defer to Long.compare() instead. Your code:

@Override
public int compareTo(Entry<E> o) {
    if (key < o.key) {
        return -1;
    } else if (key > o.key) {
        return 1;
    } else {
        return 0;
    }
}


could be:

@Override
public int compareTo(Entry o) {
    return Long.compare(key, o.key);
}


The key and satelliteData fields in Entry should also be final, and instead of being public they should be exposed through getters.
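A minimal sketch of what that could look like. The field names key and satelliteData come from the question; the constructor and getter names are illustrative, not the question's actual code:

```java
public class Entry<E> implements Comparable<Entry<E>> {

    private final long key;        // the sorting key
    private final E satelliteData; // the payload carried along with the key

    public Entry(final long key, final E satelliteData) {
        this.key = key;
        this.satelliteData = satelliteData;
    }

    public long getKey() {
        return key;
    }

    public E getSatelliteData() {
        return satelliteData;
    }

    @Override
    public int compareTo(final Entry<E> o) {
        // Delegates the three-way comparison to the JDK.
        return Long.compare(key, o.key);
    }
}
```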
Sorting

You have special handling for the buckets and the ranges, depending on negative values.

This special handling has also resulted in a lot of code duplication. You essentially have two complete method duplicates: one for sorting the high byte (which has to account for negative values), and the other for sorting the remaining low bytes. Your code 'buckets' the data (or a subset of it) into buckets based on a significant byte. The challenge is that the most significant byte has a different sort order than the other bytes.

The trick to solving this is to flip the most significant bit; the resulting order is then correct, as if the long were unsigned.
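A quick self-contained check of that claim: for adjacent keys in signed order, the unsigned order of the sign-flipped keys agrees (SIGN_MASK as defined later in this answer; the class name is just for the demo):

```java
public class SignFlipDemo {

    private static final long SIGN_MASK = 1L << 63;

    public static void main(String[] args) {
        // A few keys already in signed ascending order.
        long[] keys = { Long.MIN_VALUE, -42L, -1L, 0L, 1L, 42L, Long.MAX_VALUE };

        for (int i = 0; i < keys.length - 1; ++i) {
            long a = keys[i];
            long b = keys[i + 1];
            // Signed order of the raw keys...
            int signed = Long.compare(a, b);
            // ...matches unsigned order once the sign bit is flipped.
            int unsignedFlipped =
                    Long.compareUnsigned(a ^ SIGN_MASK, b ^ SIGN_MASK);
            if (Integer.signum(signed) != Integer.signum(unsignedFlipped)) {
                throw new AssertionError("order mismatch at " + a + ", " + b);
            }
        }
        System.out.println("sign-flip ordering verified");
    }
}
```

This is exactly why one getBucket method can serve every byte position: after the XOR, every byte can be treated as unsigned.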

Your code would boil down to something like:

private static final int BITS_PER_BUCKET = 8;
private static final int BUCKETS = 1 << BITS_PER_BUCKET;
private static final int BUCKET_MASK = BUCKETS - 1;
private static final long SIGN_MASK = 1L << 63;

/*
 * Converts a section of a key into a bucket. Treats the sign bit properly.
 */
private static final int getBucket(final long key, final int recursionDepth) {
    final int bitShift = 64 - (recursionDepth + 1) * BITS_PER_BUCKET;
    return (int)((key ^ SIGN_MASK) >>> bitShift) & BUCKET_MASK;
}

private static <E> void sortImpl(final Entry<E>[] source,
                                 final Entry<E>[] target,
                                 final int recursionDepth,
                                 final int fromIndex,
                                 final int toIndex) {
    // Try merge sort.
    if (toIndex - fromIndex <= MERGESORT_THRESHOLD) {
        // perform merge sort ...

        return;
    }

    final int[] bucketSizeMap = new int[BUCKETS];
    final int[] startIndexMap = new int[BUCKETS];
    final int[] processedMap  = new int[BUCKETS];

    // Compute the size of each bucket.
    for (int i = fromIndex; i < toIndex; ++i) {
        bucketSizeMap[getBucket(source[i].key, recursionDepth)]++;
    }

    // Initialize the start index map.
    startIndexMap[0] = fromIndex;

    // Compute the start index map in its entirety.
    for (int i = 1; i != BUCKETS; ++i) {
        startIndexMap[i] = startIndexMap[i - 1] +
                           bucketSizeMap[i - 1];
    }

    // Insert the entries from 'source' into their respective buckets in 'target'.
    for (int i = fromIndex; i < toIndex; ++i) {
        final Entry<E> e = source[i];
        final int index = getBucket(e.key, recursionDepth);
        target[startIndexMap[index] + processedMap[index]++] = e;
    }

    // Recur to sort each bucket.
    for (int i = 0; i != BUCKETS; ++i) {
        if (bucketSizeMap[i] != 0) {
            sortImpl(target,
                     source,
                     recursionDepth + 1,
                     startIndexMap[i],
                     startIndexMap[i] + bucketSizeMap[i]);
        }
    }
}


With this code, there is no need for the 'Top' sort method at all; it is superfluous. Your recursion entry can change from:

sortTopImpl(array, buffer, fromIndex, toIndex);


to:

sortImpl(array, buffer, 0, fromIndex, toIndex);


and you can delete the sortTopImpl method entirely.
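To make the single-sortImpl structure concrete, here is a condensed, self-contained sketch of the same scheme applied to a plain long[] rather than Entry objects. This is an illustration, not your code: the class and constant names are mine, an insertion sort stands in for the merge-sort fallback, and because the threshold here is tiny, an explicit depth guard is added so runs of duplicate keys cannot recurse past the last byte:

```java
import java.util.Arrays;
import java.util.Random;

public class MsdRadixSortDemo {

    private static final int BITS_PER_BUCKET = 8;
    private static final int BUCKETS = 1 << BITS_PER_BUCKET;
    private static final int BUCKET_MASK = BUCKETS - 1;
    private static final long SIGN_MASK = 1L << 63;
    private static final int MAX_DEPTH = 64 / BITS_PER_BUCKET;
    private static final int INSERTIONSORT_THRESHOLD = 16;

    private static int getBucket(final long key, final int recursionDepth) {
        final int bitShift = 64 - (recursionDepth + 1) * BITS_PER_BUCKET;
        return (int) ((key ^ SIGN_MASK) >>> bitShift) & BUCKET_MASK;
    }

    public static void sort(final long[] array) {
        final long[] buffer = array.clone();
        sortImpl(array, buffer, 0, 0, array.length);
    }

    private static void sortImpl(final long[] source,
                                 final long[] target,
                                 final int recursionDepth,
                                 final int fromIndex,
                                 final int toIndex) {
        // All 64 bits consumed: every key in this run is identical.
        if (recursionDepth == MAX_DEPTH) {
            if ((recursionDepth & 1) == 1) {
                System.arraycopy(source, fromIndex, target, fromIndex,
                                 toIndex - fromIndex);
            }
            return;
        }

        // Small runs fall back to a comparison sort.
        if (toIndex - fromIndex <= INSERTIONSORT_THRESHOLD) {
            insertionSort(source, fromIndex, toIndex);
            if ((recursionDepth & 1) == 1) {
                // At odd depths 'source' is the scratch buffer,
                // so the sorted run must be copied back.
                System.arraycopy(source, fromIndex, target, fromIndex,
                                 toIndex - fromIndex);
            }
            return;
        }

        final int[] bucketSizeMap = new int[BUCKETS];
        final int[] startIndexMap = new int[BUCKETS];
        final int[] processedMap  = new int[BUCKETS];

        // Count, prefix-sum, then distribute into the buckets.
        for (int i = fromIndex; i < toIndex; ++i) {
            bucketSizeMap[getBucket(source[i], recursionDepth)]++;
        }

        startIndexMap[0] = fromIndex;
        for (int i = 1; i != BUCKETS; ++i) {
            startIndexMap[i] = startIndexMap[i - 1] + bucketSizeMap[i - 1];
        }

        for (int i = fromIndex; i < toIndex; ++i) {
            final int bucket = getBucket(source[i], recursionDepth);
            target[startIndexMap[bucket] + processedMap[bucket]++] = source[i];
        }

        // Recur into each non-empty bucket with the array roles swapped.
        for (int i = 0; i != BUCKETS; ++i) {
            if (bucketSizeMap[i] != 0) {
                sortImpl(target, source, recursionDepth + 1,
                         startIndexMap[i], startIndexMap[i] + bucketSizeMap[i]);
            }
        }
    }

    private static void insertionSort(final long[] array,
                                      final int fromIndex,
                                      final int toIndex) {
        for (int i = fromIndex + 1; i < toIndex; ++i) {
            final long key = array[i];
            int j = i - 1;
            while (j >= fromIndex && array[j] > key) {
                array[j + 1] = array[j];
                --j;
            }
            array[j + 1] = key;
        }
    }

    public static void main(String[] args) {
        final long[] array = new long[100_000];
        final Random random = new Random(1);
        for (int i = 0; i < array.length; ++i) {
            array[i] = random.nextLong();
        }
        final long[] expected = array.clone();
        Arrays.sort(expected);
        sort(array);
        System.out.println(Arrays.equals(array, expected)); // true
    }
}
```

The ping-pong between source and target is the part worth studying: a run that bottoms out at an odd depth lives in the scratch buffer and must be copied back, which is exactly the bookkeeping the Entry version inherits.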

Note that the bucket constants are all derived from the single BITS_PER_BUCKET value. You should be able to change the bucket size by changing that one constant.

Additionally, as long as the merge-sort threshold is larger than a bucket size, there is no need to check for a limit on the depth of recursion.
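The derived constants and the recursion budget can be checked with a few lines of arithmetic (values shown for BITS_PER_BUCKET = 8; the class name is illustrative):

```java
public class BucketConstantsDemo {

    public static void main(String[] args) {
        final int bitsPerBucket = 8;
        final int buckets = 1 << bitsPerBucket; // 256
        final int bucketMask = buckets - 1;     // 0xFF
        final int maxDepth = 64 / bitsPerBucket; // 8 passes cover all 64 bits

        // The shift used at each depth walks from the most significant
        // byte (shift 56) down to the least significant byte (shift 0).
        for (int depth = 0; depth < maxDepth; ++depth) {
            final int bitShift = 64 - (depth + 1) * bitsPerBucket;
            System.out.println("depth " + depth + " -> shift " + bitShift);
        }
        System.out.println(buckets + " buckets, mask 0x"
                + Integer.toHexString(bucketMask)
                + ", max depth " + maxDepth);
    }
}
```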


Context

StackExchange Code Review Q#72045, answer score: 3
