💡

Rustで常駐プログラムでDBアクセス

2024/02/19に公開

目的

Rustで常駐プログラムのグレースフルストップ
Rustで常駐プログラムの中で定期的に処理したい
の続きです。
今回はDBアクセスになります。

コネクションプーリング

起動したプロセスでは複数のスレッドが動作しますが、それぞれのスレッドが無制限にDBにアクセスするとコネクション数の上限を超えてしまうかもしれません。
また都度DBの接続を開くのは時間がかかるので、パフォーマンスのためにもコネクションプールを導入します。

Rustには色々なコネクションプールのライブラリがありますが、ここではdeadpoolを使います。

コード

DB

PostgreSQLをうごかします。

docker-compose.yaml
services:
  db:
    image: postgres:16
    environment:
      - POSTGRES_DB=web
      - POSTGRES_USER=user
      - POSTGRES_PASSWORD=pass
      - TZ=Asia/Tokyo
      - PGTZ=Asia/Tokyo
    ports:
      - 5432:5432
    volumes:
      - resident_postgresql_data:/var/lib/postgresql/data

volumes:
  resident_postgresql_data:

DB接続周り

ここで重要になるのは「max_size: 2」です。きちんとコネクションプールが動作しているか確認するため少なめの数字を設定しています。本番とかでは状況に応じて設定してください。
タイムアウトを指定しました。コネクションエラーも見てみます。

また型を明示していますが、これはredisなど他のDBも同時に使うことがあるので、こうしておくと便利です。またどんなコネクションプールを使っているかも隠蔽できます。

pub type PgPool = deadpool_postgres::Pool;
pub type PgClient = deadpool_postgres::Client;

fn get_postgres_pool(url: &str) -> anyhow::Result<PgPool> {
    let pg_url = url::Url::parse(&url)?;
    let dbname = match pg_url.path_segments() {
        Some(mut res) => res.next(),
        None => Some("web"),
    };
    let pool_config = deadpool_postgres::PoolConfig {
        max_size: 2,
        timeouts: deadpool_postgres::Timeouts { 
            wait: Some(Duration::from_secs(2)),
            ..Default::default()
        },
        ..Default::default()
    };
    let cfg = deadpool_postgres::Config {
        user: Some(pg_url.username().to_string()),
        password: pg_url.password().map(|password| password.to_string()),
        dbname: dbname.map(|dbname| dbname.to_string()),
        host: pg_url.host_str().map(|host| host.to_string()),
        pool: Some(pool_config),
        ..Default::default()
    };
    let res = cfg.create_pool(Some(deadpool_postgres::Runtime::Tokio1), NoTls)?;
    Ok(res)
}

async fn get_postgres_client(
    pool: &deadpool_postgres::Pool,
) -> anyhow::Result<PgClient> {
    pool.get().await.map_err(Into::into)
}

改良したmake_looper

定期的に処理する以前に作成したmake_looperをDBアクセスするように改良しました。
以前との違いは引数にpg_poolを受け付けるようにしたのと、本当の処理をする関数fでDBのコネクションを持つようにしています。

DBコネクションを取得する部分がコネクションプールの醍醐味になっています。コネクション数を超えている場合にはここで待たされます。他のスレッドが解放すれば動き出します。
また指定したタイムアウトの時間を超えた場合はエラーを返します。エラーの場合適切にハンドリングする必要がありますが、ここではこのスレッドの処理は諦めて次の時間に動くことを期待します。

fn make_looper<Fut1, Fut2>(
    pg_pool: PgPool,
    token: CancellationToken,
    expression: &'static str,
    f: impl Fn(&DateTime<Utc>, PgClient) -> Fut1 + Send + Sync + 'static,
    g: impl Fn() -> Fut2 + Send + Sync + 'static,
) -> JoinHandle<()>
where
    Fut1: Future<Output = ()> + Send,
    Fut2: Future<Output = ()> + Send,
{
    spawn(async move {
        let schedule = Schedule::from_str(expression).unwrap();
        let mut next_tick = schedule.upcoming(Utc).next().unwrap();
        loop {
            // グレースフルストップのチェック
            if token.is_cancelled() {
                g().await;
                break;
            }

            let now = Utc::now();
            if now >= next_tick {
                // 定期的に行う処理実行
                match get_postgres_client(&pg_pool).await {
                    Ok(pg_conn) => f(&now, pg_conn).await,
                    Err(e) => {
                        // エラーが出たので、ここでは何もしないで次に期待する
                        println!("get_postgres_client error={}", e);
                    }
                }

                // 次の時間取得
                next_tick = schedule.upcoming(Utc).next().unwrap();
            }

            // 次の時間計算
            sleep(Duration::from_secs(std::cmp::min(
                (next_tick - now).num_seconds() as u64,
                60,
            )))
            .await;
        }
    })
}

全体

Cargo.toml
[package]
name = "resident"
version = "0.1.0"
edition = "2021"

[dependencies]
anyhow = "1"
chrono = "0.4"
cron = "0.12"
deadpool-postgres = { version = "0.12.1" }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tokio = { version = "1", features = ["rt-multi-thread", "macros", "time", "signal"] }
tokio-util = "0.7"
url = "2.5.0"
main.rs
use chrono::prelude::*;
use cron::Schedule;
use std::{future::Future, str::FromStr, time::Duration};

use deadpool_postgres::tokio_postgres::NoTls;
use tokio::{signal::ctrl_c, spawn, task::JoinHandle, time::sleep};
use tokio_util::sync::CancellationToken;

pub type PgPool = deadpool_postgres::Pool;
pub type PgClient = deadpool_postgres::Client;

