🙌

RustでERDからモデルstructを作る

2024/12/18に公開

目的

ユニークビジョン株式会社 Advent Calendar 2024のシリーズ2、12/10の記事です。

「Rustでテーブル変更に対して壊れにくいテストを書く」で紹介しているように、builderパターンを使ってテーブルと同じ構造のstructを作ると単体テストで便利です。

ここでは「RustとPostgreSQLのストアード・プロシージャーの雛形を作る」でも紹介したERDツールを使って作成します。

説明

ディレクトリ構成

- erd
  - table
    - templates
      - table_rs.ejs
    - table.erd.json
- apps
  - services
    - crates
      - postgresql
        - src
          - table

ERD

企業テーブルを作りました。

table.erd.json
{
  "meta": {
    "version": "1.0.0"
  },
  "domains": [
    {
      "lname": "名前",
      "pname": "name",
      "type": "TEXT",
      "default": "''",
      "notNull": true,
      "as": "名"
    },
    {
      "lname": "UUID",
      "pname": "uuid",
      "type": "UUID",
      "default": "",
      "notNull": true
    }
  ],
  "templates": {
    "table": [
      {
        "template": "erd/table/templates/table_rs.ejs",
        "file": "apps/services/crates/postgresql/src/table/${pname}.rs"
      }
    ]
  },
  "tables": [
    {
      "lname": "企業",
      "pname": "companies",
      "columns": [
        {
          "domain": "UUID",
          "pk": true,
          "default": "gen_random_uuid()"
        },
        {
          "lname": "企業",
          "pname": "company",
          "domain": "名前"
        }
      ]
    }
  ]
}

struct作成

ejsを定義してRustのstructを作成します

table_rs.ejs
<%
    // Rust struct name derived from the table's physical name
    // (e.g. `companies` becomes `Companies`).
    const structName = _.upperFirst(_.camelCase(table.pname));

    // Column types that map directly onto a Rust field type.
    const RUST_TYPE_BY_COLUMN_TYPE = new Map([
        ['UUID', 'Uuid'],
        ['BIGINT', 'i64'],
        ['TIMESTAMPTZ', 'DateTime<Utc>'],
        ['TIMESTAMPTZ[]', 'Vec<DateTime<Utc>>'],
        ['NUMERIC', 'rust_decimal::Decimal'],
        ['NUMERIC[]', 'Vec<rust_decimal::Decimal>'],
        ['JSONB', 'serde_json::Value'],
        ['TEXT[]', 'Vec<String>'],
        ['UUID[]', 'Vec<Uuid>'],
    ]);

    // Resolves the Rust type of a column: the direct mapping wins, then the
    // `区分` domain (a type under `crate::kbn_constants` named after the
    // column), and anything else is treated as `String`.
    const getType = (column) => {
        if (RUST_TYPE_BY_COLUMN_TYPE.has(column.type)) {
            return RUST_TYPE_BY_COLUMN_TYPE.get(column.type);
        }
        if (column.domain === '区分') {
            return `crate::kbn_constants::${_.upperFirst(_.camelCase(column.pname))}`;
        }
        return 'String';
    };

    // Column types whose field is passed to `.bind(...)` by value; every other
    // column is passed with a leading `&`.
    const BOUND_BY_VALUE = new Set(['UUID', 'BIGINT', 'TIMESTAMPTZ', 'NUMERIC']);
    const getRefer = (column) => (BOUND_BY_VALUE.has(column.type) ? '' : '&');

    // Struct field declarations; columns without `notNull` are wrapped in `Option`.
    const columns = table.columns
        .map((column) => {
            const rustType = getType(column);
            const fieldType = column.notNull ? rustType : `Option<${rustType}>`;
            return `pub ${column.pname}: ${fieldType}`;
        })
        .join(',\n    ');

    // Bare column names, shared by the INSERT column list and every RETURNING list.
    const insertColumns = table.columns
        .map((column) => `${column.pname}`)
        .join('\n    ,');

    // Positional placeholders `$1..$n`, one per table column.
    const insertValues = table.columns
        .map((_column, position) => `$${position + 1}`)
        .join('\n    ,');

    // One `.bind(...)` call per table column, in column order.
    const insertBinds = table.columns
        .map((column) => `.bind(${getRefer(column)}self.${column.pname})`)
        .join('\n            ');

    // SET assignments for UPDATE. The `uuid` column is skipped and numbering
    // starts at `$2` because `$1` is the `uuid` used in the WHERE clause.
    // NOTE: numbering is by position after filtering, so it lines up with the
    // bind order only when `uuid` is the first column of the table.
    const updates = table.columns
        .filter((column) => column.pname !== 'uuid')
        .map((column, position) => `${column.pname} = $${position + 2}`)
        .join('\n    ,');

    // Column list for SELECT, qualified with the `t1` alias.
    const selects = table.columns
        .map((column) => `t1.${column.pname}`)
        .join('\n    ,');
