Open（6件の投稿）

The Rust Programming LanguageのマルチスレッドのWebサーバのテスト

ta1kt0me

The Rust Programming Language の最後のプロジェクト「マルチスレッドのWebサーバを構築する」で、テストを書きながらRustを勉強してみる。

ta1kt0me

写経しつつ変更を加えるたびにテストを書いてみたけど、最終的にこうなった、という感じの記録を残す。
ところどころもうちょっと上手いやり方があるのかもしれない...

ta1kt0me
src/lib.rs
#[cfg(test)]
mod tests {
    use super::*;
    use std::{io::Write, time::Duration};

    // Upper bound on how long a test waits for a job to run on a worker thread.
    // A passing run returns as soon as the job signals completion; the generous
    // limit only matters when something is broken (or the machine is very slow).
    const JOB_TIMEOUT: Duration = Duration::from_secs(5);

    #[test]
    fn test_thread_pool_new() {
        let thread_pool = ThreadPool::new(1);
        let workers = &thread_pool.workers;
        assert_eq!(workers.len(), 1);
        assert_eq!(workers[0].id, 0);
    }

    // A pool with zero workers is rejected: ThreadPool::new(0) must panic.
    #[test]
    #[should_panic]
    fn test_thread_pool_new_with_size_zero() {
        ThreadPool::new(0);
    }

    #[test]
    fn test_thread_pool_execute() {
        let thread_pool = ThreadPool::new(1);

        // The job writes into this shared in-memory buffer so the test can inspect it.
        let output = Arc::new(Mutex::new(Vec::<u8>::new()));
        let job_output = Arc::clone(&output);
        // The job signals on this channel once it has written, so the test waits exactly
        // as long as needed instead of sleeping for a guessed duration (which is racy).
        let (done_tx, done_rx) = mpsc::channel();

        thread_pool.execute(move || {
            write!(job_output.lock().unwrap(), "I'm Foo").unwrap();
            done_tx.send(()).unwrap();
        });

        done_rx
            .recv_timeout(JOB_TIMEOUT)
            .expect("job was not executed by the thread pool in time");
        assert_eq!(*output.lock().unwrap(), b"I'm Foo");
    }

    #[test]
    fn test_worker_new() {
        let (sender, receiver) = mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));
        let worker = Worker::new(1, receiver);
        assert_eq!(worker.id, 1);

        // Capture the job's output in a Vec<u8> writer: https://stackoverflow.com/a/28370712
        let output = Arc::new(Mutex::new(Vec::<u8>::new()));
        let job_output = Arc::clone(&output);
        let (done_tx, done_rx) = mpsc::channel();

        let job = move || {
            write!(job_output.lock().unwrap(), "I'm Foo").unwrap();
            done_tx.send(()).unwrap();
        };
        sender.send(Message::NewJob(Box::new(job))).unwrap();

        done_rx
            .recv_timeout(JOB_TIMEOUT)
            .expect("worker did not run the job in time");
        assert_eq!(*output.lock().unwrap(), b"I'm Foo");

        // Terminate makes the worker leave its loop; joining proves the thread
        // exited without panicking.
        sender.send(Message::Terminate).unwrap();
        worker
            .thread
            .unwrap()
            .join()
            .expect("worker thread panicked");
    }
}
ta1kt0me

main.rsはテストをしやすくするために一部コードを書き換えている。
main関数の一部をstart_serverという関数に切り出す。

src/bin/main.rs
/// Binary entry point: binds the fixed local address and hands the listener to `start_server`.
fn main() {
    const ADDRESS: &str = "127.0.0.1:7878";
    start_server(TcpListener::bind(ADDRESS).unwrap())
}

// テキストにはない関数
// テスト時にserverを起動するportを調整したかったので外部から渡せるインターフェースにする
fn start_server(listener: TcpListener) {
    let pool = ThreadPool::new(4);

    for stream in listener.incoming() {
        let stream = stream.unwrap();

        pool.execute(|| {
            handle_connection(stream);
        });
    }
}
ta1kt0me

https://github.com/rust-lang-nursery/rust-cookbook/issues/123#issuecomment-651159230
start_serverを用意したのは、この辺り（port 0 でbindしてOSに空いているportを割り当ててもらい、テストごとに別のportでサーバを起動すること）をやりたかったから。

