🐁

Kotlin/JVMでデータサイエンスしたい

2020/10/23に公開

動機

  • 機械学習を勉強したい
  • が、Pythonを書きたくない
  • Kotlinで書きたい
  • IntelliJさんの肩に乗りたい

目標

基礎の基礎である、ベクトル・行列計算とグラフ表示ができるようにする

構成

IntelliJ IDEA + Kotlin/JVM + Gradle + ND4J + lets-plot + batik

環境はWindows10です

構築する

Gradleの設定

まず、IDEAで新しくプロジェクトを作る。GradleとKotlin/JVMを使いたいのでそれらを選択して作成。

次にbuild.gradleを以下のように書き換える。lets-plotを追加するのにrepositoriesにmavenのurlとjcenterが必要になる。

バージョンは投稿時点のものなので随時調べて書き換えてください。

build.gradle
repositories {
    mavenCentral()
    maven {
        url "https://jetbrains.bintray.com/lets-plot-maven"
    }
    jcenter()
}

dependencies {
    implementation "org.jetbrains.kotlin:kotlin-stdlib"
    implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk8"
    implementation "org.nd4j:nd4j-native-platform:0.9.1"
    implementation "org.jetbrains.lets-plot:lets-plot-common:1.5.2"
    implementation "org.jetbrains.lets-plot:lets-plot-batik:1.5.2"
    implementation "org.jetbrains.lets-plot:lets-plot-kotlin-api:1.0.0-rc1"
}

ライブラリを利用する

main.ktを作成する。

ハローワールド!

main.kt
fun main(args: Array<String>) {
    println("Hello, world")
}

環境の準備は整ったので実際に使ってみる。

行列計算(ND4J)

https://deeplearning4j.konduit.ai/nd4j/overview

JAVAの資料だがこの辺りを見れば使える。

main.kt
operator fun INDArray.plus(other: INDArray): INDArray = this.add(other)
operator fun INDArray.minus(other: INDArray): INDArray = this.sub(other)
operator fun INDArray.times(other: INDArray): INDArray = this.mmul(other)
operator fun INDArray.plus(other: Number): INDArray = this.add(other)
operator fun INDArray.minus(other: Number): INDArray = this.sub(other)
operator fun INDArray.times(other: Number): INDArray = this.mul(other)
operator fun INDArray.div(other: Number): INDArray = this.div(other)

fun matrixSample() {
    val a = Nd4j.create(doubleArrayOf(1.0,2.0,3.0,4.0), intArrayOf(2,2))
    val b = InvertMatrix.invert(a, false)
    println("a = \n$a")
    println("b = a^-1 = \n$b")
    println("a + b = \n${a + b}")
    println("a dot b = \n${a * b}")
}

上のサンプルは逆行列を求めて積を求めるもの。当然単位行列が求まる。

a = 
[[1.00,  2.00],  
 [3.00,  4.00]]
b = a^-1 = 
[[-2.00,  1.00],  
 [1.50,  -0.50]]
a + b = 
[[-1.00,  3.00],  
 [4.50,  3.50]]
a dot b = 
[[1.00,  0.00],  
 [0.00,  1.00]]

結果

行列はNd4j.create()DoubleArrayIntArray辺りを渡すことで作ることができる。第1引数が値、第2引数が行列の形。ベクトルはNd4j.arange()でNumpyのように作れたりする。

INDArrayは標準だとoperatorが使えないので、上の拡張関数を定義することで簡潔に書けるようにしている[1]

グラフ表示

画面に表示するには色々と必要。以下を準備する。JVMでの書き方は

https://github.com/JetBrains/lets-plot-kotlin/issues/5

のissueが詳しいので適宜参照。わかりやすい導入もあったり。

main.kt
private val SVG_COMPONENT_FACTORY_BATIK =
    { svg: SvgSvgElement -> BatikMapperComponent(svg, BATIK_MESSAGE_CALLBACK) }

private val BATIK_MESSAGE_CALLBACK = object : BatikMessageCallback {
    override fun handleMessage(message: String) {
        println(message)
    }

    override fun handleException(e: Exception) {
        if (e is RuntimeException) {
            throw e
        }
        throw RuntimeException(e)
    }
}

private val AWT_EDT_EXECUTOR = { runnable: () -> Unit ->
    runnable.invoke()
}

とりあえずサインカーブを書いてみる。

main.kt
fun plotSample() {
    SwingUtilities.invokeLater {

        val xArray = Nd4j.arange(0.0,60.0) / 10.0
        val yArray = Transforms.sin(xArray)
        val data = mapOf<String, Any>(
            "x" to xArray.data().asDouble(),
            "y" to yArray.data().asDouble(),
        )

        // 線グラフをかく
        val geom = geom_line() {
            x = "x"
            y = "y"
        }
        val p = ggplot(data) + geom + ggtitle("sin")


        val plotSpec = p.toSpec()
        val plotSize = DoubleVector(1000.0, 1000.0)

        val component =
            MonolithicAwt.buildPlotFromRawSpecs(plotSpec, plotSize, SVG_COMPONENT_FACTORY_BATIK, AWT_EDT_EXECUTOR) {
                for (message in it) {
                    println("PLOT MESSAGE: $message")
                }
            }


        val frame = JFrame("三角関数")
        frame.contentPane.add(component)
        frame.font = Font(Font.SANS_SERIF, Font.PLAIN, 10)
        frame.defaultCloseOperation = WindowConstants.EXIT_ON_CLOSE
        frame.pack()
        frame.isVisible = true
    }
}


結果

Nd4jのarangeはなぜか小数刻みができなかったので割って表現。その後はNumpyと同様にsinを計算することができる。

val xArray = Nd4j.arange(0.0,60.0) / 10.0
val yArray = Transforms.sin(xArray)

xとyにそれぞれを代入してデータをつくる。INDArray.data().asDouble()で行列から小数配列に変換することができる。

val data = mapOf<String, Any>(
    "x" to xArray.data().asDouble(),
    "y" to yArray.data().asDouble(),
)
// 線グラフをかく
val geom = geom_line() {
    x = "x"
    y = "y"
}
val p = ggplot(data) + geom + ggtitle("sin")

複数の線を引きたい場合は配列を連結した上で区別用のラベルをつけるとうまくいく。

val data = mapOf<String, Any>(
        "x" to xArray.data().asDouble() + xArray.data().asDouble(),
        "y" to yArray.data().asDouble() + (yArray * 0.5).data().asDouble(),
        "s" to List(xArray.size(1)) { "A" } + List(xArray.size(1)) { "B" }
)

val geom = geom_line() {
    x = "x"
    y = "y"
    color = "s"
}


グラフを複数描く

終わりに

  • Kotlinかきやすい
  • Javaの資産が使えるので思ったより色々できる
  • 行列を扱うとエラー吐くのが実行時になるため型が厳格であることの恩恵はあまりない...?
    • Pythonが人気なのも頷ける
  • それでもかきやすい

今回の全体のコードはこちら。

https://gist.github.com/organic-nailer/35fb3eb42ae601e464056201e7f70a90

参考

その他

脚注
  1. ここを参考。 ↩︎

Discussion