Open13

Rust TCP Server TDD ノート

ta1kt0meta1kt0me

Rustで mini Redis Server を作ってみるとき、TDD したくなったのでやり方を考えてみる

ta1kt0meta1kt0me

最初のステップはこんな感じのコード。

この辺を参考にしている。

fn main() {
    let listner = TcpListener::bind("127.0.0.1:6379").unwrap();

    for stream in listner.incoming() {
        match stream {
            Ok(stream) => {
                do_something();
            }
            Err(e) => {
                error_handling();
            }
        }
    }
}
ta1kt0meta1kt0me

接続が成功した場合

    fn server_connect_success() {
        thread::spawn(|| {
            main();
        });
        let remote = "127.0.0.1:6379".parse().unwrap();
        let mut stream = TcpStream::connect_timeout(&remote, Duration::from_secs(2)).unwrap();
        let mut buffer = String::new();
        let res = stream.read_to_string(&mut buffer).unwrap();
        assert_eq!(0, res);
    }
  • Threadを生成してサーバーを起動する
  • TcpStream::connect_timeout でサーバーに接続する
  • とりあえず何か返ってくるならbufferをチェックする
ta1kt0meta1kt0me

接続が失敗した場合
ひとまずportが間違ってるとか。

    fn server_connect_failure() {
        thread::spawn(|| {
            main();
        });
        let remote = "127.0.0.1:6378".parse().unwrap();
        let stream = TcpStream::connect_timeout(&remote, Duration::from_secs(2));
        let err = stream.unwrap_err();
        assert_eq!("Connection refused (os error 61)", err.to_string());
    }

もっといいやり方ありそう...

  • assertが文字列比較になっているので違うやり方にしたい
  • そもそもmain関数のmatchのError処理ではない

listner.incoming() がErrorになるようなmockの仕方とか何かあったりするかな...

ta1kt0meta1kt0me

PING/PONGの場合

    fn ping() {
        thread::spawn(|| {
            main();
        });
        let remote = "127.0.0.1:6379".parse().unwrap();
        let mut stream = TcpStream::connect_timeout(&remote, Duration::from_secs(2)).unwrap();
        let req = "PING".as_bytes();
        stream.write(req).unwrap();
        let mut buffer = String::new();
        stream.read_to_string(&mut buffer).unwrap();
        assert_eq!("+PONG\r\n", buffer);
    }
  • 文字列スライスを as_bytes() してから TcpStream に書き込む
  • 取得結果を buffer に書き込む
  • redisのRESP protocol の仕様に従ったデータが取得できる

http://mogile.web.fc2.com/redis/docs/reference/protocol-spec/index.html#resp-simple-strings

ta1kt0meta1kt0me

Errorのケースのassertionどうやるのがいいんだろ。
コネクション接続に失敗した場合は expect とかで扱った方がいい気もする。

ta1kt0meta1kt0me

テストの一括実行でportのコンフリクトが発生していたので解決策を探す。

https://zenn.dev/link/comments/711484f1d05381 のテストケースで考える。
https://github.com/rust-lang-nursery/rust-cookbook/issues/123#issuecomment-651159230
ポート0を指定した場合にランダムなポートが割り当てられる振る舞いを利用することでコンフリクトを回避できそう。
また main 関数を呼び出すのでなく、TcpListener を引数にしたサーバー起動関数を用意することでテストと実装で利用するポートを切り替えられるようにする。

    fn test_listener() -> TcpListener {
        let addr = "127.0.0.1:0";
        TcpListener::bind(addr).unwrap()
    }

    #[test]
    fn server_connect_success() {
        let listener = test_listener();
        let addr = listener.local_addr().unwrap();
        thread::spawn(|| {
            start_server(listener);
        });
        let mut stream = TcpStream::connect_timeout(&addr, Duration::from_secs(2)).unwrap();
        let mut buffer = String::new();
        let res = stream.read_to_string(&mut buffer).unwrap();
        assert_eq!(0, res);
    }

ta1kt0meta1kt0me

https://doc.rust-jp.rs/book-ja/ch20-00-final-project-a-web-server.html
これを踏まえてマルチスレッドの実装を組み込んだ時のThreadPoolに関するテスト。

#[cfg(test)]
mod tests {
    use std::{
        sync::mpsc::{Receiver, Sender},
        time::Duration,
    };

    use super::*;

    #[test]
    fn test_worker_new() {
        let (sender, receiver): (Sender<Message>, Receiver<Message>) = mpsc::channel();
        let worker = Worker::new(1, Arc::new(Mutex::new(receiver)));
        assert_eq!(worker.id, 1);

        let counter = Arc::new(Mutex::<i32>::new(0));
        let actual = counter.clone();
        let func = move || *counter.lock().unwrap() += 1;

        let job = Box::new(func);
        sender.send(Message::NewJob(job)).unwrap();

        thread::sleep(Duration::from_millis(1));
        assert!(actual.lock().unwrap().eq(&1));

        sender.send(Message::Terminate).unwrap();
    }

