🙆‍♀️

MLX(Swift)でシンプルな自動微分のサンプル

に公開

iOS上で自動微分したかった
公式にmodule定義のサンプルしかなくて、自動微分だけしたかったので調べた

サンプルコード

  • 行列式をかけてtargetの値になるようなparamを推定する
var param = MLXArray([0,0,0,0] as [Float])
let matrix = MLXArray([1,2,3,4,1,2,3,4,1,2,3,4,1,2,3,4] as [Float],[4,4])
let target = MLXArray([1,2,3,4] as [Float])
let train: ([MLXArray]) -> [MLXArray] = { (param: [MLXArray]) -> [MLXArray] in
    return [(param[0].matmul(matrix) - target).square().mean()]
}
let optimizer = Adam(learningRate: 1e-2)
var state = optimizer.newState(parameter: param)
for step in 0..<200 {
    let (lossArr, grads) = MLX.valueAndGrad(train)([param])
    let (newParam, newState) = optimizer.applySingle(
        gradient: grads[0],
        parameter: param,
        state: state
    )
    param = newParam
    state = newState
}
print(param)

解説

  • MLX.valueAndGradを使う
  • trainの関数は([MLXArray]) -> [MLXArray]の形式にする
  • trainにその他の引数を渡したい場合は、closureで他の変数をcaptureできるので殆どのケースで対応可能
  • valueAndGradの引数にgradの更新をしたいMLXArrayを配列で渡すと、戻りのgradsに同じ要素数の配列でgradのMLXArrayが返ってくる
  • これをoptimizerのapplySingleに渡すと更新されたparamが返ってくる
    • optimizer.newStateは更新対象のparamが複数存在する場合はmap等で同じ個数分生成してそれぞれapplySingleが必要
  • 更新されたparamを使ってtrainのループを継続する

Discussion