
Multiplication 128 x 64 bits

Submitted by: @import:stackexchange-codereview

Problem

I was fooling around with the Collatz sequence a bit and found out that long is only sufficient for starting values below 8528817511L. As I wanted to optimize my code from this answer for computations up to $10^{10}$, I needed either BigInteger or to roll my own. As the range is limited and only a few operations are needed, I tried my own. My special-case multiplication is about 5x faster than BigInteger, so I guess it wasn't a useless exercise.
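The overflow described above can be reproduced with Java's overflow-checking Math.multiplyExact and Math.addExact. These throw ArithmeticException instead of silently wrapping. The following is a minimal sketch for illustration only; CollatzOverflowDemo and peak are hypothetical names, not part of the OP's code.

```java
/** Hypothetical demo: follows a Collatz trajectory and reports whether a signed long overflows. */
public final class CollatzOverflowDemo {

    /**
     * Returns the largest value reached from n, or -1 if some 3n+1 step
     * does not fit in a signed long.
     */
    static long peak(long n) {
        long max = n;
        try {
            while (n != 1) {
                // Even: halve. Odd: 3n + 1 with overflow checks instead of silent wraparound.
                n = (n % 2 == 0) ? n / 2 : Math.addExact(Math.multiplyExact(3L, n), 1L);
                max = Math.max(max, n);
            }
        } catch (ArithmeticException overflow) {
            return -1;
        }
        return max;
    }

    public static void main(String[] args) {
        System.out.println(peak(27L));          // 9232
        System.out.println(peak(8528817511L));  // -1: the trajectory exceeds Long.MAX_VALUE
    }
}
```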

I'm interested in optimization tips and general comments (apart from formatting, which is fine, at least according to my conventions).

Below I'm giving only the multiplication code and supporting methods. The whole class can be found on GitHub.

```
// Assumed imports (not shown in the original post): Guava's checkArgument and Lombok.
import static com.google.common.base.Preconditions.checkArgument;

import lombok.EqualsAndHashCode;
import lombok.Getter;

/**
 * A class representing a 127-bit positive number
 * (negative numbers are unsupported and the highest bit is currently unusable).
 * Only those operations which were needed for this toy project are implemented.
 *
 * Speed is the goal; general usability is not.
 */
@Getter @EqualsAndHashCode final class MutableLong128 implements Cloneable {
    /**
     * Multiply {@code this} by the argument, which must be non-negative.
     * Overflow may or may not be detected.
     */
    public void multiply(long x) {
        checkArgument(x >= 0);

        long y0 = low;
        long y1 = high;

        // The variables p0 to p3 represent the product where each of them is to be treated as an unsigned quantity.
        // The weight of pi is 2^(32*i) rather than 2^(64*i).

        long p0 = unsignedLow(x) * unsignedLow(y0);
        long p1 = unsignedLow(x) * unsignedHigh(y0);
        long p2 = unsignedHigh(x) * unsignedHigh(y0);

        // Before another product can be added to p1, its higher part must be transferred to p2.
        p2 += unsignedHigh(p1);
        p1 = unsignedLow(p1) + unsignedHigh(x) * unsignedLow(y0);

        // p3 represents the highest part and a possible overflow gets ignored.
        long p3 = unsignedLow(x) * unsignedHigh(y1) + unsignedHigh(x) * unsignedLow(y1);

        // Sim
```
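This copy of the post does not include the supporting methods, so the snippet does not compile on its own. Judging only from how they are called here and in the answer, unsignedLow, unsignedHigh and composeLowHigh most likely behave like the sketch below. These bodies, and the wrapper class Long128Helpers, are assumptions rather than the OP's actual code.

```java
/**
 * Hypothetical stand-ins for the helpers MutableLong128 uses; the real
 * implementations are not included in this post.
 */
public final class Long128Helpers {
    /** L(v): lower 32 bits of v, as a non-negative long in [0, 2^32). */
    static long unsignedLow(long v) {
        return v & 0xFFFF_FFFFL;
    }

    /** U(v): upper 32 bits of v, as a non-negative long in [0, 2^32). */
    static long unsignedHigh(long v) {
        return v >>> 32;
    }

    /** L(low32) + 2^32 * L(high32); bits above 31 in high32 are shifted out. */
    static long composeLowHigh(long low32, long high32) {
        return (high32 << 32) | (low32 & 0xFFFF_FFFFL);
    }

    public static void main(String[] args) {
        long v = 0x1234_5678_9ABC_DEF0L;
        System.out.println(Long.toHexString(unsignedLow(v)));   // 9abcdef0
        System.out.println(Long.toHexString(unsignedHigh(v)));  // 12345678
        System.out.println(Long.toHexString(composeLowHigh(unsignedLow(v), unsignedHigh(v)))); // 123456789abcdef0
    }
}
```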

Solution

After many iterations, in which I tried things in C that didn't port cleanly to Java, I came up with the following, which should work equally well in Java and in C. When tested in C against the OP's function (ported to C), this function was about 60% faster. I don't know whether that speed-up will carry over to Java. Here is the Java version:

```
public void multiply(long x)
{
    long y0 = unsignedLow (low);
    long y1 = unsignedHigh(low);
    long x0 = unsignedLow (x);
    long x1 = unsignedHigh(x);
    long y12 = composeLowHigh(y1, high);

    // The low 64 bits of the output can be computed easily.
    low = x * low;

    // The upper 64 bits of the output are a combination of several
    // factors. These are the first two:
    high = (high * x0) + (y12 * x1);

    // Now handle the factors coming from the rest:
    long p01 = x0 * y1;
    long p10 = x1 * y0;
    long p00 = x0 * y0;

    // Add the high parts directly in.
    high += unsignedHigh(p01);
    high += unsignedHigh(p10);

    // Account for the possible carry from the low parts.
    long p2 = unsignedHigh(p00) + unsignedLow(p01) + unsignedLow(p10);
    high += unsignedHigh(p2);
}
```
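One self-contained way to try the answer's method is to copy it into a small class and compare it against BigInteger on random inputs. In this sketch the class name Mul128Sketch, the fields low and high, and the three helper bodies are assumptions standing in for the parts of MutableLong128 not shown. The multiply body follows the answer's steps.

```java
import java.math.BigInteger;
import java.util.Random;

/**
 * Standalone sketch of the answer's multiply(long). The fields and helper
 * bodies are assumptions standing in for the parts of MutableLong128 not shown.
 */
public final class Mul128Sketch {
    long low;   // bits 0..63, treated as unsigned
    long high;  // bits 64..127, treated as unsigned

    static long unsignedLow(long v)  { return v & 0xFFFF_FFFFL; }  // assumed: L(v)
    static long unsignedHigh(long v) { return v >>> 32; }          // assumed: U(v)
    static long composeLowHigh(long lo, long hi) {                 // assumed: L(lo) + 2^32 * L(hi)
        return (hi << 32) | (lo & 0xFFFF_FFFFL);
    }

    /** The answer's algorithm: this = (this * x) mod 2^128, all words unsigned. */
    void multiply(long x) {
        long y0 = unsignedLow(low);
        long y1 = unsignedHigh(low);
        long x0 = unsignedLow(x);
        long x1 = unsignedHigh(x);
        long y12 = composeLowHigh(y1, high);

        low = x * low;                      // low word: one wrapping 64x64 multiply
        high = (high * x0) + (y12 * x1);    // x0*y23 + x1*y12, both reduced mod 2^64
        long p01 = x0 * y1;
        long p10 = x1 * y0;
        long p00 = x0 * y0;
        high += unsignedHigh(p01);
        high += unsignedHigh(p10);
        long p2 = unsignedHigh(p00) + unsignedLow(p01) + unsignedLow(p10);
        high += unsignedHigh(p2);           // Carry out of the low word
    }

    /** Exact value of a (high, low) pair as a non-negative BigInteger. */
    static BigInteger toBig(long hi, long lo) {
        return new BigInteger(Long.toUnsignedString(hi)).shiftLeft(64)
                .or(new BigInteger(Long.toUnsignedString(lo)));
    }

    public static void main(String[] args) {
        BigInteger mask128 = BigInteger.ONE.shiftLeft(128).subtract(BigInteger.ONE);
        Random rnd = new Random(42);
        for (int i = 0; i < 100_000; i++) {
            Mul128Sketch m = new Mul128Sketch();
            m.low = rnd.nextLong();
            m.high = rnd.nextLong();
            long x = rnd.nextLong();
            BigInteger expected = toBig(m.high, m.low)
                    .multiply(new BigInteger(Long.toUnsignedString(x)))
                    .and(mask128);
            m.multiply(x);
            if (!toBig(m.high, m.low).equals(expected)) {
                throw new AssertionError("mismatch for x = " + Long.toUnsignedString(x));
            }
        }
        System.out.println("100000 random products match BigInteger");
    }
}
```

All words are treated as unsigned here, so the test also covers the top bit that the OP's class keeps unused. The algorithm never looks at signs, so this does not affect the result.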


An explanation of the algorithm

To explain the algorithm I use, it helps to first look at how to do the multiplication the "normal way", which is how the OP's function works. The normal way uses 32-bit multiplies to create 64-bit partial products, and then adds the partial products together.

Multiplication of 64-bit x against 128-bit y using 32-bit multiplies:

U(n) means upper 32 bits of n.
L(n) means lower 32 bits of n.

```
x0 = L(x)
x1 = U(x)
y0 = L(y.low)
y1 = U(y.low)
y2 = L(y.high)
y3 = U(y.high)

p00 = x0 * y0
p01 = x0 * y1
p02 = x0 * y2
p03 = x0 * y3
p10 = x1 * y0
p11 = x1 * y1
p12 = x1 * y2
p13 = x1 * y3

        ( Bits in the result )
0.......32......64......96......128.....160.....192
[------p00------]
        [------p01------]
        [------p10------]
                [------p02------]
                [------p11------]
                        [------p03------]
                        [------p12------]
                                [------p13------]
```


First, notice that p13 is not needed, because we are not looking for any bits beyond bit 127 of the result. That leaves 7 multiplies (note that the OP's function has exactly 7 multiplies) and a bunch of adds and shifts. The formulas for the resulting low and high words can be written as:

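$$
\begin{aligned}
\mathit{low} &= \bigl(p_{00} + 2^{32}\,\mathrm{L}(p_{01}) + 2^{32}\,\mathrm{L}(p_{10})\bigr) \bmod 2^{64} \\
\mathit{high} &= \bigl(\mathrm{U}(p_{01}) + \mathrm{U}(p_{10}) + p_{02} + p_{11} + 2^{32}\,\mathrm{L}(p_{03}) + 2^{32}\,\mathrm{L}(p_{12}) + \mathit{Carry}\bigr) \bmod 2^{64} \\
\mathit{Carry} &= \mathrm{U}\bigl(\mathrm{U}(p_{00}) + \mathrm{L}(p_{01}) + \mathrm{L}(p_{10})\bigr)
\end{aligned}
$$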
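For comparison, these formulas translate almost line by line into a seven-multiply schoolbook version. This is a sketch written from the formulas, not the OP's (truncated) method. The class Schoolbook128 and the helpers lo32/hi32 (standing for L and U) are hypothetical names.

```java
/**
 * Sketch of the "normal way": seven 32x32->64-bit partial products combined
 * per the low/high/Carry formulas. Written from the formulas, not the OP's code.
 */
public final class Schoolbook128 {
    static long lo32(long v) { return v & 0xFFFF_FFFFL; }  // L(v)
    static long hi32(long v) { return v >>> 32; }           // U(v)

    /** Returns {low, high} of (yHigh:yLow * x) mod 2^128, all words unsigned. */
    static long[] multiply(long yLow, long yHigh, long x) {
        long x0 = lo32(x), x1 = hi32(x);
        long y0 = lo32(yLow), y1 = hi32(yLow), y2 = lo32(yHigh), y3 = hi32(yHigh);

        long p00 = x0 * y0, p01 = x0 * y1, p10 = x1 * y0;
        long p02 = x0 * y2, p11 = x1 * y1;
        long p03 = x0 * y3, p12 = x1 * y2;  // p13 starts at bit 128, so it is skipped

        long carry = hi32(hi32(p00) + lo32(p01) + lo32(p10));
        long low = p00 + (lo32(p01) << 32) + (lo32(p10) << 32);  // wraps mod 2^64
        long high = hi32(p01) + hi32(p10) + p02 + p11
                + (lo32(p03) << 32) + (lo32(p12) << 32) + carry;
        return new long[] {low, high};
    }

    public static void main(String[] args) {
        long[] r = multiply(-1L, 0L, 2L);  // (2^64 - 1) * 2 = 2^65 - 2
        System.out.println(Long.toUnsignedString(r[1]) + " " + Long.toUnsignedString(r[0]));
        // prints "1 18446744073709551614"
    }
}
```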

My algorithm uses the following idea to speed up the computation:


We can use a 64 bit x 64 bit multiplication when we only need the bottom 64 bits of the product.

What does this mean? There are several places where we can save multiplies and adds by multiplying more than 32 bits at a time. What happens when we multiply 64 bits by 64 bits? Suppose we define these 64-bit values:

$$x_{01} = x, \qquad y_{01} = \texttt{y.low}, \qquad y_{12} = \mathrm{U}(\texttt{y.low}) + 2^{32}\,\mathrm{L}(\texttt{y.high})$$

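What does multiplying x01 * y01 produce? In terms of the partial products listed above, the full 128-bit result would be:

$$x_{01}\,y_{01} = p_{00} + 2^{32}\,(p_{01} + p_{10}) + 2^{64}\,p_{11}$$

But since the result is only 64 bits, the actual 64-bit result is:

$$x_{01}\,y_{01} \bmod 2^{64} = \bigl(p_{00} + 2^{32}\,\mathrm{L}(p_{01}) + 2^{32}\,\mathrm{L}(p_{10})\bigr) \bmod 2^{64}$$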

Notice that this is exactly what we need to compute for low. So the first computation of my algorithm is exactly that:

```
low = x * low;
```
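The claim that one wrapping 64-bit multiply already yields the low word can be checked directly. The hypothetical LowWordCheck class below compares x * yLow with the partial-product formula. It also compares it with the low 64 bits of the exact BigInteger product, which BigInteger.longValue() returns for values that do not fit in a long.

```java
import java.math.BigInteger;

/** Checks that a wrapping 64x64 multiply equals the low-word formula built from partial products. */
public final class LowWordCheck {
    /** p00 + 2^32 * (L(p01) + L(p10)), reduced mod 2^64 by long overflow. */
    static long lowFromPartials(long x, long yLow) {
        long x0 = x & 0xFFFF_FFFFL, x1 = x >>> 32;
        long y0 = yLow & 0xFFFF_FFFFL, y1 = yLow >>> 32;
        long p00 = x0 * y0, p01 = x0 * y1, p10 = x1 * y0;
        return p00 + ((p01 & 0xFFFF_FFFFL) << 32) + ((p10 & 0xFFFF_FFFFL) << 32);
    }

    public static void main(String[] args) {
        long x = 0xDEAD_BEEF_1234_5678L;
        long yLow = 0x0F0F_F0F0_AAAA_5555L;

        long direct = x * yLow;  // what `low = x * low;` computes
        // Reference: low-order 64 bits of the exact unsigned product.
        long reference = new BigInteger(Long.toUnsignedString(x))
                .multiply(new BigInteger(Long.toUnsignedString(yLow)))
                .longValue();

        System.out.println(direct == lowFromPartials(x, yLow)); // true
        System.out.println(direct == reference);                 // true
    }
}
```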


This one line solves half the problem with a single multiplication. The rest of the function is dedicated to finding the upper 64 bits. Let's look at how the equation for the upper 64 bits can be simplified:

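$$\mathit{high} = \bigl(\mathrm{U}(p_{01}) + \mathrm{U}(p_{10}) + p_{02} + p_{11} + 2^{32}\,\mathrm{L}(p_{03}) + 2^{32}\,\mathrm{L}(p_{12}) + \mathit{Carry}\bigr) \bmod 2^{64}$$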

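We can use the same idea to combine some of the multiplies, except that instead of multiplying two 64-bit numbers together, we multiply a 32-bit number by a 64-bit number. With $y_{23} = \texttt{y.high}$, the value $p_{02} + 2^{32}\,\mathrm{L}(p_{03})$ (mod $2^{64}$) is exactly the low 64 bits of $x_0\,y_{23}$. Likewise, $p_{11} + 2^{32}\,\mathrm{L}(p_{12})$ is the low 64 bits of $x_1\,y_{12}$. The equation therefore becomes:

$$
\begin{aligned}
\mathit{high} &= \bigl(\mathrm{U}(p_{01}) + \mathrm{U}(p_{10}) + x_0\,y_{23} + x_1\,y_{12} + \mathit{Carry}\bigr) \bmod 2^{64} \\
\mathit{Carry} &= \mathrm{U}\bigl(\mathrm{U}(p_{00}) + \mathrm{L}(p_{01}) + \mathrm{L}(p_{10})\bigr)
\end{aligned}
$$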
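The two identities behind this step can be spot-checked numerically: $x_0\,y_{23} \equiv p_{02} + 2^{32}\,\mathrm{L}(p_{03})$ and $x_1\,y_{12} \equiv p_{11} + 2^{32}\,\mathrm{L}(p_{12})$ (mod $2^{64}$). Below is a minimal sketch with hypothetical names (HighWordIdentities, lo32, hi32). It assumes y12 is built as U(y.low) + 2^32 L(y.high) and relies on Java's wrapping long arithmetic to do the mod $2^{64}$.

```java
import java.util.Random;

/** Spot-checks the identities that let one 32x64 multiply replace two 32x32 multiplies. */
public final class HighWordIdentities {
    static long lo32(long v) { return v & 0xFFFF_FFFFL; }  // L(v)
    static long hi32(long v) { return v >>> 32; }           // U(v)

    static boolean identitiesHold(long x, long yLow, long yHigh) {
        long x0 = lo32(x), x1 = hi32(x);
        long y1 = hi32(yLow), y2 = lo32(yHigh), y3 = hi32(yHigh);
        long y23 = yHigh;               // y23 = y.high
        long y12 = (y2 << 32) | y1;     // y12 = U(y.low) + 2^32 * L(y.high)

        long p02 = x0 * y2, p03 = x0 * y3, p11 = x1 * y1, p12 = x1 * y2;
        boolean first = x0 * y23 == p02 + (lo32(p03) << 32);
        boolean second = x1 * y12 == p11 + (lo32(p12) << 32);
        return first && second;
    }

    public static void main(String[] args) {
        Random rnd = new Random(1);
        boolean ok = true;
        for (int i = 0; i < 1_000_000 && ok; i++) {
            ok = identitiesHold(rnd.nextLong(), rnd.nextLong(), rnd.nextLong());
        }
        System.out.println(ok ? "identities hold" : "mismatch");
    }
}
```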


This is exactly what is computed in my function. It would be nice if there were a better way of handling the "Carry", but right now I'm not sure there is anything better than the straightforward way. I did try a different way, but it ended up being slower, so I reverted that change.

In the end, there are 6 multiplies, which is only one fewer than the original function, but there are only 6 adds, which is significantly fewer than the original. There are also 6 truncate operations, where we take the upper or lower 32 bits, and that is also far fewer than in the original function.

Added by the OP

The code passed my tests and won my benchmark by a factor of 1.4. Great!

```
52676.447224 ns {=timeBigInteger}
 9098.453982 ns {=timeMutableLong}
 6351.209769 ns {=timeMutableLong2}
```

Context

StackExchange Code Review Q#72286, answer score: 12
