🚀

GoでSQLファイルを実行する

2023/09/16に公開

Goの標準ライブラリにはSQLファイルを実行する簡単な方法はありませんが、インテグレーションテスト等で欲しくなることがあります。これを実現するには、SQLファイルに含まれるSQLのひとつひとつをExec等で実行すればいいのですが、ファイルに含まれる複数のSQLを単一のSQLに分割していく処理はなかなか厄介です。ナイーブな実装としては;でのstrings.Splitがありますが、この区切り文字がコメントの中にある場合や、文字列リテラルの中にある場合を考慮する必要があります。本記事はそういったケースを考慮してパース処理を行ってくれる便利なライブラリと、それを用いたコード例を紹介します。

パースに利用するライブラリ

TiDBの中に含まれるパーサを利用します。パーサ単体でパッケージとして提供されているので、利用が容易です。また、TiDBに含まれていることもあり、コードの品質や今後のメンテナンスが問題になることは少なそうです。

https://github.com/pingcap/tidb/tree/master/parser

TiDBはMySQL互換を謳っているため、MySQLを利用している場合は問題なく利用できますし、その他のDBを使っている場合であっても、そのDB特有の方言を利用していない限りは同様に利用できると思います。

以下がこのライブラリを利用してSQLファイルをパースして一文ずつ実行する例です。

rawSqls, err := os.ReadFile("path/to/sql")
if err != nil {
	log.Fatalf("failed to read seed sql file: %v", err)
}

// Parserを作成
p := parser.New()

// ファイルから読み込んだSQLをパースする
// 結果はast.StmtNodeのSliceで、この時点で単一のSQL(Statement)に分割されている
stmts, _, err := p.Parse(string(rawSqls), "", "")
if err != nil {
	log.Fatalf("failed to parse seed sql: %v", err)
}

var buf bytes.Buffer
for _, stmt := range stmts {
	buf.Reset()

	// 各ast.StmtNodeをSQL文字列に復元する
	stmt.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &buf))

	sql := buf.String()
	if _, err := db.Exec(sql); err != nil {
		log.Fatalf("failed to execute sql: err=%v, sql=%s", err, sql)
	}
}

テスト用のユーティリティで利用する例

実際に動作を確認できる簡単なサンプルとして、実際にDBと通信を行うインテグレーションテストに組み込む際のユーティリティを書いてみました(コード全体はここにあります)。

簡単のため、以下の前提を置いています。

  • 並列実行はしない
  • DBはMySQL
  • DBは1つ

また、SQLファイルを実行する話なので、シードはSQLファイルで管理されているものとしています。

実際にDBと通信を行うインテグレーションテストでは、テストケースにデータの更新が含まれている場合はテストケース同士が影響を及ぼしあう可能性があります。この問題の解決方法として、各ケースを実行する前にテーブルのデータを全て削除してシードデータを入れ直す方法があり、これを実装しています。

スキーマとシード

サンプルで用いているスキーマとシードは以下の通りです。シードには前述したパースが面倒な記述(コメントや文字列リテラルの中に;がある)を含めています。

db/initdb.d/01_schema.sql
create table users (
  id bigint,
  primary key id (id)
);

create table posts (
  id bigint,
  author_id bigint not null,
  content text not null,
  primary key id (id),
  constraint fk_author_id foreign key (author_id) references users (id)
);
db/initdb.d/02_seed.sql
insert into users (id) values (1);
insert into users (id) values (2);

insert into posts (id, author_id, content) values (1, 1, 'Hello, world!');
-- insert into posts (id, author_id, content) values (2, 1, 'Comment Out');
insert into posts (id, author_id, content) values (2, 1, '(;_;)');
insert into posts (id, author_id, content) values (3, 2, 'Hello, world!');
/*
 * insert into posts (id, author_id, content) values (3, 2, 'Comment Out');
 */
insert into posts (id, author_id, content) values (4, 2, 'insert into posts (id, author_id, content) values (4, 2, ''Hello, world!'');');

ResetAllTables - データ削除とシード投入のユーティリティ

