💭

RustとPostgreSQLのストアード・プロシージャーの雛形を作る

2024/12/10に公開

目的

ユニークビジョン株式会社 Advent Calendar 2024のシリーズ2、12/3の記事です。

RustとPostgreSQLでストアード・プロシージャーのTDDの記事では、最初にストアード・プロシージャーとそれをテストするためのRustの雛形のコードを紹介しています。

毎回これをゼロから書くのは大変なので、自動生成することにします。

説明

ERDツール

ERDツールとはユニークビジョンが開発しているVSCode拡張です。
jsonで書かれたERDからSQLやコードを自動生成するためのツールになります。

ERDと似た構文でインターフェースも定義することが可能なのでこれを使って、今回の目的の自動生成をします。

フォルダー構成

- apps
  - services
    - crates
      - postgresql
        -src
          -custom
- db
  - sql
    - stored
- erd
  - templates
    - main_rs.ejs
    - stored_template.ejs
  - sample_get_list_companies.erd.json

コード

インターフェース定義

TDDの記事のサンプルのインターフェースを定義です。

sample_get_list_companies.erd.json
{
  "meta": {
    "version": "1.0.0"
  },
  "domains": [
    {
      "lname": "名前",
      "pname": "name",
      "type": "TEXT",
      "default": "''",
      "notNull": true,
      "as": "名"
    },
    {
      "lname": "UUID",
      "pname": "uuid",
      "type": "UUID",
      "default": "",
      "notNull": true
    },
  ],
  "templates": {
    "struct": [
      {
        "template": "erd/stored/templates/stored_template.ejs",
        "file": "db/sql/stored/${pname}.sql_template"
      },
      {
        "template": "erd/stored/templates/main_rs.ejs",
        "file": "apps/services/crates/postgresql/src/custom/${pname}.rs_template"
      }
    ]
  },
  "structs": [
    {
      "lname": "企業取得インプット",
      "pname": "sample_get_list_companies",
      "option": {
        "type": "input"
      },
      "parameters": [
        {
          "lname": "企業",
          "pname": "company",
          "domain": "名前"
        }
      ]
    },
    {
      "lname": "企業取得アウトプット",
      "pname": "sample_get_list_companies",
      "option": {
        "type": "output"
      },
      "parameters": [
        {
          "lname": "企業",
          "pname": "company",
          "domain": "UUID"
        },
        {
          "lname": "企業",
          "pname": "company",
          "domain": "名前"
        }
      ]
    }
  ]
}

Rustコード自動生成

main_rs.ejs
<%
    let result = ""
    let uuidFlag = false;
    let input_struct = struct
    if (struct.option.type === "output") {
        result = "aaa"
        input_struct = structs.find((it) => {
            return it.pname === struct.pname && it.option.type === "input"
        })
    }

    const getType = (column) => {
        let result = '';
        if (column.type === 'UUID') {
            result = 'Uuid';
        } else if (column.type === 'BIGINT') {
            result = 'i64';
        } else if (column.type === 'TIMESTAMPTZ') {
            result = 'DateTime<Utc>';
        } else if (column.type === 'TIMESTAMPTZ[]') {
            result = 'Vec<DateTime<Utc>>';
        } else if (column.type === 'NUMERIC') {
            result = 'rust_decimal::Decimal';
        } else if (column.type === 'NUMERIC[]') {
            result = 'Vec<rust_decimal::Decimal>';
        } else if (column.type === 'JSONB') {
            result = 'serde_json::Value';
        } else if (column.type === 'TEXT[]') {
            result = 'Vec<String>';
        } else if (column.type === 'UUID[]') {
            result = 'Vec<Uuid>';
        } else if (column.type === 'BOOLEAN') {
            result = 'bool';
        } else {
            result = 'String';
        }
        return result;
    }

    const getRefer = (column) => {
        let result = '';
        if (column.type === 'UUID') {
            result = '';
        } else if (column.type === 'BIGINT') {
            result = '';
        } else if (column.type === 'TIMESTAMPTZ') {
            result = '';
        } else if (column.type === 'TIMESTAMPTZ[]') {
            result = '&';
        } else if (column.type === 'NUMERIC') {
            result = '';
        } else if (column.type === 'NUMERIC[]') {
            result = '&';
        } else if (column.type === 'JSONB') {
            result = '&';
        } else if (column.type === 'TEXT[]') {
            result = '&';
        } else if (column.type === 'UUID[]') {
            result = '&';
        } else {
            result = '&';
        }
        return result;
    };

    const input_columns = input_struct.parameters.map((it) => {
        let typeValue = getType(it);
        if (typeValue === 'Uuid') {
            uuidFlag = true;
        }
        if (!it.notNull) {
            typeValue = `Option<${typeValue}>`;
        }
        return  `pub ${it.pname}: ${typeValue},`;
    }).join('\n    ');

    const output_columns = struct.parameters.map((it) => {
        let typeValue = getType(it);
        if (typeValue === 'Uuid') {
            uuidFlag = true;
        }
        if (!it.notNull) {
            typeValue = `Option<${typeValue}>`;
        }
        return  `pub ${it.pname}: ${typeValue}`;
    }).join(',\n    ');

    const args = input_struct.parameters.map((it, index) => {
        return  `p_${it.pname} := $${index + 1}`;
    }).join('\n        ,');

    const binds = input_struct.parameters.map((it) => {
        return  `.bind(${getRefer(it)}params.${it.pname})`;
    }).join('\n        ');

