😺

Kotlin CoroutineでMapに対する計算を並列化して速くしてみた

2024/12/15に公開

最近Kotlinのコードを書いており、以下のわりとシンプルな問題を解決するコードを書くことになりました。

UUIDをキー、ランダムな文字列をバリューとした2つのマップがある。
それら2つのマップの差分を取得したい。
UUIDキーの集合は二つのマップにおいて共通であり、1,000,000個の要素があるとする。

e.g.)

map 1:
{
...
"4a01e0ae-b033-496b-b009-06f6590805f1": "Zenn"
"9c5c6f75-08f1-425a-887c-a2fcf2959367": "Hatena"
...
}

map 2:
{
...
"4a01e0ae-b033-496b-b009-06f6590805f1": "Zenn Dev"
"9c5c6f75-08f1-425a-887c-a2fcf2959367": "はてな ブログ"
...
}

としたら、以下の結果を吐き出すようなコードです。

"4a01e0ae-b033-496b-b009-06f6590805f1": Pair("Zenn", "Zenn Dev")

愚直な方法

愚直にやるとおそらく以下のようになると思います。

fun main() {
    // 1. Generate 1,000,000 UUIDs
    val uuids = List(1_000_000) { UUID.randomUUID() }

    // 2. Create two maps with the UUIDs as keys, assigning random strings as values
    val map1 = uuids.associateWith { generateRandomString(10) }
    val map2 = uuids.associateWith { generateRandomString(10) }

    println("Map sizes: map1=${map1.size}, map2=${map2.size}")

    // Measure the time taken to compute the diff
    val elapsedTime =
        measureTimeMillis {
            val diffResults = computeDiff(map1, map2)
            println("Number of diffs: ${diffResults.size}")
        }

    println("Time taken to compute diffs: $elapsedTime ms")
}

...

fun computeDiff(map1: Map<UUID, String>, map2: Map<UUID, String>): List<Pair<UUID, Pair<String?, String?>>> {
    val allKeys = map1.keys + map2.keys // Union of all keys
    return allKeys.mapNotNull { uuid ->
        val value1 = map1[uuid]
        val value2 = map2[uuid]
        if (value1 != value2) uuid to (value1 to value2) else null
    }
}

かかった時間は自分のローカル環境だと 302msほどでした。

Map sizes: map1=1000000, map2=1000000
Number of diffs: 1000000
Time taken to compute diffs: 302 ms

Coroutineを使って並列化する

まあ悪くはないですが、せっかくなのでcoroutineを使って並列化できないか考えてみました。全文コードは下に載せておきますが、抜粋すると以下のような感じです。

fun main() =
    runBlocking {
        ...
        println("Map sizes: map1=${map1.size}, map2=${map2.size}")

        // 3. Split the map keys into batches
        val batchSize = 10_000
        val batches = map1.keys.chunked(batchSize)

        println("Number of batches: ${batches.size}")

        // Measure the time taken to compute the diffs
        val elapsedTime =
            measureTimeMillis {
                val diffResults =
                    batches.map { batch ->
                        async(Dispatchers.Default) {
                            computeDiff(batch, map1, map2)
                        }
                    }.awaitAll().flatten() // Flatten the results

                println("Number of diffs: ${diffResults.size}")
            }

        println("Time taken to compute diffs: $elapsedTime ms")
    }
...

10000個のバッチ(この数字にとくに意味はないです、テキトーに選びました)に分割して、それぞれに対して新たなcoroutineをディスパッチする感じになります。

こちらは115msでした!

Map sizes: map1=1000000, map2=1000000
Number of batches: 100
Number of diffs: 1000000
Time taken to compute diffs: 115 ms

Dispatcherの種類について

DispatcherにDefaultを指定していますがこれはCPUバウンドの計算に適しているようです。
他にもIOバウンドに適したDispatcherも存在しています。
以下に比較を載せますが、今回のDiff計算はCPUバウンドなのでDefaultが一番速いです。

ディスパッチャ 処理時間
Dispatchers.Default 115 ms
指定なし(Mainスレッド) 157 ms
Dispatchers.IO 181 ms
ーーーー ーーーー
愚直にDiff計算 302 ms

Reference:
https://kotlinlang.org/docs/coroutine-context-and-dispatchers.html#dispatchers-and-threads

Dispatchers.Defaultを使うことで業務において大幅なパフォーマンス改善が見られたのでよかったです。

Dispatcher Defaultを使った場合の全文コード

import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.runBlocking
import java.util.UUID
import kotlin.system.measureTimeMillis

fun main() =
    runBlocking {
        // 1. Generate 100,000 UUIDs
        val uuids = List(100_0000) { UUID.randomUUID() }

        // 2. Create two maps with the UUIDs as keys, assigning random strings as values
        val map1 = uuids.associateWith { generateRandomString(10) }
        val map2 = uuids.associateWith { generateRandomString(10) }

        println("Map sizes: map1=${map1.size}, map2=${map2.size}")

        // 3. Split the map keys into batches
        val batchSize = 10_000
        val batches = map1.keys.chunked(batchSize)

        println("Number of batches: ${batches.size}")

        // Measure the time taken to compute the diffs
        val elapsedTime =
            measureTimeMillis {
                val diffResults =
                    batches.map { batch ->
                        async(Dispatchers.Default) {
                            computeDiff(batch, map1, map2)
                        }
                    }.awaitAll().flatten() // Flatten the results

                println("Number of diffs: ${diffResults.size}")
            }

        println("Time taken to compute diffs: $elapsedTime ms")
    }

// Generate a random string of the given length
fun generateRandomString(length: Int): String {
    val chars = ('A'..'Z') + ('a'..'z') + ('0'..'9')
    return (1..length).map { chars.random() }.joinToString("")
}

// Compute the diff for a batch of UUIDs
fun computeDiff(
    batch: List<UUID>,
    map1: Map<UUID, String>,
    map2: Map<UUID, String>,
): List<Pair<UUID, Pair<String?, String?>>> {
    return batch.mapNotNull { uuid ->
        val value1 = map1[uuid]
        val value2 = map2[uuid]
        if (value1 != value2) uuid to (value1 to value2) else null
    }
}

Discussion