🐷

Zigで配列をなめるときにはできるだけwhileでなくてforを使いましょうという話

2023/06/15に公開

まあ、当たり前のことなんですけどね。

前回の記事のプログラムを書き直した

前回の記事のMJPEGからJPEGを切り出すプログラムにはバグがありました。脚注にも書きましたが、それは読み込んだバッファの最後のバイトが0xffで、次のバッファの先頭が0xd9のときに、JPEGの終わりを見逃してしまうというものです。
なので、ちゃんと4つの状態を遷移するように書き直しました。

a3.zig
const std = @import("std");
const fs = std.fs;
const io = std.io;
const os = std.os;

const BUF_SIZE = 64 * 1024;
const OUTPUT_FILENAME_PATTERN = "out{d:0>4}.jpg";
const JPEG_START0 = 0xff;
const JPEG_START1 = 0xd8;
const JPEG_END0 = 0xff;
const JPEG_END1 = 0xd9;

pub fn main() !void {
    var allocator = std.heap.page_allocator;
    var buffer: [BUF_SIZE]u8 = undefined;
    var frame_num: usize = 0;
    var write_buffer = std.ArrayList(u8).init(allocator);
    defer write_buffer.deinit();
    const State = enum {
        st0, // waiting for JPEG_START0
        st1, // waiting for JPEG_START1
        st2, // waiting for JPEG_END0
        st3, // waiting for JPEG_END1
    };
    var state: State = State.st0;

    while (true) {
        const n = try io.getStdIn().read(&buffer);
        if (n == 0) break;

        var i: usize = 0;
        while (i < n) : (i += 1) {
            switch (state) {
                State.st0 => {
                    if (buffer[i] == JPEG_START0) {
                        state = State.st1;
                    }
                },
                State.st1 => {
                    if (buffer[i] == JPEG_START1) {
                        try write_buffer.append(JPEG_START0);
                        try write_buffer.append(JPEG_START1);
                        state = State.st2;
                    } else if (buffer[i] != JPEG_START0) {
                        state = State.st0;
                    }
                },
                State.st2 => {
                    try write_buffer.append(buffer[i]);
                    if (buffer[i] == JPEG_END0) {
                        state = State.st3;
                    }
                },
                State.st3 => {
                    try write_buffer.append(buffer[i]);
                    if (buffer[i] == JPEG_END1) {
                        try writeFile(frame_num, write_buffer.items);
                        frame_num += 1;
                        write_buffer.clearRetainingCapacity();
                        state = State.st0;
                    } else if (buffer[i] != JPEG_END0) {
                        state = State.st2;
                    }
                },
            }
        }
    }
}

fn writeFile(frame_num: usize, buf: []const u8) !void {
    var filename_buf: [32]u8 = undefined;
    const filename = try std.fmt.bufPrint(&filename_buf, OUTPUT_FILENAME_PATTERN, .{frame_num});
    const output_file = try fs.cwd().createFile(filename, .{});
    defer output_file.close();
    try output_file.writeAll(buf);
}
$ zig build-exe -O Debug a3.zig
$ time ./a3 < a.mjpeg 

real	3m20.802s
user	3m20.419s
sys	0m0.085s
$ zig build-exe -O ReleaseSafe a3.zig
$ time ./a3 < a.mjpeg 

real	3m20.715s
user	3m20.336s
sys	0m0.084s

これは意外でした。
buffer[i] で配列の範囲を外れることはないことはコンパイル時にわかるかと思ったのですが、それはやってくれないのですね。

$ time ./a3 < a.mjpeg 

real	0m0.299s
user	0m0.213s
sys	0m0.080s
$ time ./a3 < a.mjpeg 

real	0m0.313s
user	0m0.232s
sys	0m0.075s

ReleaseFastReleaseSafeではruntime safetyがOFFになるので、実行時間が短縮されます。

ループをwhileからforに修正

前の記事のように、@setRuntimeSafety(false);をつけてもいいのですが、正攻法としては配列bufferを0からn - 1まで回すのをwhileでなくforで回すのが良さそうです。

--- a3.zig	2023-06-15 16:25:34.626407437 +0900
+++ a4.zig	2023-06-15 16:39:07.408740529 +0900
@@ -28,16 +28,15 @@
         const n = try io.getStdIn().read(&buffer);
         if (n == 0) break;
 