src/bin/main.rs
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    // Binding to port 0 makes the OS assign a free port, so each test gets its own server.
    fn local_listener() -> TcpListener {
        TcpListener::bind("127.0.0.1:0").unwrap()
    }

    // Splits a response into status line, header line and body: the first three
    // "\r\n"-separated segments. Panics if there are fewer than three.
    fn parse_response(buffer: &[u8]) -> [String; 3] {
        let raw_response = String::from_utf8_lossy(buffer);
        let mut segments = raw_response.split("\r\n");
        [(); 3].map(|()| segments.next().unwrap().to_string())
    }

    #[test]
    fn success_root_response() {
        let listener = local_listener();
        let addr = listener.local_addr().unwrap();
        thread::spawn(move || start_server(listener));

        let mut stream = TcpStream::connect(addr).unwrap();
        // `write_all`: a bare `write` may send only part of the request.
        stream.write_all(b"GET / HTTP/1.1\r\n").unwrap();
        stream.flush().unwrap();

        // Read until the server closes the connection: a single `read` may return only
        // part of the response (and a fixed 1024-byte buffer would truncate larger pages).
        let mut buffer = Vec::new();
        stream.read_to_end(&mut buffer).unwrap();
        let [status, _, body] = parse_response(&buffer);
        assert_eq!("HTTP/1.1 200 OK", status);
        assert!(body.contains("<h1>Hello!</h1>"));
    }

    #[test]
    fn success_sleep_response() {
        let listener = local_listener();
        let addr = listener.local_addr().unwrap();
        thread::spawn(move || start_server(listener));

        let mut stream = TcpStream::connect(addr).unwrap();
        stream.write_all(b"GET /sleep HTTP/1.1\r\n").unwrap();
        stream.flush().unwrap();

        let mut buffer = Vec::new();
        stream.read_to_end(&mut buffer).unwrap();
        let [status, _, body] = parse_response(&buffer);
        assert_eq!("HTTP/1.1 200 OK", status);
        assert!(body.contains("<h1>Hello!</h1>"));
    }

    #[test]
    fn not_found_response() {
        let listener = local_listener();
        let addr = listener.local_addr().unwrap();
        thread::spawn(move || start_server(listener));

        let mut stream = TcpStream::connect(addr).unwrap();
        stream.write_all(b"GET /not_found HTTP/1.1\r\n").unwrap();
        stream.flush().unwrap();

        let mut buffer = Vec::new();
        stream.read_to_end(&mut buffer).unwrap();
        let [status, _, body] = parse_response(&buffer);
        assert_eq!("HTTP/1.1 404 NOT FOUND", status);
        assert!(body.contains("<h1>Oops!</h1>"));
    }
}
ta1kt0me

テストをリファクタリングしてみる。
headersはHashMapとかにすればいいんだろうけど手抜き。

src/bin/main.rs
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::SocketAddr;
    use std::thread;

    // The first three "\r\n"-separated segments of a raw response: the status line,
    // the line after it (the tests expect it to be empty) and the body segment.
    struct Response {
        status_code: String,
        headers: String,
        body: String,
    }

    // Sends a raw HTTP request over an already-connected TCP stream.
    struct Requester {
        stream: TcpStream,
    }

    impl Requester {
        fn new(stream: TcpStream) -> Self {
            Self { stream }
        }

        // Writes `data` as-is and returns the parsed reply.
        //
        // Uses `write_all`/`read_to_end` instead of `write`/`read`: both of the latter may
        // transfer only part of the data, and a fixed 1024-byte buffer would also truncate
        // larger pages. EOF arrives because `start_server` hands each accepted stream to
        // `handle_connection` by value, so the connection is closed once that job finishes.
        fn get(&mut self, data: &str) -> Response {
            self.stream.write_all(data.as_bytes()).unwrap();
            self.stream.flush().unwrap();

            let mut raw = Vec::new();
            self.stream.read_to_end(&mut raw).unwrap();
            Self::parse_response(&raw)
        }

        // Splits a raw response into status line, header line and body.
        // Panics if the response has fewer than three "\r\n"-separated segments.
        fn parse_response(buffer: &[u8]) -> Response {
            let raw_response = String::from_utf8_lossy(buffer);
            let mut segments = raw_response.split("\r\n");
            let [status_code, headers, body] =
                [(); 3].map(|()| segments.next().unwrap().to_string());

            Response {
                status_code,
                headers,
                body,
            }
        }
    }

    // Binding to port 0 makes the OS assign a free port, so each test gets its own server.
    fn local_listener() -> TcpListener {
        TcpListener::bind("127.0.0.1:0").unwrap()
    }

    // Starts a server on a fresh OS-assigned port in a background thread and returns
    // the address to connect to.
    fn spawn_server() -> SocketAddr {
        let listener = local_listener();
        let addr = listener.local_addr().unwrap();
        thread::spawn(move || start_server(listener));
        addr
    }

    #[test]
    fn success_root_response() {
        let stream = TcpStream::connect(spawn_server()).unwrap();
        let response = Requester::new(stream).get("GET / HTTP/1.1\r\n");
        assert_eq!("HTTP/1.1 200 OK", response.status_code);
        assert_eq!("", response.headers);
        assert!(response.body.contains("<h1>Hello!</h1>"));
    }

    #[test]
    fn success_sleep_response() {
        let stream = TcpStream::connect(spawn_server()).unwrap();
        let response = Requester::new(stream).get("GET /sleep HTTP/1.1\r\n");
        assert_eq!("HTTP/1.1 200 OK", response.status_code);
        assert_eq!("", response.headers);
        assert!(response.body.contains("<h1>Hello!</h1>"));
    }

    #[test]
    fn not_found_response() {
        let stream = TcpStream::connect(spawn_server()).unwrap();
        let response = Requester::new(stream).get("GET /not_found HTTP/1.1\r\n");
        assert_eq!("HTTP/1.1 404 NOT FOUND", response.status_code);
        assert_eq!("", response.headers);
        assert!(response.body.contains("<h1>Oops!</h1>"));
    }
}