I won't bore you with the backstory (and I didn't want to draw a bunch of 3D diagrams), so I'll jump straight into the optimizing without explaining anything about the program that motivated said optimizations.
Anywho, we have this packed struct:
const Foo = packed struct(u32) {
a: u4,
_0: u4,
b: u4,
_1: u4,
c: u4,
_2: u4,
d: u4,
_3: u4,
};
The routine we want to optimize compares two instances of Foo (we'll call them left and right) to see if the fields a, b, c, and d of left are all greater than or equal to the corresponding fields of right.
Also, some facts that will come in handy later:
- The fields prefixed with _ aren't used for anything; they're purely padding. We can set them to whatever values we desire.
- We know whether an instance of Foo will be left or right at initialization, which occurs at compile time.
Here's the simplest way to write a function that does such an operation:
pub fn cmpScalar(left: Foo, right: Foo) bool {
return left.a >= right.a and
left.b >= right.b and
left.c >= right.c and
left.d >= right.d;
}
Let's take a look at the assembly in godbolt. There's a lot of and cl/dl, 15 going on, which just serves to mask off the bits of left._0, left._1, right._0, etc. If we were to ditch the underscored fields and use u8s, the result would get better.
Since we can put whatever we want in the fields with underscores, let's just initialize them to 0 at comptime (not shown), and use asserts to make LLVM assume those fields are 0. We'll add a method to Foo that does so, and mark it inline just to be sure:
pub inline fn assertZeroes(self: @This()) void {
std.debug.assert(self._0 == 0);
std.debug.assert(self._1 == 0);
std.debug.assert(self._2 == 0);
std.debug.assert(self._3 == 0);
}
and call this method in cmpScalar on both left and right.
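The comptime initialization isn't shown above. Here is a minimal sketch of one way it could look. It assumes a hypothetical Foo.init helper and default values of 0 on the padding fields, neither of which appears in the original struct. It relies on one fact about std.debug.assert: a failed assert is unreachable, and in ReleaseFast the optimizer is allowed to treat that as an assumption.

```zig
const std = @import("std");

const Foo = packed struct(u32) {
    a: u4,
    _0: u4 = 0, // hypothetical: padding defaults to 0
    b: u4,
    _1: u4 = 0,
    c: u4,
    _2: u4 = 0,
    d: u4,
    _3: u4 = 0,

    /// Hypothetical constructor: only the real fields are passed in,
    /// so every padding field keeps its default of 0.
    pub fn init(a: u4, b: u4, c: u4, d: u4) Foo {
        return .{ .a = a, .b = b, .c = c, .d = d };
    }

    pub inline fn assertZeroes(self: @This()) void {
        std.debug.assert(self._0 == 0);
        std.debug.assert(self._1 == 0);
        std.debug.assert(self._2 == 0);
        std.debug.assert(self._3 == 0);
    }
};

pub fn cmpScalar(left: Foo, right: Foo) bool {
    // A failed assert is unreachable, so in ReleaseFast LLVM may
    // assume the padding of both operands is zero from here on.
    left.assertZeroes();
    right.assertZeroes();
    return left.a >= right.a and
        left.b >= right.b and
        left.c >= right.c and
        left.d >= right.d;
}
```

A value built with `comptime Foo.init(5, 6, 7, 8)` passes assertZeroes by construction, and its bits come out as 0x08070605.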
Looking in godbolt - nothing changed! Bummer. Perhaps LLVM isn't too adept at reasoning about exotically sized integers - maybe Zig's own backends will do better in the future?
Rather than try to optimize this scalar version further, let's just SIMDify our code to make it compare all fields at once:
pub fn cmpSimd(left: Foo, right: Foo) bool {
const left_vec: @Vector(4, u4) = .{ left.a, left.b, left.c, left.d };
const right_vec: @Vector(4, u4) = .{ right.a, right.b, right.c, right.d };
return @reduce(.And, left_vec >= right_vec);
}
And here's the godbolt link.
The codegen this time is even worse, and it runs even slower than the scalar version. The reason is that in Zig, vectors are packed: a @Vector(4, u4) actually fits in 16 bits, not 32. In our case, LLVM seems to be able to partially mitigate this with optimizations, but when compiling without -OReleaseFast, I suspect LLVM first has to pack Foo.a, Foo.b, Foo.c, and Foo.d into 16 bits, then unpack them into a 4-byte vector so it can actually operate on them with SIMD instructions.
There is a simple fix: just avoid using @Vectors of exotically sized ints:
pub fn cmpSimd(left: Foo, right: Foo) bool {
// Changed from @Vector(4, u4) to @Vector(4, u8)
const left_vec: @Vector(4, u8) = .{ @intCast(left.a), @intCast(left.b), @intCast(left.c), @intCast(left.d) };
const right_vec: @Vector(4, u8) = .{ @intCast(right.a), @intCast(right.b), @intCast(right.c), @intCast(right.d) };
return @reduce(.And, left_vec >= right_vec);
}
Again, here's the godbolt link.
There's still some bitmasking going on, so let's also add back our assertZeroes() method calls.
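Putting the u8 lanes and the asserts together gives something like the sketch below. Two details are assumptions of this sketch and don't come from the original: the padding fields get default values of 0 for convenience, and the @intCast calls are dropped, because u4 widens to u8 losslessly and coerces implicitly.

```zig
const std = @import("std");

const Foo = packed struct(u32) {
    a: u4,
    _0: u4 = 0, // hypothetical: padding defaults to 0
    b: u4,
    _1: u4 = 0,
    c: u4,
    _2: u4 = 0,
    d: u4,
    _3: u4 = 0,

    pub inline fn assertZeroes(self: @This()) void {
        std.debug.assert(self._0 == 0);
        std.debug.assert(self._1 == 0);
        std.debug.assert(self._2 == 0);
        std.debug.assert(self._3 == 0);
    }
};

pub fn cmpSimd(left: Foo, right: Foo) bool {
    left.assertZeroes();
    right.assertZeroes();
    // Each u4 field coerces to the u8 lane type without a cast.
    const left_vec: @Vector(4, u8) = .{ left.a, left.b, left.c, left.d };
    const right_vec: @Vector(4, u8) = .{ right.a, right.b, right.c, right.d };
    // Lane-wise >= yields a @Vector(4, bool); .And folds it into one bool.
    return @reduce(.And, left_vec >= right_vec);
}
```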
This is a fair improvement over our scalar code, but we can do better - the title is "To SIMD and beyond", after all! Instead of using SIMD, let's try to write our routine using scalar operations, but in such a way that we test all fields simultaneously. This is known as SIMD-within-a-register, or SWAR.
How do we test if one number is greater than or equal to another? One way: calculate a - b and see if overflow occurs. If it doesn't, then indeed, a >= b. So we just need a way to
- perform 4 subtractions of 4-bit integers independently, and
- test whether any of the 4 subtractions overflowed.
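For a single 4-bit field, the subtraction test can be written directly with @subWithOverflow. That builtin returns a tuple of the wrapped difference and an overflow bit. The function name geViaSub is hypothetical. SWAR's job is to get the same answer for all four fields from one 32-bit subtraction.

```zig
const std = @import("std");

/// For unsigned integers, a >= b exactly when a - b does not wrap around.
fn geViaSub(a: u4, b: u4) bool {
    // result[0] is the wrapped difference; result[1] is 1 if the true
    // difference was negative (i.e. the subtraction overflowed).
    const result = @subWithOverflow(a, b);
    return result[1] == 0;
}
```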
We'll start by observing the memory layout of Foo in bits (most significant bit first):
0000dddd0000cccc0000bbbb0000aaaa
We can prevent each field's subtraction from borrowing bits from the next field up by inserting an overflow guard bit just above each field in left, so the bitwise representation of left | overflow_guard becomes
0001dddd0001cccc0001bbbb0001aaaa
After performing the subtraction, we can simply mask off everything but the overflow guard bits and check whether all of them are still set. If so, none of the subtractions overflowed, which means that every field of left is >= the corresponding field of right.
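The guard bit works because of the following arithmetic, taken per byte. Write a_i and b_i for field i of left and right, and assume right's padding bits are 0 as arranged earlier. Byte i of the difference is then r_i:

$$
\begin{aligned}
r_i &= (2^4 + a_i) - b_i, \qquad 0 \le a_i, b_i \le 15 \\
&\Rightarrow\; 1 \le r_i \le 31 \\
\text{bit}_4(r_i) = 1 &\iff r_i \ge 16 \iff a_i \ge b_i
\end{aligned}
$$

Since r_i never drops below 1, no byte borrows from the byte above it. Since it never reaches 32, nothing spills past the guard bit. So each guard bit answers exactly one comparison.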
pub fn cmpSwar(left: Foo, right: Foo) bool {
const left_bits: u32 = @bitCast(left);
const right_bits: u32 = @bitCast(right);
// add overflow guard bits to left, and subtract
const overflow_guard = 0b00010000_00010000_00010000_00010000;
const diff = (left_bits | overflow_guard) - right_bits;
// mask off everything except the overflow guard bits
const masked_diff = diff & overflow_guard;
// check if all the overflow guard bits are still set
return masked_diff == overflow_guard;
}
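One way to gain confidence in cmpSwar is to cross-check it against cmpScalar. The sketch below does this for every combination of the two lowest lanes, a and b, which is where a borrow from a into b would show up. The helper countMismatches and the default padding values are hypothetical additions, and the c and d lanes are held equal.

```zig
const std = @import("std");

const Foo = packed struct(u32) {
    a: u4,
    _0: u4 = 0, // hypothetical: padding defaults to 0
    b: u4,
    _1: u4 = 0,
    c: u4,
    _2: u4 = 0,
    d: u4,
    _3: u4 = 0,
};

pub fn cmpScalar(left: Foo, right: Foo) bool {
    return left.a >= right.a and
        left.b >= right.b and
        left.c >= right.c and
        left.d >= right.d;
}

pub fn cmpSwar(left: Foo, right: Foo) bool {
    const left_bits: u32 = @bitCast(left);
    const right_bits: u32 = @bitCast(right);
    const overflow_guard: u32 = 0b00010000_00010000_00010000_00010000;
    // With zero padding in both operands, no lane borrows from the next,
    // so this u32 subtraction never underflows.
    const diff = (left_bits | overflow_guard) - right_bits;
    return (diff & overflow_guard) == overflow_guard;
}

/// Hypothetical cross-check: counts inputs where the two routines disagree.
fn countMismatches() usize {
    var mismatches: usize = 0;
    for (0..16) |la| {
        for (0..16) |ra| {
            for (0..16) |lb| {
                for (0..16) |rb| {
                    const left = Foo{ .a = @intCast(la), .b = @intCast(lb), .c = 9, .d = 0 };
                    const right = Foo{ .a = @intCast(ra), .b = @intCast(rb), .c = 9, .d = 0 };
                    if (cmpSwar(left, right) != cmpScalar(left, right)) mismatches += 1;
                }
            }
        }
    }
    return mismatches;
}
```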
There's one more optimization to make: since we know at comptime which Foos will be left and which will be right, let's preemptively set the overflow guard bits (the low bit of each padding field) in the former group. Now we can simply change
const diff = (left_bits | overflow_guard) - right_bits;
to
const diff = left_bits - right_bits;
This saves a whole one instruction! Here's what our comparison routine looks like now.
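Below is a minimal sketch of the guarded variant. It assumes left values are built at comptime with every padding field set to 1, which places exactly the overflow guard bits. The names initLeft, initRight, and cmpSwarGuarded are hypothetical. Note that a left built this way would no longer pass assertZeroes, so the two tricks apply to different operands.

```zig
const std = @import("std");

const Foo = packed struct(u32) {
    a: u4,
    _0: u4 = 0, // hypothetical: padding defaults to 0
    b: u4,
    _1: u4 = 0,
    c: u4,
    _2: u4 = 0,
    d: u4,
    _3: u4 = 0,
};

const overflow_guard: u32 = 0b00010000_00010000_00010000_00010000;

/// Hypothetical constructor for values only ever used as `left`:
/// a padding value of 1 sets bit 4 of each byte, i.e. the guard bit.
fn initLeft(a: u4, b: u4, c: u4, d: u4) Foo {
    return .{ .a = a, ._0 = 1, .b = b, ._1 = 1, .c = c, ._2 = 1, .d = d, ._3 = 1 };
}

/// Hypothetical constructor for values only ever used as `right`:
/// padding stays 0.
fn initRight(a: u4, b: u4, c: u4, d: u4) Foo {
    return .{ .a = a, .b = b, .c = c, .d = d };
}

pub fn cmpSwarGuarded(left: Foo, right: Foo) bool {
    const left_bits: u32 = @bitCast(left);
    const right_bits: u32 = @bitCast(right);
    // The guard bits already live in `left`, so the OR is gone.
    const diff = left_bits - right_bits;
    return (diff & overflow_guard) == overflow_guard;
}
```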
Let's see if our optimizing has paid off.
The benchmark code is at zigbin.io. Here are the results on my laptop CPU (your mileage may vary, especially if you have AVX-512):
Inlining: true
Scalar runtime: 40.5288 ms
Simd runtime: 51.8806 ms
Swar runtime: 20.8467 ms
Optimized swar runtime: 14.0681 ms
Inlining: false
Scalar runtime: 273.1669 ms
Simd runtime: 73.0247 ms
Swar runtime: 74.3847 ms
Optimized swar runtime: 69.0435 ms
With inlining disabled, the optimized SWAR version is only slightly faster than the SIMD version (69.0 ms vs. 73.0 ms). With inlining, it is about 3.7x faster than the SIMD version and about 2.9x faster than the scalar one. Even more curiously, the scalar code is dramatically slower than the SIMD code in the run without inlining, but much faster with inlining!
The reason for both results, somewhat ironically, is auto-vectorization. LLVM isn't able to vectorize the loop that calls cmpSimd, because cmpSimd already uses vector registers; there's no such thing as a vector of vectors. The SWAR and scalar routines don't use vector registers, so once they are inlined into the loop, LLVM can vectorize it, and they benefit immensely.
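The benchmark code itself isn't reproduced here. The sketch below is a hypothetical kernel of the kind being measured; countPasses and its signature are assumptions. It takes a comptime std.builtin.CallModifier, so the same loop can force the comparison to be inlined (.always_inline) or kept as a call (.never_inline). In the inlined case the loop body is plain integer arithmetic that LLVM's loop vectorizer can work on. In the other case each iteration is an opaque call, which generally prevents vectorization.

```zig
const std = @import("std");

const Foo = packed struct(u32) {
    a: u4,
    _0: u4 = 0, // hypothetical: padding defaults to 0
    b: u4,
    _1: u4 = 0,
    c: u4,
    _2: u4 = 0,
    d: u4,
    _3: u4 = 0,
};

pub fn cmpSwar(left: Foo, right: Foo) bool {
    const left_bits: u32 = @bitCast(left);
    const right_bits: u32 = @bitCast(right);
    const overflow_guard: u32 = 0b00010000_00010000_00010000_00010000;
    const diff = (left_bits | overflow_guard) - right_bits;
    return (diff & overflow_guard) == overflow_guard;
}

/// Hypothetical benchmark kernel: counts the pairs where every field of
/// `left` is >= the matching field of `right`.
fn countPasses(
    lefts: []const Foo,
    rights: []const Foo,
    comptime modifier: std.builtin.CallModifier,
) usize {
    std.debug.assert(lefts.len == rights.len);
    var count: usize = 0;
    for (lefts, rights) |l, r| {
        count += @intFromBool(@call(modifier, cmpSwar, .{ l, r }));
    }
    return count;
}
```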