candleで簡単なネットワークを実装してみる
candle の練習として ゼロから作る Deep Learningのコード を参考に実装してみる。
cargo new
したら、
cargo add --git https://github.com/huggingface/candle.git candle-core
で cargo-core
を追加すると基本的なテンソル処理とかができるようになる。
参考にするコードは candle
のレポ内にある、candle-core
の examples
と candle-nn
の実装がいいと思う。
HuggingFace 公式による PyTorch との読み替えチートシート:
まだまともなドキュメントがないので、頑張ってコード漁っていく必要があるが、公式のレポ自体が実装例になってるので案外どうにかなる。
ch02/and_gate
use candle_core::{Device, Result, Tensor};
pub fn and(a: &Tensor, b: &Tensor, device: &Device) -> Result<Tensor> {
let inputs = Tensor::cat(&[a, b], 0)?;
let weights = Tensor::new(&[0.5f32, 0.5f32], &device)?;
let bias = Tensor::new(&[-0.7f32], &device)?;
let sum = (weights * &inputs)?.sum(0)?.broadcast_add(&bias)?;
if sum.get(0)?.to_scalar::<f32>()? <= 0.0f32 {
Ok(Tensor::new(&[0.0f32], &device)?)
} else {
Ok(Tensor::new(&[1.0f32], &device)?)
}
}
fn main() -> Result<()> {
let device = Device::Cpu;
let pairs: Vec<[f32; 2]> = vec![[0., 0.], [1., 0.], [0., 1.], [1., 1.]];
for pair in pairs {
let a = Tensor::new(&[pair[0]], &device)?;
let b = Tensor::new(&[pair[1]], &device)?;
let output = and(&a, &b, &device)?;
println!(
"{} AND {} => {}",
pair[0],
pair[1],
output.get(0)?.to_scalar::<f32>()?
);
}
Ok(())
}
0 AND 0 => 0
1 AND 0 => 0
0 AND 1 => 0
1 AND 1 => 1
AND処理という相当シンプルなので、わざわざTensorにしなくてもいいとは思うが、candleの練習としてやってみた。
Tensor::new
するときに毎回デバイスを指定していることや、自動でブロードキャストしてくれないので明示的に .broadcast_add(&bias)?
を呼んだりしていることが PyTorch と比較して差として感じる。
.broadcast_add(&bias)
は + &bias
にするとサイズが合わないというランタイムエラーになった。(.add(&bias)
と同じ処理が実行されていると思う)
あと目立つのはいろんなところについてる ?
だが、これはめんどうなエラー処理を簡略化してるもので、特に気にしなくていい。Tensor の作成や足し算のような演算含めて、ほとんどの操作を行ったあとに ?
をつける感じになっている。
fn and
に pub
がついてるのは、あとで XOR で使うため。
ch2/or_gate, nand_gate, xor_gate
AND 以外も実装する。といっても OR と NAND は関数名とパラメータを変更するだけである。
詳細は本家レポのコードを参照:
or_gate
use candle_core::{Device, Result, Tensor};
pub fn or(a: &Tensor, b: &Tensor, device: &Device) -> Result<Tensor> {
let inputs = Tensor::cat(&[a, b], 0)?;
let weights = Tensor::new(&[0.5f32, 0.5f32], &device)?;
let bias = Tensor::new(&[-0.2f32], &device)?;
let sum = (weights * &inputs)?.sum(0)?.broadcast_add(&bias)?;
if sum.get(0)?.to_scalar::<f32>()? <= 0.0f32 {
Ok(Tensor::new(&[0.0f32], &device)?)
} else {
Ok(Tensor::new(&[1.0f32], &device)?)
}
}
fn main() -> Result<()> {
let device = Device::Cpu;
let pairs: Vec<[f32; 2]> = vec![[0., 0.], [1., 0.], [0., 1.], [1., 1.]];
for pair in pairs {
let a = Tensor::new(&[pair[0]], &device)?;
let b = Tensor::new(&[pair[1]], &device)?;
let output = or(&a, &b, &device)?;
println!(
"{} OR {} => {}",
pair[0],
pair[1],
output.get(0)?.to_scalar::<f32>()?
);
}
Ok(())
}
0 OR 0 => 0
1 OR 0 => 1
0 OR 1 => 1
1 OR 1 => 1
nand_gate
use candle_core::{Device, Result, Tensor};
pub fn nand(a: &Tensor, b: &Tensor, device: &Device) -> Result<Tensor> {
let inputs = Tensor::cat(&[a, b], 0)?;
let weights = Tensor::new(&[-0.5f32, -0.5f32], &device)?;
let bias = Tensor::new(&[0.7f32], &device)?;
let sum = (weights * &inputs)?.sum(0)?.broadcast_add(&bias)?;
if sum.get(0)?.to_scalar::<f32>()? <= 0.0f32 {
Ok(Tensor::new(&[0.0f32], &device)?)
} else {
Ok(Tensor::new(&[1.0f32], &device)?)
}
}
fn main() -> Result<()> {
let device = Device::Cpu;
let pairs: Vec<[f32; 2]> = vec![[0., 0.], [1., 0.], [0., 1.], [1., 1.]];
for pair in pairs {
let a = Tensor::new(&[pair[0]], &device)?;
let b = Tensor::new(&[pair[1]], &device)?;
let output = nand(&a, &b, &device)?;
println!(
"{} NAND {} => {}",
pair[0],
pair[1],
output.get(0)?.to_scalar::<f32>()?
);
}
Ok(())
}
0 NAND 0 => 1
1 NAND 0 => 1
0 NAND 1 => 1
1 NAND 1 => 0
XOR
mod and;
mod nand;
mod or;
use candle_core::{Device, Result, Tensor};
use and::and;
use nand::nand;
use or::or;
pub fn xor(a: &Tensor, b: &Tensor, device: &Device) -> Result<Tensor> {
let nand_output = nand(a, b, device)?;
let or_output = or(a, b, device)?;
let and_output = and(&nand_output, &or_output, device)?;
Ok(and_output)
}
fn main() -> Result<()> {
let device = Device::Cpu;
let pairs: Vec<[f32; 2]> = vec![[0., 0.], [1., 0.], [0., 1.], [1., 1.]];
for pair in pairs {
let a = Tensor::new(&[pair[0]], &device)?;
let b = Tensor::new(&[pair[1]], &device)?;
let output = xor(&a, &b, &device)?;
println!(
"{} XOR {} => {}",
pair[0],
pair[1],
output.get(0)?.to_scalar::<f32>()?
);
}
Ok(())
}
0 XOR 0 => 0
1 XOR 0 => 1
0 XOR 1 => 1
1 XOR 1 => 0
AND、OR、NAND を組み合わせて XOR を作ることができた。