
SIMD matrix multiplication

Submitted by: @import:stackexchange-codereview

Problem

I recently started toying with SIMD and came up with the following code for matrix multiplication.

First I attempted to implement it using SIMD the same way I did in SISD, just using SIMD for things like the dot product for each particular entry, which was actually slower (still trying to figure this one out).

After giving it some thought I realized I could do the calculation for the resulting matrix row-by-row instead, by lining up the registers something like this (each row is one SIMD register, and each column is the x, y, z, w parts):

With matrices $A$ and $B$, computing $C = A * B$:

  A_00*B_00     A_00*B_01     ...
+ A_01*B_10   + A_01*B_11     ...
+ A_02*B_20   + A_02*B_21     ...
+ A_03*B_30   + A_03*B_31     ...
= C_00        = C_01          ...

C_00 = Dot(A_Row0, B_Col0), C_01 = Dot(A_Row0, B_Col1), ...

  A_10*B_00     A_10*B_01     ...
+ A_11*B_10   + A_11*B_11     ...
+ ...         + ...           ...
= C_10        = C_11          ...

C_10 = Dot(A_Row1, B_Col0), C_11 = Dot(A_Row1, B_Col1), ...
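Written as equations (indices from 0 to 3, as in the diagram), the row-by-row scheme relies on the fact that each row of $C$ is a linear combination of the rows of $B$, weighted by the elements of the matching row of $A$:

```latex
C_{ij} = \sum_{k=0}^{3} A_{ik} B_{kj}
\qquad\Longrightarrow\qquad
\mathrm{Row}_i(C) = \sum_{k=0}^{3} A_{ik}\,\mathrm{Row}_k(B)
= A_{i0}\,\mathrm{Row}_0(B) + A_{i1}\,\mathrm{Row}_1(B) + A_{i2}\,\mathrm{Row}_2(B) + A_{i3}\,\mathrm{Row}_3(B)
```

Each term is a scalar $A_{ik}$ broadcast to all four lanes (`_mm_set1_ps`) times one whole register holding $\mathrm{Row}_k(B)$, so each row of $C$ costs four multiplies and three adds and never needs a horizontal add.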


I would greatly appreciate if someone with more experience with these things could tell me how far off I am from a good solution.

__m128 BCx = _mm_load_ps((float*)&B.Row0);
__m128 BCy = _mm_load_ps((float*)&B.Row1);
__m128 BCz = _mm_load_ps((float*)&B.Row2);
__m128 BCw = _mm_load_ps((float*)&B.Row3);

// Calculate Row0 in resulting matrix
__m128 ARx = _mm_set1_ps(A.Row0.X);
__m128 ARy = _mm_set1_ps(A.Row0.Y);
__m128 ARz = _mm_set1_ps(A.Row0.Z);
__m128 ARw = _mm_set1_ps(A.Row0.W);

__m128 X = _mm_mul_ps(ARx, BCx);
__m128 Y = _mm_mul_ps(ARy, BCy);
__m128 Z = _mm_mul_ps(ARz, BCz);
__m128 W = _mm_mul_ps(ARw, BCw);

// Sum the partial products to get Row0 of C
__m128 R = _mm_add_ps(_mm_add_ps(X, Y), _mm_add_ps(Z, W));
_mm_storeu_ps((float*)&C.Row0, R);

// Rows 1-3 repeat the same pattern with A.Row1 .. A.Row3

Solution

Restructure into a function to make it DRYer

Right now the structure of your code is pretty unpleasant and not very DRY. The first thing I'd recommend is restructuring this into a function. I don't know for sure how your Mat4 structure is implemented, but you indicated in the comments that it is contiguous, so I've based my assumptions on that. I'd recommend encapsulating it in this sort of function:

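#include <xmmintrin.h>  // SSE intrinsics

void dotFourByFourMatrix(const Mat4* left, const Mat4* right, Mat4* result) {
    // Rows of the right-hand matrix: loaded once, reused for every output row.
    const __m128 BCx = _mm_load_ps(reinterpret_cast<const float*>(&right->Row0));
    const __m128 BCy = _mm_load_ps(reinterpret_cast<const float*>(&right->Row1));
    const __m128 BCz = _mm_load_ps(reinterpret_cast<const float*>(&right->Row2));
    const __m128 BCw = _mm_load_ps(reinterpret_cast<const float*>(&right->Row3));

    // Assumes Mat4 is 16 contiguous, 16-byte-aligned floats, stored row by row.
    const float* leftRowPointer = reinterpret_cast<const float*>(&left->Row0);
    float* resultRowPointer = reinterpret_cast<float*>(&result->Row0);

    for (unsigned int i = 0; i < 4; ++i, leftRowPointer += 4, resultRowPointer += 4) {
        // Broadcast each element of the current row of `left` into all four lanes.
        __m128 ARx = _mm_set1_ps(leftRowPointer[0]);
        __m128 ARy = _mm_set1_ps(leftRowPointer[1]);
        __m128 ARz = _mm_set1_ps(leftRowPointer[2]);
        __m128 ARw = _mm_set1_ps(leftRowPointer[3]);

        // Note: * and + on __m128 are a GCC/Clang vector extension (not MSVC).
        __m128 X = ARx * BCx;
        __m128 Y = ARy * BCy;
        __m128 Z = ARz * BCz;
        __m128 W = ARw * BCw;

        // Row i of the result; the aligned store requires 16-byte alignment.
        __m128 R = X + Y + Z + W;
        _mm_store_ps(resultRowPointer, R);
    }
}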


You'll notice that I've made some assumptions about how the pointers increment: my assumption was that members like Row0 or X are just convenient names for parts of a single contiguous array of 16 floats. If that isn't accurate, this will break. I've also marked the BCx-like variables const because they never change after being loaded.
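For comparison, here is a self-contained sketch of the same row-by-row idea. It assumes a hypothetical `Mat4` that really is one 16-byte-aligned array of 16 floats in row-major order (so the layout assumption holds by construction), uses index offsets instead of stepping pointers, and spells the arithmetic with `_mm_mul_ps`/`_mm_add_ps` so it builds on MSVC as well as GCC/Clang. `multiplyScalar` is a hypothetical plain reference implementation for checking results.

```cpp
#include <cassert>
#include <xmmintrin.h>  // SSE: __m128, _mm_load_ps, _mm_set1_ps, _mm_mul_ps, _mm_add_ps, _mm_store_ps

// Hypothetical Mat4: 16 floats, row-major, 16-byte aligned so the aligned load/store are legal.
struct alignas(16) Mat4 {
    float m[16];
};

// result = left * right, computed one output row at a time.
void dotFourByFourMatrix(const Mat4* left, const Mat4* right, Mat4* result) {
    // Each row of `right` is loaded once and reused for all four output rows.
    const __m128 BCx = _mm_load_ps(right->m + 0);
    const __m128 BCy = _mm_load_ps(right->m + 4);
    const __m128 BCz = _mm_load_ps(right->m + 8);
    const __m128 BCw = _mm_load_ps(right->m + 12);

    for (int i = 0; i < 4; ++i) {
        // Index-based access: no pointer is ever stepped past the end of the array.
        const float* leftRow = left->m + 4 * i;

        // Broadcast each element of the current left row into all four lanes.
        const __m128 ARx = _mm_set1_ps(leftRow[0]);
        const __m128 ARy = _mm_set1_ps(leftRow[1]);
        const __m128 ARz = _mm_set1_ps(leftRow[2]);
        const __m128 ARw = _mm_set1_ps(leftRow[3]);

        // Row i of the result = sum over k of left[i][k] * (row k of right).
        const __m128 R = _mm_add_ps(_mm_add_ps(_mm_mul_ps(ARx, BCx), _mm_mul_ps(ARy, BCy)),
                                    _mm_add_ps(_mm_mul_ps(ARz, BCz), _mm_mul_ps(ARw, BCw)));
        _mm_store_ps(result->m + 4 * i, R);
    }
}

// Plain scalar reference: c[i][j] = sum over k of a[i][k] * b[k][j].
void multiplyScalar(const Mat4& a, const Mat4& b, Mat4& c) {
    for (int i = 0; i < 4; ++i) {
        for (int j = 0; j < 4; ++j) {
            float sum = 0.0f;
            for (int k = 0; k < 4; ++k) {
                sum += a.m[4 * i + k] * b.m[4 * k + j];
            }
            c.m[4 * i + j] = sum;
        }
    }
}
```

With small integer-valued inputs both versions give bit-identical results; with general inputs, expect tiny differences because the SIMD version adds the partial products in a different order.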
Potential undefined behavior in my suggestion

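An important caveat is that this may not be strictly portable. Stepping a pointer obtained from Row0 on into Row1 through Row3 is, strictly speaking, undefined behavior in C++ unless the 16 floats really are one array, and the final increment of leftRowPointer and resultRowPointer (after the last iteration) lands well past the end of Row0. In practice it is unlikely to cause problems on mainstream compilers as long as the struct has no padding. I asked a question on Stack Overflow about this that you may be interested in reading.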
Performance boosts

SIMD usually gives a speedup without much effort, as long as your data layout is good (aligned, contiguous, etc.): that gives you good caching behavior, which matters a great deal for SIMD performance. The main remaining improvement would be prefetching (that article is fantastic, and has a good section on prefetching), but you generally have to experiment with your prefetch distance, i.e. how far ahead you prefetch. A prefetch takes time to complete, so prefetching only the next matrix probably won't get you any speed improvement unless your current computation takes long enough to hide the prefetch latency. Without timing it, there's no way to know how far ahead is right.

for each matrix in a lot of matrices:
    prefetch the matrix 40 matrices ahead, for example
    do some computation with the current matrix
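A minimal C++ sketch of that loop, using the SSE `_mm_prefetch` intrinsic with the `_MM_HINT_T0` hint. The `Mat4` type, the `traceOf` work function and the `distance` parameter are hypothetical stand-ins for your own types and computation; the right distance has to be found by measurement.

```cpp
#include <cassert>
#include <cstddef>
#include <xmmintrin.h>  // _mm_prefetch, _MM_HINT_T0

// Hypothetical matrix type: 16 floats = 64 bytes, the size of a typical x86 cache line.
struct alignas(16) Mat4 {
    float m[16];
};

// Hypothetical per-matrix work, standing in for "do some computation".
inline float traceOf(const Mat4& a) {
    return a.m[0] + a.m[5] + a.m[10] + a.m[15];
}

// Sums the traces of `count` matrices, prefetching `distance` matrices ahead.
// Too small a distance and the data is not there yet; too large and it may be
// evicted again before it is used.
float sumTracesWithPrefetch(const Mat4* mats, std::size_t count, std::size_t distance) {
    float total = 0.0f;
    for (std::size_t i = 0; i < count; ++i) {
        if (i + distance < count) {
            // A prefetch is only a hint: it does not change the result.
            _mm_prefetch(reinterpret_cast<const char*>(&mats[i + distance]), _MM_HINT_T0);
        }
        total += traceOf(mats[i]);
    }
    return total;
}
```

For a simple front-to-back walk like this, the hardware prefetcher often keeps up on its own, so explicit prefetching tends to pay off more when the access pattern is irregular but known in advance.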


If you don't do work on a lot of matrices sequentially, and you don't have a good way of knowing which matrix you'll work on well in advance, prefetching won't give you anything.

Also avoid pointer chasing: dereferencing a pointer isn't free, and each extra level of indirection can cause additional cache misses (for both the pointer and whatever it points to). Again, I don't see any evidence of that here, unless that's what Row0 and the like are doing.

You also need to avoid swizzling (favorite word ever), which is when you have to rearrange data in memory to get it into a SIMD register. Using _mm_set_ps is usually a sign that you're swizzling. I don't see any evidence of that here, but it might be useful for future reference.
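To make the contrast concrete: `_mm_set_ps` builds a register from four separate scalars (and takes its arguments highest lane first), which typically compiles to several shuffle or insert instructions, whereas `_mm_load_ps` moves four contiguous, aligned floats in a single load. The function names below are hypothetical.

```cpp
#include <cassert>
#include <xmmintrin.h>

// Gathers four scattered floats into one register; the compiler usually has to
// combine the lanes with extra shuffle/insert instructions (the swizzle cost).
__m128 gatherScattered(const float* a, const float* b, const float* c, const float* d) {
    // _mm_set_ps takes lanes in reverse order: (lane3, lane2, lane1, lane0).
    return _mm_set_ps(*d, *c, *b, *a);
}

// Loads four contiguous, 16-byte-aligned floats with a single instruction.
__m128 loadContiguous(const float* p) {
    return _mm_load_ps(p);
}
```

Both produce the same register contents when the values match; the difference is only in how much work it takes to get them there.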
Use better and fewer intrinsics

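I changed your function to use the * and + operators instead of _mm_mul_ps and _mm_add_ps, which is much easier to read. Note that arithmetic operators on __m128 are a GCC/Clang vector extension; MSVC does not support them, so keep the intrinsics (or use a wrapper) if you need to build there. Honestly, I think you could condense the code even more without sacrificing much readability; I'm not sure all of the temporaries are necessary. I also replaced the call to _mm_storeu_ps with the aligned _mm_store_ps. On older CPUs this can be noticeably faster (on recent x86 cores an unaligned store to an aligned address costs about the same), but _mm_store_ps faults if the address is not 16-byte aligned, so if you have any way of guaranteeing that your matrices are allocated aligned, that would be ideal.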
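One way to guarantee that alignment, assuming a `Mat4` type you control (the one below is hypothetical), is `alignas(16)` on the type itself: stack and static objects then honor it, and since C++17 so do `new` and `std::vector`. The `static_assert`s catch layout mistakes at compile time, and `isAligned16`/`copyRow0` are hypothetical helpers showing a debug check before the aligned intrinsics.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>
#include <xmmintrin.h>

// Hypothetical Mat4: alignas(16) makes every Mat4 object start on a 16-byte boundary.
struct alignas(16) Mat4 {
    float m[16];
};
static_assert(alignof(Mat4) == 16, "Mat4 must be 16-byte aligned for _mm_load_ps/_mm_store_ps");
static_assert(sizeof(Mat4) == 64, "Mat4 must be exactly 16 contiguous floats");

// True if p is suitably aligned for the aligned SSE load/store intrinsics.
inline bool isAligned16(const void* p) {
    return reinterpret_cast<std::uintptr_t>(p) % 16 == 0;
}

// Copies row 0 of src into row 0 of dst using the aligned load/store.
inline void copyRow0(const Mat4& src, Mat4& dst) {
    // Aligned SSE loads/stores fault on a misaligned address, so check in debug builds.
    assert(isAligned16(src.m) && isAligned16(dst.m));
    _mm_store_ps(dst.m, _mm_load_ps(src.m));
}
```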
Or drop intrinsics altogether

In general I find the intrinsics to be a little hard to read, and they're less cross-platform. I strongly suggest using a library like Agner Fog's vectorclass library for both portability and readability.

If you can't or won't use a 3rd party library, it is really quite easy to write up a little wrapper class that is much more readable, and if you're into templates and macros you can make it very portable as well.
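A minimal sketch of such a wrapper, assuming SSE only; the class name `Float4` and its members are hypothetical. Because the operators call `_mm_add_ps`/`_mm_mul_ps` internally, it also works on MSVC, where `__m128` has no built-in arithmetic operators.

```cpp
#include <cassert>
#include <xmmintrin.h>

// Thin value wrapper around an SSE register holding four floats.
class Float4 {
public:
    Float4() : v_(_mm_setzero_ps()) {}
    explicit Float4(__m128 v) : v_(v) {}
    // Broadcast one scalar into all four lanes.
    explicit Float4(float s) : v_(_mm_set1_ps(s)) {}

    // Aligned load/store: p must be 16-byte aligned.
    static Float4 load(const float* p) { return Float4(_mm_load_ps(p)); }
    void store(float* p) const { _mm_store_ps(p, v_); }

    friend Float4 operator+(Float4 a, Float4 b) { return Float4(_mm_add_ps(a.v_, b.v_)); }
    friend Float4 operator*(Float4 a, Float4 b) { return Float4(_mm_mul_ps(a.v_, b.v_)); }

private:
    __m128 v_;
};
```

With it, one output row of the 4x4 product reads as `Float4(a0) * B0 + Float4(a1) * B1 + Float4(a2) * B2 + Float4(a3) * B3`, where `a0..a3` are the elements of a row of the left matrix and `B0..B3` are the rows of the right matrix loaded with `Float4::load`.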
Casts

All of that casting makes me a little nervous. Again, I don't know anything about the format of your data, but the explicit casts seem like they should be unnecessary. If you must cast, don't use C-style casts; prefer the named C++ casts (static_cast where it applies; turning a pointer to a row struct into a float* needs reinterpret_cast). You can learn more about why yo

Context

StackExchange Code Review Q#101144, answer score: 12
