Open6
The Rust Programming LanguageのマルチスレッドのWebサーバのテスト
The Rust Programming Language の最後のプロジェクト「マルチスレッドのWebサーバを構築する」で、テストを書きながらRustを勉強してみる。
写経しつつ変更を加えるたびにテストを書いてみたけど、最終的にこうなった、という感じの記録を残す。
ところどころもうちょっと上手いやり方があるのかもしれない...
src/lib.rs
#[cfg(test)]
mod tests {
    use super::*;
    use std::{io::Write, panic, time::Duration};

    /// Upper bound on how long a test waits for a submitted job to complete.
    const JOB_TIMEOUT: Duration = Duration::from_secs(5);

    /// A pool created with size 1 holds exactly one worker, whose id is 0.
    #[test]
    fn test_thread_pool_new() {
        let thread_pool = ThreadPool::new(1);
        let workers = &thread_pool.workers;
        assert_eq!(workers.len(), 1);
        assert_eq!(workers[0].id, 0);
    }

    /// Creating a pool with size 0 must panic.
    #[test]
    fn test_thread_pool_new_with_size_zero() {
        // Catch the panic so it can be asserted on explicitly.
        // https://stackoverflow.com/a/42649833
        let result = panic::catch_unwind(|| ThreadPool::new(0));
        assert!(result.is_err());
    }

    /// A job passed to `execute` is run by a worker and its side effect is observable.
    #[test]
    fn test_thread_pool_execute() {
        let thread_pool = ThreadPool::new(1);
        let output = Arc::new(Mutex::new(Vec::<u8>::new()));
        let sink = Arc::clone(&output);

        // The job signals completion over a channel, so the test does not have to
        // guess how long the worker needs (a fixed sleep is a race and makes the
        // test flaky on a loaded machine).
        let (done_tx, done_rx) = mpsc::channel();
        thread_pool.execute(move || {
            write!(sink.lock().unwrap(), "I'm Foo").unwrap();
            done_tx.send(()).unwrap();
        });

        done_rx
            .recv_timeout(JOB_TIMEOUT)
            .expect("job was not executed in time");
        assert_eq!(*output.lock().unwrap(), b"I'm Foo");
    }

