Rust TCP Server TDD ノート
Rustで mini Redis Server を作ってみるとき、TDD したくなったのでやり方を考えてみる
最初のステップはこんな感じのコード。
この辺を参考にしている。
fn main() {
let listner = TcpListener::bind("127.0.0.1:6379").unwrap();
for stream in listner.incoming() {
match stream {
Ok(stream) => {
do_something();
}
Err(e) => {
error_handling();
}
}
}
}
接続が成功した場合
fn server_connect_success() {
thread::spawn(|| {
main();
});
let remote = "127.0.0.1:6379".parse().unwrap();
let mut stream = TcpStream::connect_timeout(&remote, Duration::from_secs(2)).unwrap();
let mut buffer = String::new();
let res = stream.read_to_string(&mut buffer).unwrap();
assert_eq!(0, res);
}
- Threadを生成してサーバーを起動する
- TcpStream::connect_timeout でサーバーに接続する
- とりあえず何か返ってくるならbufferをチェックする
接続が失敗した場合
ひとまずportが間違ってるとか。
fn server_connect_failure() {
thread::spawn(|| {
main();
});
let remote = "127.0.0.1:6378".parse().unwrap();
let stream = TcpStream::connect_timeout(&remote, Duration::from_secs(2));
let err = stream.unwrap_err();
assert_eq!("Connection refused (os error 61)", err.to_string());
}
もっといいやり方ありそう...
- assertが文字列比較になっているので違うやり方にしたい
- そもそもmain関数のmatchのError処理ではない
listner.incoming()
がErrorになるようなmockの仕方とか何かあったりするかな...
PING/PONGの場合
fn ping() {
thread::spawn(|| {
main();
});
let remote = "127.0.0.1:6379".parse().unwrap();
let mut stream = TcpStream::connect_timeout(&remote, Duration::from_secs(2)).unwrap();
let req = "PING".as_bytes();
stream.write(req).unwrap();
let mut buffer = String::new();
stream.read_to_string(&mut buffer).unwrap();
assert_eq!("+PONG\r\n", buffer);
}
- 文字列スライスを
as_bytes()
してから TcpStream に書き込む - 取得結果を buffer に書き込む
- redisのRESP protocol の仕様に従ったデータが取得できる
クライアントはとりあえず動いたから connect_timeout
使ってる感あるけど、もっと良いアプローチあるかもしれない。
Errorのケースのassertionどうやるのがいいんだろ。
コネクション接続に失敗した場合は expect とかで扱った方がいい気もする。
テストの一括実行でportのコンフリクトが発生していたので解決策を探す。
https://zenn.dev/link/comments/711484f1d05381 のテストケースで考える。
ポート0を指定した場合にランダムなポートが割り当てられる振る舞いを利用することでコンフリクトを回避できそう。
また main
関数を呼び出すのでなく、TcpListener
を引数にしたサーバー起動関数を用意することでテストと実装で利用するポートを切り替えられるようにする。
fn test_listener() -> TcpListener {
let addr = "127.0.0.1:0";
TcpListener::bind(addr).unwrap()
}
#[test]
fn server_connect_success() {
let listener = test_listener();
let addr = listener.local_addr().unwrap();
thread::spawn(|| {
start_server(listener);
});
let mut stream = TcpStream::connect_timeout(&addr, Duration::from_secs(2)).unwrap();
let mut buffer = String::new();
let res = stream.read_to_string(&mut buffer).unwrap();
assert_eq!(0, res);
}
The Book のマルチスレッドサーバーを理解した方が進めやすそうだったのでテストコードを書いてみた。
これを踏まえてマルチスレッドの実装を組み込んだ時のThreadPoolに関するテスト。
#[cfg(test)]
mod tests {
use std::{
sync::mpsc::{Receiver, Sender},
time::Duration,
};
use super::*;
#[test]
fn test_worker_new() {
let (sender, receiver): (Sender<Message>, Receiver<Message>) = mpsc::channel();
let worker = Worker::new(1, Arc::new(Mutex::new(receiver)));
assert_eq!(worker.id, 1);
let counter = Arc::new(Mutex::<i32>::new(0));
let actual = counter.clone();
let func = move || *counter.lock().unwrap() += 1;
let job = Box::new(func);
sender.send(Message::NewJob(job)).unwrap();
thread::sleep(Duration::from_millis(1));
assert!(actual.lock().unwrap().eq(&1));
sender.send(Message::Terminate).unwrap();
}
#[test]
fn test_thread_pool_new() {
let thread_pool = ThreadPool::new(2);
assert_eq!(thread_pool.workers.len(), 2);
let counter_one = Arc::new(Mutex::<i32>::new(0));
let counter_two = counter_one.clone();
let actual = counter_one.clone();
let func_one = move || {
thread::sleep(Duration::from_millis(5));
*counter_one.lock().unwrap() += 1;
};
let func_two = move || {
thread::sleep(Duration::from_millis(1));
*counter_two.lock().unwrap() += 1;
};
thread_pool.execute(func_one);
thread_pool.execute(func_two);
thread::sleep(Duration::from_millis(10));
assert!(actual.lock().unwrap().eq(&2));
}
}
各種コマンドを実装してみる。
- echo
- set(PXオプション)
- get
データベースをHashMapとして実装した場合以下のようなテストを実装してみる。
#[cfg(test)]
mod tests {
use std::{thread, time::Duration};
use super::*;
#[test]
fn test_new_with_no_arg() {
let database = Arc::new(Store::new());
let command = Command::new("sample", None, database);
assert_eq!(command.instruction, "sample");
assert_eq!(command.arguments.len(), 0);
}
#[test]
fn test_new_with_an_arg() {
let database = Arc::new(Store::new());
let command = Command::new("sample", Some("$1\r\na"), database);
assert_eq!(command.instruction, "sample");
assert_eq!(command.arguments.len(), 1);
assert_eq!(command.arguments[0].value, "a");
}
#[test]
fn test_new_with_args() {
let database = Arc::new(Store::new());
let command = Command::new("sample", Some("$1\r\na\r\n$1\r\nb\r\n"), database);
assert_eq!(command.instruction, "sample");
assert_eq!(command.arguments.len(), 2);
assert_eq!(command.arguments[0].value, "a");
assert_eq!(command.arguments[1].value, "b");
}
#[test]
fn test_execute_as_ping() {
let database = Arc::new(Store::new());
let command = Command::new("PING", None, database);
let result = command.execute().unwrap();
assert_eq!("+PONG\r\n", result)
}
#[test]
fn test_execute_as_double_ping() {
let database = Arc::new(Store::new());
let command = Command::new("PING\\nPING", None, database);
let result = command.execute().unwrap();
assert_eq!("+PONG\r\n+PONG\r\n", result)
}
#[test]
fn test_execute_as_command() {
let database = Arc::new(Store::new());
let command = Command::new("COMMAND", None, database);
let result = command.execute();
assert_eq!(None, result)
}
#[test]
fn test_execute_as_shutdown() {
let database = Arc::new(Store::new());
let command = Command::new("SHUTDOWN", None, database);
let result = command.execute();
assert_eq!(None, result);
}
#[test]
fn test_execute_as_any() {
let database = Arc::new(Store::new());
let command = Command::new("FOO", None, database);
let result = command.execute().unwrap();
assert_eq!("+OK\r\n", result)
}
#[test]
fn test_execute_as_echo() {
let database = Arc::new(Store::new());
let command = Command::new("ECHO", Some("$1\r\na\r\n"), database);
let result = command.execute().unwrap();
assert_eq!("+a\r\n", result)
}
#[test]
fn test_execute_as_set() {
let database = Arc::new(Store::new());
let cloned = database.clone();
let command = Command::new("SET", Some("$3\r\nFOO\r\n$2\r\nab\r\n"), cloned);
let result = command.execute().unwrap();
assert_eq!("+OK\r\n", result);
let value = database.get("FOO").unwrap();
assert_eq!("ab", value);
}
#[test]
fn test_execute_as_set_with_px() {
let database = Arc::new(Store::new());
let cloned = database.clone();
let command = Command::new(
"SET",
Some("$7\r\nFOO\r\n$2\r\nab\r\n$2\r\nPX\r\n$1\r\n1\r\n"),
cloned,
);
let result = command.execute().unwrap();
assert_eq!("+OK\r\n", result);
thread::sleep(Duration::from_millis(5));
let value = database.get("FOO");
assert_eq!(None, value);
}
#[test]
fn test_execute_as_get() {
let database = Arc::new(Store::new());
database.set("FOO", "bar", None);
let command = Command::new("GET", Some("$3\r\nFOO\r\n"), database);
let result = command.execute().unwrap();
assert_eq!("+bar\r\n", result);
}
#[test]
fn test_execute_as_get_with_expired() {
let database = Arc::new(Store::new());
let asserted_database = database.clone();
database.set("FOO", "bar", Some(0));
let command = Command::new("GET", Some("$3\r\nFOO\r\n"), database);
let result = command.execute().unwrap();
assert_eq!("$-1\r\n", result);
assert_eq!(false, asserted_database.exists("FOO"))
}
}
TcpStreamをクライアントとしてサーバーを経由してコマンドを発行した場合のテストケース。
#[cfg(test)]
mod tests {
use std::{
io::{Read, Write},
net::{SocketAddr, TcpStream},
thread,
time::Duration,
};
use super::*;
fn test_listener() -> TcpListener {
let addr = "127.0.0.1:0";
TcpListener::bind(addr).unwrap()
}
fn start_test_server() -> SocketAddr {
let listener = test_listener();
let addr = listener.local_addr().unwrap();
thread::spawn(|| start_server(listener));
addr
}
fn send_request(stream: &mut TcpStream, command: &str) -> String {
let req = command.as_bytes();
stream.write(req).unwrap();
stream.flush().unwrap();
let mut buffer = [0; 1024];
let content_length = stream.read(&mut buffer).unwrap();
String::from_utf8_lossy(&buffer[0..content_length]).to_string()
}
#[test]
fn server_connect_success() {
let addr = start_test_server();
let mut stream = TcpStream::connect(&addr).unwrap();
let result = send_request(&mut stream, "*0\r\n");
assert_eq!(0, result.len());
}
#[test]
fn server_connect_failure() {
let addr = start_test_server();
let invalid_addr: SocketAddr = format!("{}:6379", addr.ip()).parse().unwrap();
let stream = TcpStream::connect(&invalid_addr);
let err = stream.unwrap_err();
assert_eq!("Connection refused (os error 61)", err.to_string());
}
#[test]
fn ping() {
let addr = start_test_server();
let mut stream = TcpStream::connect(&addr).unwrap();
let command = "*1\r\n$4\r\nPING\r\n";
let result = send_request(&mut stream, command);
assert_eq!("+PONG\r\n", result);
}
#[test]
fn double_ping() {
let addr = start_test_server();
let mut stream = TcpStream::connect(&addr).unwrap();
let command = "*1\r\n$10\r\nPING\\nPING\r\n";
let result = send_request(&mut stream, command);
assert_eq!("+PONG\r\n+PONG\r\n", result);
}
#[test]
fn echo() {
let addr = start_test_server();
let mut stream = TcpStream::connect(&addr).unwrap();
let command = "*2\r\n$4\r\nECHO\r\n$3\r\nhey\r\n";
let result = send_request(&mut stream, command);
assert_eq!("+hey\r\n", result);
let mut stream = TcpStream::connect(&addr).unwrap();
let command = "*2\r\n$4\r\nECHO\r\n$5\r\nhello\r\n";
let result = send_request(&mut stream, command);
assert_eq!("+hello\r\n", result);
}
#[test]
fn set_and_get() {
let addr = start_test_server();
let mut stream = TcpStream::connect(&addr).unwrap();
let command = "*3\r\n$3\r\nSET\r\n$3\r\nFOO\r\n$1\r\n1\r\n";
let result = send_request(&mut stream, command);
assert_eq!("+OK\r\n", result);
let mut stream = TcpStream::connect(&addr).unwrap();
let command = "*2\r\n$3\r\nGET\r\n$3\r\nFOO\r\n";
let result = send_request(&mut stream, command);
assert_eq!("+1\r\n", result);
}
#[test]
fn set_and_get_with_expiration() {
let addr = start_test_server();
let mut stream = TcpStream::connect(&addr).unwrap();
let command = "*3\r\n$3\r\nSET\r\n$3\r\nFOO\r\n$1\r\n1\r\n$2\r\nPX\r\n$1\r\n5\r\n";
let result = send_request(&mut stream, command);
assert_eq!("+OK\r\n", result);
let mut stream = TcpStream::connect(&addr).unwrap();
let command = "*2\r\n$3\r\nGET\r\n$3\r\nFOO\r\n";
let result = send_request(&mut stream, command);
assert_eq!("+1\r\n", result);
thread::sleep(Duration::from_millis(5));
let mut stream = TcpStream::connect(&addr).unwrap();
let command = "*2\r\n$3\r\nGET\r\n$3\r\nFOO\r\n";
let result = send_request(&mut stream, command);
assert_eq!("$-1\r\n", result);
}
}
cargo clippy
にすごく助けられるけど、結構無駄なところはありそう。
コマンドを表す構造体CommandがデータストアとなるHashMapをfieldとして保持してしまっているのは微妙かも...。 Arc<Mutex<>>>
はまだあんまりよくわかっていない。