🍍

GoでPostgreSQLのRLS(Row Level Security)を実装してみた

2022/06/23に公開

はじめに

最近GoでマルチテナントSaaSの開発をやっておりまして、PostgreSQLのRLSの実装をはじめてやって見ましたので、その内容について共有したいと思います。

RLSとは

テーブルレベルで行う設定でSELECTUPDATEINSERTDELETEで影響を受ける行を制限する仕組みです。データベースエンジンが管理する自動化したWHERE句と考えることができます。

簡単な例で見てみましょう。

-- テナントテーブル作成及び、初期データ導入
CREATE TABLE IF NOT EXISTS tenants (
  id serial PRIMARY KEY,
  name varchar(100) NOT NULL UNIQUE
);

INSERT INTO tenants (name) values ('tenant1');
INSERT INTO tenants (name) values ('tenant2');

上記の状態で全件検索してみましょう。

SELECT * FROM tenants;
 id |  name
----+---------
  1 | tenant1
  3 | tenant2
(2 rows)

全件見れますよね。次はRLSポリシーを設定してみましょう。

-- id = 1のrow dataしか扱うことできないようにRLSを設定する
ALTER TABLE tenants ENABLE ROW LEVEL SECURITY;
CREATE POLICY tenant_isolation_policy ON tenants USING (id = 1);

それでは、もう一度全件検索してみましょう。

SELECT * FROM tenants;
 id |  name
----+---------
  1 | tenant1
(1 row)

どうでしょう。WHERE id = 1みたいな条件を指定してしてないですが、tenant1しか見れないですよね。これがRLSです。

これで何が嬉しいかというと一つのデータベース、一つのschemaでテナントごとのデータにそれぞれ安全にアクセスできるところです。

RLSの必要性

多くのマルチテナントSaaSではテナント間のデータアクセスによるセキュリティリスクを抱えていて、様々な対策を取っていると思いますが、DBだけで簡単に設定できる機能としてRLSは有用ですし、コスパもよく、インフラを分けたり、databaseあるいはschemaを分けるより運用も簡単なので、おすすめしたいです。

RLSを使うための考慮事項

  • Superuser(admin)はRLSポリシーが適用されない
  • BYPASSRLS属性を利用して作成したロール、ユーザーはRLSポリシーが適用されない
  • テーブルのオーナーはFORCE ROW LEVEL SECURITYしない限り、RLSポリシーが適用されない
  • USING句に固定値を使うわけには行かないので、リクエスト度に設定する変数を使う必要がある
  • RLSはテーブルレベルの設定なので、schema、user、roleと別々で考えて設計しても良い

実際のRLS設定

上記をの考慮事項を踏まえて実際の設定を共有します。

DBユーザーの使い分け

  • 読み取り専用ユーザー、書き込みできるユーザー、アプリケーションadminユーザーを作成
    • 読み取り専用ユーザーと書き込みできるユーザーはRLSを適用して用途に合わせてアプリケーションから使う
    • アプリケーションadminユーザーはRLSポリシーを無視できるBYPASSRLS属性を付与することで全テナントを管理するツールなどで使う

USING句で使う条件について

アプリケーションから接続する度に扱うテナントを変えたいという仕様についてはDBのセッション変数を使って制御する仕組みを考えて見ました。

-- アプリケーションからDB接続する度に最初に下記のような変数を設定します。
SET LOCAL current_tenant_id = <アクセスしたいtenantのIDを指定する>

このセッション変数使うためのRLS設定は下記にようになります。

ALTER TABLE tenants ENABLE ROW LEVEL SECURITY;
CREATE POLICY tenant_isolation_policy ON tenants
USING (id = current_setting('current_tenant_id')::INTEGER);

これでセッション変数に扱うテナントIDをセットするだけで、他のテナントのデータが一切見れなくなりますので、テナントを跨いだデータアクセスによるセキュリティ事故を防ぐことができるわけです。

Goでの実装方法

では、実際私が今携わっているプロダクトではどんな実装方法をしているかについて共有したいと思います。主にRest APIの開発になります。

※今回掲載するコードはサンプルコードです。実際のコードと違いますが、イメージを掴んでもらえると幸いです。

アプリケーションの技術スタック

※他にも色々ありますが、今回は使う箇所のみ記載しています。

DB処理の流れ

db-flow

Goコード

実装の要約

  • トランザクション開始、終了はGinのmiddlewareで実装
  • 開始したトランザクションハンドラーをcontextにセットして後続処理に渡す

middlewareのコードを見てみましょう。

トランザクション開始するためのmiddlewareコードはこちら。

beforeHandlerMiddleware
// すべてのハンドラーの処理開始前にトランザクションを開始するためのmiddleware
func beforeHandlerMiddleware() gin.HandlerFunc {
  return func(c *gin.Context) {
    // トランザクション開始
    tx, txErr := boil.BeginTx(c, &sql.TxOptions{Isolation: sql.LevelDefault, ReadOnly: true})
    if txErr != nil {
      c.AbortWithError(http.StatusInternalServerError, errors.New("failed to begin tx"))
      return
    }

    // セッション変数設定
    if _, err := tx.ExecContext(c, "SET LOCAL current_tenant_id=?", 1); err != nil {
      c.AbortWithError(http.StatusInternalServerError, errors.New("failed to set current_tenant_id"))
      return
    }

    // contextにtxを保存する
    c.Set("txParam", tx)

    c.Next()
  }
}

トランザクション終了するためのmiddlewareコードはこちら。

