Zig NEWS

Kristin Rutenkolk

Data driven polymorphism

The last few days I have been writing some small posts on my cohost, trying to come up with my own solution to the question of interfaces in Zig.

So I thought this might be of interest here as well.

One of the most interesting solutions to "interfaces" or "function dispatch" I have come across is the way Clojure deals with the problem.

The main "intuition" or goal is to decouple interface definition from implementation. After all, that is what an interface is for, right? We want to specify in some way what we are talking about, and then we want to be able to use all the things that satisfy our specification. It's the dream we surely all had at some point: being able to combine program parts as easily as snapping together matching LEGO pieces.

But the reality is often messier and not so dreamy. So I was somewhat amazed when I saw a dynamic language pull it off.

We only really need two things for Clojure-style multimethods / records. Let's start with multimethods:

  1. The definition of a multimethod (indicated with defmulti)
  2. An implementation of a multimethod (indicated with defmethod)

For the definition of the multimethod we need to decide how we want it to change behaviour. For that we write a dispatch function that computes a dispatch value from the argument. Let's keep it simple for now:

fn my_dispatch(some_thing: anytype) i32 {
    return some_thing.some_value;
}
const mymulti = defmulti("mymulti", my_dispatch, usize);

Now we have a multimethod that we can call with anything that has a field some_value of type i32.

Then we need to offer different implementations for individual values returned by our dispatch function.

//some dummy functions
fn impl1(x: anytype) usize {
    _ = x;
    return 5;
}
fn impl2(x: anytype) usize {
    _ = x;
    return 6;
}
const method1 = defmethod(mymulti, 0, impl1);
const method2 = defmethod(mymulti, 1, impl2);

This sets up two implementations, or methods, for mymulti.

  • method1 will be used if our dispatch function returns 0. The argument to mymulti will be passed to impl1.
  • method2 will be used if our dispatch function returns 1. The argument to mymulti will be passed to impl2.
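To make the rule concrete, here is roughly what a call to mymulti with these two methods amounts to, written out by hand. This is a sketch of the semantics, not the code defmulti actually generates; mymultiByHand is a hypothetical name, and the @panic for an unmatched dispatch value is an assumption, since the post does not show that case.

```zig
const std = @import("std");

fn my_dispatch(some_thing: anytype) i32 {
    return some_thing.some_value;
}

fn impl1(x: anytype) usize {
    _ = x;
    return 5;
}

fn impl2(x: anytype) usize {
    _ = x;
    return 6;
}

// Hypothetical hand-written equivalent of `mymulti.call`: compute the dispatch
// value once, then compare it against each registered dispatch value in turn.
fn mymultiByHand(x: anytype) usize {
    const dispatch_value = my_dispatch(x);
    if (dispatch_value == 0) return impl1(x); // method1
    if (dispatch_value == 1) return impl2(x); // method2
    // Assumption: an unmatched dispatch value is treated as a bug.
    @panic("mymulti: no method for this dispatch value");
}

test "hand-written dispatch matches the methods above" {
    const TypeA = struct { some_value: i32 };
    try std.testing.expectEqual(@as(usize, 5), mymultiByHand(TypeA{ .some_value = 0 }));
    try std.testing.expectEqual(@as(usize, 6), mymultiByHand(TypeA{ .some_value = 1 }));
}
```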

Looks innocent enough, right?

But the virtue lies in how flexible this whole approach is. We can use any type as long as it works with our dispatch function, and we can dispatch not only on something like a type tag, but on any information we like. It's data driven polymorphism!

test "test defmulti defmethod" {
    const TypeA = struct { some_value: i32 };
    const valA1 = TypeA{ .some_value = 0 };
    const valA2 = TypeA{ .some_value = 1 };

    const TypeB = struct { some_value: i32 };
    const valB1 = TypeB{ .some_value = 0 };
    const valB2 = TypeB{ .some_value = 1 };

    try std.testing.expectEqual(@as(usize, 5), mymulti.call(valA1));
    try std.testing.expectEqual(@as(usize, 6), mymulti.call(valA2));

    try std.testing.expectEqual(@as(usize, 5), mymulti.call(valB1));
    try std.testing.expectEqual(@as(usize, 6), mymulti.call(valB2));
}

We also decouple interface definition from implementation, and we do so without depending on new types. Say mymulti lives in some library and implements an algorithm you really did not want to write yourself, but your specific use case is not supported. Now you have the option to simply extend the functionality. We also never break the behaviour of existing code; we can only extend it. Either there already is an implementation, or you can build your own.

The generated assembly also doesn't look too terrible:

"example.MultiMethod_T(\"mymulti\",fn(anytype) i32,usize).call__anon_318":
        push    rbp
        mov     rbp, rsp
        sub     rsp, 32
        mov     qword ptr [rbp - 32], rdi
        call    example.noice_dispatch__anon_346 // call dispatch fn
        mov     ecx, eax
        mov     dword ptr [rbp - 24], ecx
        mov     dword ptr [rbp - 20], ecx
        xor     eax, eax
        cmp     eax, ecx
        jne     .LBB1_2 // skip impl1 if the dispatch value is different
        mov     rdi, qword ptr [rbp - 32]
        call    example.impl1__anon_634 // call impl1 of correct dispatch value
        mov     qword ptr [rbp - 16], rax
        mov     rax, qword ptr [rbp - 16]
        add     rsp, 32
        pop     rbp
        ret
.LBB1_2: // just past the return of impl1 version
        jmp     .LBB1_3 // why, compiler?
.LBB1_3:
        mov     ecx, dword ptr [rbp - 24]
        mov     eax, 1
        cmp     eax, ecx
        jne     .LBB1_5 // skip impl2 if the dispatch value is different (error handling case)
        mov     rdi, qword ptr [rbp - 32]
        call    example.impl2__anon_636 // call impl2
        mov     qword ptr [rbp - 8], rax
        mov     rax, qword ptr [rbp - 8]
        add     rsp, 32
        pop     rbp
        ret

Let's build a bigger example to get another taste:

const Tuple = std.meta.Tuple;

fn dispatchfn(number_in_any_language: anytype) Tuple(&.{ []const u8, u32 }) {
    if (std.mem.eql(u8, number_in_any_language, "eins")) return .{ "german", 1 };
    if (std.mem.eql(u8, number_in_any_language, "zwei")) return .{ "german", 2 };
    if (std.mem.eql(u8, number_in_any_language, "drei")) return .{ "german", 3 };
    if (std.mem.eql(u8, number_in_any_language, "one")) return .{ "english", 1 };
    if (std.mem.eql(u8, number_in_any_language, "two")) return .{ "english", 2 };
    if (std.mem.eql(u8, number_in_any_language, "three")) return .{ "english", 3 };
    if (std.mem.eql(u8, number_in_any_language, "一")) return .{ "chinese", 1 };
    if (std.mem.eql(u8, number_in_any_language, "二")) return .{ "chinese", 2 };
    if (std.mem.eql(u8, number_in_any_language, "三")) return .{ "chinese", 3 };

    return .{ "idk, lol", 42 };
}

const multi = defmulti("multi", dispatchfn, usize);
const noah2_method1 = defmethod(multi, .{ "german", 1 }, impl1);
const noah2_method2 = defmethod(multi, .{ "english", 2 }, impl2);

test "test defmulti dispatch on tuples" {
    try std.testing.expectEqual(@as(usize, 5), multi.call("eins"));
    try std.testing.expectEqual(@as(usize, 6), multi.call("two"));
}

This looks shockingly dynamic for a low-level language!
I hope this demonstrates the flexibility and power multimethods hold. Unfortunately the implementation is a bit hacky right now, but nothing that can't be fixed.

So, now that we are deep into the whole dispatch game, what are protocols about? Well, multimethods have a significant drawback: they are somewhat slow (see the chain of comparisons in the assembly above). Being able to dispatch on data is great, but if the dispatch value is only known at runtime, we will inevitably trade a bit of performance.

As a solution, there are protocols. Again, you define a protocol with a name in one place, and then you can add different implementations for it wherever you like.
The difference is that we now choose the implementation based on the type.

A lot of the time, this already solves our problem, and it can be much faster. Protocols are incredibly useful for any sort of library code. You expose something like a "count" or "length" protocol once, and then you can implement what that means for any type you like, whenever you like.
Library code often struggles with generics, because you need to account for all the weird types your users might use your code with.
You often wish you could extend the functionality of a function after the fact, but that's impossible. But why?
Nothing is really stopping us, other than finding out whether an implementation exists for the type, where it is located, and then calling it.

With this we get type-based polymorphism without a single function overload, there should be no name mangling in the resulting binary, and we don't need to create "child" types just because "polymorphism is inheritance, right?" (No.)

const myproto = protocol("myproto", usize);

fn impl1(x: f32) usize {
    return @floatToInt(usize, x * 5);
}
const record1 = record(myproto, f32, impl1);

fn impl2(x: u32) usize {
    // widen to usize before adding, so the addition is not done in u32 and cannot overflow there
    return @as(usize, x) + 100;
}
const record2 = record(myproto, u32, impl2);

fn impl3(x: f64) usize {
    return @floatToInt(usize, x * 8);
}
const record3 = record(myproto, f64, impl3);

fn impl4(x: u16) usize {
    return @as(usize, x) + 27; // widen first, as in impl2
}
const record4 = record(myproto, u16, impl4);

fn impl5(x: u8) usize {
    return @as(usize, x) * 9; // multiplying in u8 would overflow for x >= 29
}
const record5 = record(myproto, u8, impl5);

fn impl6(x: bool) usize {
    return @intCast(usize, @boolToInt(x));
}
const record6 = record(myproto, bool, impl6);

At this point, I hope you can guess what the above lines do. record(some_protocol, dispatch_type, impl_function) defines an implementation of some_protocol: impl_function is called whenever some_protocol is called with an argument of type dispatch_type. Again, we clearly separate interface definition from implementation. So this must also come with a performance penalty, right?

Let's look at a usage example of myproto:

export fn lul() usize {
    // `var` keeps these values runtime-known, so the calls below are real runtime calls
    var v1: u8 = 5;
    var v2: u16 = 5;
    var v3: u32 = 5;
    var v4: f32 = 5;
    var v5: f64 = 5;
    var v6: bool = false;

    return myproto.call(v1) + myproto.call(v2) + myproto.call(v3) +
        myproto.call(v4) + myproto.call(v5) + myproto.call(v6);
}

As we can see, the same myproto is being called with a variety of types. What does the resulting assembly look like?

lul:
        push    rbp
        mov     rbp, rsp
        sub     rsp, 240
        mov     byte ptr [rbp - 195], 5
        mov     word ptr [rbp - 194], 5
        mov     dword ptr [rbp - 192], 5
        mov     dword ptr [rbp - 188], 1084227584
        movabs  rax, 4617315517961601024
        mov     qword ptr [rbp - 184], rax
        mov     byte ptr [rbp - 170], 0
        movzx   edi, byte ptr [rbp - 195]
        mov     al, dil
        mov     byte ptr [rbp - 169], al
        call    example.impl5
        mov     qword ptr [rbp - 168], rax
        mov     rax, qword ptr [rbp - 168]
        mov     qword ptr [rbp - 208], rax
        movzx   edi, word ptr [rbp - 194]
        mov     ax, di
        mov     word ptr [rbp - 154], ax
        call    example.impl4
        mov     rcx, rax
        mov     rax, qword ptr [rbp - 208]
        mov     qword ptr [rbp - 152], rcx
        mov     rcx, qword ptr [rbp - 152]
        add     rax, rcx
        mov     qword ptr [rbp - 144], rax
        setb    byte ptr [rbp - 136]
        mov     al, byte ptr [rbp - 136]

The protocol dispatch disappears completely: lul calls impl5 and impl4 directly, without any type checks or branches in between. Because the argument type is statically known, the correct implementation is selected at compile time and nothing needs to be checked at runtime.
This is probably as fast as it gets. It is the same as calling the implementation functions ourselves, except that we didn't have to.
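One way such compile-time selection can work is sketched below, under assumptions that are not the post's actual API: each record is a namespace with a Dispatch type and an impl function, all records of a hypothetical count protocol live in one container, and findRecord scans that container's declarations at comptime. Because findRecord returns a type, it runs entirely at compile time, and count compiles down to a direct call to the selected impl.

```zig
const std = @import("std");

// Hypothetical layout: each record of a "count" protocol is a namespace with a
// `Dispatch` type and an `impl` function. All records live in one container,
// similar to scanning a file (which is also a struct) for its records.
const count_records = struct {
    pub const for_u8 = struct {
        pub const Dispatch = u8;
        pub fn impl(x: u8) usize {
            return @as(usize, x) * 9;
        }
    };
    pub const for_bool = struct {
        pub const Dispatch = bool;
        pub fn impl(x: bool) usize {
            return if (x) 1 else 0;
        }
    };
};

// Runs entirely at compile time: find the record whose Dispatch type is T,
// or stop compilation with a readable message.
fn findRecord(comptime T: type) type {
    inline for (comptime std.meta.declarations(count_records)) |decl| {
        const rec = @field(count_records, decl.name);
        if (rec.Dispatch == T) return rec;
    }
    @compileError("count has no implementation for type " ++ @typeName(T));
}

// The "protocol" entry point: no runtime check, just a call to the chosen impl.
fn count(x: anytype) usize {
    return findRecord(@TypeOf(x)).impl(x);
}

test "static protocol dispatch" {
    const a: u8 = 5;
    const b: bool = true;
    try std.testing.expectEqual(@as(usize, 45), count(a));
    try std.testing.expectEqual(@as(usize, 1), count(b));
}
```

Calling count with a type that has no record, for example count(@as(f32, 1.5)), should stop compilation at the @compileError instead of failing at runtime.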

The flip side of "write a generic algorithm in a library" is also interesting: write something using only protocols, and you can use it with types it was never designed for.

You might be disappointed that the protocols aren't dispatched at runtime. Don't be: for a next version I will try my hand at that. Ideally we can make the compiler generate a jump table for us by making good use of inline switch prongs. Still, I think static protocols on their own could already be quite useful.
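A possible starting point for runtime dispatch, assuming the set of runtime types is closed and known up front, is a tagged union whose inline else prong instantiates a statically dispatched function once per payload type. Value, staticCount and dynamicCount are hypothetical names, and whether the switch actually becomes a jump table is left to the compiler.

```zig
const std = @import("std");

// Stand-in for a static protocol call such as myproto.call: the branch is
// chosen from @TypeOf(x) at compile time, and untaken branches are not analyzed.
fn staticCount(x: anytype) usize {
    const T = @TypeOf(x);
    if (T == u8) {
        return @as(usize, x) * 9;
    } else if (T == bool) {
        return if (x) 1 else 0;
    } else {
        @compileError("no implementation for type " ++ @typeName(T));
    }
}

// Hypothetical closed set of types that may show up at runtime.
const Value = union(enum) {
    small: u8,
    flag: bool,
};

// Runtime dispatch on the active tag. `inline else` generates one prong per
// tag, so each prong calls staticCount with a statically known payload type.
fn dynamicCount(v: Value) usize {
    return switch (v) {
        inline else => |payload| staticCount(payload),
    };
}

test "runtime dispatch through a tagged union" {
    try std.testing.expectEqual(@as(usize, 45), dynamicCount(Value{ .small = 5 }));
    try std.testing.expectEqual(@as(usize, 1), dynamicCount(Value{ .flag = true }));
}
```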

As a bonus: since everything happens at compile time, we can make the compiler tell us when an implementation does not exist, before the program ever runs:

$ zig build test -freference-trace
src\protocols.zig:46:21: error: protocol protocols.Protocol_T("myproto",usize) has no implementation for type comptime_float
                    @compileError(msg);
                    ^~~~~~~~~~~~~~~~~~
referenced by:
    test.test protocol record: src\protocols.zig:282:55

I did not show implementation details, but I made heavy use of Zig's type system features, compile time reflection and the standard library, sometimes maybe not in the intended way 😉

I hope you liked this article about dynamic and static dispatch, how a data driven way is possible and what that could look like in zig. Maybe you share my enthusiasm, maybe you feel unwell that I made Zig do that. Please tell me!

edit: reworded a paragraph because I duplicated some things, whoops-

Top comments (5)

Mattia Maldini

I would be interested in the implementation details of this exploit. How/where do you store the association between the type and the dispatched behaviour? Do you use the string representation of the type as a key in a hash map?

Kristin Rutenkolk

Hi, thanks for the comment.

I want to clean up the code a bit before I make a comprehensive "how?" post, but sure, I can explain what's going on now. It's just that this post already got kinda long.

In the case of protocols, the association between types and implementation functions is determined by a comptime lookup of all "records" that were given the protocol as an argument. Since every file is a struct, we can pass the file containing our implementations as a comptime type and iterate over every declaration in it.

If a declaration has a matching type (being DefRecord_T(this_protocol) ), then the declaration is an implementation of the protocol and we save it in a comptime array.

// pretend that T is @This() or @import("root") or something like this.
fn find_all_implementing_records(comptime T: type, comptime proto: anytype, res: []DefRecord_T(proto)) void {
    const decls = @typeInfo(T).Struct.decls;
    comptime var i: usize = 0;
    for (decls) |decl| {
        std.debug.assert(@TypeOf(decl) == std.builtin.Type.Declaration);

        comptime var declref = @field(T, decl.name);
        comptime var implements_proto = implements_protocol(declref, proto);

        if (implements_proto) {
            res[i] = declref;
            i = i + 1;
        }
    }
}

The basic idea now is: if the protocol is called with an argument of type Arg_T, go through the array of implementations at compile time and determine the index of the implementation with dispatch_type == Arg_T. From that, get the function pointer to the implementation function. This can all happen at comptime.

The only thing that happens at runtime is to call the appropriate implementation function.

The static case is 100% comptime.

comptime var fn_val = res[type_index].impl_fn.to_typed_fn(x_type, self);
return fn_val(x);

Multimethods work similarly right now, but the equality checks against the dispatch values are simply an unrolled for loop over the comptime-known implementations. This is certainly something that needs to improve, but it works as a first proof of concept.
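Such an unrolled loop can be sketched with inline for over a comptime tuple of (dispatch value, implementation) pairs. This is a hypothetical minimal version, not the post's code: callMulti and its parameters are made-up names, R stands in for the return type passed to defmulti (assumed from the examples), and plain == only works for dispatch values that support it, such as integers.

```zig
const std = @import("std");

fn my_dispatch(some_thing: anytype) i32 {
    return some_thing.some_value;
}

fn impl1(x: anytype) usize {
    _ = x;
    return 5;
}

fn impl2(x: anytype) usize {
    _ = x;
    return 6;
}

// Hypothetical minimal multimethod call. `methods` is a comptime tuple of
// .{ dispatch_value, implementation } pairs; `inline for` unrolls the loop
// into one comparison per method.
fn callMulti(
    comptime R: type,
    comptime dispatch: anytype,
    comptime methods: anytype,
    x: anytype,
) R {
    const dispatch_value = dispatch(x);
    inline for (methods) |method| {
        if (dispatch_value == method[0]) return method[1](x);
    }
    @panic("no method for this dispatch value");
}

const mymulti_methods = .{ .{ 0, impl1 }, .{ 1, impl2 } };

test "unrolled multimethod dispatch" {
    const TypeA = struct { some_value: i32 };
    try std.testing.expectEqual(@as(usize, 5), callMulti(usize, my_dispatch, mymulti_methods, TypeA{ .some_value = 0 }));
    try std.testing.expectEqual(@as(usize, 6), callMulti(usize, my_dispatch, mymulti_methods, TypeA{ .some_value = 1 }));
}
```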

Also... the equality checks are really hacky right now. Maybe you asked yourself how I managed to make the multimethods work with tuples... Well, there isn't any generic structural equality in Zig, is there? std.meta.eql doesn't recursively look into slices, for example, right? Soooo, right now I'm actually using std.testing.expectEqualDeep, which does exactly that. But it definitely needs to be replaced before I feel remotely comfortable with the state the code is in.
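When the dispatch type is known up front, a hand-written comparison is one way to avoid leaning on std.testing. A sketch for the (language, number) dispatch values of the tuple example; NumberKey and keyEql are hypothetical names, and a named struct stands in for the tuple to keep the field meanings obvious.

```zig
const std = @import("std");

// Hypothetical named stand-in for the (language, number) dispatch tuple.
const NumberKey = struct {
    language: []const u8,
    number: u32,
};

// Structural equality for this one type: compare the string by content with
// std.mem.eql, and the integer directly.
fn keyEql(a: NumberKey, b: NumberKey) bool {
    return std.mem.eql(u8, a.language, b.language) and a.number == b.number;
}

test "keys compare by content" {
    const key = NumberKey{ .language = "german", .number = 1 };
    // A slice of a different literal with the same bytes still compares equal.
    try std.testing.expect(keyEql(key, .{ .language = "germany"[0..6], .number = 1 }));
    try std.testing.expect(!keyEql(key, .{ .language = "english", .number = 1 }));
    try std.testing.expect(!keyEql(key, .{ .language = "german", .number = 2 }));
}
```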

Mattia Maldini

Uh, quite cool. Some details are still hazy, but now I get the gist of it. Thanks!
Looking forward to a more detailed explanation!

max

This is a great idea, looking forward to your final result.

Kristin Rutenkolk

Thanks, I'll write it soon (please forgive my Chinese, I'm not Chinese).