iter.Seqを非同期解決して普通のRESTfulでもdataloaderが使えるようにできないか
動機
- Webサーバーを書いているとしばしばRDBMSのN+1問題に遭遇する
- 別名ループクエリ
- あるレコードを取得する際に同時にこのレコードが持つ(has-oneやhas-many関係)のレコードを同時に取得するケースがある
-
user
レコードが持つuser_item
一覧を取るなど
-
- 親のレコードが複数ある場合、ナイーブな実装をするとループで回して子のレコードをクエリすることになる
- 大量にSQLが発行されてDBサーバー(場合によってはアプリケーションサーバーも)負荷が増大する
- GraphQLではこの問題が起こりがち
- あるresolverに属する子供のresolverが芋づる式に呼ばれる
- RESTfulや他のRPCサーバーの形態ではループクエリが発生する場所が比較的見えやすいが(一つのエンドポイントで解決するエンティティ一覧が決まっているため)、GraphQLでは見えにくい
- そこでdataloaderという仕組みが存在する
- 様々なdataloaderの実装が存在するが、Goでよく使われるdatalodergenやdataloadenでは、resolverが非同期に解決されることを利用し、あるキー(DBレコードのIDやGraphQL nodeのID)に対応したレコードの取得を一定時間バッファリングし、バッファされたキーをバッチ取得(RDBMSではINを利用したクエリ)を行うことでN+1の発生を防ぐ
- 使ったことがないが、graph-gophers/dataloaderも多分このやり方
- 様々なdataloaderの実装が存在するが、Goでよく使われるdatalodergenやdataloadenでは、resolverが非同期に解決されることを利用し、あるキー(DBレコードのIDやGraphQL nodeのID)に対応したレコードの取得を一定時間バッファリングし、バッファされたキーをバッチ取得(RDBMSではINを利用したクエリ)を行うことでN+1の発生を防ぐ
- RESTfulなどでみられるN+1の解決はINやJOINなどを用いるが、元のナイーブな実装から構造を大きく改変する必要がある
- 単体取得の関数から複数個バッチで取得する関数へ書き換えるなど
- 呼び出し側が複数個のレコードを、欲しい位置にマッピングする必要がある
- dataloaderの方がコードの見た目としては理解がしやすく、またプログラムの構造を大きく変えなくて良いメリットがある(と思われる)
- RESTfulではdataloaderは採用しずらい
- 同期的にレスポンスが解決される前提のため
- Go 1.23でiter.Seq/iter.Seq2が導入され、よりデータ列のより抽象的な記述が可能になった
- あるiter.Seqに対する変換を非同期に解決するのが素直に書けるのではないか?
- 今までもsync.WaitGroupとかerrgroup.Groupとか使えばできましたが...
- より宣言的に書ける可能性がある
- 非同期変換が書けたらdataloaderが導入できる
- あるiter.Seqに対する変換を非同期に解決するのが素直に書けるのではないか?
というわけで実験してみましょう
今回使う道具を紹介します
github.com/mackee/iterutils
今回の実験のためにこさえました。普通に実用だとは思います
iterutils.FromNexter2
*database/sql.Rows
のような従来のイテレーションをするやつをiter.Seq2
に変換するやつです。
FromNexter2
は一つ目の引数に iterutils.Nexter
というinterfaceを取ります。Nexter
interfaceの定義は以下です。
type Nexter interface {
Next() bool
}
*database/sql.Rows
はこのinterfaceを満たします。以下はsqlxを使った例ですが、sql.Rows
でも似たような感じでできます。
rows, err := db.QueryxContext(
ctx,
"SELECT `id`, `user_id`, `body`, `mime`, `created_at` FROM `posts` ORDER BY `created_at` DESC",
)
if err != nil
return fmt.Errorf("failed to query posts: %w", err)
}
postsIter := iterutils.FromNexter2(rows, func(rows *sqlx.Rows) (*Post, error) {
var p Post
if err := rows.StructScan(&p); err != nil {
return nil, fmt.Errorf("failed to scan post: %w", err)
}
return &p, nil
})
*sqlx.Rows
はStructScan
を持っているので、これでstructにマッピングをしています。これにより、クエリをした上でその結果をiter.Seq2[*Post, error]
に変換できます。
async.Map2
iterutilsのサブパッケージ、async
にあるMap2
という関数は、iter.Seq2
の要素を変換して別のiter.Seq2
にしますが、後述するパッケージgo-functional/v2
にあるMap2
とは違い、非同期で変換を行います。つまり、各要素に対する変換関数が同時に走る場合があります。
一方で便利なことに結果はちゃんと元の要素の順番に並べ替えられて戻ってきます。便利ですね。
使い方としては以下のような感じです。
asyncIter := async.Map2(postsIter, func(p *Post, err error) (*Post, error) {
mp, err := makePost(p, getCSRFToken(r), false)
if err != nil {
return nil, err
}
return mp, nil
})
makePost
は重いが同時実行可能な関数で、Post
のうち埋め足りなかったfieldを補うと仮定します。例えば元の要素数が20個あり、それぞれが1秒かかるとした場合、普通にgo-functional/v2/it.Map2
を用いると20秒かかりますが、async.Map2
では上記のmakePost
が本当に同時実行できる場合は1秒程度で変換が完了します。
また、iter.Seq2
なので for v, err := range asyncIter
の形で結果を受け取れるのも良い点ですね。
github.com/vikstrous/dataloadgen
前述したdataloader実装の一つです。dataloaden
とは違いコード生成は使わず、genericsを用いています。一方で同等のパフォーマンスかつ型がついた状態で扱えるので便利ですね。
github.com/BooleanCat/go-functional/v2
Go 1.23 iterを扱うためのユーティリティ関数集です。特にit.Map
やit.Map2
, it.Collect2
, it.TryCollect
をよく使います。
検証方法
private-isu を使って検証していきます。
検証環境
- Macbook Air M1 16GB
- Go 1.23.0
- MySQL 9.0
- Memcached
Dockerは無し。ベンチマーカーは同一マシン上で動作
導入直前までのスコア
- 初期スコア 9323
- commentsテーブルにindexを貼る 65574
getIndexをチューニング
private-isu の getIndex (GET /
) はmakePostsというメソッドがあり、この中でループクエリが発生している。
なので、これを上記の道具を使いながら、修正していく。
makePostsからmakePostを切り出し
func makePosts(results []Post, csrfToken string, allComments bool) ([]Post, error) {
var posts []Post
for _, p := range results {
_p, err := makePost(&p, csrfToken, allComments)
if err != nil {
return nil, err
}
if _p.User.DelFlg == 0 {
posts = append(posts, p)
}
if len(posts) >= postsPerPage {
break
}
}
return posts, nil
}
func makePost(p *Post, csrfToken string, allComments bool) (*Post, error) {
if err := db.Get(&p.CommentCount, "SELECT COUNT(*) AS `count` FROM `comments` WHERE `post_id` = ?", p.ID); err != nil {
return nil, err
}
query := "SELECT * FROM `comments` WHERE `post_id` = ? ORDER BY `created_at` DESC"
if !allComments {
query += " LIMIT 3"
}
var comments []Comment
if err := db.Select(&comments, query, p.ID); err != nil {
return nil, err
}
for i := 0; i < len(comments); i++ {
err := db.Get(&comments[i].User, "SELECT * FROM `users` WHERE `id` = ?", comments[i].UserID)
if err != nil {
return nil, err
}
}
// reverse
for i, j := 0, len(comments)-1; i < j; i, j = i+1, j-1 {
comments[i], comments[j] = comments[j], comments[i]
}
p.Comments = comments
if err := db.Get(&p.User, "SELECT * FROM `users` WHERE `id` = ?", p.UserID); err != nil {
return nil, err
}
p.CSRFToken = csrfToken
return p, nil
}
makePostで発行しているクエリをDataloader化する
type dataloaders struct {
commentCountByPostID *dataloadgen.Loader[int, int]
allCommentsByPostID *dataloadgen.Loader[int, []Comment]
latestCommentsByPostID *dataloadgen.Loader[int, []Comment]
userByID *dataloadgen.Loader[int, User]
}
func newDataloaders() *dataloaders {
return &dataloaders{
commentCountByPostID: dataloadgen.NewLoader(
func(ctx context.Context, keys []int) ([]int, []error) {
query := "SELECT `post_id`, COUNT(*) AS `count` FROM `comments` WHERE `post_id` IN (?) GROUP BY `post_id`"
query, args, err := sqlx.In(query, keys)
if err != nil {
errs := slices.Repeat([]error{err}, len(keys))
return nil, errs
}
type postCommentCount struct {
PostID int `db:"post_id"`
Count int `db:"count"`
}
var results []postCommentCount
if err := db.SelectContext(ctx, &results, query, args...); err != nil {
errs := slices.Repeat([]error{err}, len(keys))
return nil, errs
}
m := make(map[int]int, len(results))
for _, r := range results {
m[r.PostID] = r.Count
}
mm := it.Map(slices.Values(keys), func(k int) int {
return m[k]
})
counts := slices.Collect(mm)
return counts, nil
}),
allCommentsByPostID: dataloadgen.NewLoader(
func(ctx context.Context, keys []int) ([][]Comment, []error) {
query := "SELECT * FROM `comments` WHERE `post_id` IN (?) ORDER BY `created_at` DESC"
query, args, err := sqlx.In(query, keys)
if err != nil {
errs := slices.Repeat([]error{err}, len(keys))
return nil, errs
}
var results []Comment
if err := db.SelectContext(ctx, &results, query, args...); err != nil {
errs := slices.Repeat([]error{err}, len(keys))
return nil, errs
}
m := make(map[int][]Comment, len(keys))
for _, r := range results {
m[r.PostID] = append(m[r.PostID], r)
}
mm := it.Map(slices.Values(keys), func(k int) []Comment {
return m[k]
})
comments := slices.Collect(mm)
return comments, nil
},
),
latestCommentsByPostID: dataloadgen.NewLoader(
func(ctx context.Context, keys []int) ([][]Comment, []error) {
query := `
SELECT c.id, c.post_id, c.user_id, c.comment, c.created_at
FROM (
SELECT id, post_id, user_id, comment, created_at,
ROW_NUMBER() OVER (PARTITION BY post_id ORDER BY created_at DESC) as rn
FROM comments
WHERE post_id IN (?)
) c
WHERE c.rn <= 3
ORDER BY c.post_id, c.created_at DESC;`
query, args, err := sqlx.In(query, keys)
if err != nil {
errs := slices.Repeat([]error{err}, len(keys))
return nil, errs
}
var results []Comment
if err := db.SelectContext(ctx, &results, query, args...); err != nil {
errs := slices.Repeat([]error{err}, len(keys))
return nil, errs
}
m := make(map[int][]Comment, len(keys))
for _, r := range results {
m[r.PostID] = append(m[r.PostID], r)
}
mm := it.Map(slices.Values(keys), func(k int) []Comment {
return m[k]
})
comments := slices.Collect(mm)
return comments, nil
},
),
userByID: dataloadgen.NewLoader(
func(ctx context.Context, keys []int) ([]User, []error) {
query := "SELECT * FROM `users` WHERE `id` IN (?)"
query, args, err := sqlx.In(query, keys)
if err != nil {
errs := slices.Repeat([]error{err}, len(keys))
return nil, errs
}
var results []User
if err := db.SelectContext(ctx, &results, query, args...); err != nil {
errs := slices.Repeat([]error{err}, len(keys))
return nil, errs
}
m := make(map[int]User, len(results))
for _, r := range results {
m[r.ID] = r
}
mm := it.Map(slices.Values(keys), func(k int) User {
return m[k]
})
users := slices.Collect(mm)
return users, nil
},
),
}
}
makePostの同時実行版を作る
干渉・依存しない値は同時に取得するようにerrgroup.Groupを使っている
func makePostDL(ctx context.Context, p *Post, csrfToken string, allComments bool, dl *dataloaders) (*Post, error) {
eg := errgroup.Group{}
eg.Go(func() error {
count, err := dl.commentCountByPostID.Load(ctx, p.ID)
if err != nil {
return err
}
p.CommentCount = count
return nil
})
eg.Go(func() error {
var comments []Comment
if allComments {
_comments, err := dl.allCommentsByPostID.Load(ctx, p.ID)
if err != nil {
return err
}
comments = _comments
} else {
_comments, err := dl.latestCommentsByPostID.Load(ctx, p.ID)
if err != nil {
return err
}
comments = _comments
}
cwuiter := async.Map2(slices.All(comments), func(_ int, c Comment) (*Comment, error) {
user, err := dl.userByID.Load(ctx, c.UserID)
if err != nil {
return nil, err
}
c.User = user
return &c, nil
})
cwus, err := it.TryCollect(cwuiter)
if err != nil {
return err
}
slices.Reverse(cwus)
derefIter := it.Map(slices.Values(cwus), func(c *Comment) Comment {
return *c
})
p.Comments = slices.Collect(derefIter)
return nil
})
eg.Go(func() error {
u, err := dl.userByID.Load(ctx, p.UserID)
if err != nil {
return err
}
p.User = u
return nil
})
if err := eg.Wait(); err != nil {
return nil, err
}
p.CSRFToken = csrfToken
return p, nil
}
getIndex側でasync実行するようにする
比較できるようにgetIndex2
として作成。
並列数が多くなるとむしろスコアが下がってしまうので、LIMITをかけたりJOINによる絞り込みをやった。なので元のgetIndex
側も同様のクエリに改変して比較できるようにした。
と思ったが、73kぐらいでサチってしまっているのは、画像配信側がボトルネックになっている説がある。
func getIndex2(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
me := getSessionUser(r)
dl := newDataloaders()
rows, err := db.QueryxContext(
ctx,
"SELECT posts.id, posts.user_id, posts.body, posts.mime, posts.created_at FROM posts INNER JOIN users ON posts.user_id = users.id WHERE users.del_flg = 0 ORDER BY posts.created_at DESC LIMIT ?",
postsPerPage,
)
if err != nil {
log.Print(err)
return
}
postsIter := iterutils.FromNexter2(rows, func(rows *sqlx.Rows) (*Post, error) {
var p Post
if err := rows.StructScan(&p); err != nil {
return nil, err
}
return &p, nil
})
asyncIter := async.Map2(postsIter, func(p *Post, err error) (*Post, error) {
mp, err := makePostDL(ctx, p, getCSRFToken(r), false, dl)
if err != nil {
return nil, err
}
return mp, nil
})
var posts []Post
for p, err := range asyncIter {
if err != nil {
log.Print(err)
return
}
posts = append(posts, *p)
}
fmap := template.FuncMap{
"imageURL": imageURL,
}
template.Must(template.New("layout.html").Funcs(fmap).ParseFiles(
getTemplPath("layout.html"),
getTemplPath("index.html"),
getTemplPath("posts.html"),
getTemplPath("post.html"),
)).Execute(w, struct {
Posts []Post
Me User
CSRFToken string
Flash string
}{posts, me, getCSRFToken(r), getFlash(w, r, "notice")})
}