%>use crate::Pool;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use derive_builder::Builder;
use sqlx::prelude::*;

/// Base `SELECT` listing every column of the table under the alias `t1`.
///
/// It carries no `WHERE` clause: `select_all` runs it as-is and `select_one`
/// appends its own filter.
const SELECT_SQL: &str = r#"
SELECT
    <%- selects %>
    ,t1.created_uuid
    ,t1.updated_uuid
    ,t1.deleted_uuid
    ,t1.created_at
    ,t1.updated_at
    ,t1.deleted_at
    ,t1.created_pg
    ,t1.updated_pg
    ,t1.deleted_pg
    ,t1.bk
FROM
    public.<%= table.pname %> AS t1
"#;

/// `INSERT` taking one positional parameter per column and handing the stored
/// row back through `RETURNING`.
///
/// The table-specific columns take the first placeholders; the ten trailing
/// columns (`created_uuid` through `bk`) take the ones after them, in the
/// same order `insert` binds its values.
const INSERT_SQL: &str = r#"
INSERT INTO public.<%= table.pname %> (
    <%- insertColumns %>
    ,created_uuid
    ,updated_uuid
    ,deleted_uuid
    ,created_at
    ,updated_at
    ,deleted_at
    ,created_pg
    ,updated_pg
    ,deleted_pg
    ,bk
) VALUES (
    <%- insertValues %>
    ,$<%- table.columns.length + 1 %>
    ,$<%- table.columns.length + 2 %>
    ,$<%- table.columns.length + 3 %>
    ,$<%- table.columns.length + 4 %>
    ,$<%- table.columns.length + 5 %>
    ,$<%- table.columns.length + 6 %>
    ,$<%- table.columns.length + 7 %>
    ,$<%- table.columns.length + 8 %>
    ,$<%- table.columns.length + 9 %>
    ,$<%- table.columns.length + 10 %>
) RETURNING
    <%- insertColumns %>
    ,created_uuid
    ,updated_uuid
    ,deleted_uuid
    ,created_at
    ,updated_at
    ,deleted_at
    ,created_pg
    ,updated_pg
    ,deleted_pg
    ,bk
"#;

/// `UPDATE` that rewrites every column except `uuid` for the row whose `uuid`
/// equals `$1`, and returns the updated row through `RETURNING`.
///
/// Placeholders follow the order in which `update` binds its values: `$1` is
/// the `uuid`, then the remaining table columns, then the ten trailing
/// columns (`created_uuid` through `bk`).
const UPDATE_SQL: &str = r#"
UPDATE public.<%= table.pname %> AS t1 SET
    <%- updates %>
    ,created_uuid = $<%- table.columns.length + 1 %>
    ,updated_uuid = $<%- table.columns.length + 2 %>
    ,deleted_uuid = $<%- table.columns.length + 3 %>
    ,created_at = $<%- table.columns.length + 4 %>
    ,updated_at = $<%- table.columns.length + 5 %>
    ,deleted_at = $<%- table.columns.length + 6 %>
    ,created_pg = $<%- table.columns.length + 7 %>
    ,updated_pg = $<%- table.columns.length + 8 %>
    ,deleted_pg = $<%- table.columns.length + 9 %>
    ,bk = $<%- table.columns.length + 10 %>
WHERE
    t1.uuid = $1
RETURNING
    <%- insertColumns %>
    ,created_uuid
    ,updated_uuid
    ,deleted_uuid
    ,created_at
    ,updated_at
    ,deleted_at
    ,created_pg
    ,updated_pg
    ,deleted_pg
    ,bk
"#;

/// `DELETE` for the single row whose `uuid` equals `$1`; the removed row is
/// handed back through `RETURNING`.
const DELETE_SQL: &str = r#"
DELETE FROM public.<%= table.pname %> AS t1
WHERE
    t1.uuid = $1
