😺

Rustで常駐プログラムのグレースフルストップ

2023/12/01に公開

Rustで常駐プログラムのグレースフルストップ

目的

複数のスレッドが様々な処理をしている常駐プログラムを考えます。この時プロセスがSIGINT(Ctrl+C)を受け取った時に各スレッドがグレースフルストップするようなプログラムを検討します。

1秒ごとに表示するプログラム

最初は単純に1つのスレッドの中で1秒おきに表示するプログラムを考えます。
tokio::spawnで引数のクロージャーをスレッドに変化してくれます。JoinHandleという戻り値を返してJoinHandleをawaitすることで処理完了を待ちます。

main.rs
#[tokio::main]
async fn main() {
    let handle = tokio::spawn(async {
        for i in 0..10 {
            println!("{}", i);
            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        }
    });
    handle.await.unwrap();
}

実行結果

0
1
2
^C

SIGINTでプログラムは処理途中で止められてしまいます。

スレッドの中でCtrl+Cされた時にグレースフルストップしたい。

SIGINTを監視するスレッドを作ります。スレッド間で通信を行うためにchannelを用意します。
crossbeam_channelは送信も受信も複数作れるチャンネルです。後で複数のスレッドの時に生きてきます。

tokio::signal::ctrl_c()でSIGINTが来るまで止まってます。来たら以降のプログラムが動き出します。channel経由で停止をスレッドに報告します。

受信側はstop_receiver.try_recv()で直ぐに、何も受信していないか、受信したか、チャンネルが閉じてしまったかを判定できます。何も受信していない場合は通常処理をして、受信した場合はグレースフルストップを行います。

main.rs
#[tokio::main]
async fn main() {
    let (stop_sender, stop_receiver) = crossbeam_channel::unbounded::<()>();

    let _ = tokio::spawn(async move {
        tokio::signal::ctrl_c().await.unwrap();
        println!("received ctrl-c");
        stop_sender.send(()).unwrap();
    });

    let handle = tokio::spawn(async move {
        for i in 0..10 {
            match stop_receiver.try_recv() {
                Ok(_) => {
                    println!("graceful stop!");
                    break;
                },
                Err(crossbeam_channel::TryRecvError::Empty) => {},
                Err(crossbeam_channel::TryRecvError::Disconnected) => panic!("channel disconnected"),
            }
            println!("{}", i);
            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        }
    });

    handle.await.unwrap();
}

実行結果

0
1
2
^Creceived ctrl-c
graceful stop!

SIGINTを受信したら、グレースフルストップが呼ばれました!

2つの処理でグレースフルストップ(失敗)

では複数のスレッドで処理をしていて、グレースフルストップをやってみます。
stop_receiverはcloneすることで、受信できる口を増やせます。

main.rs
#[tokio::main]
async fn main() {
    let (stop_sender, stop_receiver) = crossbeam_channel::unbounded::<()>();

    let _ = tokio::spawn(async move {
        tokio::signal::ctrl_c().await.unwrap();
        println!("received ctrl-c");
        stop_sender.send(()).unwrap();
    });

    let mut handles = vec![];
    let stop_receiver_cloned = stop_receiver.clone();
    let handle = tokio::spawn(async move {
        for i in 0..10 {
            match stop_receiver_cloned.try_recv() {
                Ok(_) => {
                    println!("graceful stop! 1");
                    break;
                }
                Err(crossbeam_channel::TryRecvError::Empty) => {}
                Err(crossbeam_channel::TryRecvError::Disconnected) => {
                    panic!("channel disconnected")
                }
            }
            println!("{}", i);
            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        }
    });
    handles.push(handle);

    let stop_receiver_cloned = stop_receiver.clone();
    let handle = tokio::spawn(async move {
        for i in 10..20 {
            match stop_receiver_cloned.try_recv() {
                Ok(_) => {
                    println!("graceful stop! 2");
                    break;
                }
                Err(crossbeam_channel::TryRecvError::Empty) => {}
                Err(crossbeam_channel::TryRecvError::Disconnected) => {
                    panic!("channel disconnected")
                }
            }
            println!("{}", i);
            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        }
    });
    handles.push(handle);

    for handle in handles {
        handle.await.unwrap();
    }
}

実行結果

0
10
11
1
2
12
^Creceived ctrl-c
graceful stop! 2
thread 'tokio-runtime-worker' panicked at 'channel disconnected', src/main.rs:22:21
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
thread 'main' panicked at 'called `Result::unwrap()` on an `Err` value: JoinError::Panic(Id(18), ...)', src/main.rs:51:22

2つ目のスレッドでグレースフルストップが呼ばれず、Disconnectされてしまいました。
Rustではスコープを抜けた変数はdropされます。SIGINTを監視しているスレッドではstop_senderが送信後dropされます。送信者がdropされることと、送信するメッセージが存在しなくなるとchannelはDisconnectされてしまう実装になっています。

2つの処理でグレースフルストップ(成功)

そこで処理しているスレッドの数だけstop_senderがメッセージを送るように変更します。
SIGINTを監視してるスレッドの実装を下げて、登録したスレッドの数を受け取ります。