-        var i: usize = 0;
-        while (i < n) : (i += 1) {
+        for (buffer[0..n]) |v| {
             switch (state) {
                 State.st0 => {
-                    if (buffer[i] == JPEG_START0) {
+                    if (v == JPEG_START0) {
                         state = State.st1;
                     }

これにより buffer[i] で参照するところはなくなり、配列の範囲チェックをしなくて済むようになりました。

a4.zig
const std = @import("std");
const fs = std.fs;
const io = std.io;
const os = std.os;

const BUF_SIZE = 64 * 1024;
const OUTPUT_FILENAME_PATTERN = "out{d:0>4}.jpg";
const JPEG_START0 = 0xff;
const JPEG_START1 = 0xd8;
const JPEG_END0 = 0xff;
const JPEG_END1 = 0xd9;

pub fn main() !void {
    var allocator = std.heap.page_allocator;
    var buffer: [BUF_SIZE]u8 = undefined;
    var frame_num: usize = 0;
    var write_buffer = std.ArrayList(u8).init(allocator);
    defer write_buffer.deinit();
    const State = enum {
        st0, // waiting for JPEG_START0
        st1, // waiting for JPEG_START1
        st2, // waiting for JPEG_END0
        st3, // waiting for JPEG_END1
    };
    var state: State = State.st0;

    while (true) {
        const n = try io.getStdIn().read(&buffer);
        if (n == 0) break;

        for (buffer[0..n]) |v| {
            switch (state) {
                State.st0 => {
                    if (v == JPEG_START0) {
                        state = State.st1;
                    }
                },
                State.st1 => {
                    if (v == JPEG_START1) {
                        try write_buffer.append(JPEG_START0);
                        try write_buffer.append(JPEG_START1);
                        state = State.st2;
                    } else if (v != JPEG_START0) {
                        state = State.st0;
                    }
                },
                State.st2 => {
                    try write_buffer.append(v);
                    if (v == JPEG_END0) {
                        state = State.st3;
                    }
                },
                State.st3 => {
                    try write_buffer.append(v);
                    if (v == JPEG_END1) {
                        try writeFile(frame_num, write_buffer.items);
                        frame_num += 1;
                        write_buffer.clearRetainingCapacity();
                        state = State.st0;
                    } else if (v != JPEG_END0) {
                        state = State.st2;
                    }
                },
            }
        }
    }
}

fn writeFile(frame_num: usize, buf: []const u8) !void {
    var filename_buf: [32]u8 = undefined;
    const filename = try std.fmt.bufPrint(&filename_buf, OUTPUT_FILENAME_PATTERN, .{frame_num});
    const output_file = try fs.cwd().createFile(filename, .{});
    defer output_file.close();
    try output_file.writeAll(buf);
}

最適化の各オプションでビルドして計測した結果は以下の通り。ソースコード上から配列のインデックスによるアクセスが消えているので、ランタイムチェックが軽くなっていることがわかります。

$ zig build-exe -O Debug a4.zig
$ time ./a4 < a.mjpeg 

real	0m0.831s
user	0m0.754s
sys	0m0.069s
$ zig build-exe -O ReleaseSafe a4.zig
$ time ./a4 < a.mjpeg 

real	0m0.294s
user	0m0.208s
sys	0m0.080s
$ zig build-exe -O ReleaseFast a4.zig
$ time ./a4 < a.mjpeg 

real	0m0.299s
user	0m0.198s
sys	0m0.095s
$ zig build-exe -O ReleaseSmall a4.zig
$ time ./a4 < a.mjpeg 

real	0m0.310s
user	0m0.217s
sys	0m0.087s

なのでタイトルの通り、Zigで配列をなめるときにはできるだけwhileでなくてforを使いましょう。
配列の範囲をはみ出さないことが明確になるので、コードを読むときの脳の負荷が減るし、速いコードが生成されます。
特にC言語で書かれたものをZigに書き直したときには、うっかりwhileのままにしておかないように気をつけたほうがよいですね。まあ、lint的なツールが警告してくれるようになるとは思いますが。

関連

https://zenn.dev/tetsu_koba/articles/411d58244ba993
https://zenn.dev/tetsu_koba/articles/1f9d8367668e93

Discussion