📑

RustとPostgreSQLでストアード・プロシージャーのTDD

2024/12/07に公開

目的

ユニークビジョン株式会社 Advent Calendar 2024の12/5の記事です。

ストアード・プロシージャーを利用しているとデバッグがやりにくいという話を聞きます。
例えばWebアプリでバックエンドのAPIからストアード・プロシージャーが呼び出される場合、単体テストでAPIのテストは書けますが、間接的にしかストアード・プロシージャーをテストできません。

そこで、直接ストアード・プロシージャーの単体テストを行うことにします。ただしストアード・プロシージャーだけでテストフレームワークを作るのは大変なのでRustでストアード・プロシージャー呼び出す関数の単体テストとして実現します。

説明

ここでは企業一覧を取得するストアード・プロシージャーを考えます。
企業名を引数に与えることで、部分一致する企業を絞り込みます。

雛形

まずはロジックの無い雛形を作成します。

sample_get_list_companies.sql
DROP TYPE IF EXISTS type_sample_get_list_companies CASCADE;
CREATE TYPE type_sample_get_list_companies AS (
  company_uuid UUID
  ,company_name TEXT
);

CREATE OR REPLACE FUNCTION sample_get_list_companies (
  p_company_name TEXT DEFAULT NULL
  ,p_now TIMESTAMPTZ DEFAULT NULL
  ,p_pg TEXT DEFAULT NULL
  ,p_operator_uuid UUID DEFAULT NULL
) RETURNS SETOF type_sample_get_list_companies AS $FUNCTION$
DECLARE
  w_now TIMESTAMPTZ := COALESCE(p_now, NOW());
  w_pg TEXT := COALESCE(p_pg, 'sample_get_list_companies');
  w_operator_uuid UUID := COALESCE(p_operator_uuid, '00000000-0000-0000-0000-000000000000');
BEGIN
  RAISE NOTICE 'sample_get_list_companies started p_company_name = %', p_company_name;
END;
$FUNCTION$ LANGUAGE plpgsql;
sample_get_list_companies.rs
use serde::{Deserialize, Serialize};
use chrono::prelude::*;
use sqlx::prelude::*;
use derive_builder::Builder;
use uuid::Uuid;

#[derive(Serialize, Deserialize, Debug, Clone, Builder, Default, PartialEq, Eq)]
#[builder(setter(into))]
#[builder(default)]
#[builder(field(public))]
pub struct DbInput {
    pub company_name: String,
    pub now: Option<DateTime<Utc>>,
    pub pg: Option<String>,
    pub operator_uuid: Option<Uuid>,
}

#[derive(Serialize, Deserialize, Debug, Clone, Builder, Default, PartialEq, Eq, FromRow)]
#[builder(setter(into))]
#[builder(default)]
#[builder(field(public))]
pub struct DbOutput {
    pub company_uuid: Uuid,
    pub company_name: String
}

const SQL: &str = r#"
SELECT
    t1.*
FROM
    sample_get_list_companies(
        p_company_name := $1
        ,p_now := $2
        ,p_pg := $3
        ,p_operator_uuid := $4
    ) AS t1
"#;

pub async fn execute(
    pg_pool: &crate::Pool,
    params: DbInput,
) -> Result<Vec<DbOutput>, sqlx::Error> {
    let res: Vec<DbOutput> = sqlx::query_as(SQL)
        .bind(&params.company_name)
        .bind(params.now)
        .bind(&params.pg)
        .bind(params.operator_uuid)
        .fetch_all(pg_pool)
        .await?;
    Ok(res)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{prelude::*, setup_test};

    // RUST_LOG=info REALM_CODE=test cargo test -p postgresql test_postgresql_sample_get_list_companies -- --nocapture --test-threads=1
    #[tokio::test]
    async fn test_postgresql_sample_get_list_companies () -> anyhow::Result<()> {
        let pool = setup_test().await?;

        let params = DbInput {
            ..Default::default()
        };

        let result = execute(&pool, params).await?;
        assert_eq!(result.len(), 0);

        Ok(())
    }
}

Rustのコードに実行するコードが記述されています。実行します。

$ RUST_LOG=info REALM_CODE=test cargo test -p postgresql \
test_postgresql_sample_get_list_companies -- --nocapture --test-threads=1

running 1 test
test custom::sample_get_list_companies::tests::test_postgresql_sample_get_list_companies ... [2024-12-06T00:04:19Z INFO  sqlx::postgres::notice] sample_get_list_companies started p_company_name = 
ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.13s

テストは成功します。

データ取得

まずはデータを入れて取得してみます。
データの入れ方はRustでテーブル変更に対して壊れにくいテストを書くの記事で紹介しているbuilderパターンを使います。