main.rs
#[tokio::main]
async fn main() {
    let (stop_sender, stop_receiver) = crossbeam_channel::unbounded::<()>();

    let mut handles = vec![];
    let stop_receiver_cloned = stop_receiver.clone();
    let handle = tokio::spawn(async move {
        for i in 0..10 {
            match stop_receiver_cloned.try_recv() {
                Ok(_) => {
                    println!("graceful stop! 1");
                    break;
                }
                Err(crossbeam_channel::TryRecvError::Empty) => {}
                Err(crossbeam_channel::TryRecvError::Disconnected) => {
                    panic!("channel disconnected")
                }
            }
            println!("{}", i);
            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        }
    });
    handles.push(handle);

    let stop_receiver_cloned = stop_receiver.clone();
    let handle = tokio::spawn(async move {
        for i in 10..20 {
            match stop_receiver_cloned.try_recv() {
                Ok(_) => {
                    println!("graceful stop! 2");
                    break;
                }
                Err(crossbeam_channel::TryRecvError::Empty) => {}
                Err(crossbeam_channel::TryRecvError::Disconnected) => {
                    panic!("channel disconnected")
                }
            }
            println!("{}", i);
            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        }
    });
    handles.push(handle);

    let handle_count = handles.len();
    let _ = tokio::spawn(async move {
        tokio::signal::ctrl_c().await.unwrap();
        println!("received ctrl-c");
        for _ in 0..handle_count {
            stop_sender.send(()).unwrap();
        }
    });

    for handle in handles {
        handle.await.unwrap();
    }
}

実行結果

0
10
11
1
12
2
^Creceived ctrl-c
graceful stop! 2
graceful stop! 1

うまく行きました!

最後にリファクタリング

共通化できる部分を切り出したり、SIGINTの監視を関数化しました。

main.rs
use crossbeam_channel::{unbounded, Receiver, Sender, TryRecvError};
use std::time::Duration;
use tokio::{signal::ctrl_c, spawn, task::JoinHandle, time::sleep};

fn check_stop(stop_receiver: &Receiver<()>) -> bool {
    match stop_receiver.try_recv() {
        Ok(_) => {
            println!("received stop signal");
            true
        }
        Err(TryRecvError::Empty) => false,
        Err(TryRecvError::Disconnected) => panic!("channel disconnected"),
    }
}

fn ctrl_c_handler(stop_sender: Sender<()>, handle_count: usize) -> JoinHandle<()> {
    spawn(async move {
        ctrl_c().await.unwrap();
        println!("received ctrl-c");
        for _ in 0..handle_count {
            stop_sender.send(()).unwrap();
        }
    })
}

fn make_thread(
    stop_receiver: Receiver<()>,
    start: usize,
    end: usize,
    key: &'static str,
) -> JoinHandle<()> {
    spawn(async move {
        for i in start..end {
            if check_stop(&stop_receiver) {
                println!("graceful stop handle {}", key);
                break;
            }
            println!("{}", i);
            sleep(Duration::from_secs(1)).await;
        }
    })
}

#[tokio::main]
async fn main() {
    let (stop_sender, stop_receiver) = unbounded::<()>();
    let handles = vec![
        make_thread(stop_receiver.clone(), 0, 10, "1"),
        make_thread(stop_receiver.clone(), 10, 20, "2"),
        make_thread(stop_receiver.clone(), 100, 101, "3"),
    ];
    let _ = ctrl_c_handler(stop_sender, handles.len());
    for handle in handles {
        handle.await.unwrap();
    }
}
Cargo.toml
[package]
name = "resident"
version = "0.1.0"
edition = "2021"

[dependencies]
crossbeam-channel = "0.5"
tokio = { version = "1", features = ["rt-multi-thread", "macros", "time", "signal"] }

実行結果

0
100
10
11
1
2
12
^Creceived ctrl-c
received stop signal
graceful stop handle 2
received stop signal
graceful stop handle 1

念のため早く終わるスレッドを用意してグレースフルストップが呼ばれていないことを確認しました。

追記(2023/12/06)

社内の勉強会で発表したところ、tokioのドキュメントにグレースフルストップのトピックがあると教えてもらいました。
ここでは「CancellationToken」を使うことが紹介されてました。これで書き換えてみます。

main.rs
use std::time::Duration;
use tokio::{signal::ctrl_c, spawn, task::JoinHandle, time::sleep};
use tokio_util::sync::CancellationToken;

fn ctrl_c_handler(token: CancellationToken) -> JoinHandle<()> {
    spawn(async move {
        ctrl_c().await.unwrap();
        println!("received ctrl-c");
        token.cancel();
    })
}

fn make_thread(
    token: CancellationToken,
    start: usize,
    end: usize,
    key: &'static str,
) -> JoinHandle<()> {
    spawn(async move {
        for i in start..end {
            if token.is_cancelled() {
                println!("graceful stop handle {}", key);
                break;
            }
            println!("{}", i);
            sleep(Duration::from_secs(1)).await;
        }
    })
}

#[tokio::main]
async fn main() {
    let token = CancellationToken::new();
    let handles = vec![
        make_thread(token.clone(), 0, 10, "1"),
        make_thread(token.clone(), 10, 20, "2"),
        make_thread(token.clone(), 100, 101, "3"),
    ];
    let _ = ctrl_c_handler(token);
    for handle in handles {
        handle.await.unwrap();
    }
}
Cargo.toml
[package]
name = "resident"
version = "0.1.0"
edition = "2021"

[dependencies]
tokio = { version = "1", features = ["rt-multi-thread", "macros", "time", "signal"] }
tokio-util = "0.7.10"

CancellationTokenを使うと以下のメリットがあるのがわかりました。

  • channelは2つの値を扱うが、CancellationTokenは1つの値だけ扱えば良い
  • スレッドの数を管理しなくて良い
  • エラーハンドリングが不要になる

コードもすっきりして、良い感じになりました。

Discussion