Closed4

graphql-rubyで SUM/COUNT句 を使ったN+1を解消するLoaderを試す

necocoanecocoa

Post に対して、点数（`score` カラム）付きのレビュー（`PostReview` モデル）をつけられるテーブルがあるとする。

Post Modelはこのような形。

app/models/post.rb
class Post < ApplicationRecord
  # Score reviews attached to this post (stored as PostReview records);
  # they are destroyed along with the post.
  has_many :reviews, dependent: :destroy, class_name: 'PostReview'
end

1投稿に対してのレビュー数とレビュー得点の平均値を出したい。

app/graphql/types/post_type.rb
module Types
  class PostType < Types::BaseObject
    field :average_review_score, Float, null: true
    field :body, String, null: false
    field :created_at, GraphQL::Types::ISO8601DateTime, null: false
    field :id, ID, null: false
    field :title, String, null: false
    field :total_reviews, Integer, null: false
    field :updated_at, GraphQL::Types::ISO8601DateTime, null: false

    # Number of reviews for this post, batched across all posts in the query.
    def total_reviews
      Loaders::AssociationCountLoader.for(Post, :reviews).load(object)
    end

    # Mean review score rounded to 2 decimal places, or nil when the post has
    # no reviews (the field is nullable). Without the guard, 0 / 0.0 yields
    # NaN, which is not valid JSON and cannot be returned as a GraphQL Float.
    def average_review_score
      Loaders::AssociationCountLoader.for(Post, :reviews).load(object).then do |reviews_count|
        # Skip the SUM load entirely when there is nothing to average.
        next nil if reviews_count.zero?

        Loaders::AssociationSumLoader.for(Post, :reviews, :score).load(object).then do |total_score|
          (total_score / reviews_count.to_f).round(2)
        end
      end
    end
  end
end
app/graphql/loaders/association_count_loader.rb
module Loaders
  # Batches COUNT queries for an association (e.g. Post has_many :reviews) so
  # that resolving the count for N parent records issues a single
  # GROUP BY query instead of N separate queries.
  #
  # Example:
  #   Loaders::AssociationCountLoader.for(Post, :reviews).load(post)
  class AssociationCountLoader < GraphQL::Batch::Loader
    # Builds a throwaway loader so a missing association raises
    # ArgumentError immediately; always returns nil.
    def self.validate(model, association_name)
      new(model, association_name)
      nil
    end

    # @param model [Class] parent ActiveRecord model (e.g. Post)
    # @param association_name [Symbol] association on +model+ whose rows are counted
    # @param where [Object, nil] optional extra conditions passed to +where+ on
    #   the associated model before counting
    # @raise [ArgumentError] if +model+ has no association named +association_name+
    def initialize(model, association_name, where: nil)
      super()
      @model = model
      @association_name = association_name
      @reflection = reflection
      @where = where
    end

    # @raise [TypeError] if +record+ is not an instance of the loader's model
    def load(record)
      raise TypeError, "#{@model} loader can't load association for #{record.class}" unless record.is_a?(@model)

      super
    end

    # Fulfills every batched record with its association count. Records with
    # no associated rows, or without a key value, resolve to 0.
    def perform(records)
      counts = query(records)
      records.each do |record|
        fulfill(record, counts[record_key(record)] || 0)
      end
    end

    private

    def reflection
      reflection = @model.reflect_on_association(@association_name)
      return reflection if reflection

      raise ArgumentError, "No association #{@association_name} on #{@model}"
    end

    # Counts associated rows grouped by the join column.
    # The IN list is built from the same value record_key reads, rather than
    # from the records themselves (which would use #id). This keeps the query
    # consistent with the lookup in #perform when the association declares a
    # custom primary_key. Nil keys (e.g. unsaved records) are dropped so they
    # are not matched against rows whose join column IS NULL.
    def query(records)
      keys = records.map { |record| record_key(record) }.compact.uniq
      return {} if keys.empty?

      column = @reflection.join_primary_key
      scope = @reflection.klass
      scope = scope.where(@where) if @where
      scope.where(column => keys).group(column).count
    end

    def record_key(record)
      record[@reflection.active_record_primary_key]
    end
  end
end

