🐕

複数列ソートにおけるカーソルページネーションの実装

2024/08/05に公開

こんにちは、@nerusan です。

今回は、複数列ソートにおけるカーソルページネーションについて、説明します。

カーソルページネーションとは?

カーソルページネーションは、データベースのクエリ結果をページごとに取得するための方法です。通常のページネーションでは、ページ番号やオフセットを使用してデータを取得しますが、カーソルページネーションでは、前のページの最後の要素のカーソルを使用して次のページのデータを取得します。

カーソルページネーションの利点は、ページの移動が高速であることです。ページ番号やオフセットを使用する場合、ページ数が増えるにつれてデータベースのクエリが遅くなる可能性がありますが、カーソルページネーションでは常に一定の速度でデータを取得できます。

具体的な実装方法は、データベースの種類や使用するプログラミング言語によって異なりますが、一般的にはクエリパラメータとしてカーソルを指定し、そのカーソルを使用して次のページのデータを取得します。

カーソルページネーションは、大量のデータを効率的に取得するための有用な手法です。特に、ユーザーがスクロールやページングを行うような場合には、有効になります。

オフセットページネーションがなぜ遅くなるのか?

オフセットページネーションがなぜページ数がふえるにつれて遅くなるのかを簡単に説明します。

以下のクエリを見てください。

mysql
mysql> explain select * from users order by id limit 2 offset 1000;
+----+-------------+-------+------------+-------+---------------+---------+---------+------+------+----------+-------+
| id | select_type | table | partitions | type  | possible_keys | key     | key_len | ref  | rows | filtered | Extra |
+----+-------------+-------+------------+-------+---------------+---------+---------+------+------+----------+-------+
|  1 | SIMPLE      | users | NULL       | index | NULL          | PRIMARY | 8       | NULL | 1002 |   100.00 | NULL  |
+----+-------------+-------+------------+-------+---------------+---------+---------+------+------+----------+-------+

rowsを注目してみてください。
2件をのみを取得したいのに、rowsでは、1002件となり、対象テーブルから取得される行が1002になっているのがわかります。
これは、limit 2 offset 1000の部分が関係しており、1000件までのソートを確認し、そこから2件取り出しているため、全体として1000+2=1002件の読み取りが発生します。

つまり、offsetが100000であれば、100002件が取得され、ページ数が増える旅に遅くなるのです。

また、データベースはインデックスを使ってクエリの開始点を特定できますが、大きな OFFSET を使用すると、インデックスまたはテーブル全体をスキャンすることが多くなります。これは特に大きなテーブルでは効率が悪くなります。

typeを見てみると、indexになっており、インデックスのフルスキャンになっており、効率が良いクエリとは言えませんね。

カーソルページネーションのクエリを見てみましょう

mysql> explain select * from users where id > 2030 order by id limit 2;
+----+-------------+-------+------------+-------+---------------+---------+---------+------+------+----------+-------------+
| id | select_type | table | partitions | type  | possible_keys | key     | key_len | ref  | rows | filtered | Extra       |
+----+-------------+-------+------------+-------+---------------+---------+---------+------+------+----------+-------------+
|  1 | SIMPLE      | users | NULL       | range | PRIMARY       | PRIMARY | 8       | NULL |   39 |   100.00 | Using where |
+----+-------------+-------+------------+-------+---------------+---------+---------+------+------+----------+-------------+

typeはrangeとなっておりインデックスを使ったクエリで、rowも少ないですね!

実装

ユーザー間がポイント送付しあうアプリを考えるとします。
ユーザーは、あるユーザーにポイントを送付することができ、送付された合計値が獲得ポイントになります。

ユーザー一覧表示ページがあり、そこでは、無限ローディングで表示することを考えるとします。
その際のソートの要件として、以下を満たすものとします。

  • 獲得ポイント獲得数の降順
  • 獲得ポイントが同じ場合は、ユーザー作成日の昇順
  • 獲得ポイント、作成日も同じ場合は、ユーザーIDの昇順

テーブル

以下テーブルです。