RETURNING
    <%- insertColumns %>
    ,created_uuid
    ,updated_uuid
    ,deleted_uuid
    ,created_at
    ,updated_at
    ,deleted_at
    ,created_pg
    ,updated_pg
    ,deleted_pg
    ,bk
"#;

/// Unconditional `DELETE`: there is no `WHERE` clause, so every row of the
/// table is removed.
const DELETE_ALL_SQL: &str = r#"
DELETE FROM public.<%= table.pname %> AS t1
"#;

/// One row of the table: the table-specific columns first, followed by the
/// ten trailing columns (`created_uuid` through `bk`) that every SQL constant
/// in this file also lists.
///
/// `FromRow` lets `sqlx::query_as` map result rows into this struct. The
/// builder is generated with `setter(into)` and a struct-level `default`, so
/// fields that are not set explicitly fall back to their default value.
#[derive(Serialize, Deserialize, Debug, Clone, Builder, Default, FromRow)]
#[builder(setter(into), default, field(public))]
pub struct <%= structName %> {
    <%- columns %>,
    pub created_uuid: Uuid,
    pub updated_uuid: Uuid,
    pub deleted_uuid: Uuid,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
    pub deleted_at: Option<DateTime<Utc>>,
    pub created_pg: String,
    pub updated_pg: String,
    pub deleted_pg: String,
    pub bk: Option<String>,
}

impl <%= structName %> {
    /// Inserts this value as a new row and returns the row as stored, read
    /// back through the statement's `RETURNING` clause.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` reported while executing `INSERT_SQL`.
    pub async fn insert(&self, pool: &Pool) -> Result<Self, sqlx::Error> {
        sqlx::query_as::<_, Self>(INSERT_SQL)
            <%- insertBinds %>
            .bind(self.created_uuid)
            .bind(self.updated_uuid)
            .bind(self.deleted_uuid)
            .bind(self.created_at)
            .bind(self.updated_at)
            .bind(self.deleted_at)
            .bind(&self.created_pg)
            .bind(&self.updated_pg)
            .bind(&self.deleted_pg)
            .bind(&self.bk)
            .fetch_one(pool)
            .await
    }

    /// Overwrites the row identified by `self.uuid` with this value and
    /// returns the updated row.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` reported while executing `UPDATE_SQL`. The
    /// query is run with `fetch_one`, so a `uuid` that matches no row surfaces
    /// as an error rather than as an empty result.
    pub async fn update(&self, pool: &Pool) -> Result<Self, sqlx::Error> {
        sqlx::query_as::<_, Self>(UPDATE_SQL)
            <%- insertBinds %>
            .bind(self.created_uuid)
            .bind(self.updated_uuid)
            .bind(self.deleted_uuid)
            .bind(self.created_at)
            .bind(self.updated_at)
            .bind(self.deleted_at)
            .bind(&self.created_pg)
            .bind(&self.updated_pg)
            .bind(&self.deleted_pg)
            .bind(&self.bk)
            .fetch_one(pool)
            .await
    }

    /// Deletes the row identified by `self.uuid`; shorthand for
    /// [`Self::delete_one`].
    pub async fn delete(&self, pool: &Pool) -> Result<Self, sqlx::Error> {
        let Self { uuid, .. } = self;
        Self::delete_one(pool, uuid).await
    }

    /// Deletes the row with the given `uuid` and returns the removed row.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` reported while executing `DELETE_SQL`; as the
    /// query uses `fetch_one`, an unknown `uuid` is reported as an error.
    pub async fn delete_one(pool: &Pool, uuid: &Uuid) -> Result<Self, sqlx::Error> {
        let removed = sqlx::query_as::<_, Self>(DELETE_SQL)
            .bind(uuid)
            .fetch_one(pool)
            .await?;
        Ok(removed)
    }

    /// Deletes every row of the table, discarding the query result.
    pub async fn delete_all(pool: &Pool) -> Result<(), sqlx::Error> {
        sqlx::query(DELETE_ALL_SQL)
            .execute(pool)
            .await
            .map(|_| ())
    }

    /// Loads every row of the table.
    pub async fn select_all(pool: &Pool) -> Result<Vec<Self>, sqlx::Error> {
        sqlx::query_as::<_, Self>(SELECT_SQL)
            .fetch_all(pool)
            .await
    }

