🦀

Rust sqlxでデータベースに依存した部分のテストを書く

2022/04/17に公開

はじめに

アプリケーションにおいてデータの永続化を実現しようとすると、DBとアクセスする層が必要になることが多いです。適切なインターフェースを定義すれば、DBにアクセスする層をモック化して、その層に依存する部分のテストを書くことができます。しかし時にはDBを直接扱う層のロジックをテストしたいときもあります。

例えばJavaであれば、H2を使ってテスト用のデータベースを立ち上げることができます。しかしRustでsqlxを採用した場合、どのようにすればDBに依存するテストが実現できるのでしょうか。

あまり情報が見つからなかったので、試行錯誤しながら得られた知見をまとめておきたいと思います。

前提

本稿で使用するバージョンは、Rust 1.60.0、sqlx 0.5.13です。

方針

テスト用のPostgreSQLをDockerを使って立ち上げます。テストが繰り返し実行できるように、毎回トランザクションを貼って最後にロールバックするようにします。

ソースコード

次のようなスキーマのPostgreSQLデータベースを対象とします。

migrations/20220306122339_create_tables.sql
CREATE TABLE bookshelf_user (
  id text NOT NULL PRIMARY KEY,
  created_at timestamp NOT NULL default current_timestamp,
  updated_at timestamp NOT NULL default current_timestamp
);

テスト用のデータベースはDockerで立ち上げます。

docker-compose-test.yml
services:
  db:
    image: postgres:latest
    ports:
      - "5432:5432"
    environment:
      - POSTGRES_PASSWORD=password
$ docker-compose -f docker-compose-test.yml up -d
$ sqlx migrate run

永続化対象のUser構造体はこちらです。New Typeパターンを使っているので、多少長めの実装になっています。

use validator::Validate;

use crate::domain::error::DomainError;

#[derive(Debug, Clone, PartialEq, Eq, Validate)]
pub struct UserId {
    #[validate(length(min = 1))]
    value: String,
}

impl UserId {
    pub fn new(id: String) -> Result<Self, DomainError> {
        let object = Self { value: id };
        object.validate()?;
        Ok(object)
    }

    pub fn as_str(&self) -> &str {
        &self.value
    }

    pub fn into_string(self) -> String {
        self.value
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct User {
    pub id: UserId,
}

impl User {
    pub fn new(id: UserId) -> User {
        User { id }
    }
}

次にUserRepositoryのtraitを定義しておきます。今回の話には関係ありませんが、UserRepositoryのモックを作成できるようにするためです。

use async_trait::async_trait;
use mockall::automock;

use crate::domain::{
    entity::user::{User, UserId},
    error::DomainError,
};

#[automock]
#[async_trait]
pub trait UserRepository: Send + Sync + 'static {
    async fn create(&self, user: &User) -> Result<(), DomainError>;
    async fn find_by_id(&self, id: &UserId) -> Result<Option<User>, DomainError>;
}

最後に本体の実装とテストです。PgUserRepositoryがtraitを実装した構造体です。

use async_trait::async_trait;
use sqlx::{PgConnection, PgPool};

use crate::domain::{
    entity::user::{User, UserId},
    error::DomainError,
    repository::user_repository::UserRepository,
};

#[derive(sqlx::FromRow)]
struct UserRow {
    id: String,
}

#[derive(Debug, Clone)]
pub struct PgUserRepository {
    pool: PgPool,
}

impl PgUserRepository {
    pub fn new(pool: PgPool) -> Self {
        Self { pool }
    }
}

#[async_trait]
impl UserRepository for PgUserRepository {
    async fn create(&self, user: &User) -> Result<(), DomainError> {
        let mut conn = self.pool.acquire().await?;
        let result = InternalUserRepository::create(user, &mut conn).await?;
        Ok(result)
    }

    async fn find_by_id(&self, id: &UserId) -> Result<Option<User>, DomainError> {
        let mut conn = self.pool.acquire().await?;
        let user = InternalUserRepository::find_by_id(id, &mut conn).await?;
        Ok(user)
    }
}

pub(in crate::infrastructure) struct InternalUserRepository {}

impl InternalUserRepository {
    pub(in crate::infrastructure) async fn create(
        user: &User,
        conn: &mut PgConnection,
    ) -> Result<(), DomainError> {
        sqlx::query("INSERT INTO bookshelf_user (id) VALUES ($1)")
            .bind(user.id.as_str())
            .execute(conn)
            .await?;
        Ok(())
    }