afterHandlerMiddleware
// すべてのハンドラーの処理終了後にトランザクションを終了するためのmiddleware
func afterHandlerMiddleware() gin.HandlerFunc {
  return func(c *gin.Context) {
    // ハンドラー処理が終わったあとに実行するためにdeferを使っている
    defer func() {
      // contextからtxを取り出す
      txValue, exists := c.Get("txParam")
      if !exists {
        c.AbortWithError(http.StatusInternalServerError, errors.New("tx not exits"))
        return
      }

      tx, ok := txValue.(*sql.Tx)
      if !ok {
        c.AbortWithError(http.StatusInternalServerError, errors.New("failed to convert tx"))
        return
      }

      // エラー発生したときの処理
      if err := c.Errors.Last(); err != nil {
        // rollbackする
        tx.Rollback()
        c.AbortWithStatus(http.StatusInternalServerError)
        return
      }

      // 成功した時はcommitして終了
      tx.Commit()
    }()
    c.Next()
  }
}

そして上記のmiddlewareを含めて全体のサンプルコードはこちらです。

main.go
package main

import (
  "errors"
  "database/sql"
  "net/http"

  "github.com/gin-gonic/gin"

  "github.com/volatiletech/sqlboiler/v4/boil"
)

func main() {
  // DB接続
  db, dbErr := sql.Open("postgres", "host= port= dbname= user= search_path= sslmode=")
  if dbErr != nil {
    panic("failed to open database")
  }

  if err := db.Ping(); err != nil {
    panic("failed to connect database")
  }

  // sqlboilerにDBハンドラーを設定する
  boil.SetDB(db)

  router := gin.Default()

  router.Use(afterHandlerMiddleware())
  router.Use(beforeHandlerMiddleware())

  router.GET("/tenants", func(c *gin.Context) {
    // contextからtxを取り出す
    txValue, exists := c.Get("txParam")
    if !exists {
      c.AbortWithError(http.StatusInternalServerError, errors.New("tx not exits"))
      return
    }

    tx, ok := txValue.(*sql.Tx)
    if !ok {
      c.AbortWithError(http.StatusInternalServerError, errors.New("failed to convert tx"))
      return
    }

    // トランザクション内でSQL実行
    // tenant_id = 1のデータしか取れない
    _, err := tx.QueryContext(c, "SELECT name FROM tenants")

    if err != nil {
      c.AbortWithError(http.StatusInternalServerError, errors.New("failed to execute sql"))
      return
    }

    c.Status(http.StatusOK)
  })
}

// すべてのハンドラーの処理開始前にトランザクションを開始するためのmiddleware
func beforeHandlerMiddleware() gin.HandlerFunc {
  return func(c *gin.Context) {
    // トランザクション開始
    tx, txErr := boil.BeginTx(c, &sql.TxOptions{Isolation: sql.LevelDefault, ReadOnly: true})
    if txErr != nil {
      c.AbortWithError(http.StatusInternalServerError, errors.New("failed to begin tx"))
      return
    }

    // セッション変数設定
    if _, err := tx.ExecContext(c, "SET LOCAL current_tenant_id = ?", 1); err != nil {
      c.AbortWithError(http.StatusInternalServerError, errors.New("failed to set current_tenant_id"))
      return
    }

    // contextにtxを保存する
    c.Set("txParam", tx)

    c.Next()
  }
}

// すべてのハンドラーの処理終了後にトランザクションを終了するためのmiddleware
func afterHandlerMiddleware() gin.HandlerFunc {
  return func(c *gin.Context) {
    // ハンドラー処理が終わったあとに実行するためにdeferを使っている
    defer func() {
      // contextからtxを取り出す
      txValue, exists := c.Get("txParam")
      if !exists {
        c.AbortWithError(http.StatusInternalServerError, errors.New("tx not exits"))
        return
      }

      tx, ok := txValue.(*sql.Tx)
      if !ok {
        c.AbortWithError(http.StatusInternalServerError, errors.New("failed to convert tx"))
        return
      }

      // エラー発生したときの処理
      if err := c.Errors.Last(); err != nil {
        // rollbackする
        tx.Rollback()
        c.AbortWithStatus(http.StatusInternalServerError)
        return
      }

      // 成功した時はcommitして終了
      tx.Commit()
    }()
    c.Next()
  }
}

すべてのDB処理にトランザクションを使う理由

セッション変数を設定する方法として、JavaだとgetConnectionメソッドをOverrideして簡単に実装できそうですが、GoだとConnectionを取得するところをOverrideする術を見つからなかったので、トランザクションを使うことにしました。

javaの例
@Override
public Connection getConnection() throws SQLException {
  Connection connection = super.getConnection();
  try (Statement sql = connection.createStatement()) {
    sql.execute("SET current_tenant_id = 1");
  }
  return connection;
}

Contextにトランザクションハンドラーを保存する理由

複数のgoroutineが絡むことによって生じる煩わしさを解決するために用意されたのがcontextなので、所謂マルチスレッド対策としてリクエストスコープな値として扱う必要があるので、contextを使いました。

最後に

はじめてGoでアプリケーション開発する上で、RLSまで採用したことによって結構ハードルあげてしまったんですが、自分なりのベストプラクティス方法を見つけて実装まで完成しただけで嬉しかったです。もちろんもっと賢い方法があるかもしれませんが、とりあえずGoでRLS導入を検討している方達に参考になる情報になれれば幸いでございます。

Discussion