Learn Zig Series (#82) - DNS Resolver from Scratch

Learn Zig Series (#82) - DNS Resolver from Scratch

zig.png

What will I learn

  • How DNS works at the protocol level -- query format, response parsing, record types;
  • How to construct raw DNS query packets in Zig using packed structs and byte manipulation;
  • How to send DNS queries over UDP and parse the responses;
  • How to handle DNS compression pointers in response messages;
  • How to resolve A, AAAA, CNAME, and MX records;
  • How to build a reusable DNS resolver struct with caching and timeout support.

Requirements

  • A working modern computer running macOS, Windows or Ubuntu;
  • An installed Zig 0.14+ distribution (download from ziglang.org);
  • The ambition to learn Zig programming.

Difficulty

  • Intermediate

Curriculum (of the Learn Zig Series):

Learn Zig Series (#82) - DNS Resolver from Scratch

Solutions to Episode 81 Exercises

Exercise 1: UDP chat application

const std = @import("std");
const posix = std.posix;

/// Entry point for the UDP chat demo.
///
///   udpchat recv [port]         listen and print datagrams (default port 9000)
///   udpchat send [host] [port]  send each stdin line (default 127.0.0.1:9000)
///
/// Any other mode prints the usage line and exits.
pub fn main() !void {
    const allocator = std.heap.page_allocator;

    // argsWithAllocator works on every target; plain std.process.args() is a
    // compile error on Windows and WASI. The allocator is only used there.
    var args = try std.process.argsWithAllocator(allocator);
    defer args.deinit();
    _ = args.next(); // skip program name

    const mode = args.next() orelse {
        std.debug.print("usage: udpchat <send|recv> [host] [port]\n", .{});
        return;
    };

    if (std.mem.eql(u8, mode, "recv")) {
        try runReceiver(parsePortArg(args.next()));
    } else if (std.mem.eql(u8, mode, "send")) {
        const host = args.next() orelse "127.0.0.1";
        try runSender(host, parsePortArg(args.next()));
    } else {
        // Unknown mode: refuse rather than silently acting as a sender.
        std.debug.print("usage: udpchat <send|recv> [host] [port]\n", .{});
    }
}

/// Parse an optional decimal port argument. Falls back to 9000 when the
/// argument is missing or is not a valid u16.
fn parsePortArg(arg: ?[]const u8) u16 {
    const text = arg orelse return 9000;
    return std.fmt.parseInt(u16, text, 10) catch 9000;
}

/// Bind a UDP socket to 0.0.0.0:`port` and print every non-empty datagram,
/// prefixed with the sender's IPv4 address and port. Loops until a socket
/// call fails, and then returns that error.
fn runReceiver(port: u16) !void {
    const sock = try posix.socket(posix.AF.INET, posix.SOCK.DGRAM, 0);
    defer posix.close(sock);

    try posix.setsockopt(sock, posix.SOL.SOCKET, posix.SO.REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
    const addr = std.net.Address.initIp4(.{ 0, 0, 0, 0 }, port);
    try posix.bind(sock, &addr.any, addr.getOsSockLen());
    std.debug.print("listening on port {d}\n", .{port});

    var buf: [1024]u8 = undefined;
    while (true) {
        var src_addr: posix.sockaddr.storage = undefined;
        var addr_len: posix.socklen_t = @sizeOf(posix.sockaddr.storage);
        const n = try posix.recvfrom(sock, &buf, 0, @ptrCast(&src_addr), &addr_len);
        if (n == 0) continue;

        const sender = std.net.Address.initPosix(@ptrCast(&src_addr));
        // `sa.addr` is a u32 holding the address in network byte order (it
        // cannot be indexed directly). Its in-memory bytes are the four
        // dotted-quad octets in order.
        const ip = std.mem.asBytes(&sender.in.sa.addr);
        std.debug.print("[{d}.{d}.{d}.{d}:{d}] {s}\n", .{
            ip[0],            ip[1], ip[2], ip[3],
            sender.getPort(), buf[0..n],
        });
    }
}

/// Send each non-empty line read from stdin as one UDP datagram to
/// `host`:`port`. `host` must be a dotted-quad IPv4 address; an invalid one
/// is reported as an error instead of being turned into a bogus address.
/// Stops at end of input. A line longer than 1024 bytes is returned as an
/// error rather than silently ending the program.
fn runSender(host: []const u8, port: u16) !void {
    // parseIp4 validates all four octets. Hand-rolled parsing left missing
    // octets `undefined` and mapped garbage to 0.
    const dest = try std.net.Address.parseIp4(host, port);

    const sock = try posix.socket(posix.AF.INET, posix.SOCK.DGRAM, 0);
    defer posix.close(sock);

    const stdin = std.io.getStdIn().reader();
    var line_buf: [1024]u8 = undefined;
    while (true) {
        const raw_line = stdin.readUntilDelimiter(&line_buf, '\n') catch |err| switch (err) {
            error.EndOfStream => break,
            else => return err,
        };
        // Drop a trailing '\r' so CRLF input is not sent along with the text.
        const line = std.mem.trimRight(u8, raw_line, "\r");
        if (line.len == 0) continue;
        _ = try posix.sendto(sock, line, 0, &dest.any, dest.getOsSockLen());
    }
}

The key insight is that UDP's connectionless nature makes the receiver trivially simple -- one socket handles all senders, and recvfrom tells you who sent each message.

Exercise 2: UDP ping tool

const std = @import("std");
const posix = std.posix;

/// UDP ping client. Sends `ping_count` probes to 127.0.0.1:9000, one second
/// apart, and prints per-ping and aggregate round-trip times.
///
/// Each 16-byte probe carries the send timestamp in microseconds (bytes
/// 0..8) and the sequence number (bytes 8..16), both little-endian. The RTT
/// is computed from the timestamp echoed back by the peer. Replies whose
/// sequence number does not match the current probe are ignored, so a late
/// echo of an earlier, timed-out ping is never credited to the wrong one.
pub fn main() !void {
    const ping_count = 10;
    const interval_ns = 1_000_000_000;

    const sock = try posix.socket(posix.AF.INET, posix.SOCK.DGRAM, 0);
    defer posix.close(sock);

    const target = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 9000);
    // Each receive gives up after 2 s; that counts the ping as lost.
    const tv = posix.timeval{ .sec = 2, .usec = 0 };
    try posix.setsockopt(sock, posix.SOL.SOCKET, posix.SO.RCVTIMEO, &std.mem.toBytes(tv));

    var min_us: i64 = std.math.maxInt(i64);
    var max_us: i64 = 0;
    var total_us: i64 = 0;
    var received: u32 = 0;

    for (0..ping_count) |seq| {
        // Keep probes one interval apart on every path out of this
        // iteration (reply, loss, or send failure).
        defer std.time.sleep(interval_ns);

        var payload: [16]u8 = undefined;
        std.mem.writeInt(i64, payload[0..8], std.time.microTimestamp(), .little);
        std.mem.writeInt(u64, payload[8..16], seq, .little);

        _ = posix.sendto(sock, &payload, 0, &target.any, target.getOsSockLen()) catch continue;

        const rtt = awaitEcho(sock, seq) orelse {
            std.debug.print("ping {d}: lost\n", .{seq});
            continue;
        };

        min_us = @min(min_us, rtt);
        max_us = @max(max_us, rtt);
        total_us += rtt;
        received += 1;

        std.debug.print("ping {d}: rtt={d}us\n", .{ seq, rtt });
    }

    if (received > 0) {
        std.debug.print("\n--- stats ---\n", .{});
        std.debug.print("sent={d} recv={d} lost={d}\n", .{ ping_count, received, ping_count - received });
        std.debug.print("rtt min={d}us avg={d}us max={d}us\n", .{
            min_us, @divTrunc(total_us, received), max_us,
        });
    } else {
        std.debug.print("all packets lost\n", .{});
    }
}

/// Wait for the echo of probe `seq` and return its round-trip time in
/// microseconds, using the send timestamp carried inside the echo.
///
/// Returns null if a receive fails, for example when the socket timeout
/// expires. Datagrams shorter than a probe, and echoes of other sequence
/// numbers, are discarded. After `max_discards` of them the ping counts as
/// lost, so a stream of stray packets cannot stall the loop forever.
fn awaitEcho(sock: posix.socket_t, seq: u64) ?i64 {
    const max_discards = 16;
    var buf: [64]u8 = undefined;
    var discarded: usize = 0;
    while (discarded < max_discards) : (discarded += 1) {
        var src: posix.sockaddr.storage = undefined;
        var slen: posix.socklen_t = @sizeOf(posix.sockaddr.storage);
        const n = posix.recvfrom(sock, &buf, 0, @ptrCast(&src), &slen) catch return null;
        if (n < 16) continue;
        if (std.mem.readInt(u64, buf[8..16], .little) != seq) continue; // stale echo
        const sent_us = std.mem.readInt(i64, buf[0..8], .little);
        return std.time.microTimestamp() - sent_us;
    }
    return null;
}

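The trick here is embedding the send timestamp (and the sequence number) directly in the payload. The echoed reply then carries everything needed to compute its round-trip time and to tell which ping it answers, without maintaining separate per-ping state.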

Exercise 3: Service discovery via broadcast

const std = @import("std");
const posix = std.posix;

/// Responder: listens for DISCOVER, replies with services.
///
/// Binds 0.0.0.0:`port` (with SO_REUSEADDR) and answers every datagram that
/// starts with the 8 bytes "DISCOVER". It sends a fixed service
/// announcement back to the datagram's source address. Loops until a
/// socket call fails.
pub fn runResponder(port: u16) !void {
    const magic = "DISCOVER";
    const announcement = "myhost|http:8080,ssh:22";

    const listener = try posix.socket(posix.AF.INET, posix.SOCK.DGRAM, 0);
    defer posix.close(listener);

    const reuse = std.mem.toBytes(@as(c_int, 1));
    try posix.setsockopt(listener, posix.SOL.SOCKET, posix.SO.REUSEADDR, &reuse);

    const local = std.net.Address.initIp4(.{ 0, 0, 0, 0 }, port);
    try posix.bind(listener, &local.any, local.getOsSockLen());

    var datagram: [256]u8 = undefined;
    while (true) {
        var peer: posix.sockaddr.storage = undefined;
        var peer_len: posix.socklen_t = @sizeOf(posix.sockaddr.storage);
        const got = try posix.recvfrom(listener, &datagram, 0, @ptrCast(&peer), &peer_len);

        // Only a datagram that begins with the probe keyword gets an answer.
        const is_probe = got >= magic.len and std.mem.eql(u8, datagram[0..magic.len], magic);
        if (!is_probe) continue;

        _ = try posix.sendto(listener, announcement, 0, @ptrCast(&peer), peer_len);
    }
}

/// Discoverer: broadcasts "DISCOVER" to 255.255.255.255:`port`, then prints
/// every reply that arrives within a fixed 2-second window, along with the
/// sender's IPv4 address.
pub fn runDiscoverer(port: u16) !void {
    const window_ms: i64 = 2000;

    const sock = try posix.socket(posix.AF.INET, posix.SOCK.DGRAM, 0);
    defer posix.close(sock);

    try posix.setsockopt(sock, posix.SOL.SOCKET, posix.SO.BROADCAST, &std.mem.toBytes(@as(c_int, 1)));
    const bcast = std.net.Address.initIp4(.{ 255, 255, 255, 255 }, port);
    _ = try posix.sendto(sock, "DISCOVER", 0, &bcast.any, bcast.getOsSockLen());

    const deadline_ms = std.time.milliTimestamp() + window_ms;
    var buf: [512]u8 = undefined;
    std.debug.print("discovering services...\n", .{});
    while (true) {
        const remaining_ms = deadline_ms - std.time.milliTimestamp();
        if (remaining_ms <= 0) break;

        // SO_RCVTIMEO applies to each recvfrom call separately, so shrink it
        // to what is left of the window. Otherwise every reply would extend
        // discovery by another full 2 seconds.
        const tv = posix.timeval{
            .sec = @intCast(@divTrunc(remaining_ms, 1000)),
            .usec = @intCast(@rem(remaining_ms, 1000) * 1000),
        };
        try posix.setsockopt(sock, posix.SOL.SOCKET, posix.SO.RCVTIMEO, &std.mem.toBytes(tv));

        var src: posix.sockaddr.storage = undefined;
        var slen: posix.socklen_t = @sizeOf(posix.sockaddr.storage);
        const n = posix.recvfrom(sock, &buf, 0, @ptrCast(&src), &slen) catch |err| switch (err) {
            error.WouldBlock => break, // window ran out with no further replies
            else => return err,
        };

        const sender = std.net.Address.initPosix(@ptrCast(&src));
        // `sa.addr` is a u32 in network byte order (not indexable). Its
        // in-memory bytes are the dotted-quad octets in order.
        const ip = std.mem.asBytes(&sender.in.sa.addr);
        std.debug.print("found: {s} from {d}.{d}.{d}.{d}\n", .{
            buf[0..n],
            ip[0], ip[1], ip[2], ip[3],
        });
    }
    std.debug.print("discovery complete\n", .{});
}

The broadcast approach (SO_BROADCAST + 255.255.255.255) lets us find services without knowing any addresses ahead of time -- the network layer delivers the packet to every host on the subnet.

OK, now that we've covered UDP fundamentals, let's put them to real use. Every single time you type a URL into your browser, your machine fires off a UDP datagram to a DNS server and parses the response. That's what we're building today -- a DNS resolver from scratch. No libc getaddrinfo, no std.net.Address.resolveIp. Just raw UDP packets, hand-crafted DNS queries, and manual response parsing.

This is one of those projects where Zig's strengths really shine: packed structs for wire formats, explicit byte-order handling with std.mem.bigToNative, and error handling that forces you to deal with every malformed packet edge case. And it's directly practical -- understanding DNS at the packet level is something that pays dividends whether you're debugging network issues, building a web server, or just trying to figure out why your app takes 5 seconds to start (spoiler: it's usually DNS).

How DNS actually works

DNS (Domain Name System) is conceptually simple: you ask "what's the IP address for example.com?" and a server tells you. But the protocol details are worth understanding because they explain why DNS uses UDP and not TCP (for most queries, anyway).

A DNS query is a single UDP datagram, typically under 512 bytes. The server response is also a single UDP datagram. The whole exchange is one round trip -- send a question, get an answer. No handshake, no connection setup, no teardown. This is why DNS uses UDP: the overhead of TCP's three-way handshake would double the latency for something that needs to be as fast as possible.

The DNS message format (defined in RFC 1035, published in 1987 -- this protocol is older than most of you reading this) consists of:

  1. A 12-byte header with a transaction ID, flags, and counts
  2. A question section with the domain name and query type
  3. An answer section with the actual records (only in responses)
  4. Optional authority and additional sections (which we'll mostly ignore)

Let's build each piece.

The DNS header

The DNS header is exactly 12 bytes, and every field matters:

const std = @import("std");

/// The fixed 12-byte DNS message header (RFC 1035 §4.1.1).
///
/// This is declared `extern` rather than `packed`. An extern struct of six
/// u16 fields has C layout: exactly 12 bytes, with fields in declaration
/// order. So `std.mem.asBytes` gives the on-wire byte sequence once every
/// field is in network order.
///
/// A `packed struct` is instead backed by one u96 integer. Its `@sizeOf` is
/// 16, and its field order in memory is reversed on big-endian targets;
/// neither matches the wire format.
pub const DnsHeader = extern struct {
    id: u16, // transaction ID -- match query to response
    flags: u16, // bitfield: QR, opcode, AA, TC, RD, RA, rcode
    qdcount: u16, // number of questions
    ancount: u16, // number of answer records
    nscount: u16, // number of authority records
    arcount: u16, // number of additional records

    comptime {
        std.debug.assert(@sizeOf(DnsHeader) == 12);
    }

    /// Return a copy with every field converted from host to network
    /// (big-endian) byte order, ready to be copied onto the wire.
    pub fn toNetworkOrder(self: DnsHeader) DnsHeader {
        return .{
            .id = std.mem.nativeToBig(u16, self.id),
            .flags = std.mem.nativeToBig(u16, self.flags),
            .qdcount = std.mem.nativeToBig(u16, self.qdcount),
            .ancount = std.mem.nativeToBig(u16, self.ancount),
            .nscount = std.mem.nativeToBig(u16, self.nscount),
            .arcount = std.mem.nativeToBig(u16, self.arcount),
        };
    }

    /// Return a copy with every field converted from network (big-endian)
    /// to host byte order, e.g. after copying 12 header bytes off the wire.
    pub fn fromNetworkOrder(self: DnsHeader) DnsHeader {
        return .{
            .id = std.mem.bigToNative(u16, self.id),
            .flags = std.mem.bigToNative(u16, self.flags),
            .qdcount = std.mem.bigToNative(u16, self.qdcount),
            .ancount = std.mem.bigToNative(u16, self.ancount),
            .nscount = std.mem.bigToNative(u16, self.nscount),
            .arcount = std.mem.bigToNative(u16, self.arcount),
        };
    }
};

The id field is crucial -- it's how you match responses to queries. When you send a query with id 0x1234, the response will have the same id. This is especially important because UDP is connectionless, so if you send multiple queries from the same socket, you need the id to tell the responses apart.

The flags field is a 16-bit bitfield packed with information. For a standard recursive query (which is what your computer sends to your ISP's DNS server), you set bit 8 (0x0100, RD -- recursion desired). The response sets bit 15 (0x8000, QR -- the query/response flag), and the low 4 bits contain the response code (0 = success, 3 = NXDOMAIN / name not found).

Encoding domain names

DNS encodes domain names in a somewhat unusual format called label encoding. Instead of "example.com", the wire format is "\x07example\x03com\x00" -- each label is prefixed with its length, and the whole thing is terminated by a zero-length label.

/// Conversion between dotted domain names and the DNS wire encoding
/// (RFC 1035 §3.1 label format, §4.1.4 message compression).
pub const DnsName = struct {
    /// Encode a domain name (like "example.com") into DNS wire format.
    ///
    /// Each label is written as a length byte followed by its bytes, and the
    /// name ends with a zero length byte. A single trailing dot
    /// ("example.com.") is accepted and ignored; "." encodes the root name.
    ///
    /// Returns the number of bytes written to `buf`. Errors:
    ///   - `InvalidName`: empty input.
    ///   - `NameTooLong`: more than 253 characters, not counting a trailing dot.
    ///   - `EmptyLabel`: an empty interior or leading label ("a..b", ".com").
    ///   - `LabelTooLong`: a label longer than 63 bytes.
    ///   - `BufferTooSmall`: `buf` cannot hold the encoded name.
    pub fn encode(name: []const u8, buf: []u8) !usize {
        if (name.len == 0) return error.InvalidName;

        // Strip one trailing dot so it is not read as an empty final label.
        const trimmed = if (name[name.len - 1] == '.') name[0 .. name.len - 1] else name;
        if (trimmed.len > 253) return error.NameTooLong;

        var pos: usize = 0;
        if (trimmed.len > 0) {
            var it = std.mem.splitScalar(u8, trimmed, '.');
            while (it.next()) |label| {
                // An empty label would be written as a 0 length byte, which
                // silently ends the name early, so reject it instead.
                if (label.len == 0) return error.EmptyLabel;
                if (label.len > 63) return error.LabelTooLong;
                // Need room for the length byte, the label, and at least the terminator.
                if (pos + 1 + label.len >= buf.len) return error.BufferTooSmall;

                buf[pos] = @intCast(label.len);
                pos += 1;
                @memcpy(buf[pos..][0..label.len], label);
                pos += label.len;
            }
        }

        if (pos >= buf.len) return error.BufferTooSmall;
        buf[pos] = 0; // terminating zero-length (root) label
        return pos + 1;
    }

    /// Decode a possibly compressed DNS name that starts at `start_offset`
    /// in `packet`, writing the dotted form into `out_buf`.
    ///
    /// `packet` must be the whole message, because compression pointers are
    /// offsets from its start. `name` is a slice of `out_buf`.
    /// `bytes_consumed` counts only the bytes at `start_offset` itself: up
    /// to and including the first pointer if one is followed, otherwise up
    /// to and including the zero terminator.
    ///
    /// Errors:
    ///   - `Truncated`: the name, or a pointer, runs past the end of
    ///     `packet` before a terminator is found.
    ///   - `TooManyPointers`: more than 10 jumps, which stops pointer loops.
    ///   - `InvalidLabel`: a reserved 0x40/0x80 label prefix.
    ///   - `BufferTooSmall`: `out_buf` is too short for the name.
    pub fn decode(
        packet: []const u8,
        start_offset: usize,
        out_buf: []u8,
    ) !struct { name: []const u8, bytes_consumed: usize } {
        const max_jumps = 10;

        var offset = start_offset;
        var out_pos: usize = 0;
        // Fixed the first time the name ends or a pointer is taken. Bytes
        // read after a jump belong to some other part of the packet.
        var consumed: ?usize = null;
        var jumps: u8 = 0;

        while (true) {
            // Running off the end, directly or via a pointer, means the name
            // was never terminated.
            if (offset >= packet.len) return error.Truncated;
            const len = packet[offset];

            if (len == 0) {
                if (consumed == null) consumed = offset - start_offset + 1;
                break;
            }

            switch (len & 0xC0) {
                // Compression pointer: the low 14 bits are an offset into the packet.
                0xC0 => {
                    if (offset + 1 >= packet.len) return error.Truncated;
                    const target = (@as(usize, len & 0x3F) << 8) | packet[offset + 1];
                    if (consumed == null) consumed = offset - start_offset + 2;

                    jumps += 1;
                    if (jumps > max_jumps) return error.TooManyPointers;
                    offset = target;
                },
                // Ordinary label of 1..63 bytes.
                0x00 => {
                    if (offset + 1 + len > packet.len) return error.Truncated;
                    const sep: usize = @intFromBool(out_pos > 0);
                    if (out_pos + sep + len > out_buf.len) return error.BufferTooSmall;

                    if (sep == 1) {
                        out_buf[out_pos] = '.';
                        out_pos += 1;
                    }
                    @memcpy(out_buf[out_pos..][0..len], packet[offset + 1 ..][0..len]);
                    out_pos += len;
                    offset += 1 + len;
                },
                // RFC 1035 §4.1.4 reserves the 01 and 10 prefixes.
                else => return error.InvalidLabel,
            }
        }

        return .{
            .name = out_buf[0..out_pos],
            .bytes_consumed = consumed.?,
        };
    }
};

That decode function is where things get interesting. DNS responses use compression pointers to avoid repeating the same domain name multiple times. If a response contains records for www.example.com, mail.example.com, and example.com, the example.com part only appears once in the packet -- the other records point back to it using a 2-byte pointer.

A pointer is indicated by the top two bits of the length byte being set (0xC0). The remaining 14 bits are an offset into the packet where the rest of the name can be found. This is a really clever space optimization from 1987 and it means we need to track "jumps" to avoid infinite loops from malicious packets (hence the jumps > 10 check).

Building a DNS query

Now we can put it all together -- construct a complete DNS query packet:

/// Record types this resolver can ask for. Each integer value is exactly
/// what is written into the 16-bit QTYPE field of a question.
pub const QueryType = enum(u16) {
    /// IPv4 host address (4 bytes of RDATA).
    A = 1,
    /// IPv6 host address (16 bytes of RDATA).
    AAAA = 28,
    /// Canonical name: the queried name is an alias for this one.
    CNAME = 5,
    /// Mail exchange: a 16-bit preference followed by a host name.
    MX = 15,
    /// Free-form text strings.
    TXT = 16,
    /// Name server for the zone.
    NS = 2,
};

/// Serialize a single-question DNS query into `buf`.
///
/// Layout, all big-endian:
///   - 12-byte header: ID = `id`, flags = 0x0100 (only RD, recursion
///     desired, set), QDCOUNT = 1, ANCOUNT/NSCOUNT/ARCOUNT = 0.
///   - The label-encoded `name`.
///   - QTYPE = `qtype`, then QCLASS = 1 (IN).
///
/// Returns the number of bytes written. Fails with `error.BufferTooSmall` if
/// the query does not fit in `buf`, and with any error from
/// `DnsName.encode` for an invalid `name`.
pub fn buildQuery(
    buf: []u8,
    name: []const u8,
    qtype: QueryType,
    id: u16,
) !usize {
    const header_len = 12;
    if (buf.len < header_len) return error.BufferTooSmall;

    // Write the header field by field in network byte order. This keeps the
    // wire format independent of how any Zig struct is laid out in memory.
    // For example, `asBytes` of a `packed struct` made of six u16 fields is
    // 16 bytes, not 12.
    std.mem.writeInt(u16, buf[0..2], id, .big);
    std.mem.writeInt(u16, buf[2..4], 0x0100, .big); // flags: RD (recursion desired)
    std.mem.writeInt(u16, buf[4..6], 1, .big); // QDCOUNT
    std.mem.writeInt(u16, buf[6..8], 0, .big); // ANCOUNT
    std.mem.writeInt(u16, buf[8..10], 0, .big); // NSCOUNT
    std.mem.writeInt(u16, buf[10..12], 0, .big); // ARCOUNT

    // Question: encoded name, then QTYPE (2 bytes) and QCLASS (2 bytes).
    var pos: usize = header_len;
    pos += try DnsName.encode(name, buf[pos..]);

    if (pos + 4 > buf.len) return error.BufferTooSmall;
    std.mem.writeInt(u16, buf[pos..][0..2], @intFromEnum(qtype), .big);
    std.mem.writeInt(u16, buf[pos + 2 ..][0..2], 1, .big); // QCLASS: IN (Internet)
    return pos + 4;
}

The query format is straightforward: 12 bytes of header, then the encoded domain name, then 2 bytes for the query type (A record? AAAA? MX?) and 2 bytes for the query class (always 1 for Internet). That's it. A DNS query for example.com is about 30 bytes total.

The 0x0100 flags value sets the RD (Recursion Desired) bit, which tells the DNS server "please do the full resolution for me, don't just tell me to ask another server." This is what you want when querying your ISP's recursive resolver. If you were building a recursive resolver yourself, you'd leave RD unset and do the iterative resolution manually (following referrals from root servers to TLD servers to authoritative servers).

Parsing DNS responses

The response uses the same header format, followed by the original question (echoed back), and then the answer records. Each answer record has a name, type, class, TTL, and the actual data:

/// One resource record parsed from the answer section of a DNS response.
///
/// All data is stored inline in fixed-size buffers (no allocation), and each
/// variable-length name is paired with a length field. `data` is an untagged
/// union: check `data_tag` before reading any of its members.
pub const DnsRecord = struct {
    /// Identifies the active member of `data`. It is a named type so callers
    /// and tests can refer to it as `DnsRecord.DataTag`.
    pub const DataTag = enum { ipv4, ipv6, cname, mx, raw };

    /// Owner name of the record; only `name[0..name_len]` is meaningful.
    name: [256]u8,
    name_len: usize,
    /// Record TYPE from the wire (1 = A, 28 = AAAA, 5 = CNAME, 15 = MX, ...).
    rtype: u16,
    /// Record CLASS from the wire (1 = IN).
    class: u16,
    /// Time-to-live in seconds.
    ttl: u32,
    /// Type-specific payload. The valid member is given by `data_tag`.
    data: union {
        /// A record: IPv4 address bytes in network order.
        ipv4: [4]u8,
        /// AAAA record: IPv6 address bytes in network order.
        ipv6: [16]u8,
        /// CNAME target. Its length is stored in `cname_len`.
        cname: [256]u8,
        mx: struct {
            preference: u16,
            exchange: [256]u8,
            exchange_len: usize,
        },
        /// Any other record type: raw RDATA bytes (room for up to 512).
        raw: struct {
            buf: [512]u8,
            len: usize,
        },
    },
    data_tag: DataTag,
    /// Length of `data.cname`; meaningful only when `data_tag == .cname`.
    cname_len: usize,
};

/// Parse a DNS response message and store up to `records.len` answer
/// records in `records`.
///
/// Returns the decoded header and the number of records written. Answers
/// beyond `records.len` are ignored. Errors:
///   - `Truncated`: any field runs past the end of `packet`.
///   - `NotAResponse`: the QR bit is clear.
///   - `NameNotFound`: rcode 3 (NXDOMAIN).
///   - `ServerError`: any other non-zero rcode.
///   - `InvalidRecord`: bad RDATA length for A, AAAA or MX.
///   - Any error from `DnsName.decode`.
pub fn parseResponse(
    packet: []const u8,
    records: []DnsRecord,
) !struct { header: DnsHeader, count: usize } {
    if (packet.len < 12) return error.Truncated;

    // Read the header field by field (big-endian on the wire), so the result
    // does not depend on DnsHeader's in-memory layout being exactly 12 bytes.
    const header = DnsHeader{
        .id = std.mem.readInt(u16, packet[0..2], .big),
        .flags = std.mem.readInt(u16, packet[2..4], .big),
        .qdcount = std.mem.readInt(u16, packet[4..6], .big),
        .ancount = std.mem.readInt(u16, packet[6..8], .big),
        .nscount = std.mem.readInt(u16, packet[8..10], .big),
        .arcount = std.mem.readInt(u16, packet[10..12], .big),
    };

    // QR (bit 15) marks a response. A packet without it is a query.
    if (header.flags & 0x8000 == 0) return error.NotAResponse;

    switch (header.flags & 0x000F) {
        0 => {},
        3 => return error.NameNotFound, // NXDOMAIN
        else => return error.ServerError,
    }

    // Skip the echoed question section: name + QTYPE + QCLASS per entry.
    var offset: usize = 12;
    for (0..header.qdcount) |_| {
        offset = try skipDnsName(packet, offset);
        if (offset + 4 > packet.len) return error.Truncated;
        offset += 4;
    }

    var count: usize = 0;
    for (0..header.ancount) |_| {
        if (count == records.len) break;

        var rec: DnsRecord = undefined;
        rec.cname_len = 0;

        const name_result = try DnsName.decode(packet, offset, &rec.name);
        rec.name_len = name_result.name.len;
        offset += name_result.bytes_consumed;

        // TYPE(2) CLASS(2) TTL(4) RDLENGTH(2)
        if (offset + 10 > packet.len) return error.Truncated;
        rec.rtype = std.mem.readInt(u16, packet[offset..][0..2], .big);
        rec.class = std.mem.readInt(u16, packet[offset + 2 ..][0..2], .big);
        rec.ttl = std.mem.readInt(u32, packet[offset + 4 ..][0..4], .big);
        const rdlen: usize = std.mem.readInt(u16, packet[offset + 8 ..][0..2], .big);
        offset += 10;

        const rdata_end = offset + rdlen;
        if (rdata_end > packet.len) return error.Truncated;
        // Names inside RDATA are decoded against a view that ends with this
        // record's RDATA. Compression pointers refer to earlier bytes, so this
        // only stops a malformed name from spilling into the next record.
        const upto_rdata = packet[0..rdata_end];

        // Every arm assigns the whole union first so that member becomes
        // active. Writing through `&rec.data.X` on an undefined untagged
        // union trips Zig's inactive-field safety check.
        switch (rec.rtype) {
            1 => { // A record (IPv4)
                if (rdlen != 4) return error.InvalidRecord;
                rec.data = .{ .ipv4 = packet[offset..][0..4].* };
                rec.data_tag = .ipv4;
            },
            28 => { // AAAA record (IPv6)
                if (rdlen != 16) return error.InvalidRecord;
                rec.data = .{ .ipv6 = packet[offset..][0..16].* };
                rec.data_tag = .ipv6;
            },
            5 => { // CNAME
                rec.data = .{ .cname = undefined };
                const cname_result = try DnsName.decode(upto_rdata, offset, &rec.data.cname);
                rec.cname_len = cname_result.name.len;
                rec.data_tag = .cname;
            },
            15 => { // MX: 16-bit preference, then exchange host name
                if (rdlen < 3) return error.InvalidRecord;
                rec.data = .{ .mx = .{
                    .preference = std.mem.readInt(u16, packet[offset..][0..2], .big),
                    .exchange = undefined,
                    .exchange_len = 0,
                } };
                const mx_result = try DnsName.decode(upto_rdata, offset + 2, &rec.data.mx.exchange);
                rec.data.mx.exchange_len = mx_result.name.len;
                rec.data_tag = .mx;
            },
            else => { // anything else: keep up to 512 raw RDATA bytes
                const copy_len = @min(rdlen, 512);
                rec.data = .{ .raw = .{ .buf = undefined, .len = copy_len } };
                @memcpy(rec.data.raw.buf[0..copy_len], packet[offset..][0..copy_len]);
                rec.data_tag = .raw;
            },
        }

        offset = rdata_end;
        records[count] = rec;
        count += 1;
    }

    return .{ .header = header, .count = count };
}

/// Return the offset just past the wire-encoded name at `start` without
/// decoding it. A compression pointer ends the in-place part of the name.
/// All reads are bounds-checked against `packet`.
fn skipDnsName(packet: []const u8, start: usize) !usize {
    var offset = start;
    while (true) {
        if (offset >= packet.len) return error.Truncated;
        const len = packet[offset];
        switch (len & 0xC0) {
            0x00 => {
                if (len == 0) return offset + 1;
                offset += 1 + @as(usize, len);
            },
            0xC0 => {
                if (offset + 2 > packet.len) return error.Truncated;
                return offset + 2;
            },
            else => return error.InvalidLabel,
        }
    }
}

The response parser has to handle several record types. A records are the most common -- they contain a 4-byte IPv4 address. AAAA records contain a 16-byte IPv6 address. CNAME records contain another domain name (which means you might need to do a second lookup to resolve the actual address). MX records contain a priority value and a mail server name.

Notice how the CNAME and MX data fields contain domain names that also use compression pointers. The DnsName.decode function handles this because it takes the full packet as input, allowing it to follow pointers anywhere in the message. This is why DNS parsing always needs access to the complete packet, not just the current record.

The resolver: putting it all together

Now let's build the actual resolver that sends queries and collects answers:

/// Minimal stub resolver: sends one recursive query per `resolve` call over
/// UDP to an IPv4 DNS server on port 53, and parses the matching answer.
pub const DnsResolver = struct {
    sock: std.posix.socket_t,
    server: std.net.Address,
    next_id: u16,

    /// How many unrelated datagrams (wrong source, wrong transaction ID, or
    /// shorter than a header) `resolve` discards before giving up.
    const max_stray_datagrams = 8;

    /// Open a UDP socket with a 3-second receive timeout, aimed at
    /// `server_ip`:53.
    pub fn init(server_ip: [4]u8) !DnsResolver {
        // Fully qualified std.posix: this file only imports `std`.
        const sock = try std.posix.socket(std.posix.AF.INET, std.posix.SOCK.DGRAM, 0);
        errdefer std.posix.close(sock);

        // Set a 3-second receive timeout
        const tv = std.posix.timeval{ .sec = 3, .usec = 0 };
        try std.posix.setsockopt(sock, std.posix.SOL.SOCKET, std.posix.SO.RCVTIMEO, &std.mem.toBytes(tv));

        return .{
            .sock = sock,
            .server = std.net.Address.initIp4(server_ip, 53),
            // Seed from the OS CSPRNG. A clock-derived start is easy for an
            // off-path attacker to guess when forging answers.
            .next_id = std.crypto.random.int(u16),
        };
    }

    /// Close the resolver's socket.
    pub fn deinit(self: *DnsResolver) void {
        std.posix.close(self.sock);
    }

    /// Query `name` for `qtype` and store the answer records in `records`.
    ///
    /// Datagrams that do not come from the configured server, or do not
    /// carry this query's transaction ID, are discarded rather than aborting
    /// the lookup. After `max_stray_datagrams` of them the call fails with
    /// `error.IdMismatch` (or `error.Truncated` if the last one was shorter
    /// than a header). When no datagram arrives within the receive timeout,
    /// the error from `recvfrom` is returned (`error.WouldBlock`).
    pub fn resolve(
        self: *DnsResolver,
        name: []const u8,
        qtype: QueryType,
        records: []DnsRecord,
    ) !struct { count: usize, response_id: u16 } {
        var query_buf: [512]u8 = undefined;
        const id = self.next_id;
        self.next_id +%= 1;

        const query_len = try buildQuery(&query_buf, name, qtype, id);

        _ = try std.posix.sendto(
            self.sock,
            query_buf[0..query_len],
            0,
            &self.server.any,
            self.server.getOsSockLen(),
        );

        var resp_buf: [512]u8 = undefined;
        var strays: usize = 0;
        while (true) {
            var src: std.posix.sockaddr.storage = undefined;
            var slen: std.posix.socklen_t = @sizeOf(std.posix.sockaddr.storage);
            const n = try std.posix.recvfrom(self.sock, &resp_buf, 0, @ptrCast(&src), &slen);
            const response = resp_buf[0..n];

            // Accept only a full header from our server that echoes our ID;
            // anything else may be a late answer or a spoofing attempt.
            const from = std.net.Address.initPosix(@ptrCast(&src));
            const is_ours = n >= 12 and
                from.eql(self.server) and
                std.mem.readInt(u16, response[0..2], .big) == id;

            if (is_ours) {
                const result = try parseResponse(response, records);
                return .{ .count = result.count, .response_id = id };
            }

            strays += 1;
            if (strays > max_stray_datagrams) {
                return if (n < 12) error.Truncated else error.IdMismatch;
            }
        }
    }
};

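The next_id field is seeded when the resolver is created instead of starting at 0 every time. A fixed starting ID would be trivially predictable, which is a security concern we'll get to in a moment. Each query then increments the id with wrapping addition (+%=). Note that a clock-derived seed is only marginally harder to guess; production resolvers draw transaction IDs from a cryptographically secure RNG such as std.crypto.random.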

The resolve function is straightforward: build query, send it, receive response, verify the ID matches, parse the records. The 3-second timeout prevents the resolver from hanging forever if the DNS server doesn't respond.

Using the resolver

Here's a complete program that resolves a domain name and prints the results:

const std = @import("std");
const dns = @import("dns.zig");

/// Resolve the A records of example.com through Google Public DNS (8.8.8.8)
/// and print them. A failed lookup is reported and then returned as an
/// error, so the process exits with a non-zero status instead of success.
pub fn main() !void {
    var resolver = try dns.DnsResolver.init(.{ 8, 8, 8, 8 }); // Google's public DNS
    defer resolver.deinit();

    const name = "example.com";
    var records: [16]dns.DnsRecord = undefined;

    // Resolve A records (IPv4)
    const result = resolver.resolve(name, .A, &records) catch |err| {
        std.debug.print("DNS lookup failed: {s}\n", .{@errorName(err)});
        return err;
    };

    std.debug.print("DNS results for {s}:\n", .{name});
    // Iterate by pointer: each DnsRecord is several hundred bytes.
    for (records[0..result.count]) |*rec| {
        switch (rec.data_tag) {
            .ipv4 => {
                std.debug.print("  A: {d}.{d}.{d}.{d} (TTL: {d}s)\n", .{
                    rec.data.ipv4[0], rec.data.ipv4[1],
                    rec.data.ipv4[2], rec.data.ipv4[3],
                    rec.ttl,
                });
            },
            .cname => {
                std.debug.print("  CNAME: {s}\n", .{rec.data.cname[0..rec.cname_len]});
            },
            else => {
                std.debug.print("  type={d}\n", .{rec.rtype});
            },
        }
    }
}

If you run this, you should see one or more lines like A: 93.184.216.34 (TTL: 3600s). These are example.com's current IPv4 addresses; they have changed over the years, so don't expect a particular value. The TTL tells you how long you can cache this result before asking again (for example, 3600 seconds = 1 hour).

Security considerations

Building a DNS resolver is a great learning exercise, but there are some real security concerns you should know about:

DNS spoofing: Since DNS uses UDP, an attacker who can guess the transaction ID and source port can inject fake responses before the real server replies. This is why real resolvers use randomized source ports and transaction IDs, and why DNSSEC exists (cryptographic signatures on DNS records).

Buffer overflows: Malformed DNS packets with incorrect length fields or circular compression pointers could cause out-of-bounds reads. Our parser handles this with bounds checking on every read and a pointer-jump limit, but in C this is where things historically went wrong. Zig's slice bounds checking and explicit error handling make this significantly safer.

Amplification attacks: DNS responses are often much larger than queries (especially TXT and ANY queries), which makes DNS servers useful for DDoS amplification. If you build a DNS server (which we'll do in the next episode), rate limiting and response size limits are essential.

Having said that, for a local resolver used in your own applications, the primary risk is just getting incorrect results from a compromised network. Using DNS-over-HTTPS or DNS-over-TLS (which we'll explore later in the networking section) mitigates most of these issues.

Testing the resolver

Testing DNS code without hitting real servers is important for CI and offline development. We can build test packets manually:

const std = @import("std");
const dns = @import("dns.zig");

test "encode and decode domain name" {
    var buf: [256]u8 = undefined;
    const len = try dns.DnsName.encode("example.com", &buf);

    // Should be: 7, 'e','x','a','m','p','l','e', 3, 'c','o','m', 0
    try std.testing.expectEqual(@as(usize, 13), len);
    try std.testing.expectEqual(@as(u8, 7), buf[0]);
    try std.testing.expectEqualStrings("example", buf[1..8]);
    try std.testing.expectEqual(@as(u8, 3), buf[8]);
    try std.testing.expectEqualStrings("com", buf[9..12]);
    try std.testing.expectEqual(@as(u8, 0), buf[12]);

    // Decode it back
    var decode_buf: [256]u8 = undefined;
    const result = try dns.DnsName.decode(buf[0..len], 0, &decode_buf);
    try std.testing.expectEqualStrings("example.com", result.name);
    try std.testing.expectEqual(@as(usize, 13), result.bytes_consumed);
}

test "build query packet" {
    var buf: [512]u8 = undefined;
    const len = try dns.buildQuery(&buf, "example.com", .A, 0x1234);

    // Transaction ID, big-endian
    try std.testing.expectEqual(@as(u8, 0x12), buf[0]);
    try std.testing.expectEqual(@as(u8, 0x34), buf[1]);

    // Flags (0x0100 = RD set)
    try std.testing.expectEqual(@as(u8, 0x01), buf[2]);
    try std.testing.expectEqual(@as(u8, 0x00), buf[3]);

    // QDCOUNT = 1, ANCOUNT = NSCOUNT = ARCOUNT = 0
    try std.testing.expectEqual(@as(u8, 0x00), buf[4]);
    try std.testing.expectEqual(@as(u8, 0x01), buf[5]);
    try std.testing.expectEqualSlices(u8, &.{ 0, 0, 0, 0, 0, 0 }, buf[6..12]);

    // Total length: 12 (header) + 13 (name) + 4 (qtype + qclass) = 29
    try std.testing.expectEqual(@as(usize, 29), len);

    // Question: encoded name, QTYPE = 1 (A), QCLASS = 1 (IN)
    try std.testing.expectEqualSlices(u8, "\x07example\x03com\x00\x00\x01\x00\x01", buf[12..len]);
}

test "parse synthetic A record response" {
    // Build a minimal DNS response by hand
    var pkt: [64]u8 = undefined;
    var pos: usize = 0;

    // Header: id=0x1234, flags=0x8180 (response, RD+RA, no error), 1 question, 1 answer
    std.mem.writeInt(u16, pkt[0..2], 0x1234, .big);
    std.mem.writeInt(u16, pkt[2..4], 0x8180, .big);
    std.mem.writeInt(u16, pkt[4..6], 1, .big); // qdcount
    std.mem.writeInt(u16, pkt[6..8], 1, .big); // ancount
    std.mem.writeInt(u16, pkt[8..10], 0, .big); // nscount
    std.mem.writeInt(u16, pkt[10..12], 0, .big); // arcount
    pos = 12;

    // Question: test.com, A, IN
    const name_len = try dns.DnsName.encode("test.com", pkt[pos..]);
    pos += name_len;
    std.mem.writeInt(u16, pkt[pos..][0..2], 1, .big);
    pos += 2; // qtype A
    std.mem.writeInt(u16, pkt[pos..][0..2], 1, .big);
    pos += 2; // qclass IN

    // Answer: pointer to question name, A, IN, TTL=300, 4 bytes, 1.2.3.4
    pkt[pos] = 0xC0;
    pkt[pos + 1] = 12;
    pos += 2; // compression pointer to offset 12
    std.mem.writeInt(u16, pkt[pos..][0..2], 1, .big);
    pos += 2; // type A
    std.mem.writeInt(u16, pkt[pos..][0..2], 1, .big);
    pos += 2; // class IN
    std.mem.writeInt(u32, pkt[pos..][0..4], 300, .big);
    pos += 4; // TTL
    std.mem.writeInt(u16, pkt[pos..][0..2], 4, .big);
    pos += 2; // rdlength
    pkt[pos] = 1;
    pkt[pos + 1] = 2;
    pkt[pos + 2] = 3;
    pkt[pos + 3] = 4;
    pos += 4;

    var records: [4]dns.DnsRecord = undefined;
    const result = try dns.parseResponse(pkt[0..pos], &records);
    try std.testing.expectEqual(@as(usize, 1), result.count);
    // Compare against the enum literal so this works whether or not the tag
    // enum is exposed as a named type.
    try std.testing.expect(records[0].data_tag == .ipv4);
    try std.testing.expectEqualStrings("test.com", records[0].name[0..records[0].name_len]);
    try std.testing.expectEqual(@as(u32, 300), records[0].ttl);
    try std.testing.expectEqualSlices(u8, &.{ 1, 2, 3, 4 }, &records[0].data.ipv4);
}

test "NXDOMAIN response is reported as NameNotFound" {
    var pkt = [_]u8{0} ** 12;
    std.mem.writeInt(u16, pkt[0..2], 0x1234, .big);
    std.mem.writeInt(u16, pkt[2..4], 0x8183, .big); // QR | RD | RA, rcode 3
    var records: [1]dns.DnsRecord = undefined;
    try std.testing.expectError(error.NameNotFound, dns.parseResponse(&pkt, &records));
}

test "self-referencing compression pointer is rejected" {
    // A pointer to itself would loop forever without the jump limit.
    const pkt = [_]u8{ 0xC0, 0x00 };
    var out: [64]u8 = undefined;
    try std.testing.expectError(error.TooManyPointers, dns.DnsName.decode(&pkt, 0, &out));
}

Building synthetic DNS packets by hand is tedious, but it gives you complete control over the test. You can create malformed packets, test compression pointer edge cases, and verify your parser handles every variant correctly without needing a network connection. This is the kind of test that runs in CI without any external dependencies ;-)

Exercises

  1. Extend the DNS resolver to support AAAA record lookups (IPv6). Resolve a domain that has both A and AAAA records (like google.com) and print both the IPv4 and IPv6 addresses. The AAAA record type is 28, and the data is 16 bytes containing the IPv6 address -- format it as 8 groups of 4 hex digits separated by colons.

  2. Implement a DNS cache using a hash map (episode 22). The cache key should be the combination of domain name + query type. The cached value should include the DNS records and the timestamp when they were cached. On lookup, check if the cached result is still valid (current time < cache time + TTL). Write a test that caches a result, verifies it hits on the second lookup, and misses after the TTL expires.

  3. Build a batch DNS resolver that reads a list of domain names from a file (one per line) and resolves all of them, printing the results in a table format: domain | type | value | TTL. Handle errors gracefully -- if one domain fails (NXDOMAIN, timeout), print the error and continue with the next one. Measure and print the total resolution time at the end.

Thanks for reading!

@scipio



0
0
0.000
0 comments