fn get_postgres_pool(url: &str) -> anyhow::Result<PgPool> {
    let pg_url = url::Url::parse(&url)?;
    let dbname = match pg_url.path_segments() {
        Some(mut res) => res.next(),
        None => Some("web"),
    };
    let pool_config = deadpool_postgres::PoolConfig {
        max_size: 2,
        timeouts: deadpool_postgres::Timeouts { 
            wait: Some(Duration::from_secs(2)),
            ..Default::default()
        },
        ..Default::default()
    };
    let cfg = deadpool_postgres::Config {
        user: Some(pg_url.username().to_string()),
        password: pg_url.password().map(|password| password.to_string()),
        dbname: dbname.map(|dbname| dbname.to_string()),
        host: pg_url.host_str().map(|host| host.to_string()),
        pool: Some(pool_config),
        ..Default::default()
    };
    let res = cfg.create_pool(Some(deadpool_postgres::Runtime::Tokio1), NoTls)?;
    Ok(res)
}

async fn get_postgres_client(
    pool: &deadpool_postgres::Pool,
) -> anyhow::Result<PgClient> {
    pool.get().await.map_err(Into::into)
}

fn ctrl_c_handler(token: CancellationToken) -> JoinHandle<()> {
    spawn(async move {
        ctrl_c().await.unwrap();
        println!("received ctrl-c");
        token.cancel();
    })
}

fn make_looper<Fut1, Fut2>(
    pg_pool: PgPool,
    token: CancellationToken,
    expression: &'static str,
    f: impl Fn(&DateTime<Utc>, PgClient) -> Fut1 + Send + Sync + 'static,
    g: impl Fn() -> Fut2 + Send + Sync + 'static,
) -> JoinHandle<()>
where
    Fut1: Future<Output = ()> + Send,
    Fut2: Future<Output = ()> + Send,
{
    spawn(async move {
        let schedule = Schedule::from_str(expression).unwrap();
        let mut next_tick = schedule.upcoming(Utc).next().unwrap();
        loop {
            // グレースフルストップのチェック
            if token.is_cancelled() {
                g().await;
                break;
            }

            let now = Utc::now();
            if now >= next_tick {
                // 定期的に行う処理実行
                match get_postgres_client(&pg_pool).await {
                    Ok(pg_conn) => f(&now, pg_conn).await,
                    Err(e) => {
                        // エラーが出たので、ここでは何もしないで次に期待する
                        println!("get_postgres_client error={}", e);
                    }
                }

                // 次の時間取得
                next_tick = schedule.upcoming(Utc).next().unwrap();
            }

            // 次の時間計算
            sleep(Duration::from_secs(std::cmp::min(
                (next_tick - now).num_seconds() as u64,
                60,
            )))
            .await;
        }
    })
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let pg_url = std::env::var("PG_URL").unwrap_or("postgres://user:pass@localhost:5432/web".to_owned());
    let pg_pool = get_postgres_pool(&pg_url)?;
    let token: CancellationToken = CancellationToken::new();
    let handles = vec![make_looper(
        pg_pool.clone(),
        token.clone(),
        "*/1 * * * * *",
        |&now: &_, pg_conn: _| async move {
            println!("定期的に処理する何か1 {}", now);
            let _result = pg_conn.query("SELECT pg_sleep(5)", &[]).await.unwrap();
        },
        || async move {
            println!("graceful stop looper 1");
        },
    ),make_looper(
        pg_pool.clone(),
        token.clone(),
        "*/1 * * * * *",
        |&now: &_, pg_conn: _| async move {
            println!("定期的に処理する何か2 {}", now);
            let _result = pg_conn.query("SELECT  pg_sleep(5)", &[]).await.unwrap();
        },
        || async move {
            println!("graceful stop looper 2");
        },
    ),make_looper(
        pg_pool.clone(),
        token.clone(),
        "*/1 * * * * *",
        |&now: &_, pg_conn: _| async move {
            println!("定期的に処理する何か3 {}", now);
            let _result = pg_conn.query("SELECT  pg_sleep(5)", &[]).await.unwrap();
        },
        || async move {
            println!("graceful stop looper 3");
        },
    ),make_looper(
        pg_pool.clone(),
        token.clone(),
        "*/1 * * * * *",
        |&now: &_, pg_conn: _| async move {
            println!("定期的に処理する何か4 {}", now);
            let _result = pg_conn.query("SELECT  pg_sleep(5)", &[]).await.unwrap();
        },
        || async move {
            println!("graceful stop looper 4");
        },
    )];

    #[allow(clippy::let_underscore_future)]
    let _ = ctrl_c_handler(token);
    for handle in handles {
        handle.await.unwrap();
    }
    Ok(())
}

結果

4つのスレッドを動かしています。DBでは単純にスリープするSQL文になっています。実行結果は以下の通りです。期待通り2つのコネクションしか取らず、タイムアウトのエラーがでています。

定期的に処理する何か3 2024-02-18 23:43:20.000768886 UTC
定期的に処理する何か4 2024-02-18 23:43:20.000816040 UTC
get_postgres_client error=Timeout occurred while waiting for a slot to become available
get_postgres_client error=Timeout occurred while waiting for a slot to become available

まとめ

コネクションプールを使ったDBアクセスができるようになりました。指定した以上にコネクションを取ることは無いので安心して使えます。

おまけ

これを書いている時に「pg_pool.clone()」をしたら、コネクションはどんどん取れちゃうだろうからArcとか使う必要があるかなと思ってました。実行したらうまくいったのでビックリです。
deadpoolのソースを見たところ自前でArcを使って制御してました。

Discussion