    /// Loads the row with the given `uuid`, or `None` when no row matches.
    pub async fn select_one(pool: &Pool, uuid: &Uuid) -> Result<Option<Self>, sqlx::Error> {
        // SELECT_SQL has no WHERE clause, so the key filter is appended here.
        let sql = format!("{} WHERE t1.uuid = $1", SELECT_SQL);
        sqlx::query_as::<_, Self>(&sql)
            .bind(uuid)
            .fetch_optional(pool)
            .await
    }
}

作成された企業テーブル

companies.rs
use crate::Pool;
use chrono::{DateTime, Utc};
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use sqlx::prelude::*;
use uuid::Uuid;

/// Base `SELECT` listing every column of `public.companies` under the alias
/// `t1`.
///
/// It carries no `WHERE` clause: `select_all` runs it as-is and `select_one`
/// appends its own filter.
const SELECT_SQL: &str = r#"
SELECT
    t1.uuid
    ,t1.company_name
    ,t1.created_uuid
    ,t1.updated_uuid
    ,t1.deleted_uuid
    ,t1.created_at
    ,t1.updated_at
    ,t1.deleted_at
    ,t1.created_pg
    ,t1.updated_pg
    ,t1.deleted_pg
    ,t1.bk
FROM
    public.companies AS t1
"#;

/// `INSERT` into `public.companies` taking twelve positional parameters
/// (`$1`..`$12`, one per column) and handing the stored row back through
/// `RETURNING`.
///
/// The placeholder order matches the order in which `insert` binds its values.
const INSERT_SQL: &str = r#"
INSERT INTO public.companies (
    uuid
    ,company_name
    ,created_uuid
    ,updated_uuid
    ,deleted_uuid
    ,created_at
    ,updated_at
    ,deleted_at
    ,created_pg
    ,updated_pg
    ,deleted_pg
    ,bk
) VALUES (
    $1
    ,$2
    ,$3
    ,$4
    ,$5
    ,$6
    ,$7
    ,$8
    ,$9
    ,$10
    ,$11
    ,$12
) RETURNING
    uuid
    ,company_name
    ,created_uuid
    ,updated_uuid
    ,deleted_uuid
    ,created_at
    ,updated_at
    ,deleted_at
    ,created_pg
    ,updated_pg
    ,deleted_pg
    ,bk
"#;

/// `UPDATE` that rewrites every column of `public.companies` except `uuid`
/// for the row whose `uuid` equals `$1`, and returns the updated row through
/// `RETURNING`.
///
/// Placeholders follow the order in which `update` binds its values: `$1` is
/// the `uuid`, `$2`..`$12` are the remaining columns.
const UPDATE_SQL: &str = r#"
UPDATE public.companies AS t1 SET
    company_name = $2
    ,created_uuid = $3
    ,updated_uuid = $4
    ,deleted_uuid = $5
    ,created_at = $6
    ,updated_at = $7
    ,deleted_at = $8
    ,created_pg = $9
    ,updated_pg = $10
    ,deleted_pg = $11
    ,bk = $12
WHERE
    t1.uuid = $1
RETURNING
    uuid
    ,company_name
    ,created_uuid
    ,updated_uuid
    ,deleted_uuid
    ,created_at
    ,updated_at
    ,deleted_at
    ,created_pg
    ,updated_pg
    ,deleted_pg
    ,bk
"#;

/// `DELETE` for the single row of `public.companies` whose `uuid` equals
/// `$1`; the removed row is handed back through `RETURNING`.
const DELETE_SQL: &str = r#"
DELETE FROM public.companies AS t1
WHERE
    t1.uuid = $1
RETURNING
    uuid
    ,company_name
    ,created_uuid
    ,updated_uuid
    ,deleted_uuid
    ,created_at
    ,updated_at
    ,deleted_at
    ,created_pg
    ,updated_pg
    ,deleted_pg
    ,bk
"#;

/// Unconditional `DELETE`: there is no `WHERE` clause, so every row of
/// `public.companies` is removed.
const DELETE_ALL_SQL: &str = r#"
DELETE FROM public.companies AS t1
"#;