schema.sql
CREATE TABLE `users` (
    `id`                      BIGINT NOT NULL AUTO_INCREMENT COMMENT 'ユーザーの識別子',
    `family_name`             VARCHAR(256) NOT NULL COMMENT '苗字',
    `family_name_kana`        VARCHAR(256) NOT NULL COMMENT '苗字カナ',
    `first_name`              VARCHAR(256) NOT NULL COMMENT '名前',
    `first_name_kana`         VARCHAR(256) NOT NULL COMMENT '名前カナ',
    `email`                   VARCHAR(256) NOT NULL COMMENT 'メールアドレス',
    `password`                VARCHAR(256) NOT NULL COMMENT 'パスワードハッシュ',
    `sending_point`           INT NOT NULL COMMENT '送信可能ポイント',
    `created_at`              DATETIME(6) NOT NULL COMMENT 'レコード作成日時',
    `update_at`               DATETIME(6) NOT NULL COMMENT 'レコード修正日時',
    PRIMARY KEY (`id`),
    UNIQUE KEY `uix_email` (`email`) USING BTREE
) Engine=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='ユーザー';

CREATE TABLE `transactions` (
    `id`                 BIGINT UNSIGNED NOT NULL AUTO_INCREMENT COMMENT '取引の識別子',
    `sending_user_id`    BIGINT UNSIGNED NOT NULL COMMENT '送信ユーザのID',
    `receiving_user_id`  BIGINT UNSIGNED NOT NULL COMMENT '受信ユーザのID',
    `transaction_point`  INT NOT NULL COMMENT '取引ポイント',
    `transaction_at`     DATETIME(6) NOT NULL COMMENT '取引日時',
    PRIMARY KEY (`id`)
) Engine=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='取引';

Goのコード

repository/users.go
type GetAllWithCursorParam struct {
	Size            int          `db:"size"`
	CursorPoint     int          `db:"point"`
	CursorUserID    model.UserID `db:"user_id"`
	CursorCreatedAt time.Time    `db:"created_at"`
}

// GetAllWithCursor ポイント順にユーザを取得する
// カーソルページネーションを使用して取得する
// ソート順は以下の通り
//
// 1. ポイントが多い順
// 2. ポイントが同じ場合は登録日が古い順
// 3. ポイントと登録日が同じ場合はユーザIDが大きい順
func (r *Repository) GetAllWithCursor(ctx context.Context, db Queryer, param GetAllWithCursorParam) ([]*entities.User, error) {
	sql := `
		WITH points AS (
			SELECT receiving_user_id AS user_id, SUM(transaction_point) AS point
			FROM transactions
			GROUP BY receiving_user_id
		)
		SELECT u.*
		FROM users AS u
		INNER JOIN points AS p
		ON u.id = p.user_id
		WHERE
		  p.point < ? 
			OR (p.point = ? AND u.created_at > ?) 
			OR (p.point = ? AND u.created_at = ? AND u.id > ?)
		ORDER BY p.point DESC, u.created_at ASC, u.id ASC
		LIMIT ?;`

	var users []*entities.User
	err := db.SelectContext(ctx, &users, sql,
		param.CursorPoint,
		param.CursorPoint,
		param.CursorCreatedAt,
		param.CursorPoint,
		param.CursorCreatedAt,
		param.CursorUserID,
		param.Size,
	)
	if err != nil {
		return users, errors.Wrap(err, "failed to get all users in user repo")
	}

	return users, nil
}

type GetUsersParam struct {
	Size int
}

// GetUsers ポイント順にユーザを取得する
//
// ソート順は以下の通り
//
// 1. ポイントが多い順
// 2. ポイントが同じ場合は登録日が古い順
// 3. ポイントと登録日が同じ場合はユーザIDが大きい順
func (r *Repository) GetUsers(ctx context.Context, db Queryer, param GetUsersParam) ([]*entities.User, error) {
	sql := `
		WITH points AS (
			SELECT receiving_user_id AS user_id, SUM(transaction_point) AS point
			FROM transactions
			GROUP BY receiving_user_id
		)
		SELECT u.*
		FROM users AS u
		INNER JOIN points AS p
		ON u.id = p.user_id
		ORDER BY p.point DESC, u.created_at ASC, u.id ASC
		LIMIT ?;`

	var users []*entities.User
	err := db.SelectContext(ctx, &users, sql,
		param.Size,
	)
	if err != nil {
		return users, errors.Wrap(err, "failed to get all users in user repo")
	}

	return users, nil
}
servce/get_users.go
package service

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"time"

	"github.com/cockroachdb/errors"
	"github.com/hack-31/point-app-backend/domain"
	"github.com/hack-31/point-app-backend/domain/model"
	"github.com/hack-31/point-app-backend/repository"
	"github.com/hack-31/point-app-backend/repository/entities"
	"github.com/jmoiron/sqlx"
)

