複数列ソートにおけるカーソルページネーションの実装
こんにちは、@nerusan です。
今回は、複数列ソートにおけるカーソルページネーションについて、説明します。
カーソルページネーションとは?
カーソルページネーションは、データベースのクエリ結果をページごとに取得するための方法です。通常のページネーションでは、ページ番号やオフセットを使用してデータを取得しますが、カーソルページネーションでは、前のページの最後の要素のカーソルを使用して次のページのデータを取得します。
カーソルページネーションの利点は、ページの移動が高速であることです。ページ番号やオフセットを使用する場合、ページ数が増えるにつれてデータベースのクエリが遅くなる可能性がありますが、カーソルページネーションでは常に一定の速度でデータを取得できます。
具体的な実装方法は、データベースの種類や使用するプログラミング言語によって異なりますが、一般的にはクエリパラメータとしてカーソルを指定し、そのカーソルを使用して次のページのデータを取得します。
カーソルページネーションは、大量のデータを効率的に取得するための有用な手法です。特に、ユーザーがスクロールやページングを行うような場合には、有効になります。
オフセットページネーションがなぜ遅くなるのか?
オフセットページネーションがなぜページ数がふえるにつれて遅くなるのかを簡単に説明します。
以下のクエリを見てください。
mysql> explain select * from users order by id limit 2 offset 1000;
+----+-------------+-------+------------+-------+---------------+---------+---------+------+------+----------+-------+
| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra |
+----+-------------+-------+------------+-------+---------------+---------+---------+------+------+----------+-------+
| 1 | SIMPLE | users | NULL | index | NULL | PRIMARY | 8 | NULL | 1002 | 100.00 | NULL |
+----+-------------+-------+------------+-------+---------------+---------+---------+------+------+----------+-------+
rowsを注目してみてください。
2件をのみを取得したいのに、rowsでは、1002件となり、対象テーブルから取得される行が1002になっているのがわかります。
これは、limit 2 offset 1000
の部分が関係しており、1000件までのソートを確認し、そこから2件取り出しているため、全体として1000+2=1002件の読み取りが発生します。
つまり、offsetが100000であれば、100002件が取得され、ページ数が増える旅に遅くなるのです。
また、データベースはインデックスを使ってクエリの開始点を特定できますが、大きな OFFSET を使用すると、インデックスまたはテーブル全体をスキャンすることが多くなります。これは特に大きなテーブルでは効率が悪くなります。
typeを見てみると、indexになっており、インデックスのフルスキャンになっており、効率が良いクエリとは言えませんね。
カーソルページネーションのクエリを見てみましょう
mysql> explain select * from users where id > 2030 order by id limit 2;
+----+-------------+-------+------------+-------+---------------+---------+---------+------+------+----------+-------------+
| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra |
+----+-------------+-------+------------+-------+---------------+---------+---------+------+------+----------+-------------+
| 1 | SIMPLE | users | NULL | range | PRIMARY | PRIMARY | 8 | NULL | 39 | 100.00 | Using where |
+----+-------------+-------+------------+-------+---------------+---------+---------+------+------+----------+-------------+
typeはrangeとなっておりインデックスを使ったクエリで、rowも少ないですね!
実装
ユーザー間がポイント送付しあうアプリを考えるとします。
ユーザーは、あるユーザーにポイントを送付することができ、送付された合計値が獲得ポイントになります。
ユーザー一覧表示ページがあり、そこでは、無限ローディングで表示することを考えるとします。
その際のソートの要件として、以下を満たすものとします。
- 獲得ポイント獲得数の降順
- 獲得ポイントが同じ場合は、ユーザー作成日の昇順
- 獲得ポイント、作成日も同じ場合は、ユーザーIDの昇順
テーブル
以下テーブルです。
CREATE TABLE `users` (
`id` BIGINT NOT NULL AUTO_INCREMENT COMMENT 'ユーザーの識別子',
`family_name` VARCHAR(256) NOT NULL COMMENT '苗字',
`family_name_kana` VARCHAR(256) NOT NULL COMMENT '苗字カナ',
`first_name` VARCHAR(256) NOT NULL COMMENT '名前',
`first_name_kana` VARCHAR(256) NOT NULL COMMENT '名前カナ',
`email` VARCHAR(256) NOT NULL COMMENT 'メールアドレス',
`password` VARCHAR(256) NOT NULL COMMENT 'パスワードハッシュ',
`sending_point` INT NOT NULL COMMENT '送信可能ポイント',
`created_at` DATETIME(6) NOT NULL COMMENT 'レコード作成日時',
`update_at` DATETIME(6) NOT NULL COMMENT 'レコード修正日時',
PRIMARY KEY (`id`),
UNIQUE KEY `uix_email` (`email`) USING BTREE
) Engine=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='ユーザー';
CREATE TABLE `transactions` (
`id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT COMMENT '取引の識別子',
`sending_user_id` BIGINT UNSIGNED NOT NULL COMMENT '送信ユーザのID',
`receiving_user_id` BIGINT UNSIGNED NOT NULL COMMENT '受信ユーザのID',
`transaction_point` INT NOT NULL COMMENT '取引ポイント',
`transaction_at` DATETIME(6) NOT NULL COMMENT '取引日時',
PRIMARY KEY (`id`)
) Engine=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='取引';
Goのコード
type GetAllWithCursorParam struct {
Size int `db:"size"`
CursorPoint int `db:"point"`
CursorUserID model.UserID `db:"user_id"`
CursorCreatedAt time.Time `db:"created_at"`
}
// GetAllWithCursor ポイント順にユーザを取得する
// カーソルページネーションを使用して取得する
// ソート順は以下の通り
//
// 1. ポイントが多い順
// 2. ポイントが同じ場合は登録日が古い順
// 3. ポイントと登録日が同じ場合はユーザIDが大きい順
func (r *Repository) GetAllWithCursor(ctx context.Context, db Queryer, param GetAllWithCursorParam) ([]*entities.User, error) {
sql := `
WITH points AS (
SELECT receiving_user_id AS user_id, SUM(transaction_point) AS point
FROM transactions
GROUP BY receiving_user_id
)
SELECT u.*
FROM users AS u
INNER JOIN points AS p
ON u.id = p.user_id
WHERE
p.point < ?
OR (p.point = ? AND u.created_at > ?)
OR (p.point = ? AND u.created_at = ? AND u.id > ?)
ORDER BY p.point DESC, u.created_at ASC, u.id ASC
LIMIT ?;`
var users []*entities.User
err := db.SelectContext(ctx, &users, sql,
param.CursorPoint,
param.CursorPoint,
param.CursorCreatedAt,
param.CursorPoint,
param.CursorCreatedAt,
param.CursorUserID,
param.Size,
)
if err != nil {
return users, errors.Wrap(err, "failed to get all users in user repo")
}
return users, nil
}
type GetUsersParam struct {
Size int
}
// GetUsers ポイント順にユーザを取得する
//
// ソート順は以下の通り
//
// 1. ポイントが多い順
// 2. ポイントが同じ場合は登録日が古い順
// 3. ポイントと登録日が同じ場合はユーザIDが大きい順
func (r *Repository) GetUsers(ctx context.Context, db Queryer, param GetUsersParam) ([]*entities.User, error) {
sql := `
WITH points AS (
SELECT receiving_user_id AS user_id, SUM(transaction_point) AS point
FROM transactions
GROUP BY receiving_user_id
)
SELECT u.*
FROM users AS u
INNER JOIN points AS p
ON u.id = p.user_id
ORDER BY p.point DESC, u.created_at ASC, u.id ASC
LIMIT ?;`
var users []*entities.User
err := db.SelectContext(ctx, &users, sql,
param.Size,
)
if err != nil {
return users, errors.Wrap(err, "failed to get all users in user repo")
}
return users, nil
}
package service
import (
"context"
"encoding/base64"
"encoding/json"
"time"
"github.com/cockroachdb/errors"
"github.com/hack-31/point-app-backend/domain"
"github.com/hack-31/point-app-backend/domain/model"
"github.com/hack-31/point-app-backend/repository"
"github.com/hack-31/point-app-backend/repository/entities"
"github.com/jmoiron/sqlx"
)
type GetUsers struct {
DB repository.Queryer
UserRepo domain.UserRepo
TransactionRepo domain.TransactionRepo
TokenGenerator domain.TokenGenerator
}
func NewGetUsers(db *sqlx.DB, repo *repository.Repository, jwter domain.TokenGenerator) *GetUsers {
return &GetUsers{
DB: db,
UserRepo: repo,
TransactionRepo: repo,
TokenGenerator: jwter,
}
}
type GetUsersRequest struct {
Size int
NextCursor string
}
type GetUsersResponse struct {
Users []struct {
ID model.UserID
FirstName string
FirstNameKana string
FamilyName string
FamilyNameKana string
Email string
AcquisitionPoint int
}
NextCursor string
}
// ユーザ一覧取得サービス
//
// @params ctx コンテキスト
//
// @return
// ユーザ一覧
func (r *GetUsers) GetUsers(ctx context.Context, input GetUsersRequest) (GetUsersResponse, error) {
// ユーザ一覧を取得する
type cursor struct {
UserID model.UserID `json:"user_id"`
Point int `json:"point"`
CreatedAt time.Time `json:"created_at"`
}
var users []*entities.User
var c cursor
// 2回目以降のリクエスト
if input.NextCursor != "" {
// base64をデコードして、JSONを構造体にマッピング
data, err := base64.URLEncoding.DecodeString(input.NextCursor)
if err != nil {
return GetUsersResponse{}, errors.Wrap(err, "failed to decode nextCursor in GetUsersService.GetUsers")
}
if err := json.Unmarshal(data, &c); err != nil {
return GetUsersResponse{}, errors.Wrap(err, "failed to unmarshal nextCursor in GetUsersService.GetUsers")
}
// カーソルをもとに、次のユーザーリストを取得
users, err = r.UserRepo.GetAllWithCursor(ctx, r.DB, repository.GetAllWithCursorParam{
Size: input.Size,
CursorPoint: c.Point,
CursorUserID: c.UserID,
CursorCreatedAt: c.CreatedAt,
})
if err != nil {
return GetUsersResponse{}, errors.Wrap(err, "failed to get users in GetUsersService.GetUsers")
}
}
// 初回のリクエスト
if input.NextCursor == "" {
var err error
if input.Size == 0 {
input.Size = 10
}
// 初回はカーソルはないので、ソートして、上位のサイズ分を取得
users, err = r.UserRepo.GetUsers(ctx, r.DB, repository.GetUsersParam{
Size: input.Size,
})
if err != nil {
return GetUsersResponse{}, errors.Wrap(err, "failed to get users in GetUsersService.GetUsers")
}
}
if len(users) == 0 {
return GetUsersResponse{
Users: []struct {
ID model.UserID
FirstName string
FirstNameKana string
FamilyName string
FamilyNameKana string
Email string
AcquisitionPoint int
}{},
NextCursor: "",
}, nil
}
// ユーザIDsを取得する
userIDs := make([]model.UserID, 0, len(users))
for _, user := range users {
userIDs = append(userIDs, model.UserID(user.ID))
}
// 取得ポイントを取得する
points, err := r.TransactionRepo.GetAquistionPoint(ctx, r.DB, userIDs)
if err != nil {
return GetUsersResponse{}, errors.Wrap(err, "failed to get points in GetUsersService.GetUsers")
}
res := make([]struct {
ID model.UserID
FirstName string
FirstNameKana string
FamilyName string
FamilyNameKana string
Email string
AcquisitionPoint int
}, 0, len(users))
// ユーザに取得ポイントを設定する
for _, v := range users {
res = append(res, struct {
ID model.UserID
FirstName string
FirstNameKana string
FamilyName string
FamilyNameKana string
Email string
AcquisitionPoint int
}{
ID: model.UserID(v.ID),
FirstName: v.FirstName,
FirstNameKana: v.FirstNameKana,
FamilyName: v.FamilyName,
FamilyNameKana: v.FamilyNameKana,
Email: v.Email,
AcquisitionPoint: points[model.UserID(v.ID)],
})
}
var nextCursorStr string
// 取得ユーザー数とリクエストサイズ数が同じ場合、
// 次のページが存在する可能性があるので、カーソルを作成する
if len(users) == input.Size {
// JSONにして、base64エンコードしてクライアントに返す
data, err := json.Marshal(cursor{
UserID: model.UserID(users[len(users)-1].ID),
Point: points[model.UserID(users[len(users)-1].ID)],
CreatedAt: users[len(users)-1].CreatedAt,
})
if err != nil {
return GetUsersResponse{}, errors.Wrap(err, "failed to marshal nextCursor in GetUsersService.GetUsers")
}
nextCursorStr = base64.URLEncoding.EncodeToString(data)
}
return GetUsersResponse{
Users: res,
NextCursor: nextCursorStr,
}, nil
}
詳しくはコードの説明を見てもらえたらと思いますが、
簡単に説明します。
カーソルページネーションではカーソルというものを使います。
カーソルページネーションにおける「カーソル」とは、データベースやAPIからデータを取得する際に使用される一種のマーカーや参照点のことです。このカーソルは、データセット内の現在の位置を示し、次のページのデータを取得するために使用されます。
repository.GetAllWithCursor関数では、カーソルとして渡されたcursorPoint、cursorCreateAt、cursorUserIDを基にクエリを構築します。
カーソルは、ORDER BYで指定されている point、created_at、idが対象になります。
また、カーソルはクライアントから送付されることを仮定し、ソート順の最後のレコードの値になります。
このクエリは、次の条件を満たすレコードを取得します:
- pointがcursorPointより小さい
- または、pointがcursorPointと等しく、create_atがcursorCreateAtより大きい
- または、pointがcursorPointと等しく、create_atがcursorCreateAtと等しく、idがidcursorUserIDより大きい
この条件を付け加えることで、カーソルで指定された値よりも後のレコードのみを抽出できます。
この場合、point, create_atなどが重複する場合でも適切に抽出してくれます。
ただし、idは必ずユニークである必要があります。(詳しくは後述)
少し条件としては複雑そうに見えますが、よく考えると、正しいことがわかります。
以下のデータが入っているとします。
points.point | user.created_at | user.id |
---|---|---|
112 | 2020-10-9 | 80 |
110 | 2020-10-10 | 8 |
100 | 2020-10-10 | 1 |
100 | 2020-10-10 | 2 |
100 | 2020-10-10 | 3 |
90 | 2020-10-10 | 30 |
カーソルがの値が
- cursorPoint: 100
- cursorCretedAt: 2020-10-10
- cursorUserID: 1
の場合、以下のレコードが抽出されて、うまく抽出できているのがわかります。
points.point | user.created_at | user.id |
---|---|---|
100 | 2020-10-10 | 2 |
100 | 2020-10-10 | 3 |
90 | 2020-10-10 | 30 |
ソートの優先順位、昇順、降順でWHERE句のクエリの組み立てが異なるので注意しましょう。
service.GetUsers関数では、repository.GetAllWithCursor関数を呼び出し、レスポンスを作成します。また、次のページを取得するためのカーソル値をソート順の最後のレコードから作成します。
カーソルの作成は、JSONエンコードおよびBase64エンコードを利用します。
JSONエンコードとBase64エンコードの組み合わせを使うことで、安全かつ効率的にカーソルを取り扱うことができます。
{"user_id":46,"point":3300,"created_at":"2024-03-11T00:18:37.116025+09:00"}
⇩base64エンコード
eyJ1c2VyX2lkIjo0NiwicG9pbnQiOjMzMDAsImNyZWF0ZWRfYXQiOiIyMDI0LTAzLTExVDAwOjE4OjM3LjExNjAyNSswOTowMCJ9
ただし、クライアントが指定されたsizeよりも取得数が少ない場合は、次のページはないので、nextTokenは返却しません。
全体のコードは以下のリポジトリにおいていますので、参考にしていただければと思います。
まとめ
カーソルページネーションについて記載しました。
最近よく見るページネーション方法なので、しっかり押さえておきましょう。
Discussion