Open4

iter.Seqを非同期解決して普通のRESTfulでもdataloaderが使えるようにできないか

macopymacopy

動機

  • Webサーバーを書いているとしばしばRDBMSのN+1問題に遭遇する
    • 別名ループクエリ
    • あるレコードを取得する際に同時にこのレコードが持つ(has-oneやhas-many関係)のレコードを同時に取得するケースがある
      • user レコードが持つ user_item 一覧を取るなど
    • 親のレコードが複数ある場合、ナイーブな実装をするとループで回して子のレコードをクエリすることになる
      • 大量にSQLが発行されてDBサーバー(場合によってはアプリケーションサーバーも)負荷が増大する
  • GraphQLではこの問題が起こりがち
    • あるresolverに属する子供のresolverが芋づる式に呼ばれる
    • RESTfulや他のRPCサーバーの形態ではループクエリが発生する場所が比較的見えやすいが(一つのエンドポイントで解決するエンティティ一覧が決まっているため)、GraphQLでは見えにくい
  • そこでdataloaderという仕組みが存在する
    • 様々なdataloaderの実装が存在するが、Goでよく使われるdatalodergendataloadenでは、resolverが非同期に解決されることを利用し、あるキー(DBレコードのIDやGraphQL nodeのID)に対応したレコードの取得を一定時間バッファリングし、バッファされたキーをバッチ取得(RDBMSではINを利用したクエリ)を行うことでN+1の発生を防ぐ
  • RESTfulなどでみられるN+1の解決はINやJOINなどを用いるが、元のナイーブな実装から構造を大きく改変する必要がある
    • 単体取得の関数から複数個バッチで取得する関数へ書き換えるなど
    • 呼び出し側が複数個のレコードを、欲しい位置にマッピングする必要がある
  • dataloaderの方がコードの見た目としては理解がしやすく、またプログラムの構造を大きく変えなくて良いメリットがある(と思われる)
  • RESTfulではdataloaderは採用しずらい
    • 同期的にレスポンスが解決される前提のため
  • Go 1.23でiter.Seq/iter.Seq2が導入され、よりデータ列のより抽象的な記述が可能になった
    • あるiter.Seqに対する変換を非同期に解決するのが素直に書けるのではないか?
      • 今までもsync.WaitGroupとかerrgroup.Groupとか使えばできましたが...
      • より宣言的に書ける可能性がある
    • 非同期変換が書けたらdataloaderが導入できる

というわけで実験してみましょう

macopymacopy

今回使う道具を紹介します

github.com/mackee/iterutils

今回の実験のためにこさえました。普通に実用だとは思います

iterutils.FromNexter2

*database/sql.Rows のような従来のイテレーションをするやつをiter.Seq2に変換するやつです。

FromNexter2は一つ目の引数に iterutils.Nexterというinterfaceを取ります。Nexter interfaceの定義は以下です。

type Nexter interface {
	Next() bool
}

*database/sql.Rowsはこのinterfaceを満たします。以下はsqlxを使った例ですが、sql.Rowsでも似たような感じでできます。

rows, err := db.QueryxContext(
    ctx,
    "SELECT `id`, `user_id`, `body`, `mime`, `created_at` FROM `posts` ORDER BY `created_at` DESC",
)
if err != nil 
    return fmt.Errorf("failed to query posts: %w", err)
}
postsIter := iterutils.FromNexter2(rows, func(rows *sqlx.Rows) (*Post, error) {
    var p Post
    if err := rows.StructScan(&p); err != nil {
        return nil, fmt.Errorf("failed to scan post: %w", err)
    }
    return &p, nil
})

*sqlx.RowsStructScanを持っているので、これでstructにマッピングをしています。これにより、クエリをした上でその結果をiter.Seq2[*Post, error]に変換できます。

async.Map2