type GetUsers struct {
	DB              repository.Queryer
	UserRepo        domain.UserRepo
	TransactionRepo domain.TransactionRepo
	TokenGenerator  domain.TokenGenerator
}

func NewGetUsers(db *sqlx.DB, repo *repository.Repository, jwter domain.TokenGenerator) *GetUsers {
	return &GetUsers{
		DB:              db,
		UserRepo:        repo,
		TransactionRepo: repo,
		TokenGenerator:  jwter,
	}
}

type GetUsersRequest struct {
	Size       int
	NextCursor string
}

type GetUsersResponse struct {
	Users []struct {
		ID               model.UserID
		FirstName        string
		FirstNameKana    string
		FamilyName       string
		FamilyNameKana   string
		Email            string
		AcquisitionPoint int
	}
	NextCursor string
}


// ユーザ一覧取得サービス
//
// @params ctx コンテキスト
//
// @return
// ユーザ一覧
func (r *GetUsers) GetUsers(ctx context.Context, input GetUsersRequest) (GetUsersResponse, error) {
	// ユーザ一覧を取得する
	type cursor struct {
		UserID    model.UserID `json:"user_id"`
		Point     int          `json:"point"`
		CreatedAt time.Time    `json:"created_at"`
	}
	var users []*entities.User
	var c cursor
  // 2回目以降のリクエスト
	if input.NextCursor != "" {
    // base64をデコードして、JSONを構造体にマッピング
		data, err := base64.URLEncoding.DecodeString(input.NextCursor)
		if err != nil {
			return GetUsersResponse{}, errors.Wrap(err, "failed to decode nextCursor in GetUsersService.GetUsers")
		}
		if err := json.Unmarshal(data, &c); err != nil {
			return GetUsersResponse{}, errors.Wrap(err, "failed to unmarshal nextCursor in GetUsersService.GetUsers")
		}
    // カーソルをもとに、次のユーザーリストを取得
		users, err = r.UserRepo.GetAllWithCursor(ctx, r.DB, repository.GetAllWithCursorParam{
			Size:            input.Size,
			CursorPoint:     c.Point,
			CursorUserID:    c.UserID,
			CursorCreatedAt: c.CreatedAt,
		})
		if err != nil {
			return GetUsersResponse{}, errors.Wrap(err, "failed to get users in GetUsersService.GetUsers")
		}
	}
  // 初回のリクエスト
	if input.NextCursor == "" {
		var err error
		if input.Size == 0 {
			input.Size = 10
		}
    // 初回はカーソルはないので、ソートして、上位のサイズ分を取得
		users, err = r.UserRepo.GetUsers(ctx, r.DB, repository.GetUsersParam{
			Size: input.Size,
		})
		if err != nil {
			return GetUsersResponse{}, errors.Wrap(err, "failed to get users in GetUsersService.GetUsers")
		}
	}
	if len(users) == 0 {
		return GetUsersResponse{
			Users: []struct {
				ID               model.UserID
				FirstName        string
				FirstNameKana    string
				FamilyName       string
				FamilyNameKana   string
				Email            string
				AcquisitionPoint int
			}{},
			NextCursor: "",
		}, nil
	}

	// ユーザIDsを取得する
	userIDs := make([]model.UserID, 0, len(users))
	for _, user := range users {
		userIDs = append(userIDs, model.UserID(user.ID))
	}

	// 取得ポイントを取得する
	points, err := r.TransactionRepo.GetAquistionPoint(ctx, r.DB, userIDs)
	if err != nil {
		return GetUsersResponse{}, errors.Wrap(err, "failed to get points in GetUsersService.GetUsers")
	}

	res := make([]struct {
		ID               model.UserID
		FirstName        string
		FirstNameKana    string
		FamilyName       string
		FamilyNameKana   string
		Email            string
		AcquisitionPoint int
	}, 0, len(users))

	// ユーザに取得ポイントを設定する
	for _, v := range users {
		res = append(res, struct {
			ID               model.UserID
			FirstName        string
			FirstNameKana    string
			FamilyName       string
			FamilyNameKana   string
			Email            string
			AcquisitionPoint int
		}{
			ID:               model.UserID(v.ID),
			FirstName:        v.FirstName,
			FirstNameKana:    v.FirstNameKana,
			FamilyName:       v.FamilyName,
			FamilyNameKana:   v.FamilyNameKana,
			Email:            v.Email,
			AcquisitionPoint: points[model.UserID(v.ID)],
		})
	}
	var nextCursorStr string
  // 取得ユーザー数とリクエストサイズ数が同じ場合、
  // 次のページが存在する可能性があるので、カーソルを作成する
	if len(users) == input.Size {
    // JSONにして、base64エンコードしてクライアントに返す
		data, err := json.Marshal(cursor{
			UserID:    model.UserID(users[len(users)-1].ID),
			Point:     points[model.UserID(users[len(users)-1].ID)],
			CreatedAt: users[len(users)-1].CreatedAt,
		})
		if err != nil {
			return GetUsersResponse{}, errors.Wrap(err, "failed to marshal nextCursor in GetUsersService.GetUsers")
		}
		nextCursorStr = base64.URLEncoding.EncodeToString(data)
	}

	return GetUsersResponse{
		Users:      res,
		NextCursor: nextCursorStr,
	}, nil
}

