Open2

【dbt】dbt dbterd

YuichiYuichi
import argparse
import json
import os
import subprocess
import sys
from itertools import chain
from typing import Any, Dict, List, Set


# ER図を出力するのに必要なrelationshipsを含むノードの情報を抽出する関数
# manifest内の"nodes"から、"resource_type"が"test"で、"test_metadata"の"name"が"relationships"のノードを抽出し、リストとして返します。
def extract_relationships_nodes(manifest: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    ER図を出力するために必要な"relationships"テストを含むノードを抽出します。

    Parameters:
    manifest (Dict[str, Any]): manifestファイルの内容を表す辞書

    Returns:
    List[Dict[str, Any]]: "relationships"テストを含むノードのリスト
    """
    result = []
    for node in manifest["nodes"].values():
        if (
            node["resource_type"] == "test"
            and node.get("test_metadata", {}).get("name") == "relationships"
        ):
            result.append(node)
    return result


# target_nodesをスタートとして、Nホップ先のノードを返す関数
# 指定されたrelationships_nodes内で、target_nodesに関連するノードを最大max_hops回のホップ数で探索し、最大max_nodes個のノードを返します。
def search_n_hops_nodes(
    relationships_nodes: List[Dict[str, Any]],
    target_nodes: Set[str],
    max_hops: int,
    max_nodes: int = 10,  # ホップ数が多い場合、ER図に登場するノード数が多いため、上限を設ける
    n: int = 0,
) -> Set[str]:
    """
    relationshipsノードに基づいて、指定したホップ数で依存関係にあるノードを探索します。

    Parameters:
    relationships_nodes (List[Dict[str, Any]]): "relationships"テストを含むノードのリスト
    target_nodes (Set[str]): スタートノードのセット
    max_hops (int): 最大ホップ数
    max_nodes (int): 最大ノード数(ER図に含めるノード数の制限)
    n (int): 現在のホップ数(再帰のための引数)

    Returns:
    Set[str]: ホップ数に基づいて取得したノードのセット
    """
    if n >= max_hops:
        return target_nodes

    initial_target_nodes = target_nodes.copy()
    for target_node_name in initial_target_nodes:
        for node in relationships_nodes:
            if target_node_name in node["depends_on"]["nodes"]:
                for new_node in node["depends_on"]["nodes"]:
                    if len(target_nodes) >= max_nodes:
                        print(
                            f"ER図内にあまりに多くのノード(n={len(target_nodes)})が含まれるため、出力を制限します",
                            file=sys.stderr,
                        )
                        return target_nodes
                    target_nodes.add(new_node)
    if initial_target_nodes == target_nodes:
        return target_nodes
    return search_n_hops_nodes(
        relationships_nodes,
        target_nodes,
        max_hops=max_hops,
        max_nodes=max_nodes,
        n=n + 1,
    )


# dbterdコマンドを実行して、ER図を出力する関数
# 指定されたノード群に対してdbterdを実行し、ER図を生成します。
def run_dbterd(n_hops_nodes: Set[str]) -> None:
    """
    dbterdコマンドを実行して、指定されたノードに基づいたER図を生成します。

    Parameters:
    n_hops_nodes (Set[str]): ER図に含めるノードのセット
    """
    dbterd_command = [
        "dbterd",
        "run",
        "--target",
        "mermaid",
        "--entity-name-format",
        "schema.table",
        "--resource-type",
        "source",
        "--resource-type",
        "model",
    ]
    dbterd_select_targets = list(
        chain.from_iterable([["--select", f"exact:{node}"] for node in n_hops_nodes])
    )
    subprocess.run(dbterd_command + dbterd_select_targets)


# ノード名に対応するMarkdownファイルのパスを取得する関数
# manifest内のノード情報を基に、対応するMarkdownファイルのパスを返します。
def get_markdown_file_path(manifest: Dict[str, Any], node_name: str) -> str:
    """
    ノード名に対応するMarkdownファイルのパスを取得します。

    Parameters:
    manifest (Dict[str, Any]): manifestファイルの内容を表す辞書
    node_name (str): ノード名

    Returns:
    str: ノードに対応するMarkdownファイルのパス
    """
    filename_without_extension, _ = os.path.splitext(
        manifest["nodes"][node_name]["original_file_path"]
    )
    return filename_without_extension + ".md"


# 相対パスを取得する関数
# base_file_pathからtarget_file_pathへの相対パスを返します。
def get_relative_path(base_file_path: str, target_file_path: str) -> str:
    """
    2つのファイルパス間の相対パスを取得します。

    Parameters:
    base_file_path (str): 基準となるファイルのパス
    target_file_path (str): 相対パスを取得するターゲットファイルのパス

    Returns:
    str: base_file_pathからtarget_file_pathへの相対パス
    """
    base_dir = os.path.dirname(base_file_path)
    return os.path.relpath(target_file_path, base_dir)


def main():
    """
    コマンドライン引数を解析し、manifestファイルに基づいてER図を生成するメイン関数です。

    引数として、manifestファイルのパス、ターゲットノード名、最大ホップ数、最大ノード数を受け取り、関連するノードを探索後、ER図を生成します。
    ER図はMarkdown形式で出力され、関連する他のモデルのER図へのリンクも追加されます。
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--manifest_filename",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--target_node_name",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--max_hops",
        type=int,
        required=True,
    )
    parser.add_argument(
        "--max_nodes",
        type=int,
        required=True,
    )
    args, _ = parser.parse_known_args()
    target_node_name = args.target_node_name

    with open(args.manifest_filename, "r") as file:
        manifest = json.load(file)

    n_hops_nodes = search_n_hops_nodes(
        extract_relationships_nodes(manifest),
        {target_node_name},
        max_hops=args.max_hops,
        max_nodes=args.max_nodes,
    )

    if len(n_hops_nodes) <= 1:
        print(
            "ER図を出力するのに十分なノード数がありません。ER図を書きたい場合はrelationshipsのテストを追加することを検討してください",
            file=sys.stderr,
        )
        sys.exit(1)

    run_dbterd(n_hops_nodes)

    with open("target/output.md", "r") as file:
        mermaid = file.read()

    markdown_file_path = get_markdown_file_path(manifest, target_node_name)
    with open(f"{markdown_file_path}", "w") as file:
        file.write(f"## {target_node_name}のER図\n")
        file.write("```mermaid\n")
        file.write(mermaid)
        file.write("```\n")
        file.write("## 関連するモデルのER図へのリンク\n")
        for node in n_hops_nodes:
            if node != args.target_node_name:
                node_file_path = get_relative_path(
                    markdown_file_path, get_markdown_file_path(manifest, node)
                )
                file.write(f"- [{node}]({node_file_path})\n")
    print(f"{markdown_file_path}にER図を出力しました")


if __name__ == "__main__":
    main()