Zig NEWS

jack


A* algorithm

Hello people, long time no blog. I've been trying to write a roguelike game recently, in Zig of course. Naturally, I needed a decent path-finding utility, and the A* algorithm immediately came to mind. Thanks to Amit Patel's awesome writing, I had a blast implementing the algorithm. Now I want to share the experience with you.

I won't explain the details of the algorithm here, as Amit has already done an excellent job. If you aren't familiar with the subject, I recommend reading through his article linked above. OK, let's begin.

To implement the algorithm, we first need to define the following:

  1. Type of position
  2. Type of node
  3. Type of graph

What is a position? A 2D or 3D vector seems the obvious choice. However, if you really think about it, we don't care about the geometric meaning of a position at all; all we need is some sort of identifier to tell positions apart. Just like a file descriptor in the Unix world, a plain number is enough to represent anything. In my case, I chose usize to stand for a position.
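To make the idea concrete, here is a minimal sketch of how 2D grid coordinates can be folded into a single usize id and back, assuming a rectangular map stored row by row. GridIds, toId and toXY are hypothetical names for illustration, not part of the article's code, and the snippet assumes a Zig version from roughly the article's era (0.11 to 0.13).

```zig
const std = @import("std");

/// Hypothetical helper for a map that is `width` cells wide, stored row by row.
const GridIds = struct {
    width: usize,

    /// Encode a 2D cell (x, y) as a single position id.
    pub fn toId(self: GridIds, x: usize, y: usize) usize {
        return y * self.width + x;
    }

    /// Decode a position id back into its (x, y) cell.
    pub fn toXY(self: GridIds, id: usize) struct { x: usize, y: usize } {
        return .{ .x = id % self.width, .y = id / self.width };
    }
};

test "grid ids round-trip" {
    const grid = GridIds{ .width = 10 };
    const id = grid.toId(3, 2); // row 2, column 3 -> 2 * 10 + 3
    const xy = grid.toXY(id);
    try std.testing.expectEqual(@as(usize, 23), id);
    try std.testing.expectEqual(@as(usize, 3), xy.x);
    try std.testing.expectEqual(@as(usize, 2), xy.y);
}
```

The search itself never decodes the id; only the map's own methods (neighbours, costs) need to know what the number means.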

Although a node is only used internally by the algorithm, its structure is important enough to discuss here. Simply put, a node also represents a position in the search graph. Beyond that, a node records its established path back to the starting point (the cost so far, plus the position we reached it from) and its estimated distance to the final destination (formally, the heuristic cost). Here's the final definition:

const Node = struct {
    from: usize,  // Position we came from on the best path found so far
    gcost: usize, // Cost from the starting point to this node
    hcost: usize, // Estimated (heuristic) cost from this node to the destination
};

You might ask: "What is the position id of the node?" Well, we don't need to store it in the struct at all, because every node is added to a hash table whose key is the position id.
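A minimal sketch of that idea, using the same Node layout: the id lives only in the hash map key, and following the from links is enough to recover a path. The tracePath helper is hypothetical; it writes into a caller-provided buffer instead of an ArrayList to keep the example small, and it assumes every node on the way has an entry and the buffer is large enough. It targets roughly the Zig 0.11 to 0.13 standard library.

```zig
const std = @import("std");

const Node = struct {
    from: usize,
    gcost: usize,
    hcost: usize,
};

// Position id -> node; the id itself is never stored inside Node.
const NodeMap = std.AutoHashMap(usize, Node);

/// Hypothetical helper: follow `from` links from `to` back to `start`,
/// writing ids into `buf`, then reverse them into start-to-goal order.
fn tracePath(nodes: *const NodeMap, start: usize, to: usize, buf: []usize) []usize {
    var len: usize = 0;
    var pos = to;
    while (true) {
        buf[len] = pos;
        len += 1;
        if (pos == start) break;
        pos = nodes.get(pos).?.from;
    }
    std.mem.reverse(usize, buf[0..len]);
    return buf[0..len];
}

test "ids live in the map key, not in Node" {
    var nodes = NodeMap.init(std.testing.allocator);
    defer nodes.deinit();
    // Hand-built search result for the path 7 -> 8 -> 12.
    try nodes.put(7, .{ .from = 7, .gcost = 0, .hcost = 2 });
    try nodes.put(8, .{ .from = 7, .gcost = 1, .hcost = 1 });
    try nodes.put(12, .{ .from = 8, .gcost = 2, .hcost = 0 });

    var buf: [8]usize = undefined;
    const path = tracePath(&nodes, 7, 12, &buf);
    try std.testing.expectEqualSlices(usize, &[_]usize{ 7, 8, 12 }, path);
}
```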

Last but not least, the graph type. Do we need to know every position in the graph up front? That depends on how complicated the map is. On a simple map, the A* algorithm only needs to visit a few positions to calculate the path; in a daunting maze, it might have to visit many more. So a better approach is to hand the algorithm only the positions it actually asks for. How do we accomplish that? With an interface: the algorithm takes the graph as anytype and relies only on a few methods, so any type that provides them can be searched. Here's my definition of the graph:

// A searchable graph struct must provide the following:
//
//     // Iterator yielding the ids of a node's walkable neighbours
//     pub const Iterator = struct {
//         pub fn next(self: *@This()) ?usize
//     };
//
//     // Returns an iterator over the neighbours of node `id`
//     pub fn iterateNeigh(self: @This(), id: usize) Iterator
//
//     // Returns the graph cost of moving between 2 nodes
//     pub fn gcost(self: @This(), from: usize, to: usize) usize
//
//     // Returns the heuristic (estimated) cost between 2 nodes
//     pub fn hcost(self: @This(), from: usize, to: usize) usize

Dude, they're just comments! Yes, that's right. It's actually a contract between the algorithm and its user: if you want your map to be searchable by the algorithm, it has to fulfil these requirements. We only need three APIs: one for iterating over a position's walkable neighbours, and two for calculating costs between positions.
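As an illustration of the contract, here is a hypothetical GridGraph: an open rectangular grid with no walls, 4-way movement, a step cost of 1 and Manhattan distance as the heuristic. Neighbours are computed only when the algorithm asks for them, which is the on-demand access described above. The type and its fields are assumptions for this sketch, not the article's map; it is written to match the commented contract (and should pass verifyGraphStruct), assuming a Zig 0.11 to 0.13 era compiler.

```zig
const std = @import("std");

/// Hypothetical open grid (no walls); position id = y * width + x.
const GridGraph = struct {
    width: usize,
    height: usize,

    pub const Iterator = struct {
        graph: GridGraph,
        x: usize,
        y: usize,
        dir: u8 = 0, // next direction to try: 0 left, 1 right, 2 up, 3 down

        /// Yields the ids of in-bounds 4-way neighbours, one per call.
        pub fn next(self: *Iterator) ?usize {
            while (self.dir < 4) {
                const d = self.dir;
                self.dir += 1;
                switch (d) {
                    0 => if (self.x > 0) return self.graph.id(self.x - 1, self.y),
                    1 => if (self.x + 1 < self.graph.width) return self.graph.id(self.x + 1, self.y),
                    2 => if (self.y > 0) return self.graph.id(self.x, self.y - 1),
                    else => if (self.y + 1 < self.graph.height) return self.graph.id(self.x, self.y + 1),
                }
            }
            return null;
        }
    };

    fn id(self: GridGraph, x: usize, y: usize) usize {
        return y * self.width + x;
    }

    pub fn iterateNeigh(self: GridGraph, pos: usize) Iterator {
        return .{ .graph = self, .x = pos % self.width, .y = pos / self.width };
    }

    /// Every step between adjacent cells costs 1.
    pub fn gcost(self: GridGraph, from: usize, to: usize) usize {
        _ = self;
        _ = from;
        _ = to;
        return 1;
    }

    /// Manhattan distance between the two cells.
    pub fn hcost(self: GridGraph, from: usize, to: usize) usize {
        const fx = from % self.width;
        const fy = from / self.width;
        const tx = to % self.width;
        const ty = to / self.width;
        return absDiff(fx, tx) + absDiff(fy, ty);
    }

    fn absDiff(a: usize, b: usize) usize {
        return if (a > b) a - b else b - a;
    }
};
```

On a 3×3 grid, the centre cell (id 4) yields four neighbours, while the corner cell (id 0) yields only 1 and 3. With unit-cost 4-way moves, Manhattan distance never overestimates the remaining cost, which is the property A* needs to return shortest paths.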

Can we do better? Of course. Zig's reflection system makes compile-time type checking possible, so we can easily write a function that verifies the graph struct:

inline fn verifyGraphStruct(graph: anytype) void {
    const gtype = @TypeOf(graph);
    if (!@hasDecl(gtype, "iterateNeigh") or
        !@hasDecl(gtype, "gcost") or
        !@hasDecl(gtype, "hcost"))
    {
        @compileError("graph must provide `iterateNeigh`, `gcost` and `hcost` methods");
    }
    switch (@typeInfo(@typeInfo(@TypeOf(gtype.iterateNeigh)).Fn.return_type.?)) {
        .Struct => if (!std.meta.hasMethod(@typeInfo(@TypeOf(gtype.iterateNeigh)).Fn.return_type.?, "next")) {
            @compileError("`iterateNeigh` must return a valid iterator");
        },
        else => @compileError("`iterateNeigh` must return Iterator"),
    }
    if (@typeInfo(@TypeOf(gtype.gcost)).Fn.return_type.? != usize) {
        @compileError("`gcost` must return usize");
    }
    if (@typeInfo(@TypeOf(gtype.hcost)).Fn.return_type.? != usize) {
        @compileError("`hcost` must return usize");
    }
}

Everything is in place now, so we can go ahead and write the algorithm. It's really quite simple code, as long as you've done the homework above:

const std = @import("std");
const math = std.math;

pub fn calculatePath(allocator: std.mem.Allocator, graph: anytype, from: usize, to: usize) !?std.ArrayList(usize) {
    verifyGraphStruct(graph);

    // Bookkeeping lives in an arena and is freed in one go on return.
    var arena = std.heap.ArenaAllocator.init(allocator);
    defer arena.deinit();
    var nodes = NodeMap.init(arena.allocator());
    var frontier = NodeQueue.init(arena.allocator(), {});

    const start_hcost = graph.hcost(from, to);
    try nodes.put(from, .{ .from = from, .gcost = 0, .hcost = start_hcost });
    try frontier.add(.{ .id = from, .fcost = start_hcost });
    while (frontier.count() != 0) {
        const item = frontier.remove();
        const current = item.id;
        if (current == to) break;

        // Skip stale entries: a cheaper path to `current` was queued later.
        const cur = nodes.get(current).?;
        if (item.fcost > cur.gcost + cur.hcost) continue;

        var it = graph.iterateNeigh(current);
        while (it.next()) |next| {
            const new_gcost = cur.gcost + graph.gcost(current, next);
            const neighbour = nodes.get(next);
            if (neighbour == null or new_gcost < neighbour.?.gcost) {
                const new_hcost = graph.hcost(next, to);
                try nodes.put(next, .{
                    .from = current,
                    .gcost = new_gcost,
                    .hcost = new_hcost,
                });
                // Queue the priority by value, so later map updates cannot
                // reorder entries that are already inside the heap.
                try frontier.add(.{ .id = next, .fcost = new_gcost + new_hcost });
            }
        }
    } else {
        // Frontier exhausted without reaching `to`: there is no path.
        return null;
    }

    // Walk the `from` links back from the destination, then reverse.
    var pos = to;
    var path = try std.ArrayList(usize).initCapacity(allocator, 10);
    errdefer path.deinit();
    try path.append(pos);
    while (pos != from) {
        pos = nodes.get(pos).?.from;
        try path.append(pos);
    }
    std.mem.reverse(usize, path.items);
    return path;
}

const NodeMap = std.AutoHashMap(usize, Node);
// Frontier entry: position id plus its f = g + h at the time it was queued.
const QueueItem = struct { id: usize, fcost: usize };
const NodeQueue = std.PriorityQueue(QueueItem, void, compareItem);
fn compareItem(_: void, a: QueueItem, b: QueueItem) math.Order {
    return math.order(a.fcost, b.fcost);
}

The function returns !?std.ArrayList(usize). An error is returned if memory allocation fails, and null is returned if there's no valid path between the two positions; otherwise the caller owns the returned list and must deinit it. NodeMap accumulates the positions visited so far, and is walked backwards later to form the final path. NodeQueue drives the expansion of the search area: its frontier positions are ordered by the sum of graph cost and heuristic cost. As soon as we reach the destination, the search stops and walks backwards from it to the starting point to recover the path.
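The frontier ordering can be seen in isolation with std.PriorityQueue. In this sketch each queued entry carries the position id together with the f = g + h value it had when it was added; Entry, lessF and Frontier are hypothetical names, and the code assumes the Zig 0.11 to 0.13 era std API. Storing that number by value, rather than recomputing it from the node map inside the comparator, keeps the heap consistent if a node's costs change while it is still queued.

```zig
const std = @import("std");

/// Hypothetical frontier entry: id plus the f = g + h it had when queued.
const Entry = struct { id: usize, fcost: usize };

/// Lower f means higher priority, so the queue behaves as a min-heap.
fn lessF(_: void, a: Entry, b: Entry) std.math.Order {
    return std.math.order(a.fcost, b.fcost);
}

const Frontier = std.PriorityQueue(Entry, void, lessF);

test "frontier pops the lowest g + h first" {
    var frontier = Frontier.init(std.testing.allocator, {});
    defer frontier.deinit();

    // (g, h) pairs: id 1 -> 3 + 4, id 2 -> 5 + 0, id 3 -> 1 + 2.
    try frontier.add(.{ .id = 1, .fcost = 3 + 4 });
    try frontier.add(.{ .id = 2, .fcost = 5 + 0 });
    try frontier.add(.{ .id = 3, .fcost = 1 + 2 });

    try std.testing.expectEqual(@as(usize, 3), frontier.remove().id);
    try std.testing.expectEqual(@as(usize, 2), frontier.remove().id);
    try std.testing.expectEqual(@as(usize, 1), frontier.remove().id);
}
```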

Whew, that's all. I've also written a simple program to test the algorithm. You know what, it's quite fun :-).
Demo

If you want to see the full source code, here's the link: https://github.com/Jack-Ji/jok/blob/main/src/utils/pathfind.zig
