😇

SwiftのMLXでConditionで他のArrayをfilterする方法

に公開
  • let cond = array .< 1とかのconditionを直接otherArray[cond]と与えてしまうと、
    libc++abi: terminating due to uncaught exception of type char const*
    という実行時エラーになってしまうので明示的にInt32のMLXArrayにしないといけない(2025/09時点)
func conditionToIndices(condition: MLXArray) -> MLXArray {
    let arange = MLX.where(condition, MLXArray(0..<condition.shape[0]), MLXArray(Int32.max))
    let sorted = MLX.sorted(arange)
    if sorted.shape[0] == 0 {
        return MLXArray([])
    }
    let index = MLX.argMax(sorted)
    return sorted[0..<index.item(Int.self)]
}
  • Indiceを取らずに、直接MLX.whereでfalseの場合にInt32.maxに振り分けて有効数字までのindiceを持ってくる手法
  • 標準機能でできたら良いのだけれど見つからず

Discussion