iterutilsのサブパッケージ、asyncにあるMap2という関数は、iter.Seq2の要素を変換して別のiter.Seq2にしますが、後述するパッケージgo-functional/v2にあるMap2とは違い、非同期で変換を行います。つまり、各要素に対する変換関数が同時に走る場合があります。

一方で便利なことに結果はちゃんと元の要素の順番に並べ替えられて戻ってきます。便利ですね。

使い方としては以下のような感じです。

asyncIter := async.Map2(postsIter, func(p *Post, err error) (*Post, error) {
    mp, err := makePost(p, getCSRFToken(r), false)
    if err != nil {
        return nil, err
    }
    return mp, nil
})

makePostは重いが同時実行可能な関数で、Postのうち埋め足りなかったfieldを補うと仮定します。例えば元の要素数が20個あり、それぞれが1秒かかるとした場合、普通にgo-functional/v2/it.Map2を用いると20秒かかりますが、async.Map2では上記のmakePostが本当に同時実行できる場合は1秒程度で変換が完了します。

また、iter.Seq2なので for v, err := range asyncIter の形で結果を受け取れるのも良い点ですね。

github.com/vikstrous/dataloadgen

前述したdataloader実装の一つです。dataloadenとは違いコード生成は使わず、genericsを用いています。一方で同等のパフォーマンスかつ型がついた状態で扱えるので便利ですね。

github.com/BooleanCat/go-functional/v2

Go 1.23 iterを扱うためのユーティリティ関数集です。特にit.Mapit.Map2, it.Collect2, it.TryCollectをよく使います。

macopymacopy

検証方法

private-isu を使って検証していきます。

検証環境

  • Macbook Air M1 16GB
  • Go 1.23.0
  • MySQL 9.0
  • Memcached

Dockerは無し。ベンチマーカーは同一マシン上で動作

導入直前までのスコア

  • 初期スコア 9323
  • commentsテーブルにindexを貼る 65574
macopymacopy

getIndexをチューニング

private-isu の getIndex (GET /) はmakePostsというメソッドがあり、この中でループクエリが発生している。

なので、これを上記の道具を使いながら、修正していく。

makePostsからmakePostを切り出し

func makePosts(results []Post, csrfToken string, allComments bool) ([]Post, error) {
	var posts []Post

	for _, p := range results {
		_p, err := makePost(&p, csrfToken, allComments)
		if err != nil {
			return nil, err
		}

		if _p.User.DelFlg == 0 {
			posts = append(posts, p)
		}
		if len(posts) >= postsPerPage {
			break
		}
	}

	return posts, nil
}

func makePost(p *Post, csrfToken string, allComments bool) (*Post, error) {
	if err := db.Get(&p.CommentCount, "SELECT COUNT(*) AS `count` FROM `comments` WHERE `post_id` = ?", p.ID); err != nil {
		return nil, err
	}

	query := "SELECT * FROM `comments` WHERE `post_id` = ? ORDER BY `created_at` DESC"
	if !allComments {
		query += " LIMIT 3"
	}
	var comments []Comment
	if err := db.Select(&comments, query, p.ID); err != nil {
		return nil, err
	}

	for i := 0; i < len(comments); i++ {
		err := db.Get(&comments[i].User, "SELECT * FROM `users` WHERE `id` = ?", comments[i].UserID)
		if err != nil {
			return nil, err
		}
	}

	// reverse
	for i, j := 0, len(comments)-1; i < j; i, j = i+1, j-1 {
		comments[i], comments[j] = comments[j], comments[i]
	}

	p.Comments = comments

	if err := db.Get(&p.User, "SELECT * FROM `users` WHERE `id` = ?", p.UserID); err != nil {
		return nil, err
	}

	p.CSRFToken = csrfToken
	return p, nil
}

makePostで発行しているクエリをDataloader化する

type dataloaders struct {
	commentCountByPostID   *dataloadgen.Loader[int, int]
	allCommentsByPostID    *dataloadgen.Loader[int, []Comment]
	latestCommentsByPostID *dataloadgen.Loader[int, []Comment]
	userByID               *dataloadgen.Loader[int, User]
}

