Zig NEWS

InspectorBoat

To SIMD and beyond: Optimizing a simple comparison routine

I won't bore you with the backstory (and I didn't want to draw a bunch of 3d diagrams), so I'll jump straight into the optimizing without explaining anything about the program that motivated said optimizations.

Anywho, we have this packed struct:

const Foo = packed struct(u32) {
    a: u4,
    _0: u4,
    b: u4,
    _1: u4,
    c: u4,
    _2: u4,
    d: u4,
    _3: u4,
};

The routine we want to optimize compares two instances of Foo - we'll call them left and right - to see if the fields a, b, c, and d of left are all greater than or equal to the corresponding fields of right.
Also, some facts that will come in handy later:

  • The fields prefixed with _ aren't used for anything - they're purely padding. We can set them to whatever values we desire.
  • We know whether an instance of Foo will be left or right at initialization, which occurs at compile time (a sketch of such an initialization follows below).
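
For concreteness, here's a minimal sketch of what a comptime-initialized Foo with zeroed padding might look like (the field values are made up; this isn't code from the post):

const left: Foo = .{
    .a = 3,
    .b = 7,
    .c = 1,
    .d = 9,
    // padding fields explicitly zeroed
    ._0 = 0,
    ._1 = 0,
    ._2 = 0,
    ._3 = 0,
};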

Here's the simplest way to write a function that does such an operation:

pub fn cmpScalar(left: Foo, right: Foo) bool {
    return left.a >= right.a and
        left.b >= right.b and
        left.c >= right.c and
        left.d >= right.d;
}

Let's take a look at the assembly in godbolt. There's a lot of and cl, 15 / and dl, 15 going on, which serves only to mask off the bits of left._0, left._1, right._0, etc. If we were to ditch the underscored fields and use u8s, the result gets better.
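
For comparison, a hypothetical padding-free layout using u8 fields (not something used in the rest of the post) would look like this:

const FooBytes = packed struct(u32) {
    a: u8,
    b: u8,
    c: u8,
    d: u8,
};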

Since we can put whatever we want in the fields with underscores, let's just initialize them to 0 at comptime (not shown), and use asserts to make LLVM assume those fields are 0. We'll add a method to Foo that does so - and mark it inline, just to be sure:

pub inline fn assertZeroes(self: @This()) void {
    std.debug.assert(self._0 == 0);
    std.debug.assert(self._1 == 0);
    std.debug.assert(self._2 == 0);
    std.debug.assert(self._3 == 0);
}

and call this method in cmpScalar on both left and right.
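
Spelled out, that might look something like this (a sketch assembled from the description above):

pub fn cmpScalar(left: Foo, right: Foo) bool {
    // Tell LLVM it may assume the padding fields are zero.
    left.assertZeroes();
    right.assertZeroes();
    return left.a >= right.a and
        left.b >= right.b and
        left.c >= right.c and
        left.d >= right.d;
}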

Looking in godbolt - nothing changed! Bummer. Perhaps LLVM isn't too adept at reasoning about exotically sized integers - maybe Zig's own backends will do better in the future?

Rather than try to optimize this scalar version further, let's just SIMDify our code to make it compare all fields at once:

pub fn cmpSimd(left: Foo, right: Foo) bool {
    const left_vec: @Vector(4, u4) = .{ left.a, left.b, left.c, left.d };
    const right_vec: @Vector(4, u4) = .{ right.a, right.b, right.c, right.d };
    return @reduce(.And, left_vec >= right_vec);
}

And here's the godbolt link.
The codegen this time is even worse, and it performs even worse than the scalar version. The reason is that in Zig, vectors are packed - e.g. a @Vector(4, u4) actually fits in 16 bits, not 32. In our case, LLVM seems to be able to partially mitigate this with optimizations, but when compiling without -OReleaseFast, I suspect LLVM first has to pack Foo.a, Foo.b, Foo.c, and Foo.d into 16 bits, then unpack them into a 4-byte vector so it can actually operate on them with SIMD instructions.

There is a simple fix. Just avoid using @Vectors of exotically sized ints:

pub fn cmpSimd(left: Foo, right: Foo) bool {
    // Changed from @Vector(4, u4) to @Vector(4, u8)
    const left_vec: @Vector(4, u8) = .{ @intCast(left.a), @intCast(left.b), @intCast(left.c), @intCast(left.d) };
    const right_vec: @Vector(4, u8) = .{ @intCast(right.a), @intCast(right.b), @intCast(right.c), @intCast(right.d) };
    return @reduce(.And, left_vec >= right_vec);
}

Again, here's the godbolt link.
There's still some bitmasking going on, so let's also add back our assertZeroes() method calls.
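
Putting the pieces together, the u8-vector version with the asserts added back might look something like this (a sketch assembled from the snippets above):

pub fn cmpSimd(left: Foo, right: Foo) bool {
    // Tell LLVM it may assume the padding fields are zero, so the masking can be dropped.
    left.assertZeroes();
    right.assertZeroes();
    const left_vec: @Vector(4, u8) = .{ @intCast(left.a), @intCast(left.b), @intCast(left.c), @intCast(left.d) };
    const right_vec: @Vector(4, u8) = .{ @intCast(right.a), @intCast(right.b), @intCast(right.c), @intCast(right.d) };
    return @reduce(.And, left_vec >= right_vec);
}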

This is a fair improvement over our scalar code, but we can do better - the title is "To SIMD and beyond", after all! Instead of using SIMD, let's try to write our routine using scalar operations, but in such a way that we test all fields simultaneously. This is known as SIMD-within-a-register, or SWAR.
How do we test whether one number is greater than or equal to another? One way: calculate a - b and see if overflow occurs. If it didn't, then indeed, a >= b. So we just need a way to

  1. Perform 4 subtractions of 4-bit integers independently, and
  2. Test that overflow did not occur in any of the 4 subtractions (a single-pair sketch of this check follows below).
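
For a single pair of fields, that check can be written directly with @subWithOverflow (a throwaway sketch for illustration; the SWAR routine below handles all four fields at once):

fn geSingle(a: u4, b: u4) bool {
    // @subWithOverflow returns .{ result, overflow_bit }; for unsigned types
    // the overflow bit is 1 exactly when the subtraction wrapped, i.e. when b > a.
    const sub = @subWithOverflow(a, b);
    return sub[1] == 0;
}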

We'll start by observing the memory layout of Foo in bits:
0000dddd0000cccc0000bbbb0000aaaa
We can prevent each field's subtraction from borrowing bits from the next field up by inserting an overflow guard bit just above each field in left, so the bitwise representation of left | overflow_guard becomes
0001dddd0001cccc0001bbbb0001aaaa
After performing the subtraction, we can simply mask off everything but the overflow guard, and check if all of the bits are still set. If so, then none of the subtractions overflowed, which means that every field of left >= the same field in right.

pub fn cmpSwar(left: Foo, right: Foo) bool {
    const left_bits: u32 = @bitCast(left);
    const right_bits: u32 = @bitCast(right);

    // add overflow guard bits to left, and subtract
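    // (this relies on right's padding fields being zero, as established above)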
    const overflow_guard = 0b00010000_00010000_00010000_00010000;
    const diff = (left_bits | overflow_guard) - right_bits;

    // mask off everything except the overflow guard bits
    const masked_diff = diff & overflow_guard;

    // check if all the overflow guard bits are still set
    return masked_diff == overflow_guard;
}

There's one more optimization to make: since we know which Foos will be left and which will be right at comptime, let's just preemptively OR the guard bits into the former group. Now we can simplify

const diff = (left_bits | overflow_guard) - right_bits;

to

const diff = left_bits - right_bits;

This saves a whole one instruction! Here's what our comparison routine looks like now.
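
In source form, the final routine might look something like this (a sketch with a made-up name; it assumes the guard bits were already OR'd into left's padding at initialization, and that right's padding is zero):

pub fn cmpSwarOptimized(left: Foo, right: Foo) bool {
    const left_bits: u32 = @bitCast(left);
    const right_bits: u32 = @bitCast(right);

    // left's padding already contains the guard bits, so no OR is needed here
    const overflow_guard: u32 = 0b00010000_00010000_00010000_00010000;
    const diff = left_bits - right_bits;

    // check if all the overflow guard bits are still set
    return (diff & overflow_guard) == overflow_guard;
}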

Let's see if our optimizing has paid off.

The benchmark code is at zigbin.io. Here are the results on my laptop cpu (your mileage may vary, especially if you have AVX-512):

Inlining: true
Scalar runtime: 40.5288 ms
Simd runtime: 51.8806 ms
Swar runtime: 20.8467 ms
Optimized swar runtime: 14.0681 ms
Inlining: false
Scalar runtime: 273.1669 ms
Simd runtime: 73.0247 ms
Swar runtime: 74.3847 ms
Optimized swar runtime: 69.0435 ms

With inlining disabled, the optimized SWAR version is barely faster, but with inlining, the speedup is over 200%. Even more curiously, the scalar code is dramatically slower than the SIMD code in the run without inlining, but much faster with inlining!

The reason for both results, somewhat ironically, is auto-vectorization: LLVM isn't able to vectorize the loop that uses cmpSimd, because that routine already uses vector registers; there's no such thing as a vector of vectors. Since the SWAR and scalar routines don't use vector registers, they can benefit immensely from auto-vectorization.
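
For reference, the benchmarked loop presumably has roughly this shape (an assumed sketch, not the actual benchmark code from zigbin.io): LLVM can auto-vectorize it across iterations when the body is the scalar or SWAR routine, but not when the body already uses vector registers.

fn countGreaterOrEqual(lefts: []const Foo, rights: []const Foo) usize {
    var count: usize = 0;
    for (lefts, rights) |left, right| {
        // a scalar/SWAR body here is what LLVM can turn into SIMD across iterations
        count += @intFromBool(cmpSwar(left, right));
    }
    return count;
}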

Top comments (3)

Corentin Godeau

This kind of bitwise trick can sometimes be very handy indeed. If I'm not mistaken, the book Hacker's Delight is full of this kind of trick. That might interest some of the readers of your nice post! 😄

Loris Cro

Great post, thanks for sharing. The subject matter was interesting in its own right, but I also really appreciated the writing style: clear and to the point.

Fifnmar

I'm wondering if you could "vectorize the call to cmpSimd" by comparing 8 of them at the same time (with a u256).