/// One row of `public.companies`, with the same twelve columns that the SQL
/// constants in this file select and return.
///
/// `FromRow` lets `sqlx::query_as` map result rows into this struct. The
/// builder is generated with `setter(into)` and a struct-level `default`, so
/// fields that are not set explicitly fall back to their default value.
#[derive(Serialize, Deserialize, Debug, Clone, Builder, Default, FromRow)]
#[builder(setter(into), default, field(public))]
pub struct Companies {
    pub uuid: Uuid,
    pub company_name: String,
    pub created_uuid: Uuid,
    pub updated_uuid: Uuid,
    pub deleted_uuid: Uuid,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
    pub deleted_at: Option<DateTime<Utc>>,
    pub created_pg: String,
    pub updated_pg: String,
    pub deleted_pg: String,
    pub bk: Option<String>,
}

impl Companies {
    /// Inserts this value as a new row and returns the row as stored, read
    /// back through the statement's `RETURNING` clause.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` reported while executing `INSERT_SQL`.
    pub async fn insert(&self, pool: &Pool) -> Result<Self, sqlx::Error> {
        sqlx::query_as::<_, Self>(INSERT_SQL)
            .bind(self.uuid)
            .bind(&self.company_name)
            .bind(self.created_uuid)
            .bind(self.updated_uuid)
            .bind(self.deleted_uuid)
            .bind(self.created_at)
            .bind(self.updated_at)
            .bind(self.deleted_at)
            .bind(&self.created_pg)
            .bind(&self.updated_pg)
            .bind(&self.deleted_pg)
            .bind(&self.bk)
            .fetch_one(pool)
            .await
    }

    /// Overwrites the row identified by `self.uuid` with this value and
    /// returns the updated row.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` reported while executing `UPDATE_SQL`. The
    /// query is run with `fetch_one`, so a `uuid` that matches no row surfaces
    /// as an error rather than as an empty result.
    pub async fn update(&self, pool: &Pool) -> Result<Self, sqlx::Error> {
        sqlx::query_as::<_, Self>(UPDATE_SQL)
            .bind(self.uuid)
            .bind(&self.company_name)
            .bind(self.created_uuid)
            .bind(self.updated_uuid)
            .bind(self.deleted_uuid)
            .bind(self.created_at)
            .bind(self.updated_at)
            .bind(self.deleted_at)
            .bind(&self.created_pg)
            .bind(&self.updated_pg)
            .bind(&self.deleted_pg)
            .bind(&self.bk)
            .fetch_one(pool)
            .await
    }

    /// Deletes the row identified by `self.uuid`; shorthand for
    /// [`Self::delete_one`].
    pub async fn delete(&self, pool: &Pool) -> Result<Self, sqlx::Error> {
        let Self { uuid, .. } = self;
        Self::delete_one(pool, uuid).await
    }

    /// Deletes the row with the given `uuid` and returns the removed row.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` reported while executing `DELETE_SQL`; as the
    /// query uses `fetch_one`, an unknown `uuid` is reported as an error.
    pub async fn delete_one(pool: &Pool, uuid: &Uuid) -> Result<Self, sqlx::Error> {
        let removed = sqlx::query_as::<_, Self>(DELETE_SQL)
            .bind(uuid)
            .fetch_one(pool)
            .await?;
        Ok(removed)
    }

    /// Deletes every row of `public.companies`, discarding the query result.
    pub async fn delete_all(pool: &Pool) -> Result<(), sqlx::Error> {
        sqlx::query(DELETE_ALL_SQL)
            .execute(pool)
            .await
            .map(|_| ())
    }

    /// Loads every row of `public.companies`.
    pub async fn select_all(pool: &Pool) -> Result<Vec<Self>, sqlx::Error> {
        sqlx::query_as::<_, Self>(SELECT_SQL).fetch_all(pool).await
    }

    /// Loads the row with the given `uuid`, or `None` when no row matches.
    pub async fn select_one(pool: &Pool, uuid: &Uuid) -> Result<Option<Self>, sqlx::Error> {
        // SELECT_SQL has no WHERE clause, so the key filter is appended here.
        let sql = format!("{} WHERE t1.uuid = $1", SELECT_SQL);
        sqlx::query_as::<_, Self>(&sql)
            .bind(uuid)
            .fetch_optional(pool)
            .await
    }
}

まとめ

今回作成したものは主キーがUUIDであることを決め打ちしています。
もし色々な主キーが必要ならejsあたりを修正すれば対応可能です。

ユニークビジョンではこんな感じで簡単なSQLなら自動生成しています。
複雑なSQLは「RustとPostgreSQLでストアード・プロシージャーのTDD」にあるようにストアード・プロシージャーで作成しています。
なのでRustの中では手でSQLを書くことは少なくなっています。

Discussion