    async fn find_by_id(id: &UserId, conn: &mut PgConnection) -> Result<Option<User>, DomainError> {
        let row: Option<UserRow> = sqlx::query_as("SELECT * FROM bookshelf_user WHERE id = $1")
            .bind(id.as_str())
            .fetch_optional(conn)
            .await?;

        let id = row.map(|row| UserId::new(row.id)).transpose()?;
        Ok(id.map(|id| User::new(id)))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    use sqlx::postgres::PgPoolOptions;

    #[tokio::test]
    async fn test_user_repository() -> anyhow::Result<()> {
        dotenv::dotenv().ok();

        let db_url = fetch_database_url();
        let pool = PgPoolOptions::new()
            .max_connections(5)
            .connect_timeout(Duration::from_secs(1))
            .connect(&db_url)
            .await?;
        let mut tx = pool.begin().await?;

        let id = UserId::new(String::from("foo"))?;
        let user = User::new(id.clone());

        let fetched_user = InternalUserRepository::find_by_id(&id, &mut tx).await?;
        assert!(fetched_user.is_none());

        InternalUserRepository::create(&user, &mut tx).await?;

        let fetched_user = InternalUserRepository::find_by_id(&id, &mut tx).await?;
        assert_eq!(fetched_user, Some(user));

        tx.rollback().await?;
        Ok(())
    }

    fn fetch_database_url() -> String {
        use std::env::VarError;

        match std::env::var("DATABASE_URL") {
            Ok(s) => s,
            Err(VarError::NotPresent) => panic!("Environment variable DATABASE_URL is required."),
            Err(VarError::NotUnicode(_)) => {
                panic!("Environment variable DATABASE_URL is not unicode.")
            }
        }
    }
}

解説

PgUserRepositoryが実際のアプリケーションで使われるリポジトリですが、こちらは直接テストしません。代わりに実際の処理をInternalUserRepositoryに移譲して、こちらをテストするようにします。一種のHunble Objectパターンとみなせるでしょう。

InternalUserRepositoryはメソッドを定義せず、関連関数のみを定義します。関数をまとめておくくらいの意味合いしかないので、moduleに関数をまとめておくくらいでも良いでしょう。

関連関数はSQLに必要な情報と、&mut PgConnectionを引数に受け取ります。これは公式ドキュメントにおいて、connectionとtransactionを両方受け取れるようにする方法として紹介されているものです。

https://docs.rs/sqlx/0.5.13/sqlx/trait.Acquire.html

However, if you really just want to accept both, a transaction or a connection as an argument to a function, then it’s easier to just accept a mutable reference to a database connection like so:

PgUserRepositoryにおいてはコネクションを生成して渡します。

let mut conn = self.pool.acquire().await?;
let result = InternalUserRepository::create(user, &mut conn).await?;

テストにおいてはトランザクションを生成して渡します。

let mut tx = pool.begin().await?;
(略)
let fetched_user = InternalUserRepository::find_by_id(&id, &mut tx).await?;

こうすることで、テストの場合だけ最後にロールバックすることができるようになります。

Pros, Cons

Pros

  • 本物のPostgreSQLが使える
    • PostgreSQLの独自機能も使える
    • H2などのインメモリデータベースを使う場合と比べて、挙動の違いで悩まされることがない

Cons