    #[test]
    fn test_thread_pool_new() {
        let thread_pool = ThreadPool::new(2);
        assert_eq!(thread_pool.workers.len(), 2);

        let counter_one = Arc::new(Mutex::<i32>::new(0));
        let counter_two = counter_one.clone();
        let actual = counter_one.clone();
        let func_one = move || {
            thread::sleep(Duration::from_millis(5));
            *counter_one.lock().unwrap() += 1;
        };
        let func_two = move || {
            thread::sleep(Duration::from_millis(1));
            *counter_two.lock().unwrap() += 1;
        };
        thread_pool.execute(func_one);
        thread_pool.execute(func_two);

        thread::sleep(Duration::from_millis(10));
        assert!(actual.lock().unwrap().eq(&2));
    }
}
ta1kt0meta1kt0me

各種コマンドを実装してみる。

  • echo
  • set(PXオプション)
  • get

データベースをHashMapとして実装した場合以下のようなテストを実装してみる。

#[cfg(test)]
mod tests {
    use std::{thread, time::Duration};

    use super::*;

    #[test]
    fn test_new_with_no_arg() {
        let database = Arc::new(Store::new());
        let command = Command::new("sample", None, database);
        assert_eq!(command.instruction, "sample");
        assert_eq!(command.arguments.len(), 0);
    }

    #[test]
    fn test_new_with_an_arg() {
        let database = Arc::new(Store::new());
        let command = Command::new("sample", Some("$1\r\na"), database);
        assert_eq!(command.instruction, "sample");
        assert_eq!(command.arguments.len(), 1);
        assert_eq!(command.arguments[0].value, "a");
    }

    #[test]
    fn test_new_with_args() {
        let database = Arc::new(Store::new());
        let command = Command::new("sample", Some("$1\r\na\r\n$1\r\nb\r\n"), database);
        assert_eq!(command.instruction, "sample");
        assert_eq!(command.arguments.len(), 2);
        assert_eq!(command.arguments[0].value, "a");
        assert_eq!(command.arguments[1].value, "b");
    }

    #[test]
    fn test_execute_as_ping() {
        let database = Arc::new(Store::new());
        let command = Command::new("PING", None, database);
        let result = command.execute().unwrap();
        assert_eq!("+PONG\r\n", result)
    }

    #[test]
    fn test_execute_as_double_ping() {
        let database = Arc::new(Store::new());
        let command = Command::new("PING\\nPING", None, database);
        let result = command.execute().unwrap();
        assert_eq!("+PONG\r\n+PONG\r\n", result)
    }

    #[test]
    fn test_execute_as_command() {
        let database = Arc::new(Store::new());
        let command = Command::new("COMMAND", None, database);
        let result = command.execute();
        assert_eq!(None, result)
    }

    #[test]
    fn test_execute_as_shutdown() {
        let database = Arc::new(Store::new());
        let command = Command::new("SHUTDOWN", None, database);
        let result = command.execute();
        assert_eq!(None, result);
    }

    #[test]
    fn test_execute_as_any() {
        let database = Arc::new(Store::new());
        let command = Command::new("FOO", None, database);
        let result = command.execute().unwrap();
        assert_eq!("+OK\r\n", result)
    }

    #[test]
    fn test_execute_as_echo() {
        let database = Arc::new(Store::new());
        let command = Command::new("ECHO", Some("$1\r\na\r\n"), database);
        let result = command.execute().unwrap();
        assert_eq!("+a\r\n", result)
    }

    #[test]
    fn test_execute_as_set() {
        let database = Arc::new(Store::new());
        let cloned = database.clone();
        let command = Command::new("SET", Some("$3\r\nFOO\r\n$2\r\nab\r\n"), cloned);
        let result = command.execute().unwrap();
        assert_eq!("+OK\r\n", result);
        let value = database.get("FOO").unwrap();
        assert_eq!("ab", value);
    }

    #[test]
    fn test_execute_as_set_with_px() {
        let database = Arc::new(Store::new());
        let cloned = database.clone();
        let command = Command::new(
            "SET",
            Some("$7\r\nFOO\r\n$2\r\nab\r\n$2\r\nPX\r\n$1\r\n1\r\n"),
            cloned,
        );
        let result = command.execute().unwrap();
        assert_eq!("+OK\r\n", result);
        thread::sleep(Duration::from_millis(5));
        let value = database.get("FOO");
        assert_eq!(None, value);
    }

    #[test]
    fn test_execute_as_get() {
        let database = Arc::new(Store::new());
        database.set("FOO", "bar", None);
        let command = Command::new("GET", Some("$3\r\nFOO\r\n"), database);
        let result = command.execute().unwrap();
        assert_eq!("+bar\r\n", result);
    }