func newDataloaders() *dataloaders {
	return &dataloaders{
		commentCountByPostID: dataloadgen.NewLoader(
			func(ctx context.Context, keys []int) ([]int, []error) {
				query := "SELECT `post_id`, COUNT(*) AS `count` FROM `comments` WHERE `post_id` IN (?) GROUP BY `post_id`"
				query, args, err := sqlx.In(query, keys)
				if err != nil {
					errs := slices.Repeat([]error{err}, len(keys))
					return nil, errs
				}
				type postCommentCount struct {
					PostID int `db:"post_id"`
					Count  int `db:"count"`
				}
				var results []postCommentCount
				if err := db.SelectContext(ctx, &results, query, args...); err != nil {
					errs := slices.Repeat([]error{err}, len(keys))
					return nil, errs
				}
				m := make(map[int]int, len(results))
				for _, r := range results {
					m[r.PostID] = r.Count
				}
				mm := it.Map(slices.Values(keys), func(k int) int {
					return m[k]
				})
				counts := slices.Collect(mm)
				return counts, nil
			}),
		allCommentsByPostID: dataloadgen.NewLoader(
			func(ctx context.Context, keys []int) ([][]Comment, []error) {
				query := "SELECT * FROM `comments` WHERE `post_id` IN (?) ORDER BY `created_at` DESC"
				query, args, err := sqlx.In(query, keys)
				if err != nil {
					errs := slices.Repeat([]error{err}, len(keys))
					return nil, errs
				}
				var results []Comment
				if err := db.SelectContext(ctx, &results, query, args...); err != nil {
					errs := slices.Repeat([]error{err}, len(keys))
					return nil, errs
				}
				m := make(map[int][]Comment, len(keys))
				for _, r := range results {
					m[r.PostID] = append(m[r.PostID], r)
				}
				mm := it.Map(slices.Values(keys), func(k int) []Comment {
					return m[k]
				})
				comments := slices.Collect(mm)
				return comments, nil
			},
		),
		latestCommentsByPostID: dataloadgen.NewLoader(
			func(ctx context.Context, keys []int) ([][]Comment, []error) {
				query := `
SELECT c.id, c.post_id, c.user_id, c.comment, c.created_at
FROM (
    SELECT id, post_id, user_id, comment, created_at,
           ROW_NUMBER() OVER (PARTITION BY post_id ORDER BY created_at DESC) as rn
    FROM comments
    WHERE post_id IN (?)
) c
WHERE c.rn <= 3
ORDER BY c.post_id, c.created_at DESC;`
				query, args, err := sqlx.In(query, keys)
				if err != nil {
					errs := slices.Repeat([]error{err}, len(keys))
					return nil, errs
				}
				var results []Comment
				if err := db.SelectContext(ctx, &results, query, args...); err != nil {
					errs := slices.Repeat([]error{err}, len(keys))
					return nil, errs
				}
				m := make(map[int][]Comment, len(keys))
				for _, r := range results {
					m[r.PostID] = append(m[r.PostID], r)
				}
				mm := it.Map(slices.Values(keys), func(k int) []Comment {
					return m[k]
				})
				comments := slices.Collect(mm)
				return comments, nil
			},
		),
		userByID: dataloadgen.NewLoader(
			func(ctx context.Context, keys []int) ([]User, []error) {
				query := "SELECT * FROM `users` WHERE `id` IN (?)"
				query, args, err := sqlx.In(query, keys)
				if err != nil {
					errs := slices.Repeat([]error{err}, len(keys))
					return nil, errs
				}
				var results []User
				if err := db.SelectContext(ctx, &results, query, args...); err != nil {
					errs := slices.Repeat([]error{err}, len(keys))
					return nil, errs
				}
				m := make(map[int]User, len(results))
				for _, r := range results {
					m[r.ID] = r
				}
				mm := it.Map(slices.Values(keys), func(k int) User {
					return m[k]
				})
				users := slices.Collect(mm)
				return users, nil
			},
		),
	}
}