詳しくはコードの説明を見てもらえたらと思いますが、
簡単に説明します。

カーソルページネーションではカーソルというものを使います。
カーソルページネーションにおける「カーソル」とは、データベースやAPIからデータを取得する際に使用される一種のマーカーや参照点のことです。このカーソルは、データセット内の現在の位置を示し、次のページのデータを取得するために使用されます。

repository.GetAllWithCursor関数では、カーソルとして渡されたcursorPoint、cursorCreateAt、cursorUserIDを基にクエリを構築します。
カーソルは、ORDER BYで指定されている point、created_at、idが対象になります。
また、カーソルはクライアントから送付されることを仮定し、ソート順の最後のレコードの値になります。
このクエリは、次の条件を満たすレコードを取得します:

  1. pointがcursorPointより小さい
  2. または、pointがcursorPointと等しく、create_atがcursorCreateAtより大きい
  3. または、pointがcursorPointと等しく、create_atがcursorCreateAtと等しく、idがidcursorUserIDより大きい

この条件を付け加えることで、カーソルで指定された値よりも後のレコードのみを抽出できます。
この場合、point, create_atなどが重複する場合でも適切に抽出してくれます。
ただし、idは必ずユニークである必要があります。(詳しくは後述)

少し条件としては複雑そうに見えますが、よく考えると、正しいことがわかります。

以下のデータが入っているとします。

points.point user.created_at user.id
112 2020-10-9 80
110 2020-10-10 8
100 2020-10-10 1
100 2020-10-10 2
100 2020-10-10 3
90 2020-10-10 30

カーソルがの値が

  • cursorPoint: 100
  • cursorCretedAt: 2020-10-10
  • cursorUserID: 1

の場合、以下のレコードが抽出されて、うまく抽出できているのがわかります。

points.point user.created_at user.id
100 2020-10-10 2
100 2020-10-10 3
90 2020-10-10 30

ソートの優先順位、昇順、降順でWHERE句のクエリの組み立てが異なるので注意しましょう。

service.GetUsers関数では、repository.GetAllWithCursor関数を呼び出し、レスポンスを作成します。また、次のページを取得するためのカーソル値をソート順の最後のレコードから作成します。

カーソルの作成は、JSONエンコードおよびBase64エンコードを利用します。

JSONエンコードとBase64エンコードの組み合わせを使うことで、安全かつ効率的にカーソルを取り扱うことができます。

{"user_id":46,"point":3300,"created_at":"2024-03-11T00:18:37.116025+09:00"}
⇩base64エンコード
eyJ1c2VyX2lkIjo0NiwicG9pbnQiOjMzMDAsImNyZWF0ZWRfYXQiOiIyMDI0LTAzLTExVDAwOjE4OjM3LjExNjAyNSswOTowMCJ9

ただし、クライアントが指定されたsizeよりも取得数が少ない場合は、次のページはないので、nextTokenは返却しません。

全体のコードは以下のリポジトリにおいていますので、参考にしていただければと思います。

https://github.com/hack-31/point-app-backend

まとめ

カーソルページネーションについて記載しました。
最近よく見るページネーション方法なので、しっかり押さえておきましょう。

GitHubで編集を提案

Discussion