patternjavaModerate
Multiplication 128 x 64 bits
128bitsmultiplication
Problem
I was fooling around with the Collatz sequence a bit and found out that long is only sufficient for starting values below 8528817511L. As I wanted to optimize my code from this answer for computations up to \$10^{10}\$, I needed either BigInteger or to roll my own. As the range is limited and just a few operations are needed, I tried my own. My special-case multiplication is 5x faster, so I guess it wasn't a useless exercise.
I'm interested in optimization tips and general comments (apart from formatting as it's fine, at least according to my conventions).
Below I'm giving only the multiplication code and supporting methods. The whole class can be found on github.
```
/**
 * A class representing a 127 bit positive number
 * (negative numbers are unsupported and the highest bit is currently unusable).
 * Only those operations which were needed for this toy project are implemented.
 *
 * Speed is the goal, general usability is not given.
 */
@Getter @EqualsAndHashCode final class MutableLong128 implements Cloneable {
    /**
     * Multiply {@code this} by the argument which must be non-negative.
     * Overflow may or may not be detected.
     */
    public void multiply(long x) {
        checkArgument(x>=0);
        long y0 = low;
        long y1 = high;
        // The variables p0 to p3 represent the product where each of them is to be treated as an unsigned quantity.
        // The weight of p_i is 2**32 rather than 2**64.
        long p0 = unsignedLow(x) * unsignedLow(y0);
        long p1 = unsignedLow(x) * unsignedHigh(y0);
        long p2 = unsignedHigh(x) * unsignedHigh(y0);
        // Before another product can be added to p1, its higher part must be transferred to p2.
        p2 += unsignedHigh(p1);
        p1 = unsignedLow(p1) + unsignedHigh(x) * unsignedLow(y0);
        // p3 represents the highest part and a possible overflow gets ignored.
        long p3 = unsignedLow(x) * unsignedHigh(y1) + unsignedHigh(x) * unsignedLow(y1);
        // Sim
```
Solution
After many iterations, in which I tried things in C that didn't port cleanly to Java, I came up with the following, which should work equally well in Java and in C. When tested in C against the OP's function (ported to C), this function was about 60% faster. I don't know whether that speedup will carry over to Java. Here is the Java version:
```
public void multiply(long x)
{
    long y0  = unsignedLow (low);
    long y1  = unsignedHigh(low);
    long x0  = unsignedLow (x);
    long x1  = unsignedHigh(x);
    long y12 = composeLowHigh(y1, high);

    // The low 64 bits of the output can be computed easily.
    low = x * low;

    // The upper 64 bits of the output are a combination of several
    // factors. These are the first two:
    high = (high * x0) + (y12 * x1);

    // Now handle the factors coming from the rest:
    long p01 = x0 * y1;
    long p10 = x1 * y0;
    long p00 = x0 * y0;

    // Add the high parts directly in.
    high += unsignedHigh(p01);
    high += unsignedHigh(p10);

    // Account for the possible carry from the low parts.
    long p2 = unsignedHigh(p00) + unsignedLow(p01) + unsignedLow(p10);
    high += unsignedHigh(p2);
}
```
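To try the function outside the original class, it can be dropped into a small stand-in such as the sketch below. The answer does not show `unsignedLow`, `unsignedHigh` or `composeLowHigh`, so their bodies here are assumptions inferred from how the names are used. The class name `Mul128Sketch` and the `toBigInteger` helper are hypothetical and exist only for checking results.

```java
import java.math.BigInteger;

/** Hypothetical stand-in for the question's class; not the author's actual code. */
final class Mul128Sketch {
    long low;   // bits 0..63, treated as unsigned
    long high;  // bits 64..127

    Mul128Sketch(long high, long low) {
        this.high = high;
        this.low = low;
    }

    // Assumed helper: lower 32 bits of n as a non-negative long.
    static long unsignedLow(long n) { return n & 0xFFFFFFFFL; }

    // Assumed helper: upper 32 bits of n as a non-negative long.
    static long unsignedHigh(long n) { return n >>> 32; }

    // Assumed helper: the low 32 bits of lowPart become bits 0..31 and
    // the low 32 bits of highPart become bits 32..63 of the result.
    static long composeLowHigh(long lowPart, long highPart) {
        return (highPart << 32) | (lowPart & 0xFFFFFFFFL);
    }

    /** The multiplication from the answer, for non-negative x. */
    void multiply(long x) {
        long y0 = unsignedLow(low);
        long y1 = unsignedHigh(low);
        long x0 = unsignedLow(x);
        long x1 = unsignedHigh(x);
        long y12 = composeLowHigh(y1, high);

        low = x * low;                       // low 64 bits of the product
        high = (high * x0) + (y12 * x1);     // first two contributions to the high half

        long p01 = x0 * y1;
        long p10 = x1 * y0;
        long p00 = x0 * y0;
        high += unsignedHigh(p01);
        high += unsignedHigh(p10);

        // Carry out of bits 32..63 of the low half.
        long p2 = unsignedHigh(p00) + unsignedLow(p01) + unsignedLow(p10);
        high += unsignedHigh(p2);
    }

    /** The stored value as a non-negative BigInteger (for checking only). */
    BigInteger toBigInteger() {
        BigInteger hi = new BigInteger(Long.toUnsignedString(high));
        BigInteger lo = new BigInteger(Long.toUnsignedString(low));
        return hi.shiftLeft(64).add(lo);
    }

    public static void main(String[] args) {
        Mul128Sketch v = new Mul128Sketch(0L, -1L);   // value 2^64 - 1
        v.multiply(3L);
        System.out.println(v.toBigInteger());
    }
}
```

Comparing `toBigInteger()` against `BigInteger.multiply` reduced modulo 2^128 is a simple way to check the result for arbitrary inputs.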
An explanation of the algorithm
To explain the algorithm I use, first it would be good to look at how to do the multiplication the "normal way", which is the way the OP's function works. The normal way involves using 32-bit multiplies to create 64-bit partial products, and adding the partial products together.
```
Multiplication of 64-bit x against 128-bit y using 32-bit multiplies:

U(n) means upper 32 bits of n.
L(n) means lower 32 bits of n.

x0 = L(x)
x1 = U(x)
y0 = L(y.low)
y1 = U(y.low)
y2 = L(y.high)
y3 = U(y.high)

p00 = x0 * y0
p01 = x0 * y1
p02 = x0 * y2
p03 = x0 * y3
p10 = x1 * y0
p11 = x1 * y1
p12 = x1 * y2
p13 = x1 * y3

              ( Bits in the result )
0.......32......64......96......128.....160.....192
[------p00------]
        [------p01------]
        [------p10------]
                [------p02------]
                [------p11------]
                        [------p03------]
                        [------p12------]
                                [------p13------]
```
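First notice that p13 is not needed because we are not looking for any bits beyond bit 127 of the result. That leaves 7 multiplies (notice the OP's function has exactly 7 multiplies), and a bunch of adds/shifts. With every sum truncated to 64 bits, the formula for the resulting low and high could be written as:
```
low   = p00 + (L(p01) << 32) + (L(p10) << 32)
high  = U(p01) + U(p10) + p02 + p11 + (L(p03) << 32) + (L(p12) << 32) + Carry
Carry = U(U(p00) + L(p01) + L(p10))
```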
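The "normal way" can be written out almost literally from the partial products. The following is a minimal sketch, not the OP's function: the class name `SchoolbookMul128` and the helpers `lo32`/`hi32` (standing in for `L` and `U`) are hypothetical, and the method returns the two halves in an array instead of mutating an object.

```java
/** Hypothetical illustration of the 7-multiply schoolbook scheme. */
final class SchoolbookMul128 {
    static long lo32(long n) { return n & 0xFFFFFFFFL; }  // L(n)
    static long hi32(long n) { return n >>> 32; }         // U(n)

    /**
     * Returns {low, high} of (x * y) mod 2^128, where y is the unsigned
     * 128-bit value with halves yHigh:yLow and x is read as unsigned.
     */
    static long[] multiply(long x, long yLow, long yHigh) {
        long x0 = lo32(x), x1 = hi32(x);
        long y0 = lo32(yLow), y1 = hi32(yLow);
        long y2 = lo32(yHigh), y3 = hi32(yHigh);

        // The seven partial products that reach bits 0..127 (p13 is dropped).
        long p00 = x0 * y0;
        long p01 = x0 * y1;
        long p02 = x0 * y2;
        long p03 = x0 * y3;
        long p10 = x1 * y0;
        long p11 = x1 * y1;
        long p12 = x1 * y2;

        // Carry out of bits 32..63; the sum is below 3 * 2^32, so it fits in a long.
        long carry = hi32(hi32(p00) + lo32(p01) + lo32(p10));

        // Java long arithmetic wraps, so each sum is already taken mod 2^64.
        long low = p00 + ((lo32(p01) + lo32(p10)) << 32);
        long high = hi32(p01) + hi32(p10) + p02 + p11
                  + ((lo32(p03) + lo32(p12)) << 32) + carry;
        return new long[] {low, high};
    }

    public static void main(String[] args) {
        long[] r = multiply(3L, -1L, 0L);  // 3 * (2^64 - 1)
        System.out.println(Long.toHexString(r[1]) + " " + Long.toHexString(r[0]));
    }
}
```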
This algorithm uses this idea to speed up the computation:
We can use a 64 bit x 64 bit multiplication when we only need the bottom 64 bits of the product.
What does this mean? There are several places where we can save multiplies and adds by multiplying more than 32-bits at a time. What happens when we multiply 64 bits against 64 bits? Suppose we define these 64 bit values:
```
x01 = x
y01 = y.low
y12 = U(y.low) + (L(y.high) << 32)
y23 = y.high
```
What does multiplying x01 * y01 produce? The full 128-bit result in terms of the partial products listed above would be:
```
x01 * y01 = p00 + (p01 << 32) + (p10 << 32) + (p11 << 64)
```
But since the result is only 64 bits, the actual 64-bit result would be:
```
x01 * y01 = p00 + (L(p01) << 32) + (L(p10) << 32)
```
But if you notice, this is exactly what we needed to compute for low. So the first computation of my algorithm is exactly that:
```
low = x * low;
```
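The claim that a plain 64-bit multiply already yields the correct low half can be checked against `BigInteger`. This is only an illustrative sketch: the class name `LowBitsDemo` and the method `exactLow64` are hypothetical and are not part of the answer.

```java
import java.math.BigInteger;

/** Hypothetical check that long multiplication keeps the low 64 bits of the exact product. */
final class LowBitsDemo {
    /** Low 64 bits of the exact product of two unsigned 64-bit values, computed the slow way. */
    static long exactLow64(long x, long yLow) {
        BigInteger bx = new BigInteger(Long.toUnsignedString(x));
        BigInteger by = new BigInteger(Long.toUnsignedString(yLow));
        // longValue() returns only the low-order 64 bits when the value does not fit.
        return bx.multiply(by).longValue();
    }

    public static void main(String[] args) {
        long x = 0x7FFFFFFF00000001L;
        long yLow = 0xFEDCBA9876543210L;   // negative as a signed long, large as unsigned
        System.out.println((x * yLow) == exactLow64(x, yLow));
    }
}
```

Java's `long` multiplication wraps modulo 2^64 regardless of how the operands' sign bits are interpreted, which is why `low = x * low;` needs no further correction.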
This one line solves half the problem with a single multiplication. The rest of the function is dedicated to finding the upper 64 bits. Let's look at how the equation for the upper 64 bits can be simplified:
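```
high = U(p01) + U(p10) + p02 + p11 + (L(p03) << 32) + (L(p12) << 32) + Carry
```
We can use the same idea to combine some of the multiplies together. Except instead of multiplying two 64-bit numbers together, we can multiply a 32-bit number by a 64-bit number. The value p02 + (L(p03) << 32) is the low 64 bits of x0 * y23, and p11 + (L(p12) << 32) is the low 64 bits of x1 * y12. So the formula becomes:
```
high  = U(p01) + U(p10) + x0 * y23 + x1 * y12 + Carry
Carry = U(U(p00) + L(p01) + L(p10))
```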
This is exactly what is computed in my function. It would be nice if there was a better way of taking care of the "Carry", but as of right now I'm not sure if there is anything better than the straightforward way. I did try a different way but it ended up being slower so I reverted that change.
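One alternative for the carry, which is not the method used in this answer, is `Math.multiplyHigh`, available from Java 9. The sketch below assumes x is non-negative, as in the question. The class name `MultiplyHighSketch` is hypothetical, and whether this is faster than the version above has not been measured here.

```java
/** Hypothetical variant that takes the high half of x * y.low from Math.multiplyHigh (Java 9+). */
final class MultiplyHighSketch {
    /**
     * Returns {low, high} of (x * y) mod 2^128 for x >= 0,
     * where y is the unsigned 128-bit value with halves yHigh:yLow.
     */
    static long[] multiply(long x, long yLow, long yHigh) {
        // multiplyHigh treats both operands as signed. If yLow is negative as a
        // signed long, its unsigned value is yLow + 2^64, so the unsigned high
        // half of the product is larger by exactly x.
        long carry = Math.multiplyHigh(x, yLow) + ((yLow >> 63) & x);
        long low = x * yLow;           // wraps mod 2^64
        long high = x * yHigh + carry; // wraps mod 2^64
        return new long[] {low, high};
    }

    public static void main(String[] args) {
        long[] r = multiply(3L, -1L, 0L);  // 3 * (2^64 - 1)
        System.out.println(Long.toHexString(r[1]) + " " + Long.toHexString(r[0]));
    }
}
```

This replaces the three 32-bit partial products and the explicit carry with one high-half multiply plus a sign correction. The result should be benchmarked before drawing any conclusion, since the answer reports that at least one other carry strategy turned out slower.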
In the end, there are 6 multiplies, which is only one fewer than in the original function, but there are only 6 adds, which is significantly fewer than in the original. There are also 6 truncate operations, where we take the upper or lower 32 bits, and that is also far fewer than in the original function.
Added by the OP
The code passed my tests and won my benchmark by a factor of 1.4. Great!
52676.447224 ns {=timeBigInteger}
9098.453982 ns {=timeMutableLong}
6351.209769 ns {=timeMutableLong2}
Context
StackExchange Code Review Q#72286, answer score: 12