Zig NEWS

Riccardo Binetti


Using comptime to invert bijective functions on enums

Hi everyone! This is my first post on Zig News, and I've decided to break the ice with a bit of bikeshedding I fell into while working on Advent of Code 2022.

We have the usual "Rock Paper Scissors" game, and given a Shape we need to be able to retrieve the Shape that it beats and the Shape it's beaten by.

pub const Shape = enum(u8) {
    rock,
    paper,
    scissors,

    pub fn beats(self: Shape) Shape {
        return switch (self) {
            .rock => .scissors,
            .paper => .rock,
            .scissors => .paper,
        };
    }

    pub fn beatenBy(self: Shape) Shape {
        return switch (self) {
            .rock => .paper,
            .paper => .scissors,
            .scissors => .rock,
        };
    }
};

This works, but it's error-prone: nothing guarantees that the second function is actually the inverse of the first one.

Moreover, if the "Rock Paper Scissors" ISO Committee decides to change the rules for RockPaperScissors23 (now with Atomics!), we have two places where our code needs to change.
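Without comptime, the usual stopgap is an exhaustive test that checks both round trips for every shape. The sketch below keeps the two hand-written functions unchanged and assumes only `std.enums.values` from the standard library. It catches a mismatch when `zig test` runs, not at compile time, and it still leaves two places to edit when the rules change:

```zig
const std = @import("std");

pub const Shape = enum(u8) {
    rock,
    paper,
    scissors,

    pub fn beats(self: Shape) Shape {
        return switch (self) {
            .rock => .scissors,
            .paper => .rock,
            .scissors => .paper,
        };
    }

    pub fn beatenBy(self: Shape) Shape {
        return switch (self) {
            .rock => .paper,
            .paper => .scissors,
            .scissors => .rock,
        };
    }
};

// Checks that the two hand-written tables agree for every shape.
// Runs only under `zig test`, so a mismatch surfaces at test time.
test "beatenBy is the inverse of beats" {
    for (std.enums.values(Shape)) |shape| {
        try std.testing.expect(shape.beats().beatenBy() == shape);
        try std.testing.expect(shape.beatenBy().beats() == shape);
    }
}
```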

Can we do better? Using the power of comptime, I think we can.

pub const Shape = enum(u8) {
    rock,
    paper,
    scissors,

    pub fn beats(self: Shape) Shape {
        return switch (self) {
            .rock => .scissors,
            .paper => .rock,
            .scissors => .paper,
        };
    }

    pub fn beatenBy(self: Shape) Shape {
        switch (self) {
            inline else => |shape| {
                const winner = comptime blk: {
                    inline for (@typeInfo(Shape).Enum.fields) |field| {
                        const other_shape = @intToEnum(Shape, field.value);
                        if (other_shape.beats() == shape) {
                            break :blk other_shape;
                        }
                    }
                    unreachable;
                };
                return winner;
            },
        }
    }
};

As you can see, beatenBy is now generated at comptime from beats, so any change to the rules only has to be made in the beats function.
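The snippets in this post use Zig 0.10 builtins. Later releases renamed `@intToEnum` to `@enumFromInt` (0.11) and changed the `@typeInfo` tag names (0.14). A sketch of the same comptime-generated beatenBy that sidesteps both by using `std.enums.values` follows. It assumes a Zig release with `inline else` (0.10 or later). Because the search runs inside a `comptime` block, a plain `for` is enough there:

```zig
const std = @import("std");

pub const Shape = enum(u8) {
    rock,
    paper,
    scissors,

    pub fn beats(self: Shape) Shape {
        return switch (self) {
            .rock => .scissors,
            .paper => .rock,
            .scissors => .paper,
        };
    }

    pub fn beatenBy(self: Shape) Shape {
        switch (self) {
            // One prong per tag; `shape` is comptime-known inside each prong.
            inline else => |shape| {
                // Evaluated entirely at compile time, so a plain `for` suffices.
                const winner = comptime blk: {
                    for (std.enums.values(Shape)) |other| {
                        if (other.beats() == shape) break :blk other;
                    }
                    @compileError("no shape beats " ++ @tagName(shape));
                };
                return winner;
            },
        }
    }
};
```

If `beats` ever stops covering some shape, the `@compileError` fires while compiling `beatenBy`, not at runtime.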

In fact, we can also generalize the concept and create a comptime function that returns the inverse of an arbitrary bijective function from an enum domain T1 to an enum codomain T2:

fn invert(comptime T1: type, comptime T2: type, comptime originalFn: fn (a: T1) T2) fn (T2) T1 {
    if (@typeInfo(T1).Enum.fields.len != @typeInfo(T2).Enum.fields.len) {
        @compileError("Trying to invert a non-bijective function: domain " ++
            @typeName(T1) ++ " and codomain " ++ @typeName(T2) ++ " have different sizes");
    }

    return struct {
        fn function(self: T2) T1 {
            switch (self) {
                inline else => |input| {
                    const inverse = comptime blk: {
                        inline for (@typeInfo(T1).Enum.fields) |field| {
                            const candidate = @intToEnum(T1, field.value);
                            if (originalFn(candidate) == input) {
                                break :blk candidate;
                            }
                        }
                        @compileError("Trying to invert a non-bijective function: " ++
                            @tagName(input) ++ " is not contained in the codomain");
                    };
                    return inverse;
                },
            }
        }
    }.function;
}

This reduces the definition of beatenBy inside Shape to

pub const beatenBy = invert(Shape, Shape, beats);
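Here is a self-contained sketch of invert wired into Shape. It is written with `std.enums.values` so it does not depend on the renamed `@intToEnum` builtin, and it assumes Zig 0.10 or later. The error text for a missing value says "image" rather than "codomain": a value that no input maps to is still in the codomain, it is just not in the function's image.

```zig
const std = @import("std");

fn invert(comptime T1: type, comptime T2: type, comptime originalFn: fn (T1) T2) fn (T2) T1 {
    // With equal sizes, "every output is hit" is enough for a bijection.
    if (comptime std.enums.values(T1).len != std.enums.values(T2).len) {
        @compileError("Trying to invert a non-bijective function: domain " ++
            @typeName(T1) ++ " and codomain " ++ @typeName(T2) ++ " have different sizes");
    }

    return struct {
        fn function(input: T2) T1 {
            switch (input) {
                // `target` is comptime-known in each generated prong.
                inline else => |target| {
                    const inverse = comptime blk: {
                        for (std.enums.values(T1)) |candidate| {
                            if (originalFn(candidate) == target) break :blk candidate;
                        }
                        @compileError("Trying to invert a non-bijective function: " ++
                            @tagName(target) ++ " is not in the image of the function");
                    };
                    return inverse;
                },
            }
        }
    }.function;
}

pub const Shape = enum(u8) {
    rock,
    paper,
    scissors,

    pub fn beats(self: Shape) Shape {
        return switch (self) {
            .rock => .scissors,
            .paper => .rock,
            .scissors => .paper,
        };
    }

    // Generated at compile time from `beats`; `pub` like the hand-written version.
    pub const beatenBy = invert(Shape, Shape, beats);
};
```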

Using invert also gives us an additional advantage: a compile-time guarantee that the function is bijective. You can check this quickly with the following code:

pub const Shape = enum {
    square,
    triangle,
    circle,

    pub fn toColor(self: Shape) Color {
        return switch (self) {
            .square => .red,
            .triangle => .green,
            .circle => .blue,
        };
    }
};

pub const Color = enum {
    red,
    green,
    blue,

    pub const toShape = invert(Shape, Color, Shape.toColor);
};

If you make two different Shapes return the same Color in toColor (and toShape is actually used somewhere, since Zig only analyzes declarations that are referenced), compilation fails with:

src/main.zig:20:25: error: Trying to invert a non-bijective function: blue is not contained in the codomain
                        @compileError("Trying to invert a non-bijective function: " ++
                        ^~~~~~~~~~~~~

And if you add an extra element to either the Shape enum or the Color enum, compilation fails with:

src/main.zig:5:9: error: Trying to invert a non-bijective function: domain main.Shape and codomain main.Color have different sizes
        @compileError("Trying to invert a non-bijective function: domain " ++
        ^~~~~~~~~~~~~
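If you want the bijectivity check to fire even when no code ever calls the inverse, one option is to run it from a container-level `comptime` block. `assertBijective` below is a hypothetical helper with illustrative messages, not code from the post. It is a sketch assuming Zig 0.10 or later with `std.enums.values`:

```zig
const std = @import("std");

// Hypothetical helper: fails compilation unless `f` is a bijection T1 -> T2.
fn assertBijective(comptime T1: type, comptime T2: type, comptime f: fn (T1) T2) void {
    const domain = comptime std.enums.values(T1);
    const codomain = comptime std.enums.values(T2);
    if (domain.len != codomain.len) {
        @compileError("domain " ++ @typeName(T1) ++ " and codomain " ++
            @typeName(T2) ++ " have different sizes");
    }
    // Sizes are equal, so every codomain value being hit exactly once
    // is enough for f to be a bijection.
    inline for (codomain) |target| {
        const hits = comptime blk: {
            var n: usize = 0;
            for (domain) |candidate| {
                if (f(candidate) == target) n += 1;
            }
            break :blk n;
        };
        if (hits != 1) {
            @compileError(@tagName(target) ++ " is not hit exactly once");
        }
    }
}

const Shape = enum {
    square,
    triangle,
    circle,

    pub fn toColor(self: Shape) Color {
        return switch (self) {
            .square => .red,
            .triangle => .green,
            .circle => .blue,
        };
    }
};

const Color = enum { red, green, blue };

// Evaluated whenever this file is analyzed, even if no inverse is ever called.
comptime {
    assertBijective(Shape, Color, Shape.toColor);
}
```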

I'm aware that the "Rock Paper Scissors" problem could have been solved differently (e.g. by representing the "beats" relationship as a circular buffer and looking at the element after or before a given one), but I took the opportunity to get a little more comfortable with comptime.
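For comparison, here is a minimal sketch of that circular-buffer alternative (hypothetical, not code from the post). The shapes are stored in a fixed array, `cycle`, ordered so that each shape beats the one before it, wrapping around. `beats` and `beatenBy` then step one slot backwards or forwards. Rule changes still touch only one place (the array), but the approach only fits relations that really are a rotation:

```zig
const std = @import("std");

pub const Shape = enum {
    rock,
    paper,
    scissors,

    // Each shape beats the one before it, wrapping around:
    // paper beats rock, scissors beats paper, rock beats scissors.
    const cycle = [_]Shape{ .rock, .paper, .scissors };

    fn index(self: Shape) usize {
        return std.mem.indexOfScalar(Shape, &cycle, self).?;
    }

    pub fn beats(self: Shape) Shape {
        // One slot backwards, modulo the cycle length.
        return cycle[(self.index() + cycle.len - 1) % cycle.len];
    }

    pub fn beatenBy(self: Shape) Shape {
        // One slot forwards, modulo the cycle length.
        return cycle[(self.index() + 1) % cycle.len];
    }
};
```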

Let me know in the comments if you have any suggestions or corrections, and happy Zigging!

Top comments (6)

slavka

I don't quite understand: is it necessary to use this switch + inline else construct? If so, why?

Riccardo Binetti

This bit me in my first implementation too, so I'll explain how I got from there to the final code.

My first implementation was this:

pub fn beatenBy(self: Shape) Shape {
    const winner = blk: {
        inline for (@typeInfo(Shape).Enum.fields) |field| {
            const other_shape = @intToEnum(Shape, field.value);
            if (other_shape.beats() == self) {
                break :blk other_shape;
            }
        }
        unreachable;
    };
    return winner;
}

This seems to work, but the problem is that it does its work at runtime. The inline for unrolls the loop over all Shapes at compile time, but the comparisons happen at runtime, so the produced code (after comptime) looks something like this:

pub fn beatenBy(self: Shape) Shape {
    const winner = blk: {
        if (Shape.rock.beats() == self) break :blk .rock;
        if (Shape.paper.beats() == self) break :blk .paper;
        if (Shape.scissors.beats() == self) break :blk .scissors;
        unreachable;
    };
    return winner;
}

The problem is that this does not give the comptime guarantees I described above. Indeed, if you try it with the Shapes and Colors example, it compiles without complaining even when the function is not bijective, and only fails at runtime (by hitting the unreachable above).

If we want compile-time guarantees, we need to add comptime before blk:, so that the block is evaluated at compile time. But if you do just that to the code above, the compiler fails with:

src/main.zig:14:47: error: unable to resolve comptime value
            if (other_shape.beats() == self) {
            ~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~

This is because we're trying to use self, which is a runtime value, in a comptime block. What we want instead is a single piece of code that handles every possible value of self as a comptime-known value, and that's exactly what inline else is for.
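// A minimal, hypothetical illustration of that property, separate from the
// Shape rules: inside each generated prong, the capture is comptime-known.
```zig
const std = @import("std");

const Shape = enum { rock, paper, scissors };

// Hypothetical example: `inline else` generates one prong per tag, and the
// capture `tag` is comptime-known inside each prong.
fn nameLength(self: Shape) usize {
    switch (self) {
        inline else => |tag| {
            // Allowed because `tag` is comptime-known; this is computed once
            // per prong at compile time. Writing `comptime @tagName(self).len`
            // instead would be rejected, since `self` is only known at runtime.
            const len = comptime @tagName(tag).len;
            return len;
        },
    }
}
```

The comptime blk: in beatenBy relies on the same mechanism: the compared-against shape is the prong's comptime-known capture, not the runtime self.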

If we expand the inline else by hand, the code in the original post looks something like this:

pub fn beatenBy(self: Shape) Shape {
    switch (self) {
        .rock => {
            const winner = comptime blk: {
                inline for (@typeInfo(Shape).Enum.fields) |field| {
                    const other_shape = @intToEnum(Shape, field.value);
                    if (other_shape.beats() == .rock) {
                        break :blk other_shape;
                    }
                }
                unreachable;
            };
            return winner;
        },
        .paper => {
            const winner = comptime blk: {
                inline for (@typeInfo(Shape).Enum.fields) |field| {
                    const other_shape = @intToEnum(Shape, field.value);
                    if (other_shape.beats() == .paper) {
                        break :blk other_shape;
                    }
                }
                unreachable;
            };
            return winner;
        },
        .scissors => {
            const winner = comptime blk: {
                inline for (@typeInfo(Shape).Enum.fields) |field| {
                    const other_shape = @intToEnum(Shape, field.value);
                    if (other_shape.beats() == .scissors) {
                        break :blk other_shape;
                    }
                }
                unreachable;
            };
            return winner;
        },
    }
}

But since the shape being compared against is now comptime-known in each of the three prongs, each comptime blk: block can actually be evaluated at compile time, and the final resulting code is:

pub fn beatenBy(self: Shape) Shape {
    switch (self) {
        .rock => {
            const winner = .paper;
            return winner;
        },
        .paper => {
            const winner = .scissors;
            return winner;
        },
        .scissors => {
            const winner = .rock;
            return winner;
        },
    }
}

I hope this makes it clearer (I might consider integrating this into the original post).

slavka

Oh, I see, so the value between the | signs in the inline else construct is associated with the "current iteration over all possible enum values", which actually makes sense. Thanks!

I suppose the article could benefit from some version of this explanation, since it highlights a less obvious point about this feature of the language, especially if the reader hasn't had an opportunity to use it.

By the way, is the last snippet in your comment written by hand? If not, it would be cool to see how it was generated; otherwise, it might be a good idea to include the generated machine code or something similar to verify that no actual work is done at runtime.

Riccardo Binetti

It's written by hand to give an intuition of how this works, but if you look at the output of the code on Godbolt and search the assembly, you can see that the beats function is never called, and the beatenBy function is translated into a handful of comparisons, jumps and movs of constants (the winners calculated at compile time).

Loris Cro

This is one use of fancy meta programming I can get behind!

odalet

Oh! Nice!