    /// A worker keeps the id it was given, runs a `NewJob` message and stops on `Terminate`.
    #[test]
    fn test_worker_new() {
        let (sender, receiver) = mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));
        let worker = Worker::new(1, receiver);
        assert_eq!(worker.id, 1);

        // https://stackoverflow.com/a/28370712
        let output = Arc::new(Mutex::new(Vec::<u8>::new()));
        let sink = Arc::clone(&output);
        let func = move || write!(sink.lock().unwrap(), "I'm Foo").unwrap();

        // Queue the job followed by the terminate message, then join the worker
        // thread. Joining replaces the former fixed sleep: once the thread has
        // exited, everything it ran is visible here.
        // NOTE(review): assumes the worker handles messages in the order sent and
        // leaves its loop on `Terminate` - confirm against `Worker::new`.
        sender.send(Message::NewJob(Box::new(func))).unwrap();
        sender.send(Message::Terminate).unwrap();
        worker.thread.unwrap().join().unwrap();

        assert_eq!(*output.lock().unwrap(), b"I'm Foo");
    }
}
main.rsはテストをしやすくするために一部コードを書き換えている。
main関数の一部をstart_serverという関数に切り出す。
src/bin/main.rs
/// Entry point: binds a listener to 127.0.0.1:7878 and hands it to `start_server`.
///
/// Panics if the address cannot be bound.
fn main() {
    start_server(TcpListener::bind("127.0.0.1:7878").unwrap());
}
/// Serves every connection accepted on `listener` through a `ThreadPool` created with size 4.
///
/// This function is not part of the book's text. The listener is passed in from the
/// outside so that tests can choose the port the server runs on.
///
/// Panics if accepting a connection fails.
fn start_server(listener: TcpListener) {
    let pool = ThreadPool::new(4);
    listener
        .incoming()
        .map(Result::unwrap)
        .for_each(|stream| {
            pool.execute(move || {
                handle_connection(stream);
            });
        });
}
start_serverを用意したのはこの辺りをやりたかったから。
src/bin/main.rs
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Binds to port 0 so that the OS picks a free port each time a listener is created.
    fn local_listener() -> TcpListener {
        TcpListener::bind("127.0.0.1:0").unwrap()
    }

    /// Splits a raw response into `[status line, headers, body]`.
    ///
    /// The status line ends at the first CRLF and the body starts after the first
    /// blank line, so a body that itself contains CRLF is returned in full.
    /// Panics if the response has no status line or no blank line.
    fn parse_response(buffer: &[u8]) -> [String; 3] {
        let raw_response = String::from_utf8_lossy(buffer);
        let (status, rest) = raw_response
            .split_once("\r\n")
            .expect("response has no status line");
        let (headers, body) = match rest.strip_prefix("\r\n") {
            // The blank line follows the status line directly: there are no headers.
            Some(body) => ("", body),
            None => rest
                .split_once("\r\n\r\n")
                .expect("response has no blank line before the body"),
        };
        [status, headers, body].map(str::to_string)
    }

    /// `GET /` is answered with 200 and the hello page.
    #[test]
    fn success_root_response() {
        let listener = local_listener();
        let addr = listener.local_addr().unwrap();
        thread::spawn(|| start_server(listener));

        let mut stream = TcpStream::connect(addr).unwrap();
        // `write_all` retries short writes; a bare `write` may send only part of the request.
        stream.write_all(b"GET / HTTP/1.1\r\n").unwrap();
        stream.flush().unwrap();

        let mut buffer = [0; 1024];
        // Parse only the bytes actually received, not the zero padding after them.
        // NOTE(review): assumes the whole response arrives in this single read.
        let len = stream.read(&mut buffer).unwrap();
        let [status, _, body] = parse_response(&buffer[..len]);
        assert_eq!("HTTP/1.1 200 OK", status);
        assert!(body.contains("<h1>Hello!</h1>"));
    }

    /// `GET /sleep` is answered with 200 and the hello page.
    #[test]
    fn success_sleep_response() {
        let listener = local_listener();
        let addr = listener.local_addr().unwrap();
        thread::spawn(|| start_server(listener));

        let mut stream = TcpStream::connect(addr).unwrap();
        stream.write_all(b"GET /sleep HTTP/1.1\r\n").unwrap();
        stream.flush().unwrap();

        let mut buffer = [0; 1024];
        let len = stream.read(&mut buffer).unwrap();
        let [status, _, body] = parse_response(&buffer[..len]);
        assert_eq!("HTTP/1.1 200 OK", status);
        assert!(body.contains("<h1>Hello!</h1>"));
    }

    /// An unknown path is answered with 404 and the error page.
    #[test]
    fn not_found_response() {
        let listener = local_listener();
        let addr = listener.local_addr().unwrap();
        thread::spawn(|| start_server(listener));

        let mut stream = TcpStream::connect(addr).unwrap();
        stream.write_all(b"GET /not_found HTTP/1.1\r\n").unwrap();
        stream.flush().unwrap();

        let mut buffer = [0; 1024];
        let len = stream.read(&mut buffer).unwrap();
        let [status, _, body] = parse_response(&buffer[..len]);
        assert_eq!("HTTP/1.1 404 NOT FOUND", status);
        assert!(body.contains("<h1>Oops!</h1>"));
    }
}
テストをリファクタリングしてみる。
headersはHashMapとかにすればいいんだろうけど手抜き。
src/bin/main.rs
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// The three textual parts of a response as split by `Requester::parse_response`.
    struct Response {
        /// The full status line, e.g. `HTTP/1.1 200 OK` (not only the numeric code).
        status_code: String,
        /// Everything between the status line and the blank line, kept as raw text.
        headers: String,
        /// Everything after the blank line.
        body: String,
    }

    /// Minimal test client that sends one request over an already connected stream.
    struct Requester {
        stream: TcpStream,
    }

    impl Requester {
        /// Wraps a connected stream.
        fn new(stream: TcpStream) -> Self {
            Self { stream }
        }

        /// Sends `data` as the request and returns the parsed response.
        ///
        /// Panics on any I/O error or if the response cannot be split.
        fn get(&mut self, data: &str) -> Response {
            // `write_all` retries short writes; a bare `write` may send only part of the request.
            self.stream.write_all(data.as_bytes()).unwrap();
            self.stream.flush().unwrap();

            let mut buffer = [0; 1024];
            // Parse only the bytes actually received, not the zero padding after them.
            // NOTE(review): assumes the whole response arrives in this single read.
            let len = self.stream.read(&mut buffer).unwrap();
            Self::parse_response(&buffer[..len])
        }

        /// Splits a raw response into status line, headers and body.
        ///
        /// The status line ends at the first CRLF and the body starts after the first
        /// blank line, so a body that itself contains CRLF is returned in full.
        fn parse_response(buffer: &[u8]) -> Response {
            let raw_response = String::from_utf8_lossy(buffer);
            let (status_code, rest) = raw_response
                .split_once("\r\n")
                .expect("response has no status line");
            let (headers, body) = match rest.strip_prefix("\r\n") {
                // The blank line follows the status line directly: there are no headers.
                Some(body) => ("", body),
                None => rest
                    .split_once("\r\n\r\n")
                    .expect("response has no blank line before the body"),
            };
            Response {
                status_code: status_code.to_string(),
                headers: headers.to_string(),
                body: body.to_string(),
            }
        }
    }

    /// Binds to port 0 so that the OS picks a free port each time a listener is created.
    fn local_listener() -> TcpListener {
        TcpListener::bind("127.0.0.1:0").unwrap()
    }

    /// Starts the server on a background thread and returns a client connected to it.
    fn connect_to_new_server() -> TcpStream {
        let listener = local_listener();
        let addr = listener.local_addr().unwrap();
        thread::spawn(|| start_server(listener));
        TcpStream::connect(addr).unwrap()
    }

    /// `GET /` is answered with 200, no headers and the hello page.
    #[test]
    fn success_root_response() {
        let mut requester = Requester::new(connect_to_new_server());
        let response = requester.get("GET / HTTP/1.1\r\n");
        assert_eq!("HTTP/1.1 200 OK", response.status_code);
        assert_eq!("", response.headers);
        assert!(response.body.contains("<h1>Hello!</h1>"));
    }

    /// `GET /sleep` is answered with 200, no headers and the hello page.
    #[test]
    fn success_sleep_response() {
        let mut requester = Requester::new(connect_to_new_server());
        let response = requester.get("GET /sleep HTTP/1.1\r\n");
        assert_eq!("HTTP/1.1 200 OK", response.status_code);
        assert_eq!("", response.headers);
        assert!(response.body.contains("<h1>Hello!</h1>"));
    }

    /// An unknown path is answered with 404, no headers and the error page.
    #[test]
    fn not_found_response() {
        let mut requester = Requester::new(connect_to_new_server());
        let response = requester.get("GET /not_found HTTP/1.1\r\n");
        assert_eq!("HTTP/1.1 404 NOT FOUND", response.status_code);
        assert_eq!("", response.headers);
        assert!(response.body.contains("<h1>Oops!</h1>"));
    }
}