    #[test]
    fn test_execute_as_get_with_expired() {
        let database = Arc::new(Store::new());
        let asserted_database = database.clone();
        database.set("FOO", "bar", Some(0));
        let command = Command::new("GET", Some("$3\r\nFOO\r\n"), database);
        let result = command.execute().unwrap();
        assert_eq!("$-1\r\n", result);
        assert_eq!(false, asserted_database.exists("FOO"))
    }
}
ta1kt0meta1kt0me

TcpStreamをクライアントとしてサーバーを経由してコマンドを発行した場合のテストケース。

#[cfg(test)]
mod tests {
    use std::{
        io::{Read, Write},
        net::{SocketAddr, TcpStream},
        thread,
        time::Duration,
    };

    use super::*;

    fn test_listener() -> TcpListener {
        let addr = "127.0.0.1:0";
        TcpListener::bind(addr).unwrap()
    }

    fn start_test_server() -> SocketAddr {
        let listener = test_listener();
        let addr = listener.local_addr().unwrap();
        thread::spawn(|| start_server(listener));

        addr
    }

    fn send_request(stream: &mut TcpStream, command: &str) -> String {
        let req = command.as_bytes();
        stream.write(req).unwrap();
        stream.flush().unwrap();

        let mut buffer = [0; 1024];
        let content_length = stream.read(&mut buffer).unwrap();
        String::from_utf8_lossy(&buffer[0..content_length]).to_string()
    }

    #[test]
    fn server_connect_success() {
        let addr = start_test_server();
        let mut stream = TcpStream::connect(&addr).unwrap();
        let result = send_request(&mut stream, "*0\r\n");
        assert_eq!(0, result.len());
    }

    #[test]
    fn server_connect_failure() {
        let addr = start_test_server();
        let invalid_addr: SocketAddr = format!("{}:6379", addr.ip()).parse().unwrap();
        let stream = TcpStream::connect(&invalid_addr);
        let err = stream.unwrap_err();
        assert_eq!("Connection refused (os error 61)", err.to_string());
    }

    #[test]
    fn ping() {
        let addr = start_test_server();
        let mut stream = TcpStream::connect(&addr).unwrap();
        let command = "*1\r\n$4\r\nPING\r\n";
        let result = send_request(&mut stream, command);
        assert_eq!("+PONG\r\n", result);
    }

    #[test]
    fn double_ping() {
        let addr = start_test_server();
        let mut stream = TcpStream::connect(&addr).unwrap();
        let command = "*1\r\n$10\r\nPING\\nPING\r\n";
        let result = send_request(&mut stream, command);
        assert_eq!("+PONG\r\n+PONG\r\n", result);
    }

    #[test]
    fn echo() {
        let addr = start_test_server();
        let mut stream = TcpStream::connect(&addr).unwrap();
        let command = "*2\r\n$4\r\nECHO\r\n$3\r\nhey\r\n";
        let result = send_request(&mut stream, command);
        assert_eq!("+hey\r\n", result);

        let mut stream = TcpStream::connect(&addr).unwrap();
        let command = "*2\r\n$4\r\nECHO\r\n$5\r\nhello\r\n";
        let result = send_request(&mut stream, command);
        assert_eq!("+hello\r\n", result);
    }

    #[test]
    fn set_and_get() {
        let addr = start_test_server();
        let mut stream = TcpStream::connect(&addr).unwrap();
        let command = "*3\r\n$3\r\nSET\r\n$3\r\nFOO\r\n$1\r\n1\r\n";
        let result = send_request(&mut stream, command);
        assert_eq!("+OK\r\n", result);

        let mut stream = TcpStream::connect(&addr).unwrap();
        let command = "*2\r\n$3\r\nGET\r\n$3\r\nFOO\r\n";
        let result = send_request(&mut stream, command);
        assert_eq!("+1\r\n", result);
    }

    #[test]
    fn set_and_get_with_expiration() {
        let addr = start_test_server();
        let mut stream = TcpStream::connect(&addr).unwrap();
        let command = "*3\r\n$3\r\nSET\r\n$3\r\nFOO\r\n$1\r\n1\r\n$2\r\nPX\r\n$1\r\n5\r\n";
        let result = send_request(&mut stream, command);
        assert_eq!("+OK\r\n", result);

        let mut stream = TcpStream::connect(&addr).unwrap();
        let command = "*2\r\n$3\r\nGET\r\n$3\r\nFOO\r\n";
        let result = send_request(&mut stream, command);
        assert_eq!("+1\r\n", result);

        thread::sleep(Duration::from_millis(5));
        let mut stream = TcpStream::connect(&addr).unwrap();
        let command = "*2\r\n$3\r\nGET\r\n$3\r\nFOO\r\n";
        let result = send_request(&mut stream, command);
        assert_eq!("$-1\r\n", result);
    }
}
ta1kt0meta1kt0me

cargo clippy にすごく助けられるけど、結構無駄なところはありそう。
コマンドを表す構造体CommandがデータストアとなるHashMapをfieldとして保持してしまっているのは微妙かも...。 Arc<Mutex<>>> はまだあんまりよくわかっていない。