makePostの同時実行版を作る

干渉・依存しない値は同時に取得するようにerrgroup.Groupを使っている

func makePostDL(ctx context.Context, p *Post, csrfToken string, allComments bool, dl *dataloaders) (*Post, error) {
	eg := errgroup.Group{}
	eg.Go(func() error {
		count, err := dl.commentCountByPostID.Load(ctx, p.ID)
		if err != nil {
			return err
		}
		p.CommentCount = count
		return nil
	})
	eg.Go(func() error {
		var comments []Comment
		if allComments {
			_comments, err := dl.allCommentsByPostID.Load(ctx, p.ID)
			if err != nil {
				return err
			}
			comments = _comments
		} else {
			_comments, err := dl.latestCommentsByPostID.Load(ctx, p.ID)
			if err != nil {
				return err
			}
			comments = _comments
		}

		cwuiter := async.Map2(slices.All(comments), func(_ int, c Comment) (*Comment, error) {
			user, err := dl.userByID.Load(ctx, c.UserID)
			if err != nil {
				return nil, err
			}
			c.User = user
			return &c, nil
		})
		cwus, err := it.TryCollect(cwuiter)
		if err != nil {
			return err
		}
		slices.Reverse(cwus)
		derefIter := it.Map(slices.Values(cwus), func(c *Comment) Comment {
			return *c
		})

		p.Comments = slices.Collect(derefIter)
		return nil
	})

	eg.Go(func() error {
		u, err := dl.userByID.Load(ctx, p.UserID)
		if err != nil {
			return err
		}
		p.User = u
		return nil
	})
	if err := eg.Wait(); err != nil {
		return nil, err
	}

	p.CSRFToken = csrfToken
	return p, nil
}

getIndex側でasync実行するようにする

比較できるようにgetIndex2として作成。
並列数が多くなるとむしろスコアが下がってしまうので、LIMITをかけたりJOINによる絞り込みをやった。なので元のgetIndex側も同様のクエリに改変して比較できるようにした。

と思ったが、73kぐらいでサチってしまっているのは、画像配信側がボトルネックになっている説がある。

func getIndex2(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	me := getSessionUser(r)
	dl := newDataloaders()

	rows, err := db.QueryxContext(
		ctx,
		"SELECT posts.id, posts.user_id, posts.body, posts.mime, posts.created_at FROM posts INNER JOIN users ON posts.user_id = users.id WHERE users.del_flg = 0 ORDER BY posts.created_at DESC LIMIT ?",
		postsPerPage,
	)
	if err != nil {
		log.Print(err)
		return
	}
	postsIter := iterutils.FromNexter2(rows, func(rows *sqlx.Rows) (*Post, error) {
		var p Post
		if err := rows.StructScan(&p); err != nil {
			return nil, err
		}
		return &p, nil
	})
	asyncIter := async.Map2(postsIter, func(p *Post, err error) (*Post, error) {
		mp, err := makePostDL(ctx, p, getCSRFToken(r), false, dl)
		if err != nil {
			return nil, err
		}
		return mp, nil
	})
	var posts []Post
	for p, err := range asyncIter {
		if err != nil {
			log.Print(err)
			return
		}
		posts = append(posts, *p)
	}

	fmap := template.FuncMap{
		"imageURL": imageURL,
	}

	template.Must(template.New("layout.html").Funcs(fmap).ParseFiles(
		getTemplPath("layout.html"),
		getTemplPath("index.html"),
		getTemplPath("posts.html"),
		getTemplPath("post.html"),
	)).Execute(w, struct {
		Posts     []Post
		Me        User
		CSRFToken string
		Flash     string
	}{posts, me, getCSRFToken(r), getFlash(w, r, "notice")})
}