Rust側を修正します。

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{prelude::*, setup_test};

    // RUST_LOG=info REALM_CODE=test cargo test -p postgresql test_postgresql_sample_get_list_companies -- --nocapture --test-threads=1
    #[tokio::test]
    async fn test_postgresql_sample_get_list_companies () -> anyhow::Result<()> {
        let pool = setup_test().await?;

        make_companies(
            &pool,
            &mut CompaniesBuilder::default().company_name("あいうえお工業"),
        )
        .await?;
        make_companies(
            &pool,
            &mut CompaniesBuilder::default().company_name("かきくけこ工業"),
        )
        .await?;

        let params = DbInput {
            ..Default::default()
        };

        let result = execute(&pool, params).await?;
        assert_eq!(result.len(), 2);

        Ok(())
    }
}
$ RUST_LOG=info REALM_CODE=test cargo test -p postgresql \
test_postgresql_sample_get_list_companies -- --nocapture --test-threads=1

running 1 test
test custom::sample_get_list_companies::tests::test_postgresql_sample_get_list_companies ... [2024-12-06T00:10:54Z INFO  sqlx::postgres::notice] sample_get_list_companies started p_company_name = 
thread 'custom::sample_get_list_companies::tests::test_postgresql_sample_get_list_companies' panicked at crates/postgresql/src/custom/sample_get_list_companies.rs:76:9:
assertion `left == right` failed
  left: 0
 right: 2

テストは失敗します。まだストアード・プロシージャーにコードを書いていないからです。ではストアード・プロシージャーを修正します。

BEGIN
  RAISE NOTICE 'sample_get_list_companies started p_company_name = %', p_company_name;
  
  RETURN QUERY SELECT
    t1.uuid
    ,t1.company_name
  FROM
    public.companies AS t1
  ;
END;

テストを実行します。

$ RUST_LOG=info REALM_CODE=test cargo test -p postgresql \
test_postgresql_sample_get_list_companies -- --nocapture --test-threads=1

running 1 test
test custom::sample_get_list_companies::tests::test_postgresql_sample_get_list_companies ... [2024-12-06T00:14:07Z INFO  sqlx::postgres::notice] sample_get_list_companies started p_company_name = 
ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.13s

成功しました。

絞り込み

次に絞り込み条件を実装します。まずはテストの修正です。

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{prelude::*, setup_test};

    // RUST_LOG=info REALM_CODE=test cargo test -p postgresql test_postgresql_sample_get_list_companies -- --nocapture --test-threads=1
    #[tokio::test]
    async fn test_postgresql_sample_get_list_companies() -> anyhow::Result<()> {
        let pool = setup_test().await?;

        make_companies(
            &pool,
            &mut CompaniesBuilder::default().company_name("あいうえお工業"),
        )
        .await?;
        make_companies(
            &pool,
            &mut CompaniesBuilder::default().company_name("かきくけこ工業"),
        )
        .await?;
 
        let params = DbInput {
            company_name: "あいうえお".to_string(),
            ..Default::default()
        };

        let result = execute(&pool, params).await?;
        assert_eq!(result.len(), 1);

        Ok(())
    }
}

テストを実行します。

$ RUST_LOG=info REALM_CODE=test cargo test -p postgresql \
test_postgresql_sample_get_list_companies -- --nocapture --test-threads=1

running 1 test
test custom::sample_get_list_companies::tests::test_postgresql_sample_get_list_companies ... [2024-12-06T00:10:54Z INFO  sqlx::postgres::notice] sample_get_list_companies started p_company_name = あいうえお
thread 'custom::sample_get_list_companies::tests::test_postgresql_sample_get_list_companies' panicked at crates/postgresql/src/custom/sample_get_list_companies.rs:76:9:
assertion `left == right` failed
  left: 2
 right: 1

テストは失敗します。
ストアード・プロシージャーを修正します。

BEGIN
  RAISE NOTICE 'sample_get_list_companies started p_company_name = %', p_company_name;
  
  RETURN QUERY SELECT
    t1.uuid
    ,t1.company_name
  FROM
    public.companies AS t1
  WHERE
    (
      p_company_name IS NULL 
      OR t1.company_name ILIKE '%' || replace(replace(p_company_name, '_', '\_'), '%', '\%') || '%'
    )
  ;
END;

テストを実行します。

$ RUST_LOG=info REALM_CODE=test cargo test -p postgresql \
test_postgresql_sample_get_list_companies -- --nocapture --test-threads=1

running 1 test
test custom::sample_get_list_companies::tests::test_postgresql_sample_get_list_companies ... [2024-12-06T00:20:01Z INFO  sqlx::postgres::notice] sample_get_list_companies started p_company_name = あいうえお
ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.16s

成功しました。

コード

完成したコードは以下のようになります。

sample_get_list_companies.sql
DROP TYPE IF EXISTS type_sample_get_list_companies CASCADE;
CREATE TYPE type_sample_get_list_companies AS (
  company_uuid UUID
  ,company_name TEXT
);