  • テスト用に本物のデータベースを立ち上げないといけない

うまく行かなかった方法

ここから先は思い出しながら書いているので、不正確なことが書いてある可能性もあります。ご了承ください。

SQLiteでテストする

本番ではPostgreSQLを使い、単体テストではSQLiteを使うという方法も検討しました。しかし以下の試みはすべてうまく行きませんでした。

このアプローチでは、PostgreSQL特有の機能は使えなくなってしまうという問題もあります。

リポジトリにExecutorを持たせる

RepositoryにExecutorを持たせるアイデアです。Executorをジェネリック型として持たせれば、使用するデータベースの違いを吸収できそうに見えます。

こんなイメージです。

struct UserRepository<E> {
    executor: E,
}

impl<'a, E> UserRepository<E>
where
    E: Executor<'a>,
{
    async fn find_by_id(id: &UserId, conn: &mut PgConnection) -> Result<Option<User>, DomainError> {
        let row: Option<UserRow> = sqlx::query_as("SELECT * FROM bookshelf_user WHERE id = $1")
            .bind(id.as_str())
            .fetch_optional(self.executor)
            .await?;

        let id = row.map(|row| UserId::new(row.id)).transpose()?;
        Ok(id.map(|id| User::new(id)))
    }
}

しかしこの方法はどのように試行錯誤してもコンパイルが通りませんでした。色々問題はありますが、self.executorがmoveされてしまうのが困るところの一つです。

以下の内容も参考になりますが、現時点ではうまく行かなそうです。

https://github.com/launchbadge/sqlx/discussions/1136

リポジトリにAnyPoolを持たせる

Anyというドライバーが存在し、ランタイムに実際のDB種類を決定することができるようです。

https://docs.rs/sqlx/0.5.13/sqlx/struct.Any.html
https://docs.rs/sqlx/0.5.13/sqlx/type.AnyPool.html

しかしDecode実装が最低限の型に対してしか提供されていないため、例えばtimeクレートの型を使うことができません。

https://docs.rs/sqlx/0.5.13/sqlx/decode/trait.Decode.html

私は使っていませんが、Anyに対しては静的にSQLのチェックが行えるquery!マクロなども利用することができないようです。

https://github.com/launchbadge/sqlx/issues/964

おわりに

今回のサンプルコードは、私が趣味で開発しているこちらのアプリケーションを題材にしました。2022/04/17現在、開発途中のためまだまだ足りないところだらけですが、良ければ参考にしてみてください。

https://github.com/hiterm/bookshelf-api

2024-09-12追記

sqlx 0.6.1でテスト用の機能が追加されました。

以下のような利点があるので、特別な事情がない限り使わない手はないでしょう。

  • テストごとに独立したデータベースを作成してくれる
    • テストごとに環境が独立するので、手動でロールバックする必要がない
  • マイグレーションも自動で実施してくれる
    • sqlx-cliをインストールする必要がなくなり、CIが高速になる

実装は以下のようになります。

use async_trait::async_trait;
use sqlx::PgPool;

use crate::domain::{
    entity::user::{User, UserId},
    error::DomainError,
    repository::user_repository::UserRepository,
};

#[derive(sqlx::FromRow)]
struct UserRow {
    id: String,
}

#[derive(Debug, Clone)]
pub struct PgUserRepository {
    pool: PgPool,
}

impl PgUserRepository {
    pub fn new(pool: PgPool) -> Self {
        Self { pool }
    }
}

#[async_trait]
impl UserRepository for PgUserRepository {
    async fn create(&self, user: &User) -> Result<(), DomainError> {
        sqlx::query("INSERT INTO bookshelf_user (id) VALUES ($1)")
            .bind(user.id.as_str())
            .execute(&self.pool)
            .await?;
        Ok(())
    }

    async fn find_by_id(&self, id: &UserId) -> Result<Option<User>, DomainError> {
        let row: Option<UserRow> = sqlx::query_as("SELECT * FROM bookshelf_user WHERE id = $1")
            .bind(id.as_str())
            .fetch_optional(&self.pool)
            .await?;

        let id = row.map(|row| UserId::new(row.id)).transpose()?;
        Ok(id.map(User::new))
    }
}

#[cfg(feature = "test-with-database")]
#[cfg(test)]
mod tests {
    use super::*;

    #[sqlx::test]
    async fn test_user_repository(pool: PgPool) -> anyhow::Result<()> {
        dotenv::dotenv().ok();

        let repository = PgUserRepository::new(pool);

        let id = UserId::new(String::from("foo"))?;
        let user = User::new(id.clone());

        let fetched_user = repository.find_by_id(&id).await?;
        assert!(fetched_user.is_none());

        repository.create(&user).await?;

        let fetched_user = repository.find_by_id(&id).await?;
        assert_eq!(fetched_user, Some(user));

        Ok(())
    }
}

上の実装と比較すると、InternalUserRepositoryを作る必要もなくなり、非常にシンプルになっています。

ちなみにGitHub Actionsであれば、以下のように実行できます。

https://github.com/hiterm/bookshelf-api/blob/2ccf2f0df77b0aa9a44bca1b3cdfb5a243e7f36a/.github/workflows/ci.yml#L24-L28

Discussion