Scala: マッピングモデルを使う構造で表示用データを高速に組み立てる方法

5 min read読了の目安(約5200字

概要

「ユーザーが複数のグループに所属する」という設計の場合。
こんなかんじのマッピングテーブルを使う実装ってよくあると思います。

case class UserId(value: Long)
case class User (
  id:     UserId,
  name:   String
)

case class GroupId(value: Long)
case class Group (
  id:     GroupId,
  name:   String
)

case class GroupMemberId(value: Long)
case class GroupMember (
  id:      GroupMemberId,
  userId:  UserId,
  groupId: GroupId
)

で、このデータ構造から下記のような「ユーザ一覧とその所属グループ一覧」を取得したい場合のScalaでの良い書き方と悪い書き方を紹介していこうと思います。

case class UserView (
  id:     UserId,
  name:   String,
  groups: Seq[Group]
)

言語やフレームワークによってはLEFT JOINとか使ってクエリの方で解決する手段もあると思いますが、今回はScalaの世界で実装する前提で書きます。
ちなみにバージョンは 2.13.1です。

悪いパターン

実装確認

まず悪いパターンから。

testA.scala
val userSeq: Seq[User] = UserRepository.findAll();
val groupSeq: Seq[Group] = GroupRepository.findAll();
val groupMemberSeq: Seq[GroupMember] = GroupMembersRepository.findAll();

val result: Seq[UserView] = userSeq.map(user => { // ①
  val groupIdSeq: Seq[GroupId] =
    groupMemberSeq
      .filter(_.userId == user.id) // ②
      .map(_.groupId)

  val userGroups: Seq[Group] = // ③
    groupSeq.filter(group =>
      groupIdSeq.contains(group.id)
    )
    
  UserView(
    user.id,
    user.name,
    userGroups
  )
})
  1. userSeqmapで回して
  2. userIdGroupMemberを特定し
  3. それに紐づくGroupを取得する

といった流れですね。
仕様をそのまま素直に書くとこういう書き方になるかと思います。
どこがいけないのでしょうか?
では実行速度を見てみます。
まずは10件。

testA.scala
// データを10件用意
val userSeq:        Seq[User]        = (1L to 10L).map(id => User(UserId(id), "name" + id.toString))
val groupSeq:       Seq[Group]       = (1L to 10L).map(id => Group(GroupId(id), "name" + id.toString))
val groupMemberSeq: Seq[GroupMember] = (1L to 10L).map(id => GroupMember(GroupMemberId(id), UserId(id), GroupId(id)))

val startTime = System.currentTimeMillis
val result: Seq[UserView] = ...
println((System.currentTimeMillis - startTime).toString + "ミリ秒")

3回ほど実行。

$ scala testA.scala
3ミリ秒
$ scala testA.scala
2ミリ秒
$ scala testA.scala
3ミリ秒

問題なさそう。
では10,000件。

testA.scala
// データを10,000件用意
val userSeq:        Seq[User]        = (1L to 10000L).map(id => User(UserId(id), "name" + id.toString))
val groupSeq:       Seq[Group]       = (1L to 10000L).map(id => Group(GroupId(id), "name" + id.toString))
val groupMemberSeq: Seq[GroupMember] = (1L to 10000L).map(id => GroupMember(GroupMemberId(id), UserId(id), GroupId(id)))
...

同じく3回実行。

$ scala testA.scala
5231ミリ秒
$ scala testA.scala
4681ミリ秒
$ scala testA.scala
5225ミリ秒

悲惨ですね。なぜこうなってしまうのでしょうか。

問題箇所

計算量を確認します。

testA.scala
val result: Seq[UserView] = userSeq.map(user => { // ①ユーザー人数分回る
  val groupIdSeq: Seq[GroupId] =
    groupMemberSeq
      .filter(_.userId == user.id) // ②マッピング数分回る
      .map(_.groupId)

  val userGroups: Seq[Group] =
    groupSeq.filter(group => // ③グループ数分回る
      groupIdSeq.contains(group.id)
    )

  UserView(
    user.id,
    user.name,
    userGroups
  )
})    

filterはコレクションの先頭から末尾まで確認するので、計算時間は線形です。
なので計算量は

\begin{aligned} x &= O(①*②+①*③)\\ &= O(n*n+n*n)\\ &= O(n^2+n^2)\\ &= O(2n^2) \end{aligned}

となり、ユーザー数、グループ数、マッピング数それぞれが10,000件(n=10,000)の場合
200,000,000 = 2*10000*10000
2億回。絶望的ですね。
O(n^2) の部分をなんとかすれば速くなりそうです。

余談: Zennのフォント、O(オー)が縦長なので数字の0(ゼロ)っぽくてややこしいですね
数式用のMarkdown記法使えば綺麗に書けました😓

良いパターン

testB.scala
val startTime = System.currentTimeMillis
val groupMap: Map[UserId, Seq[Group]] =
  groupMemberSeq.groupBy(_.userId) map { case (userId, members) =>
    (
      userId,
      members.flatMap(member =>
        groupSeq.find(_.id == member.groupId)
      )
    )
}
val result: Seq[UserView] = userSeq.map(user =>
  UserView(
    user.id,
    user.name,
    groupMap(user.id)
  )
)
println((System.currentTimeMillis - startTime).toString + "ミリ秒")

要らないと思いますが10件で実行した結果。

$ scala testB.scala
4ミリ秒
$ scala testB.scala
4ミリ秒
$ scala testB.scala
2ミリ秒

大差無いですね。
では10,000件

$ scala testB.scala
472ミリ秒
$ scala testB.scala
420ミリ秒
$ scala testB.scala
470ミリ秒

うーん、爆速ってわけじゃないですが割とマシになりました。

工夫した点

testB.scala
// ①マップを生成
val groupMap: Map[UserId, Seq[Group]] =
  groupMemberSeq.groupBy(_.userId) map { case (userId, members) => // ②マッピング数分回る
    (
      userId,
      members.flatMap(member => // ③そのユーザーのマッピング数分回る
        groupSeq.find(_.id == member.groupId) // ④Vectorのfindは実質定数
      )
    )
}
val result: Seq[UserView] = userSeq.map(user => // ⑤ユーザー数分回る
  UserView(
    user.id,
    user.name,
    groupMap(user.id) // ⑥Map(HashMap)のfindなので実質定数
  )
)
  • ①で最初にGroupMemberを使って、ユーザーIdをキーにしたグループリストのマップを作ってしまいます。
  • ②でマッピング数分回るので10,000回
  • ③は今回の例では1回。ユーザーが10グループに所属してたら10回となりますがさほど影響しません
  • ④はfindなので線形、Vectorを使えば実質定数です(性能特性
  • ⑤はユーザー数分

なので

\begin{aligned} x &= O(②*③*④+⑤*⑥)\\ &= O(n*1*1+n*1)\\ &= O(n + n)\\ &= O(2n) \end{aligned}

となりn=10,000なら20,000。
2万回で済む、というわけですね。
全ユーザーが10グループに属しているとしても11万回で、悪いパターンの2億回よりはマシです。

まとめ

Scala勉強し始めたころはfilterが便利すぎて多用してしまいがちですが、計算量はしっかり意識して実装したいですね。
findは定数なので、groupByでマップを作ってしまってキーでfindする、というこの方法は割と色んな場面で使えるのではないかと思います。
また、データ量が少ないと問題が表面化しないのも罠ですね。

お読み頂きありがとうございました。
改善の余地はある気がするので、
間違っている点、改善できる点あればご指摘頂けますと幸いです🙇‍♂️