Zig NEWS

Zig Erasure Coding -- WTF is Zig Comptime 2 (Part 1)

The power and complexity of Comptime in Zig


Ed Yu (@edyu on GitHub and @edyu on Twitter)
January.17.2024



Introduction

Zig is a modern systems programming language, and although it claims to be a better C, many people who initially didn't need systems programming were attracted to it by the simplicity of its syntax compared to alternatives such as C++ or Rust.

However, because the language is so powerful, some of its syntax is not obvious to those first coming to it. I was one such person.

Today we will explore a unique aspect of metaprogramming in Zig that sets it apart from other similarly low-level languages -- comptime.

I've already written an earlier article on comptime but today I want to explore an example using a simplified version of erasure coding that I've implemented using Zig with extensive use of comptime.

Due to the length of the code example, I'll explain the code in a series of articles and eventually even compare it to an Odin implementation.

WTF is Comptime

I gave an overview of Zig comptime in WTF Zig Comptime. In short, comptime is a way to designate either a variable or a piece of code such as a block or a function as something to be run at compile time as opposed to runtime.

The first benefit of running code at comptime is that the results are essentially constants that the compiled program can use directly. As a result, computing them costs nothing at runtime.

One example is the same factorial code I showed in WTF Zig Comptime. In such a case, even if I pass in a large number n, which takes a long time to calculate at comptime, during runtime it's just another constant.

By using a comptime variable, I still have the flexibility to change the number anytime and just recompile the program to recalculate the correct result.

pub fn factorial(comptime n: u8) comptime_int {
    var r = 1; // no need to write it as comptime var r = 1
    for (1..(n + 1)) |i| {
        r *= i;
    }
    return r;
}

const v = comptime factorial(120);

The other benefit comes from the needs of metaprogramming. One of the primary reasons to use metaprogramming is to allow for different behaviors for different types.

Because the compiler has to know precisely how much memory each type needs, Zig types must be complete at compile time. This means that if you want to use a type function (once again, see WTF Zig Comptime) to create different types based on some variable you pass in, that variable must be comptime.

For example, say you want a matrix type where a 3 x 3 matrix is completely separate from a 4 x 4 matrix, so that you cannot accidentally treat one as the other. If you still want to use the same type function to create both, you must declare the n of the n x n matrix as a comptime parameter that you pass in.

const std = @import("std");

// a square matrix is a matrix where number_of_rows == number_of_cols
pub fn SquareMatrix(comptime n: comptime_int) type {
    return struct {
        allocator: std.mem.Allocator,
        comptime numRows: u8 = n,
        comptime numCols: u8 = n,

        const Self = @This();

        pub fn init(allocator: std.mem.Allocator) Self {
            return .{
                .allocator = allocator,
            };
        }

        // other methods
    };
}
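To make the "separate types" point concrete, here is a minimal sketch. It uses a stripped-down stand-in for SquareMatrix (the data field and the trace3 helper are my own assumptions for illustration, not part of the real code) to show that the compiler treats each size as its own type.

```zig
const std = @import("std");

// Hypothetical stand-in for SquareMatrix; the `data` field is an assumption
// made so the example is self-contained.
fn SquareMatrix(comptime n: comptime_int) type {
    return struct {
        data: [n * n]u8 = [_]u8{0} ** (n * n),
        comptime numRows: u8 = n,
        comptime numCols: u8 = n,
    };
}

// Hypothetical helper that accepts only 3 x 3 matrices.
// Passing a SquareMatrix(4) here is rejected at compile time.
fn trace3(mat: SquareMatrix(3)) u8 {
    return mat.data[0] + mat.data[4] + mat.data[8];
}

pub fn main() void {
    const M3 = SquareMatrix(3);
    const M4 = SquareMatrix(4);

    comptime {
        // Different comptime arguments produce different types...
        std.debug.assert(M3 != M4);
        // ...while the same argument always yields the same type.
        std.debug.assert(M3 == SquareMatrix(3));
    }

    const a = M3{};
    std.debug.print("rows={d} trace={d}\n", .{ a.numRows, trace3(a) });
}
```

If you try trace3(M4{}) instead, the program should fail to compile, which is exactly the kind of enforcement I wanted from the type function.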

WTF is Erasure Coding

Now, with the introduction out of the way, let's talk about Erasure Coding.

Erasure Coding is a way to split up your data into multiple pieces (shards) so that if you lose one or more of the pieces, you can still retrieve the original data. If you think this sounds like duplication or replication, or even RAID, then you are somewhat correct. The difference is that in simple replication (or RAID 1), you make complete copies of the data, so if you make 2 copies, you have a total of 3 copies (including the original), and as long as you don't lose all 3 copies, you can get the data back. In a more complex RAID setup such as RAID 5, data is striped across multiple disks and additional parity blocks are added to tolerate disk loss. RAID 5 can tolerate the loss of 1 disk, whereas RAID 6 can tolerate up to 2 disk failures.

In Erasure Coding, you transform the data while you split it, so the shards are not necessarily copies of each other. You still need a minimum number of shards to reconstruct the original data, and as long as you still have at least that many, you can reconstruct the data.

We typically describe Erasure Coding with two numbers, N and K, where N is the total number of shards and K is the minimum number of shards you need in order to reconstruct the data. For instance, in a (5, 3) Erasure Coding, N is 5 and K is 3, so you need any 3 of the 5 shards in order to reconstruct the data. In other words, you can tolerate the loss of up to (N - K) = 2 shards.

What makes Erasure Coding better than simple replication is that it needs less disk space for the same level of protection.

In a quick test, using a file of 680 bytes, if I make 2 additional copies, the total storage needed is 680 * 3 = 2040 bytes. If instead I use my Zig Erasure Coding, I need 5 shards of 240 bytes each, which means the total storage needed is 240 * 5 = 1200 bytes. The total saving in storage comparing the two is (2040 - 1200) / 2040 ≈ 41%.
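The same kind of comparison can be sketched for any file size. The snippet below is only a back-of-the-envelope calculation: the function names and the 900-byte file are made up for illustration, and the shard size is a lower bound, since an actual encoder may pad each shard further.

```zig
const std = @import("std");

// Total bytes stored when keeping `copies` full copies of a file.
fn replicationTotal(file_size: usize, copies: usize) usize {
    return file_size * copies;
}

// Smallest possible shard size when K shards must be enough to rebuild
// the file: ceil(file_size / k). Real encoders may pad beyond this.
fn minShardSize(file_size: usize, k: usize) usize {
    return (file_size + k - 1) / k;
}

pub fn main() void {
    const n: usize = 5; // total shards
    const k: usize = 3; // shards needed to reconstruct
    const file_size: usize = 900; // hypothetical file size in bytes

    const shard = minShardSize(file_size, k);
    const erasure_total = shard * n;
    // 3 full copies also survive the loss of any 2 copies.
    const replica_total = replicationTotal(file_size, 3);

    std.debug.print("tolerated shard losses: {d}\n", .{n - k});
    std.debug.print("erasure coding: {d} bytes, replication: {d} bytes\n", .{ erasure_total, replica_total });
}
```

Both schemes in this sketch survive the loss of 2 pieces, which is what makes the storage totals comparable.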

Of course, the storage saving is not free, as there is extra computation involved both in encoding (splitting the original data into shards) and decoding (reconstructing the data from shards).

Erasure Code Example

The example Erasure Coding is based upon Erasure Coding For The Masses by Dr. Vishesh Khemani.

The method described is not the most efficient Erasure Coding but it's easy to understand and I really liked how Dr. Vishesh Khemani used simple finite field arithmetic to describe Erasure Coding.

Finite (Binary) Field Arithmetic

If you are mathematically inclined, you are certainly welcome to read the Wikipedia article or in particular binary fields. I'm not but I'll try my best to give a simple explanation of binary fields from a programmer's perspective.

When you have a finite field, you are basically doing modular maths, which means that every result is reduced modulo the modulus (usually 1 more than the largest allowed number). For example, in a finite field of size 3, the only valid numbers are [0, 1, 2] and the modulus is 3 (1 more than the largest number, which is 2). If you add 1 and 2 together, you would normally get 3, but in modular arithmetic you reduce the result modulo 3, so 3 mod 3 = 0. For us programmers, this is very similar to how indices wrap around in a ring buffer.

In the case of binary fields, the maths is even easier as you only have [0, 1] as valid numbers and addition is really just XOR because 1 + 1 = 2 mod 2 = 0 = 1 XOR 1. What's even cooler is that because -1 = 1 mod 2 (try going counter-clockwise on your ring buffer), we can do subtraction just as we do addition because n - 1 = n + (-1) = n + 1.
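Here is a small sketch that checks these claims by brute force. The addMod and subMod helpers are hypothetical names of my own and assume both inputs are already smaller than the modulus.

```zig
const std = @import("std");

// (a + b) mod modulus, widened to u16 so the sum cannot overflow a u8.
fn addMod(a: u8, b: u8, modulus: u8) u8 {
    return @intCast((@as(u16, a) + b) % modulus);
}

// (a - b) mod modulus, assuming a < modulus and b < modulus.
// Adding the modulus first keeps the intermediate value non-negative.
fn subMod(a: u8, b: u8, modulus: u8) u8 {
    return @intCast((@as(u16, a) + modulus - b) % modulus);
}

pub fn main() void {
    // Modulo 3: 1 + 2 wraps around to 0.
    std.debug.assert(addMod(1, 2, 3) == 0);

    // Modulo 2: addition and subtraction are both just XOR.
    const bits = [_]u8{ 0, 1 };
    for (bits) |a| {
        for (bits) |b| {
            std.debug.assert(addMod(a, b, 2) == (a ^ b));
            std.debug.assert(subMod(a, b, 2) == (a ^ b));
        }
    }
    std.debug.print("addition and subtraction mod 2 both match XOR\n", .{});
}
```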

Code: matrix.zig

Ok, that's probably too much maths already. Let's go back to the code example.

The first code I want to describe is an implementation of an m x n matrix where m is the number of rows and n is the number of columns.

I decided to use a type function to return the matrix type because I wanted to make sure the compiler can enforce that a 3 x 2 matrix is different from a 2 x 3 matrix.

The way to do so is to pass in both m and n as comptime variables to the type function.

pub fn Matrix(comptime m: comptime_int, comptime n: comptime_int) type {
    return struct {
        allocator: std.mem.Allocator = undefined,
        mdata: std.ArrayList(u8) = undefined,
        mtype: DataOrder = undefined,
        comptime numRows: u8 = m,
        comptime numCols: u8 = n,

        const Self = @This();

        // followed by type methods
    };
}

Do note that I ended up saving m and n as fields inside the Matrix(m, n) type as numRows and numCols. Zig doesn't require you to access fields only through methods, and by exposing them as fields, I can use mat.numRows and mat.numCols directly if mat is a Matrix(m, n). I do have to designate both numRows and numCols as comptime in order to store m and n.

The code should be fairly straightforward but I do want to explain why I have an enum called DataOrder. The internal data format is an array using ArrayList called mdata. Typically, you can store a 2D matrix as a 1D vector (or ArrayList in Zig) in either row-major or column-major order. The original reason to have DataOrder was so that I could have the flexibility to store the matrix data in either order, and then transpose is really just changing the mtype from one to the other.
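As an illustration of the two layouts, here is a small sketch that stores the same 2 x 3 matrix both ways and looks up elements by (row, column). The helper names are hypothetical and are not taken from matrix.zig.

```zig
const std = @import("std");

// Row-major: rows are stored one after another.
fn rowMajorIndex(r: usize, c: usize, num_cols: usize) usize {
    return r * num_cols + c;
}

// Column-major: columns are stored one after another.
fn colMajorIndex(r: usize, c: usize, num_rows: usize) usize {
    return c * num_rows + r;
}

pub fn main() void {
    // The same 2 x 3 matrix in both layouts:
    // | 1 2 3 |
    // | 4 5 6 |
    const row_major = [_]u8{ 1, 2, 3, 4, 5, 6 };
    const col_major = [_]u8{ 1, 4, 2, 5, 3, 6 };

    for (0..2) |r| {
        for (0..3) |c| {
            const a = row_major[rowMajorIndex(r, c, 3)];
            const b = col_major[colMajorIndex(r, c, 2)];
            std.debug.assert(a == b);
        }
    }
    std.debug.print("both layouts describe the same matrix\n", .{});
}
```

Reading the row-major buffer of the 2 x 3 matrix as if it were column-major gives the 3 x 2 transpose, which is the trick that made flipping the order look attractive.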

In the end, not only did I not need such flexibility, but it also made my code much more complicated. I kept it there mostly as a reminder to myself to not do premature optimization.

Premature Optimization is the root of all evil
                            -- Donald Knuth

One artifact of such complication is that getSlice() returns a row or a column depending on whether the matrix is stored internally in row-major or column-major order. The burden is on the caller to keep track of the DataOrder because the caller is the one that initialized the matrix. Of course, getRow() and getCol() are exposed and preferred.

// returns either a row or a column based on matrix data order
// use getRow() or getCol() if you want row/col regardless of data order
pub fn getSlice(self: *const Self, rc: usize) []u8 {
    switch (self.mtype) {
        .row => {
            std.debug.assert(rc < m);
            const i = rc * n;
            return self.mdata.items[i .. i + n];
        },
        .col => {
            std.debug.assert(rc < n);
            const i = rc * m;
            return self.mdata.items[i .. i + m];
        },
    }
}

The reverse setSlice() has similar logic and in the end, I'm not sure whether the optimization is ever needed. Once again, it's probably better to not go through such complication and just expose setRow() and setCol() instead.

fn setSlice(self: *Self, rc: usize, new_rc: []const u8) void {
    switch (self.mtype) {
        .row => {
            std.debug.assert(rc < m and new_rc.len >= n);
            const i = rc * n;
            self.mdata.replaceRange(i, n, new_rc[0..n]) catch return;
        },
        .col => {
            std.debug.assert(rc < n and new_rc.len >= m);
            const i = rc * m;
            self.mdata.replaceRange(i, m, new_rc[0..m]) catch return;
        },
    }
}

Code: finite_field.zig

Finally, we can look at the code for binary fields. finite_field.zig exposes the n of the finite field p^n as a comptime parameter n. Because we only care about binary finite fields, p is always 2 and doesn't need to be passed in.

pub fn BinaryFiniteField(comptime n: comptime_int) type {
    return struct {
        exp: u8 = undefined,
        order: u8 = undefined,
        divisor: u8 = undefined,
        // more code
    };
}

The init() method mostly initializes the 3 numbers needed later. The exp is just the n passed in. Note that because of the use case and my lack of advanced mathematics knowledge, I only care about 1 <= n <= 7. This also makes calculating the order easy, because a u8 is wide enough to hold 1 left-shifted by n. The divisor is the integer you get by substituting 2 for x in the irreducible polynomial shown in the comments, and each of these values happens to be a prime number. Leave a comment if you know the maths and want to explain to other readers why that's the case.

pub fn init() ValueError!Self {
    var d: u8 = undefined;

    // Irreducible polynomial for mod multiplication
    d = switch (n) {
        1 => 3, // 1 + x ? undef  shift(0b11)=2
        2 => 7, // 1 + x + x^2    shift(0b111)=3
        3 => 11, // 1 + x + x^3   shift(0b1011)=4
        4 => 19, // 1 + x + x^4   shift(0b10011)=5
        5 => 37, // 1 + x^2 + x^5 shift(0b100101)=6
        6 => 67, // 1 + x + x^6   shift(0b1000011)=7
        7 => 131, // 1 + x + x^7  shift(0b10000011)=8
        else => return ValueError.InvalidExponentError,
    };

    return .{
        .exp = @intCast(n),
        .divisor = d,
        .order = @as(u8, 1) << @intCast(n),
    };
}

The order is 1 + the maximum allowed number in the field. For example, if n is 3, the order is 8 and the field cannot contain any number greater than 7.

pub fn validated(self: *const Self, a: usize) ValueError!u8 {
    if (a < self.order) {
        return @intCast(a);
    } else {
        return ValueError.InvalidNumberError;
    }
}

Addition and Subtraction

Recall that for binary fields, addition is just exclusive OR.

pub fn add(self: *const Self, a: usize, b: usize) ValueError!u8 {
    return try self.validated((try self.validated(a)) ^ (try self.validated(b)));
}

Also, negation is simply the number itself and subtraction is just addition of the negation.

pub fn neg(self: *const Self, a: usize) ValueError!u8 {
    return try self.validated(a);
}

pub fn sub(self: *const Self, a: usize, b: usize) ValueError!u8 {
    return try self.add(a, try self.neg(b));
}

Multiplication

Multiplication is probably the most complicated of all the methods, except when n is 1.

pub fn mul(self: *const Self, a: usize, b: usize) ValueError!u8 {
    if (self.exp == 1) {
        return self.validated(try self.validated(a) * try self.validated(b));
    }

    // n > 1
    const x = try self.validated(a);
    const y = try self.validated(b);
    var result: u16 = 0;
    for (0..8) |i| {
        const j = 7 - i;
        if (((y >> @intCast(j)) & 1) == 1) {
            result ^= @as(u16, x) << @intCast(j);
        }
    }
    while (result >= self.order) {
        // count how many binary digits result has
        var j = countBits(result);
        j -= self.exp + 1;
        result ^= @as(u16, self.divisor) << @intCast(j);
    }
    return try self.validated(result);
}
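To see the two steps of mul() in isolation (a shift-and-XOR multiply followed by reduction with the divisor), here is a simplified sketch hard-coded for n = 3 with the divisor 11 (0b1011) from init(). The free function mulGf8 is a hypothetical restatement for illustration, not the code from finite_field.zig, and it only works for this one field.

```zig
const std = @import("std");

// Multiply two elements of the binary field with n = 3 (values 0..7),
// using the polynomial 1 + x + x^3 (0b1011 = 11) for the reduction.
fn mulGf8(a: u8, b: u8) u8 {
    const order: u16 = 8;
    const divisor: u16 = 0b1011;
    std.debug.assert(a < order and b < order);

    // Step 1: long multiplication where "add" is XOR (no carries).
    var result: u16 = 0;
    var bit: u4 = 0;
    while (bit < 3) : (bit += 1) {
        if (((b >> @intCast(bit)) & 1) == 1) {
            result ^= @as(u16, a) << bit;
        }
    }

    // Step 2: cancel any bits at position 3 or above by XOR-ing in a
    // shifted copy of the divisor. The product of two 3-bit values has
    // at most 5 bits, so only positions 4 and 3 need checking.
    var degree: u4 = 4;
    while (degree >= 3) : (degree -= 1) {
        if (((result >> degree) & 1) == 1) {
            result ^= divisor << (degree - 3);
        }
    }
    return @intCast(result);
}

pub fn main() void {
    // (x + 1) * (x^2 + x + 1) = x^3 + 1, which reduces to x, i.e. 2.
    std.debug.assert(mulGf8(3, 7) == 2);
    // x * x^2 = x^3, which reduces to x + 1, i.e. 3.
    std.debug.assert(mulGf8(2, 4) == 3);
    std.debug.print("3 * 7 = {d}, 2 * 4 = {d}\n", .{ mulGf8(3, 7), mulGf8(2, 4) });
}
```

The real mul() handles every n from 2 to 7 by finding the highest set bit with countBits() instead of hard-coding the positions to check.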

It requires a function to count the number of binary digits of a number. There is probably a much better way to figure this out in Zig but since I don't know, I just implemented a naive solution. Leave a comment if you know a better way.

fn countBits(num: usize) u8 {
    var v = num;
    var c: u8 = 0;
    while (v != 0) {
        v >>= 1;
        c += 1;
    }
    return c;
}
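One possible alternative, offered only as a sketch, is the @clz builtin, which counts leading zero bits. For a u16, the number of binary digits is 16 minus the leading zeros. The name countBitsClz and the fixed u16 parameter are my own choices for this example, picked because result in mul() is a u16.

```zig
const std = @import("std");

// Number of binary digits needed to represent num (0 for num == 0).
fn countBitsClz(num: u16) u8 {
    // @clz on a u16 returns a value in 0..16, which fits in a u8.
    const leading_zeros: u8 = @clz(num);
    return 16 - leading_zeros;
}

pub fn main() void {
    std.debug.assert(countBitsClz(0) == 0);
    std.debug.assert(countBitsClz(1) == 1);
    std.debug.assert(countBitsClz(9) == 4); // 0b1001
    std.debug.assert(countBitsClz(0xFFFF) == 16);
    std.debug.print("9 needs {d} bits\n", .{countBitsClz(9)});
}
```

This should agree with the naive loop for every u16 value, including 0, without iterating over the bits.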

Division

Division is simply the multiplication of the inverse.

pub fn div(self: *const Self, a: usize, b: usize) ValueError!u8 {
    return try self.mul(a, try self.invert(b));
}

However, finding the inverse of a number is more complicated than negation. We first have to check whether the number is 0, which has no inverse. We then find the inverse using brute force by trying every number within the order to see whether the result is 1, because b is the inverse of a if and only if a * b == 1. Once again, if you have a better way, leave a comment.

pub fn invert(self: *const Self, a: usize) ValueError!u8 {
    if (try self.validated(a) == 0) {
        return ValueError.NoInverseError;
    }
    for (0..self.order) |b| {
        if (try self.mul(a, b) == 1) {
            return try self.validated(b);
        }
    }
    return ValueError.NoInverseError;
}

Matrix Representation

The final 3 methods are used to find the n x n matrix representation for a number in the field. This is also the primary reason why we had to implement a matrix. The main idea is to separate the number into its basis based on the polynomials mentioned in the comments of init().

fn setCol(m: *mat.Matrix(n, n), c: usize, a: u8) void {
    for (0..n) |r| {
        const v = (a >> @intCast(r)) & 1;
        m.set(r, c, v);
    }
}

fn setAllCols(self: *const Self, m: *mat.Matrix(n, n), a: usize) !void {
    var basis: u8 = 1;
    for (0..n) |c| {
        const p = try self.mul(a, basis);
        basis <<= 1;
        setCol(m, c, p);
    }
}

// n x n binary matrix representation
pub fn toMatrix(self: *const Self, allocator: std.mem.Allocator, a: usize) !mat.Matrix(n, n) {
    var m = try mat.Matrix(n, n).init(allocator, mat.DataOrder.row);
    try self.setAllCols(&m, a);
    return m;
}

Bonus

By adding a method called format to matrix.zig, it's much easier to print the matrix in debugging calls such as std.debug.print, because Zig implicitly checks whether such a method exists on the struct and uses it to format the output. As such, I can just call std.debug.print("{}", .{mat}) if mat is of type Matrix(m, n).

pub fn format(self: *const Self, comptime _: []const u8, _: std.fmt.FormatOptions, stream: anytype) !void {
    switch (self.mtype) {
        .row => {
            try stream.print("\n{d}x{d} row ->\n", .{ m, n });
            for (0..m) |r| {
                for (0..n) |c| {
                    try stream.print("{d} ", .{self.get(r, c)});
                }
                try stream.print("\n", .{});
            }
        },
        .col => {
            try stream.print("\n{d}x{d} col ->\n", .{ m, n });
            for (0..n) |c| {
                for (0..m) |r| {
                    try stream.print("{d} ", .{self.get(r, c)});
                }
                try stream.print("\n", .{});
            }
        },
    }
}

The End

You can read my part 1 at WTF Zig Comptime.

The inspiration of the code is from Erasure Coding For The Masses.

You can find the code for the article here.

The full erasure coding code is here.