%>use serde::{Deserialize, Serialize};
use chrono::prelude::*;
use sqlx::prelude::*;
use derive_builder::Builder;<%= uuidFlag ? "\nuse uuid::Uuid;" : "" %>

#[derive(Serialize, Deserialize, Debug, Clone, Builder, Default, PartialEq, Eq)]
#[builder(setter(into))]
#[builder(default)]
#[builder(field(public))]
pub struct DbInput {
    <%- input_columns %>
    pub now: Option<DateTime<Utc>>,
    pub pg: Option<String>,
    pub operator_uuid: Option<Uuid>,
}

#[derive(Serialize, Deserialize, Debug, Clone, Builder, Default, PartialEq, Eq, FromRow)]
#[builder(setter(into))]
#[builder(default)]
#[builder(field(public))]
pub struct DbOutput {
    <%- output_columns %>
}

const SQL: &str = r#"
SELECT
    t1.*
FROM
    <%= struct.pname %>(
        <%- args %>
        ,p_now := $<%= input_struct.parameters.length + 1 %>
        ,p_pg := $<%= input_struct.parameters.length + 2 %>
        ,p_operator_uuid := $<%= input_struct.parameters.length + 3 %>
    ) AS t1
"#;

pub async fn execute(
    pg_pool: &crate::Pool,
    params: DbInput,
) -> Result<Vec<DbOutput>, sqlx::Error> {
    let res: Vec<DbOutput> = sqlx::query_as(SQL)
        <%- binds %>
        .bind(params.now)
        .bind(&params.pg)
        .bind(params.operator_uuid)
        .fetch_all(pg_pool)
        .await?;
    Ok(res)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{prelude::*, setup_test};

    // RUST_LOG=info REALM_CODE=test cargo test -p postgresql test_postgresql_<%= struct.pname %> -- --nocapture --test-threads=1
    #[tokio::test]
    async fn test_postgresql_<%= struct.pname %> () -> anyhow::Result<()> {
        let (pool, _) = setup_test().await?;

        let params = DbInput {
            ..Default::default()
        };

        let result = execute(&pool, params).await?;
        assert_eq!(result.len(), 0);

        Ok(())
    }
}

ストアード・プロシージャー自動生成

stored_template.ejs
<%
    let input_struct = struct
    if (struct.option.type === "output") {
        input_struct = structs.find((it) => {
            return it.pname === struct.pname && it.option.type === "input"
        })
    }

    const output_types = struct.parameters.map((it) => {
        return `${it.pname} ${it.type}`
    }).join('\n  ,');

    const input_params = input_struct.parameters.map((it) => {
        return `p_${it.pname} ${it.type} DEFAULT NULL`
    }).join('\n  ,');

    const raise_persent = input_struct.parameters.map((it) => {
      return `p_${it.pname} = %`
    }).join(', ');

    const raise_params = input_struct.parameters.map((it) => {
      return `p_${it.pname}`
  }).join(', ');

%>DROP TYPE IF EXISTS type_<%= struct.pname %> CASCADE;
CREATE TYPE type_<%= struct.pname %> AS (
  <%= output_types %>
);

CREATE OR REPLACE FUNCTION <%= struct.pname %> (
  <%= input_params %>
  ,p_now TIMESTAMPTZ DEFAULT NULL
  ,p_pg TEXT DEFAULT NULL
  ,p_operator_uuid UUID DEFAULT NULL
) RETURNS SETOF type_<%= struct.pname %> AS $FUNCTION$
DECLARE
  w_now TIMESTAMPTZ := COALESCE(p_now, NOW());
  w_pg TEXT := COALESCE(p_pg, '<%= struct.pname %>');
  w_operator_uuid UUID := COALESCE(p_operator_uuid, '00000000-0000-0000-0000-000000000000');
BEGIN
  RAISE NOTICE '<%= struct.pname %> started <%= raise_persent %>', <%= raise_params %>;
END;
$FUNCTION$ LANGUAGE plpgsql;

使い方

VSCodeでインターフェースファイルを開くと、右上に以下の画像のようなボタンが表示されます。一番左のロケットのようなボタンを押すと指定されたフォルダーにファイルが生成されます。

Rustもストアード・プロシージャーも雛形なので.templateという拡張子がついています。これを外して目的のファイルを作ります。
もし既に作成済でコードを書いている場合は、再生成してもtemplateという拡張子なので上書きされる心配はありません。必要な部分をコピーして利用します。

まとめ

ストアード・プロシージャーのTDDのための雛形作成を説明しました。ERDツールはユニークビジョンの社内ではよく使われていますがドキュメント不足なので使いにくいかと思います。その場合は自前の自動生成を用意すると良いです。

Discussion