テーブルのデータを全て削除して、シードデータを入れ直す処理としてResetAllTablesを定義します。実際の処理は、前者をTruncateAllTables、後者をSeedTablesとして実装しています。SeedTablesの方は前の章のコード例とほぼ同じです。

func ResetAllTables() {
	TruncateAllTables()
	SeedTables()
}
func TruncateAllTables() {
	// VIEWやSYSTEM VIEWを除いた普通のテーブルのみをリストアップ
	rows, err := db.Query("show full tables where Table_Type = 'BASE TABLE'")
	if err != nil {
		log.Fatalf("failed to show tables: %v", err)
	}
	defer rows.Close()

	if _, err := db.Exec("set foreign_key_checks = 0"); err != nil {
		log.Fatalf("failed to disable foreign_key_checks: %v", err)
	}

	// 外部キー制約を一時的に外した状態で、各テーブルのデータを削除(truncate)していく
	for rows.Next() {
		var tableName, tableType string
		err = rows.Scan(&tableName, &tableType)
		if err != nil {
			log.Fatalf("failed to show tables: %v", err)
		}

		truncateSQL := fmt.Sprintf("truncate `%s`", tableName)
		if _, err := db.Exec(truncateSQL); err != nil {
			log.Fatalf("failed to truncate table: %v", err)
		}
	}

	if _, err := db.Exec("set foreign_key_checks = 1"); err != nil {
		log.Fatalf("failed to enable foreign_key_checks: %v", err)
	}

	if rows.Err() != nil {
		log.Fatalf("failed to iterate show tables results: %v", rows.Err())
	}
}
//go:embed db/initdb.d/*
var sqlDir embed.FS

var p *parser.Parser = parser.New()

func SeedTables() {
	seedSQL, err := sqlDir.ReadFile("db/initdb.d/02_seed.sql")
	if err != nil {
		log.Fatalf("failed to read seed sql file: %v", err)
	}

	stmts, _, err := p.Parse(string(seedSQL), "", "")
	if err != nil {
		log.Fatalf("failed to parse seed sql: %v", err)
	}

	var buf bytes.Buffer
	for _, stmt := range stmts {
		buf.Reset()

		stmt.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &buf))

		sql := buf.String()
		if _, err := db.Exec(sql); err != nil {
			log.Fatalf("failed to execute sql: err=%v, sql=%s", err, sql)
		}
	}
}

ResetAllTablesの動作確認

ResetAllTablesの動作を確認するために簡単なテストコードを書いています。まず、postsテーブルからデータを読み取るAllPostsを以下のように定義しています。

type Post struct {
	ID       int
	AuthorID int
	Content  string
}

func AllPosts(ctx context.Context) ([]Post, error) {
	rows, err := db.QueryContext(ctx, "select * from posts order by id asc")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var posts []Post
	for rows.Next() {
		var post Post
		if err := rows.Scan(&post.ID, &post.AuthorID, &post.Content); err != nil {
			return nil, err
		}
		posts = append(posts, post)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return posts, nil
}

このAllPostsの結果がpostsテーブルへのinsertを行った後でも、ResetAllTablesを行うことで変わっていないことを確認しています。

func TestResetAllTables(t *testing.T) {
	ResetAllTables()

	want := []Post{
		{ID: 1, AuthorID: 1, Content: "Hello, world!"},
		{ID: 2, AuthorID: 1, Content: "(;_;)"},
		{ID: 3, AuthorID: 2, Content: "Hello, world!"},
		{ID: 4, AuthorID: 2, Content: "insert into posts (id, author_id, content) values (4, 2, 'Hello, world!');"},
	}

	got, err := AllPosts(context.Background())
	if err != nil {
		t.Fatal(err)
	}

	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("mismatch detected (-want +got):\n%s", diff)
	}

	db.Exec("insert into posts (id, author_id, content) values (5, 2, 'This record does not exist in seed.');")
	ResetAllTables()

	got, err = AllPosts(context.Background())
	if err != nil {
		t.Fatal(err)
	}

	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("mismatch detected (-want +got):\n%s", diff)
	}
}

Discussion