15x faster than Eigen
Accelerating CPU-side matrix multiplication with AMX and Neon
Hello! This month I’ve been digging into dense matrices: how to write blazing fast CPU-side matrix multiplications, and how to adjust the storage pattern to minimize the time spent waiting on cache misses.
I’ll soon be starting a new job at Bedrock Energy, where I’ll be working on accelerating some physical simulations that we use to design an optimal field of boreholes for geothermal heat pumps. These sorts of simulations, using the Finite Element Method (FEM), essentially come down to solving the equation Au=b, where A is a known matrix, b is a known vector, and u is the vector you’re trying to solve for. You could solve this by finding the inverse of A (assuming it’s invertible) and computing u=A^(-1)*b, but for an n x n matrix inversion is an O(n^3) operation, and the matrices we deal with are quite large. So practical techniques involve an optimization process: start with a guess for u, compute Au and compare to the known target b, and adjust your guess to produce a better guess. That process involves a lot of repeated matmuls, so it’s important to be able to do them fast.
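To make that loop concrete, here’s a minimal sketch of one such scheme: plain residual-correction iteration, u <- u + alpha*(b - Au), on a dense row-major matrix. The step size alpha and the fixed iteration count are stand-ins of mine; real FEM solvers use smarter iterations like conjugate gradients. The point is just that every pass is dominated by the product Au.

#include <cstddef>
#include <vector>

// Sketch of iteratively solving A*u = b for a dense, row-major n x n matrix A.
// alpha is a hand-picked step size; real solvers choose their updates more
// cleverly, but each iteration is still dominated by the matrix-vector product.
void solve_sketch(const std::vector<float> &A, const std::vector<float> &b,
                  std::vector<float> &u, size_t n, float alpha, int iters) {
    std::vector<float> r(n);
    for (int it = 0; it < iters; it++) {
        // r = b - A*u  (the expensive part)
        for (size_t i = 0; i < n; i++) {
            float Au_i = 0.0f;
            for (size_t k = 0; k < n; k++) Au_i += A[i*n + k] * u[k];
            r[i] = b[i] - Au_i;
        }
        // Nudge the guess toward a smaller residual: u = u + alpha * r
        for (size_t i = 0; i < n; i++) u[i] += alpha * r[i];
    }
}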
An aside: in practice the matrix A is usually very sparse, with most entries being 0. For that scenario the optimal storage and matmul algorithms end up being quite different than the general case. In this blog post we’ll ignore that, and just look at dense matrices. I may look at sparse matrices in a future post.
And one last thing before we get into the meat of the optimization. When I heard “matrix multiplication”, my first instinct was “can we do this on a GPU?” GPUs are linear-algebra crunching machines, thanks to ~35 years of optimization for games and ML workloads. The issue in our case, though, is the high cost of data transfer from RAM to GPU memory. In the optimization cycle described above, each time we try a guess for u we need to send that vector over to the GPU. It’s my impression, although I haven’t yet done a careful evaluation, that the time taken by data transfer would eat up the gain from faster matrix multiplications.
So I’m left wondering: how fast can we get matrix multiplications on modern CPUs? Let’s find out. To give a goalpost to aim for, we’ll benchmark against Eigen, which is in my experience the linear algebra library that everyone uses.
1. A naive implementation
We’ll start with the simplest implementation we can think of. We’ll store matrices at rest in row-major order. This means that the values are in one big NxM array, walking through the matrix from left to right and top to bottom, in the same order that your eyes are reading through this text.
#include <cstdint>
#include <cstdlib>

typedef float f32;     // shorthand types used throughout this post
typedef uint64_t u64;

struct RowMatrix {
    size_t rows, cols;
    f32 *values;
    RowMatrix(size_t rows, size_t cols)
        : rows(rows), cols(cols)
    {
        values = (f32*)aligned_alloc(128, sizeof(f32) * rows * cols);
}
~RowMatrix() {
free(values);
}
inline void set(size_t row, size_t col, f32 value) {
        values[row * cols + col] = value;
}
inline f32 get(size_t row, size_t col) {
        return values[row * cols + col];
}
};

The matrix multiplication will also be as simple as possible. Matrix multiplication is defined such that if C = A*B, then C_ij = sum_k(A_ik * B_kj). We’ll implement that directly with a triply nested for-loop.
void naive_matmul(RowMatrix &C, RowMatrix &A, RowMatrix &B) {
for (size_t i = 0; i < C.rows; i++) {
for (size_t j = 0; j < C.cols; j++) {
f32 acc = 0.0f;
for (size_t k = 0; k < A.cols; k++) {
acc += A.get(i, k) * B.get(k, j);
}
C.set(i, j, acc);
}
}
}

Done! How’d we do? On a 512 x 512 matrix of f32s, our naive implementation was 0.03x as fast as Eigen. Oof! I’ll notate this as a speed of 0.03E, where 1E is the speed of Eigen, and 2E is twice as fast.
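For reference, the comparison looks roughly like the sketch below, reusing the RowMatrix and naive_matmul from above. The harness details (iteration count, how the matrices are filled, whether there’s a warm-up pass) are my own assumptions rather than the exact benchmark behind the numbers in this post; on the Eigen side it just times C = A*B on MatrixXf.

#include <Eigen/Dense>
#include <chrono>
#include <cstdio>

// Hypothetical timing helper: average seconds per call over `iters` runs.
template <typename F>
double seconds_per_call(F &&f, int iters = 10) {
    auto start = std::chrono::steady_clock::now();
    for (int i = 0; i < iters; i++) f();
    auto end = std::chrono::steady_clock::now();
    return std::chrono::duration<double>(end - start).count() / iters;
}

int main() {
    const size_t n = 512;
    Eigen::MatrixXf EA = Eigen::MatrixXf::Random(n, n);
    Eigen::MatrixXf EB = Eigen::MatrixXf::Random(n, n);
    Eigen::MatrixXf EC(n, n);
    double eigen_time = seconds_per_call([&] { EC.noalias() = EA * EB; });

    RowMatrix A(n, n), B(n, n), C(n, n);
    for (size_t i = 0; i < n; i++)
        for (size_t j = 0; j < n; j++) {
            A.set(i, j, EA(i, j));
            B.set(i, j, EB(i, j));
        }
    double naive_time = seconds_per_call([&] { naive_matmul(C, A, B); });

    // 1E is Eigen's speed, so our speed in "E" units is eigen_time / naive_time.
    printf("naive: %.2fE\n", eigen_time / naive_time);
    return 0;
}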
2. Transpose first
So how can we improve on that? Let’s think about memory first. When you issue a load instruction, the CPU will first check the L1 cache for the data before repeating with the L2, then L3, and finally issuing a fetch from RAM. The caches are split into cache lines, which are usually 64 bytes these days. When the fetch misses the cache and falls through to RAM, it will pull up a whole cache line’s worth of consecutive memory at a time, replacing some line which was already in the cache. This helps to keep the number of lines for the CPU to check small, and also helps make the ~400-cycle long fetch worth it.
This is relevant to the inner loop of our naive implementation. In that k-loop we’re iterating over a row of A and a column of B. Since our matrices are in row-major format, the elements of that row of A are consecutive in memory. When we load the first 4-byte element, we actually pull 64B/4B=8 elements at once into the cache, making the next iterations of the loop quicker to run. The elements of the column of B, however, are spaced out in memory (for our 512-wide matrices, consecutive column elements are 512*4B = 2KiB apart), so every cache line we fetch for B contains only one element we actually need.
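If you want to see this effect in isolation, here’s a quick experiment you can run (an aside of mine, separate from the matmul code): sum a large row-major matrix once along rows and once along columns, and time both.

#include <chrono>
#include <cstdio>
#include <cstdlib>

// Walk a big row-major matrix two ways and compare the timings.
int main() {
    const size_t n = 4096;
    float *m = (float*)aligned_alloc(128, sizeof(float) * n * n);
    for (size_t i = 0; i < n * n; i++) m[i] = 1.0f;

    auto time_sum = [&](bool by_rows) {
        auto start = std::chrono::steady_clock::now();
        float acc = 0.0f;
        for (size_t a = 0; a < n; a++)
            for (size_t b = 0; b < n; b++)
                acc += by_rows ? m[a*n + b]   // consecutive addresses
                               : m[b*n + a];  // addresses 4*n bytes apart
        auto end = std::chrono::steady_clock::now();
        printf("%s-major sum %.0f: %f s\n", by_rows ? "row" : "column", acc,
               std::chrono::duration<double>(end - start).count());
    };
    time_sum(true);
    time_sum(false);
    free(m);
    return 0;
}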
Under the hypothesis that the extra fetches are affecting our runtime, we can try first transposing B so that when we’re in our hot loop those fetches can run smoothly.
void naive_matmul_transp(RowMatrix &C, RowMatrix &A, RowMatrix &B) {
RowMatrix BB(B.cols, B.rows);
    for (size_t i = 0; i < B.rows; i++) {
        for (size_t j = 0; j < B.cols; j++) {
BB.set(j, i, B.get(i, j));
}
}
for (size_t i = 0; i < C.rows; i++) {
for (size_t j = 0; j < C.cols; j++) {
f32 acc = 0.0f;
for (size_t k = 0; k < A.cols; k++) {
acc += A.get(i, k) * BB.get(j, k);
}
C.set(i, j, acc);
}
}
}

This speeds up our matmul by 1.5x, bringing it to 0.04E. Instrumentation shows that the transpose is just 1% of the runtime, and the multiply 99%.
3. SIMD!
Next let’s try some SIMD. I’m on an Apple M1 chip, which is ARM, so we’re looking at the Neon instruction set. Neon allows us to do loads, adds, and multiplies of four f32s with one instruction, as well as a bunch of other stuff that we won’t get into here. As long as the chip has enough floating point hardware, this should allow us to speed up our matmuls by a factor of 4.
To understand the code sample, let’s go through some Neon basics. A float32x4_t is Neon’s C type for a 128-bit vector of four f32s. We can put them in variables as normal, and the compiler will make sure to use the vector registers instead of the normal ones. vmulq_f32 and vaddq_f32 each take two float32x4_ts and produce the elementwise product and sum of the two vectors, respectively. So vmulq_f32({1,2,3,4}, {5,6,7,8}) → {5,12,21,32}, and vaddq_f32({1,2,3,4}, {5,6,7,8}) → {6,8,10,12}. vaddvq_f32 takes a float32x4_t and adds the component f32s together, producing a single f32 sum.
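As a tiny warm-up before the full kernel, here’s how those three intrinsics combine into a 4-element dot product, which is exactly the pattern the inner loop below repeats:

#include <arm_neon.h>

// dot({1,2,3,4}, {5,6,7,8}) = 5 + 12 + 21 + 32 = 70
f32 dot4_example() {
    float32x4_t a = {1.0f, 2.0f, 3.0f, 4.0f};
    float32x4_t b = {5.0f, 6.0f, 7.0f, 8.0f};
    float32x4_t acc = {};                   // {0, 0, 0, 0}
    acc = vaddq_f32(acc, vmulq_f32(a, b));  // {5, 12, 21, 32}
    return vaddvq_f32(acc);                 // 70
}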
#include <arm_neon.h>
void matmul_1(RowMatrix &C, RowMatrix &A, RowMatrix &B) {
RowMatrix BB(B.cols, B.rows);
    for (size_t i = 0; i < B.rows; i++) {
        for (size_t j = 0; j < B.cols; j++) {
BB.set(j, i, B.get(i, j));
}
}
for (size_t i = 0; i < C.rows; i++) {
for (size_t j = 0; j < C.cols; j++) {
float32x4_t acc = {};
for (size_t k = 0; k < A.cols; k += 4) {
float32x4_t *ap = (float32x4_t *)(A.values + A.cols * i + k);
float32x4_t *bp = (float32x4_t *)(BB.values + BB.cols * j + k);
float32x4_t product = vmulq_f32(*ap, *bp);
acc = vaddq_f32(acc, product);
}
C.set(i, j, vaddvq_f32(acc));
}
}
}

I’ll note here that this implementation assumes the dimensions of your matrices are divisible by 4. It’s not hard to extend the matrix implementation to make sure that this is always true internally by padding the edges of the matrix with zeros. For brevity, I won’t bother to do that in this post.
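In case you do want to try it, the sketch below shows one way the padding could look. The rounding helper and the choice to zero-fill the whole allocation are mine, not code used elsewhere in this post; zero padding is safe because the extra terms in the k-loop only ever add 0 to the accumulator.

#include <cstring>

// Round n up to the next multiple of 4.
inline size_t round_up_4(size_t n) { return (n + 3) & ~size_t(3); }

struct PaddedRowMatrix {
    size_t rows, cols;                // logical dimensions
    size_t padded_rows, padded_cols;  // storage dimensions, multiples of 4
    f32 *values;
    PaddedRowMatrix(size_t rows, size_t cols)
        : rows(rows), cols(cols),
          padded_rows(round_up_4(rows)), padded_cols(round_up_4(cols))
    {
        // Round the byte count up to the alignment, as aligned_alloc expects.
        size_t bytes = sizeof(f32) * padded_rows * padded_cols;
        bytes = (bytes + 127) & ~size_t(127);
        values = (f32*)aligned_alloc(128, bytes);
        memset(values, 0, bytes);  // the zero padding never changes a dot product
    }
    ~PaddedRowMatrix() { free(values); }
    inline void set(size_t row, size_t col, f32 value) {
        values[row * padded_cols + col] = value;
    }
    inline f32 get(size_t row, size_t col) {
        return values[row * padded_cols + col];
    }
};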
Using Neon gives us a 4.5x speedup, bringing us to 0.18E. I’m not sure how we do better than 4x there. If you have an idea, I’d be eager to hear about it in the comments.
4. Loop unrolling
I tried a few things next, none of which improved the performance: moving the pointer computation outside the k-loop; incrementing the pointer instead of indexing with p[k]; using fused multiply-add in the form of vfmaq_f32. The next improvement came from a little loop unrolling:
void matmul_6(RowMatrix &C, RowMatrix &A, RowMatrix &B) {
RowMatrix BB(B.cols, B.rows);
    for (size_t i = 0; i < B.rows; i++) {
        for (size_t j = 0; j < B.cols; j++) {
BB.set(j, i, B.get(i, j));
}
}
for (size_t i = 0; i < C.rows; i++) {
for (size_t j = 0; j < C.cols; j++) {
float32x4_t acc_1 = {};
float32x4_t acc_2 = {};
float32x4_t acc_3 = {};
float32x4_t acc_4 = {};
float32x4_t *ap = (float32x4_t *)(A.values + A.cols * i);
float32x4_t *bp = (float32x4_t *)(BB.values + BB.cols * j);
for (size_t k = 0; k < A.cols; k += 16) {
acc_1 = vaddq_f32(acc_1, vmulq_f32(*(ap++), *(bp++)));
acc_2 = vaddq_f32(acc_2, vmulq_f32(*(ap++), *(bp++)));
acc_3 = vaddq_f32(acc_3, vmulq_f32(*(ap++), *(bp++)));
acc_4 = vaddq_f32(acc_4, vmulq_f32(*(ap++), *(bp++)));
}
C.set(
i, j,
vaddvq_f32(acc_1) + vaddvq_f32(acc_2) +
vaddvq_f32(acc_3) + vaddvq_f32(acc_4)
);
}
}
}

This bumped us up by 1.8x, landing us at 0.32E.
Loop unrolling like this generally helps improve performance by reducing the number of loop instructions to be executed and allowing more loop iterations to run in parallel. Exactly how many times to unroll is a priori a mystery; profile on your target platform to find the right number. For my machine, 4 hit the sweet spot; 2 or 8 were both slower.
Also slower was using f32 accumulators instead of the float32x4 accumulators. I guess the vector reduce vaddvq_f32 is slower than the vector add vaddq_f32.
5. Apple Silicon AMX
The next big idea is to use a newer, bigger form of SIMD. Apple Silicon has an undocumented instruction set known as AMX, brilliantly reverse-engineered by Peter Cawley (corsix@). I’ll be using his header file to get access to AMX intrinsics throughout this post. AMX gives you enormous 512-bit vector registers for operands and lets you drive a huge amount of floating point hardware to do a 16x16 vector outer product with a single instruction. Bram Wasti has some great visuals to illustrate how bonkers huge these registers are compared to Neon in his post on jott.live.
The idea then is to rearrange our matrix multiplication to be a bunch of outer products. This takes a couple of insights. The first one is that a sub-block C(a:b, c:d) = (AB)(a:b, c:d) = A(a:b, :) * B(:, c:d). The second is that any matrix multiplication AB can be written as a sum of outer products of matching columns of A and rows of B: the k-th column of A with the k-th row of B, summed over k. Combining these two insights, we can compute a 16x16 block of our result matrix C by picking out the right 16 rows of A and 16 columns of B, and then iterating along the k-axis computing outer products of pairs of 16-float vectors.
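Before bringing in AMX, here’s that decomposition spelled out in plain scalar code, as a reference sketch rather than something we’d benchmark: one 16x16 block of C accumulated as a sum of outer products, one per value of k. Each AMX_FMA32 in the kernel below does one full iteration of this k-loop in a single instruction; the kernel also transposes A up front so that the 16-float slice of a column of A sits contiguously in memory and can be grabbed with a single load.

// Reference: compute C[block_row .. block_row+15][block_col .. block_col+15]
// as a sum over k of 16x16 outer products.
void block_outer_products_reference(RowMatrix &C, RowMatrix &A, RowMatrix &B,
                                    size_t block_row, size_t block_col) {
    f32 z[16][16] = {};  // plays the role of AMX's Z accumulator
    for (size_t k = 0; k < A.cols; k++) {
        // x[i] = A(block_row + i, k): a 16-float slice of column k of A
        // y[j] = B(k, block_col + j): a 16-float slice of row k of B
        // z += outer_product(x, y)
        for (size_t i = 0; i < 16; i++)
            for (size_t j = 0; j < 16; j++)
                z[i][j] += A.get(block_row + i, k) * B.get(k, block_col + j);
    }
    for (size_t i = 0; i < 16; i++)
        for (size_t j = 0; j < 16; j++)
            C.set(block_row + i, block_col + j, z[i][j]);
}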
#include "amx/aarch64.h";
inline void transpose(RowMatrix &dst, RowMatrix &src) {
for (size_t row = 0; row < src.rows; row++) {
for (size_t col = 0; col < src.cols; col++) {
dst.set(col, row, src.get(row, col));
}
}
}
void matmul_amx_1(RowMatrix &C, RowMatrix &A, RowMatrix &B) {
RowMatrix AA(A.cols, A.rows);
transpose(AA, A);
RowMatrix CC(C.cols, C.rows);
// Enable AMX instructions.
// Have to do this or the program will crash.
AMX_SET();
// Loop over each 16x16 block in C to compute it.
for (size_t block_row = 0; block_row < C.rows; block_row += 16) {
for (size_t block_col = 0; block_col < C.cols; block_col += 16) {
// Only set for k == 0
// to reset the matrix-accumulator Z.
u64 reset_z = 1ull << 27;
for (size_t k = 0; k < A.cols; k++) {
// Load 16 floats from AA into X register 0.
                AMX_LDX((u64)(AA.values + k*AA.cols + block_row));
// Load 16 floats from B into Y register 0.
                AMX_LDY((u64)(B.values + k*B.cols + block_col));
// if (reset_z) Z = 0;
// Z += outer_product(X, Y);
AMX_FMA32(reset_z);
reset_z = 0;
}
// Read the 16x16 floats from the various Z registers
// into the correct block of CC.
for (u64 i = 0; i < 16; i++) {
u64 z_reg = (i * 4ull) << 56;
AMX_STZ(z_reg | (u64)(CC.values + (block_col + i)*CC.cols + block_row));
}
}
}
// Disable AMX instructions.
// Have to do this because AMX_SET() is not reentrant.
AMX_CLR();
transpose(C, CC);
}

This gives us a huge 10x improvement, bringing us to 3.27E.
6. Bigger blocks, fewer loads
We can further improve our utilization of the hardware by doing a 32x32 block of C at once. The advantage here comes from improving the ratio of loads to outer products: previously we did 2 loads to compute 1 outer product, whereas now we do 4 loads to compute 4 outer products. This cuts our program’s total loads in half, saving memory bandwidth and cutting the runtime.
void matmul_amx_2(RowMatrix &C, RowMatrix &A, RowMatrix &B) {
RowMatrix AA(A.cols, A.rows);
transpose(AA, A);
RowMatrix CC(C.cols, C.rows);
AMX_SET();
size_t BLOCK_SIZE = 32;
for (size_t block_row = 0; block_row < C.rows; block_row += BLOCK_SIZE) {
for (size_t block_col = 0; block_col < C.cols; block_col += BLOCK_SIZE) {
u64 reset_z = 1ull << 27; // only set for k == 0
for (size_t k = 0; k < A.cols; k++) {
// Load 32 floats into X and Y registers 0 and 1.
f32 *aa_addr = AA.values + k*AA.cols + block_row;
AMX_LDX((0ull << 56) | (u64)aa_addr);
                AMX_LDX((1ull << 56) | (u64)(aa_addr + 16));
f32 *b_addr = B.values + k*B.cols + block_col;
AMX_LDY((0ull << 56) | (u64)b_addr);
                AMX_LDY((1ull << 56) | (u64)(b_addr + 16));
// Do a 32x32 outer product as 4 16x16 outer products:
// [ X0@Y0 X1@Y0 ]
// [ X0@Y1 X1@Y1 ]
AMX_FMA32(reset_z | (0ull << 20) | (0ull << 10) | 0ull);
AMX_FMA32(reset_z | (1ull << 20) | (64ull << 10) | 0ull);
AMX_FMA32(reset_z | (2ull << 20) | (0ull << 10) | 64ull);
AMX_FMA32(reset_z | (3ull << 20) | (64ull << 10) | 64ull);
reset_z = 0;
}
// Read the 32x32 outer product out of the Z registers.
            // Bit 62 asks STZ to store a pair of Z registers (128 bytes) at once.
            const u64 load_store_2 = 1ull << 62;
            for (u64 i = 0; i < 16; i++) {
                u64 reg = i*4ull;
                // Z rows reg and reg+1 hold row i of the 32-wide block.
                AMX_STZ(
                    load_store_2
                    | (reg + 0) << 56
                    | (u64)(CC.values + (block_col + i + 0)*CC.cols + block_row)
                );
                // Z rows reg+2 and reg+3 hold row i+16 of the block.
                AMX_STZ(
                    load_store_2
                    | (reg + 2) << 56
                    | (u64)(CC.values + (block_col + i + 16)*CC.cols + block_row)
                );
}
}
}
AMX_CLR();
transpose(C, CC);
}

This gives another ~2x improvement, for a total of 6.19E.
7. Faster transposes with SIMD
When we started we justified ignoring the matrix transpose because it took just 1% of the runtime, whereas the naive multiply took 99%. Over the course of this post we’ve cut down that multiply by a factor of >200x, so it’s probably time to look at the transpose. Profiling now confirms: our two transposes together are 66% of the runtime.
We can do a 4x4 transpose with some creative use of Neon byte-juggling instructions.
inline void transpose_4x4_block(f32 *dst, f32 *src, size_t rows, size_t cols) {
float32x4_t row0 = vld1q_f32(src + 0*cols); // 0 1 2 3
float32x4_t row1 = vld1q_f32(src + 1*cols); // 4 5 6 7
float32x4_t row2 = vld1q_f32(src + 2*cols); // 8 9 a b
float32x4_t row3 = vld1q_f32(src + 3*cols); // c d e f
float32x4x2_t mix01 = vtrnq_f32(row0, row1);
// 0 4 2 6
// 1 5 3 7
float32x4x2_t mix23 = vtrnq_f32(row2, row3);
// 8 c a e
// 9 d b f
float32x4_t out0 = vcombine_f32(vget_low_f32(mix01.val[0]), vget_low_f32(mix23.val[0]));
vst1q_f32(dst + 0*rows, out0); // 0 4 8 c
float32x4_t out1 = vcombine_f32(vget_low_f32(mix01.val[1]), vget_low_f32(mix23.val[1]));
vst1q_f32(dst + 1*rows, out1); // 1 5 9 d
float32x4_t out2 = vcombine_f32(vget_high_f32(mix01.val[0]), vget_high_f32(mix23.val[0]));
vst1q_f32(dst + 2*rows, out2); // 2 6 a e
float32x4_t out3 = vcombine_f32(vget_high_f32(mix01.val[1]), vget_high_f32(mix23.val[1]));
vst1q_f32(dst + 3*rows, out3); // 3 7 b f
}

Using this as a primitive, we can build up a bigger transpose by walking over the matrix and doing 4x4 block transposes.
// B = A.transpose()
inline void transpose_4x4(RowMatrix &B, RowMatrix &A) {
f32 *dst = B.values;
f32 *src = A.values;
size_t rows = A.rows;
size_t cols = A.cols;
for (size_t j = 0; j < cols; j += 4) {
for (size_t i = 0; i < rows; i += 4) {
transpose_4x4_block(dst + j*rows + i, src + i*cols + j, rows, cols);
}
}
}

This speeds up the transposes by a factor of 3x, so that they’re now just 35% of the runtime of the matmul, giving us a 1.8x overall speedup and putting us at 11.06E.
8. Transposes with AMX
And, one more idea: let’s do bigger transposes at once by using AMX. The AMX_EXTRV instruction treats the Z registers collectively as a square matrix of cells, each holding a few values, and moves a column from Z into a single Y register. By loading rows into Z, moving columns into Y, and then reading those columns out to memory as rows, we can transpose a 16x16 block of f32s with 48 instructions: 16 loads, 16 moves, and 16 stores.
inline void transpose_16x16_block(f32 *dst, f32 *src, size_t rows, size_t cols) {
AMX_SET(); {
for (u64 row = 0; row < 16; row++) {
u64 z_reg = row*4;
AMX_LDZ((z_reg << 56) | (u64)(src + row*cols));
}
for (size_t row = 0; row < 16; row++) {
size_t z_column = row * 4;
AMX_EXTRY(
AMX_EXTRV_PLAIN_ZTOY
| AMX_EXTRV_LANEWIDTH_32BIT
| (z_column << 20)
);
            AMX_STY((u64)(dst + row*rows));
}
} AMX_CLR();
}
// transpose_16x16() is then very similar to transpose_4x4(),
// just moving over 16x16 blocks and calling our new kernel.

This chops our transpose time down by a factor of 4.2x, so that it’s now just 12% of our total matmul time, improving our total time by 1.36x, and leaving us at a final 15.03E.
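For completeness, here’s roughly what that wrapper could look like. This is my sketch of the omitted code, mirroring transpose_4x4() above; it isn’t the exact code behind the numbers.

// B = A.transpose(), in 16x16 blocks.
inline void transpose_16x16(RowMatrix &B, RowMatrix &A) {
    size_t rows = A.rows;
    size_t cols = A.cols;
    for (size_t j = 0; j < cols; j += 16) {
        for (size_t i = 0; i < rows; i += 16) {
            transpose_16x16_block(B.values + j*rows + i, A.values + i*cols + j, rows, cols);
        }
    }
}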
Over the course of this article we took an initial naive implementation of matrix multiplication and incrementally improved its runtime by 500x by improving cache locality and making use of all the intrinsics available. Our final implementation was 15x faster than Eigen. That was a nice surprise! Eigen makes use of SIMD, but evidently it does not try to use Apple Silicon’s AMX extensions. It would be interesting to see what it would take to patch Eigen to make use of these extensions when they’re available.
I set out on this exploration to learn how fast matmuls work and to try using SIMD intrinsics. I’m happy to report that I learned a great deal. One more piece of evidence in favor of Matthias Endler’s recommendation to reinvent the wheel.
A great thing about this project has been how at every point there are several more threads to pull on to learn more. Some that I’m left with here at the end:
patch Eigen to use Apple Silicon AMX
repeat the optimization process for sparse matrices
translate to x86 SIMD, and explore Intel’s AMX
adjust the naive code to convince the compiler to auto-vectorize it
figure out why that SIMD change gave a 4.5x speedup instead of just 4x
do more careful profiling with hardware performance counters to verify or disprove our guesses about why each step helped
get hard numbers for the cost + benefit of doing the matmuls on the GPU.
when does it start making sense? when the matrices are large? when you’ll reuse them some number of times?
explore some of the further optimizations listed in the Algorithmica HPC post on matmuls
Not sure which of these I’ll try. Until next time!
Resources
If you want to find out more about what instructions exist, you can check out either the ARM docs or the intrinsics tree at simd.info.