CREATE OR REPLACE FUNCTION sample_get_list_companies (
  p_company_name TEXT DEFAULT NULL
  ,p_now TIMESTAMPTZ DEFAULT NULL
  ,p_pg TEXT DEFAULT NULL
  ,p_operator_uuid UUID DEFAULT NULL
) RETURNS SETOF type_sample_get_list_companies AS $FUNCTION$
DECLARE
  w_now TIMESTAMPTZ := COALESCE(p_now, NOW());
  w_pg TEXT := COALESCE(p_pg, 'sample_get_list_companies');
  w_operator_uuid UUID := COALESCE(p_operator_uuid, '00000000-0000-0000-0000-000000000000');
BEGIN
  RAISE NOTICE 'sample_get_list_companies started p_company_name = %', p_company_name;
  
  RETURN QUERY SELECT
    t1.uuid
    ,t1.company_name
  FROM
    public.companies AS t1
  WHERE
    (
      p_company_name IS NULL 
      OR t1.company_name ILIKE '%' || replace(replace(p_company_name, '_', '\_'), '%', '\%') || '%'
    )
  ;
END;
$FUNCTION$ LANGUAGE plpgsql;
sample_get_list_companies.rs
use chrono::prelude::*;
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use sqlx::prelude::*;
use uuid::Uuid;

#[derive(Serialize, Deserialize, Debug, Clone, Builder, Default, PartialEq, Eq)]
#[builder(setter(into))]
#[builder(default)]
#[builder(field(public))]
pub struct DbInput {
    pub company_name: String,
    pub now: Option<DateTime<Utc>>,
    pub pg: Option<String>,
    pub operator_uuid: Option<Uuid>,
}

#[derive(Serialize, Deserialize, Debug, Clone, Builder, Default, PartialEq, Eq, FromRow)]
#[builder(setter(into))]
#[builder(default)]
#[builder(field(public))]
pub struct DbOutput {
    pub company_uuid: Uuid,
    pub company_name: String,
}

const SQL: &str = r#"
SELECT
    t1.*
FROM
    sample_get_list_companies(
        p_company_name := $1
        ,p_now := $2
        ,p_pg := $3
        ,p_operator_uuid := $4
    ) AS t1
"#;

pub async fn execute(pg_pool: &crate::Pool, params: DbInput) -> Result<Vec<DbOutput>, sqlx::Error> {
    let res: Vec<DbOutput> = sqlx::query_as(SQL)
        .bind(&params.company_name)
        .bind(params.now)
        .bind(&params.pg)
        .bind(params.operator_uuid)
        .fetch_all(pg_pool)
        .await?;
    Ok(res)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{prelude::*, setup_test};

    // RUST_LOG=info REALM_CODE=test cargo test -p postgresql test_postgresql_sample_get_list_companies -- --nocapture --test-threads=1
    #[tokio::test]
    async fn test_postgresql_sample_get_list_companies() -> anyhow::Result<()> {
        let pool = setup_test().await?;

        make_companies(
            &pool,
            &mut CompaniesBuilder::default().company_name("あいうえお工業"),
        )
        .await?;
        make_companies(
            &pool,
            &mut CompaniesBuilder::default().company_name("かきくけこ工業"),
        )
        .await?;
 
        let params = DbInput {
            company_name: "あいうえお".to_string(),
            ..Default::default()
        };

        let result = execute(&pool, params).await?;
        assert_eq!(result.len(), 1);

        Ok(())
    }
}
lib.rs
use crate::prelude::*;
use sqlx::postgres::PgPoolOptions;
pub mod custom;
pub mod table;
pub use sqlx;
pub use sqlx::Error;

pub type Pool = sqlx::Pool<sqlx::Postgres>;

pub async fn get_postgres_pool(url: &str, max_connections: u32) -> Result<Pool, sqlx::Error> {
    let res = PgPoolOptions::new()
        .max_connections(max_connections)
        .connect(url)
        .await?;
    Ok(res)
}

pub mod prelude {
    pub use crate::table::companies::*;
    pub use crate::table::users::*;
    pub use crate::table::*;
}

pub async fn clear_db(pool: &Pool) -> Result<(), sqlx::Error> {
    Users::delete_all(pool).await?;
    Companies::delete_all(pool).await?;
    Ok(())
}

pub async fn setup_test() -> Result<Pool, sqlx::Error> {
    env_logger::init();
    let pool = get_postgres_pool("postgres://user:pass@postgresql/web", 5).await?;
    clear_db(&pool).await?;
    Ok(pool)
}

まとめ

雛形の段階でテストが成功するコードになっていました。段階的にテストを追加し、エラーを起こし、コードを修正して、テストが成功するTDDのループが回りました。

実際の開発ではインターフェースをJSONで定義して、ストアード・プロシージャーとRustの雛形は自動生成しています。
詳細はこちら
RustとPostgreSQLのストアード・プロシージャーの雛形を作る

またストアード・プロシージャーを修正するたびにテストDBに自動で反映するようにしています。ここまでやると本当にストレス無く開発が進みます。
詳細はこちら
PostgreSQLでローカルでストアード・プロシージャーを自動適用する

Discussion