app/graphql/loaders/association_sum_loader.rb
module Loaders
  # Batches SUM(+sum_column+) queries over an association (e.g. the total
  # score of a post's reviews). Resolving N parent records issues a single
  # GROUP BY query instead of N separate queries.
  #
  # Example:
  #   Loaders::AssociationSumLoader.for(Post, :reviews, :score).load(post)
  class AssociationSumLoader < GraphQL::Batch::Loader
    # Builds a throwaway loader so a missing association raises
    # ArgumentError immediately; always returns nil.
    def self.validate(model, association_name, sum_column)
      new(model, association_name, sum_column)
      nil
    end

    # @param model [Class] parent ActiveRecord model (e.g. Post)
    # @param association_name [Symbol] association on +model+ to aggregate over
    # @param sum_column [Symbol, String] column of the associated model to sum
    # @param where [Object, nil] optional extra conditions passed to +where+ on
    #   the associated model before summing
    # @raise [ArgumentError] if +model+ has no association named +association_name+
    def initialize(model, association_name, sum_column, where: nil)
      super()
      @model = model
      @association_name = association_name
      @reflection = reflection
      @sum_column = sum_column
      @where = where
    end

    # @raise [TypeError] if +record+ is not an instance of the loader's model
    def load(record)
      raise TypeError, "#{@model} loader can't load association for #{record.class}" unless record.is_a?(@model)

      super
    end

    # Fulfills every batched record with its summed value. Records with no
    # associated rows, or without a key value, resolve to 0.
    def perform(records)
      sums = query(records)
      records.each do |record|
        fulfill(record, sums[record_key(record)] || 0)
      end
    end

    private

    def reflection
      reflection = @model.reflect_on_association(@association_name)
      return reflection if reflection

      raise ArgumentError, "No association #{@association_name} on #{@model}"
    end

    # Sums +sum_column+ grouped by the join column.
    # The IN list is built from the same value record_key reads, rather than
    # from the records themselves (which would use #id). This keeps the query
    # consistent with the lookup in #perform when the association declares a
    # custom primary_key. Nil keys (e.g. unsaved records) are dropped so they
    # are not matched against rows whose join column IS NULL.
    def query(records)
      keys = records.map { |record| record_key(record) }.compact.uniq
      return {} if keys.empty?

      column = @reflection.join_primary_key
      scope = @reflection.klass
      scope = scope.where(@where) if @where
      scope.where(column => keys).group(column).sum(@sum_column)
    end

    def record_key(record)
      record[@reflection.active_record_primary_key]
    end
  end
end

necocoanecocoa

AssociationCountLoader / AssociationSumLoader のように関連（association）をたどるのではなく、RecordLoader のように任意のカラムの値をキーにして集計するときの loader

Loaders::SumLoader.for(PostReview, :post_id, :score).load(object.id) # total score of this post's reviews
こんな感じで使えるはず(合計得点の取得)

count_loader.rb
module Loaders
  # Batches "COUNT(*) ... GROUP BY column" queries on an arbitrary model,
  # keyed by the value of +column+ rather than by an association.
  class CountLoader < GraphQL::Batch::Loader
    # @param model [Class] ActiveRecord model whose rows are counted
    # @param column [Symbol, String] column to group by and match load keys against
    # @param where [Object, nil] optional extra conditions passed to +where+
    def initialize(model, column, where: nil)
      super()
      @model = model
      @column = column
      @where = where
      # Incoming keys are cast to this column's type so they compare equal to
      # the keys of the grouped count hash.
      @column_type = model.type_for_attribute(column)
    end

    def load(key)
      normalized_key = @column_type.cast(key)
      super(normalized_key)
    end

    # Resolves each key with its row count; keys with no matching rows get 0.
    def perform(keys)
      totals = query(keys)
      keys.each do |key|
        fulfill(key, totals[key] || 0)
      end
    end

    private

    def query(keys)
      relation = @where ? @model.where(@where) : @model
      relation.where(@column => keys).group(@column).count
    end
  end
end
sum_loader.rb
module Loaders
  # Batches "SUM(sum_column) ... GROUP BY column" queries on an arbitrary
  # model, keyed by the value of +column+ rather than by an association.
  class SumLoader < GraphQL::Batch::Loader
    # @param model [Class] ActiveRecord model whose rows are aggregated
    # @param column [Symbol, String] column to group by and match load keys against
    # @param sum_column [Symbol, String] column whose values are summed
    # @param where [Object, nil] optional extra conditions passed to +where+
    def initialize(model, column, sum_column, where: nil)
      super()
      @model = model
      @column = column
      @sum_column = sum_column
      @where = where
      # Incoming keys are cast to this column's type so they compare equal to
      # the keys of the grouped sum hash.
      @column_type = model.type_for_attribute(column)
    end

    def load(key)
      normalized_key = @column_type.cast(key)
      super(normalized_key)
    end

    # Resolves each key with its summed value; keys with no matching rows get 0.
    def perform(keys)
      totals = query(keys)
      keys.each do |key|
        fulfill(key, totals[key] || 0)
      end
    end

    private

    def query(keys)
      relation = @where ? @model.where(@where) : @model
      relation.where(@column => keys).group(@column).sum(@sum_column)
    end
  end
end

necocoanecocoa

CountLoader
AssociationCountLoader
どちらも where を引数として渡せる
これは RecordLoader のインターフェイスを参考にしたため。

AssociationLoader には where 引数がなかったため Associationの方はいらなかったかもしれない

necocoanecocoa

Batchの簡素な説明

`def perform(records)` には、`.load(object)` で渡された N 個の `object` がまとめて配列として渡される。

`counts = query(records)`（SumLoader では `sums = query(records)`）で実行したいクエリを実行し、

その結果を

# Resolve each record's promise with its aggregate, falling back to 0 when
# the query returned no row for that record's key.
records.each { |record| fulfill(record, counts[record_key(record)] || 0) }

`fulfill` メソッドに key と value を渡すと、その結果がキャッシュされる。

そうすると、`.load(object)` の `object` と key が一致する value が（Promise 経由で）返される。
value を使ってさらに処理を重ねたいときは `.then` にブロックを渡す。

なので、Loader の `perform` と `query` を書き換えれば好きな Loader を作ることができる。
むしろ Loader を作らないと N+1 をいい感じに解消できないと思われる。

このスクラップは2021/07/31にクローズされました