🙄
goでNeuralNetworkをフルスクラッチしてみた(コードのみ)
../apps/scratchai/cmd/digitguesser/main.go
package main
import "scratchai/internal/digitguesser"
// main runs the full MNIST training pipeline (load, train, evaluate).
func main() {
	digitguesser.TrainMnist()
}
../apps/scratchai/internal/mnist/utils.go
package mnist
import (
"lib/pkg/ann"
"path"
)
// Set represents a data set of image-label pairs held in memory.
type Set struct {
	NRow   int        // pixel rows per image
	NCol   int        // pixel columns per image
	Images []RawImage // raw pixel data, one entry per sample
	Labels []Label    // digit labels, parallel to Images
}
// ReadSet reads a set from the images file iname and the corresponding
// labels file lname.
func ReadSet(iname, lname string) (*Set, error) {
	s := &Set{}
	var err error
	s.NRow, s.NCol, s.Images, err = ReadImageFile(iname)
	if err != nil {
		return nil, err
	}
	s.Labels, err = ReadLabelFile(lname)
	if err != nil {
		return nil, err
	}
	return s, nil
}
// Count returns the number of points available in the data set.
func (s *Set) Count() int {
	n := len(s.Images)
	return n
}
// Get returns the i-th image and its corresponding label.
func (s *Set) Get(i int) (RawImage, Label) {
	img, lbl := s.Images[i], s.Labels[i]
	return img, lbl
}
// Sweeper is an iterator over the points in a data set.
type Sweeper struct {
	set *Set // data set being iterated
	i   int  // index of the next point to yield
}
// Next returns the next image and its label in the data set.
// If the end is reached, present is set to false.
func (sw *Sweeper) Next() (image RawImage, label Label, present bool) {
	if sw.i >= len(sw.set.Images) {
		return nil, 0, false
	}
	image, label = sw.set.Images[sw.i], sw.set.Labels[sw.i]
	// Advance the cursor so successive calls walk the whole set.
	// (Previously missing: Next returned the first pair forever.)
	sw.i++
	return image, label, true
}
// Sweep creates a new sweep iterator over the data set, positioned at the
// first point.
func (s *Set) Sweep() *Sweeper {
	sw := &Sweeper{set: s}
	return sw
}
// Load reads both the training and the testing MNIST data sets, given
// a local directory dir, containing the MNIST distribution files.
func Load(dir string) (train, test *Set, err error) {
	train, err = ReadSet(path.Join(dir, "train-images-idx3-ubyte.gz"), path.Join(dir, "train-labels-idx1-ubyte.gz"))
	if err != nil {
		return nil, nil, err
	}
	test, err = ReadSet(path.Join(dir, "t10k-images-idx3-ubyte.gz"), path.Join(dir, "t10k-labels-idx1-ubyte.gz"))
	if err != nil {
		return nil, nil, err
	}
	return train, test, nil
}
// LoadData loads the train and test MNIST sets from dir and flattens them
// into a single dataset/labels pair for the ann package: training samples
// first, then test samples.
func LoadData(dir string) (dataset ann.Dataset, labels []ann.Labels, err error) {
	train, test, err := Load(dir)
	if err != nil {
		return nil, nil, err
	}
	total := train.Count() + test.Count()
	dataset = make(ann.Dataset, 0, total)
	labels = make([]ann.Labels, 0, total)
	// Append both sets in order: train, then test.
	for _, set := range []*Set{train, test} {
		for i := 0; i < set.Count(); i++ {
			image, label := set.Get(i)
			dataset = append(dataset, imageToData(image))
			labels = append(labels, labelToLabels(label))
		}
	}
	return dataset, labels, nil
}
// imageToData converts a RawImage into an ann.Data vector, normalizing each
// pixel from [0, 255] to [0, 1].
func imageToData(image RawImage) ann.Data {
	out := make(ann.Data, len(image))
	for i := range image {
		out[i] = ann.Number(image[i]) / 255.0
	}
	return out
}
// labelToLabels converts a digit label into a one-hot ann.Labels vector of
// length 10: 1.0 at the label's index, 0.0 everywhere else.
func labelToLabels(label Label) ann.Labels {
	out := make(ann.Labels, 10)
	out[label] = 1.0
	return out
}
../apps/scratchai/internal/mnist/loader_test.go
package mnist
import "testing"
import (
"fmt"
)
// TestReadLabelFile checks that the t10k label file parses and yields
// exactly 10000 labels.
func TestReadLabelFile(t *testing.T) {
	labels, err := ReadLabelFile("data/t10k-labels-idx1-ubyte.gz")
	if err != nil {
		t.Fatalf("read (%s)", err)
	}
	if got := len(labels); got != 10000 {
		t.Errorf("unexpected count %d", got)
	}
}
// TestReadImageFile checks that the t10k image file parses and yields
// exactly 10000 images, and prints the detected format.
func TestReadImageFile(t *testing.T) {
	nrow, ncol, imgs, err := ReadImageFile("data/t10k-images-idx3-ubyte.gz")
	if err != nil {
		t.Fatalf("read (%s)", err)
	}
	if got := len(imgs); got != 10000 {
		t.Errorf("unexpected count %d", got)
	}
	fmt.Printf("%d images, %dx%d format\n", len(imgs), nrow, ncol)
}
// TestLoad checks that both the train and test sets load from ./data.
func TestLoad(t *testing.T) {
	train, test, err := Load("./data")
	if err != nil {
		t.Fatalf("load (%s)", err)
	}
	// t.Logf replaces the bootstrap builtin println, which wrote to stderr
	// and printed only the slice header for train.Images[0].
	t.Logf("train=%d test=%d", train.Count(), test.Count())
	t.Logf("first image: %v", train.Images[0])
}
../apps/scratchai/internal/mnist/loader.go
package mnist
import (
"compress/gzip"
"encoding/binary"
"image"
"image/color"
"io"
"os"
)
const (
	imageMagic = 0x00000803 // IDX magic number expected at the head of image files
	labelMagic = 0x00000801 // IDX magic number expected at the head of label files
	Width      = 28         // MNIST image width in pixels
	Height     = 28         // MNIST image height in pixels
)
// RawImage holds the pixel intensities of an image.
// 255 is foreground (black), 0 is background (white).
type RawImage []byte
// ColorModel implements image.Image; pixels are 8-bit grayscale.
func (img RawImage) ColorModel() color.Model {
	return color.GrayModel
}
// Bounds implements image.Image; every MNIST image is Width x Height.
func (img RawImage) Bounds() image.Rectangle {
	return image.Rect(0, 0, Width, Height)
}
// At implements image.Image, returning the grayscale pixel at (x, y).
func (img RawImage) At(x, y int) color.Color {
	idx := y*Width + x
	return color.Gray{Y: img[idx]}
}
// ReadImageFile opens the named image file (training or test), parses it and
// returns all images in order.
func ReadImageFile(name string) (rows, cols int, imgs []RawImage, err error) {
	f, err := os.Open(name)
	if err != nil {
		return 0, 0, nil, err
	}
	defer f.Close()
	z, err := gzip.NewReader(f)
	if err != nil {
		return 0, 0, nil, err
	}
	// Close the gzip reader too (it was previously leaked); any parse error
	// from readImageFile takes precedence over the Close error.
	defer z.Close()
	return readImageFile(z)
}
// readImageFile parses the IDX image format from r: a big-endian header of
// (magic, count, rows, cols) int32 values followed by count blocks of
// rows*cols pixel bytes. It returns os.ErrInvalid on a bad magic number.
func readImageFile(r io.Reader) (rows, cols int, imgs []RawImage, err error) {
	var magic, n, nrow, ncol int32
	if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
		return 0, 0, nil, err
	}
	if magic != imageMagic {
		return 0, 0, nil, os.ErrInvalid
	}
	if err = binary.Read(r, binary.BigEndian, &n); err != nil {
		return 0, 0, nil, err
	}
	if err = binary.Read(r, binary.BigEndian, &nrow); err != nil {
		return 0, 0, nil, err
	}
	if err = binary.Read(r, binary.BigEndian, &ncol); err != nil {
		return 0, 0, nil, err
	}
	imgs = make([]RawImage, n)
	size := int(nrow) * int(ncol)
	for i := range imgs {
		imgs[i] = make(RawImage, size)
		// io.ReadFull already errors on any short read, so the old
		// post-read length check was dead code and has been removed.
		if _, err := io.ReadFull(r, imgs[i]); err != nil {
			return 0, 0, nil, err
		}
	}
	return int(nrow), int(ncol), imgs, nil
}
// Label is a digit label in the range 0 to 9.
type Label uint8
// ReadLabelFile opens the named label file (training or test), parses it and
// returns all labels in order.
func ReadLabelFile(name string) (labels []Label, err error) {
	f, err := os.Open(name)
	if err != nil {
		return nil, err
	}
	defer f.Close()
	z, err := gzip.NewReader(f)
	if err != nil {
		return nil, err
	}
	// Close the gzip reader too (it was previously leaked); any parse error
	// from readLabelFile takes precedence over the Close error.
	defer z.Close()
	return readLabelFile(z)
}
// readLabelFile parses the IDX label format from r: a big-endian header of
// (magic, count) int32 values followed by count single-byte labels. It
// returns os.ErrInvalid on a bad magic number.
func readLabelFile(r io.Reader) (labels []Label, err error) {
	var magic, n int32
	if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
		return nil, err
	}
	if magic != labelMagic {
		return nil, os.ErrInvalid
	}
	if err = binary.Read(r, binary.BigEndian, &n); err != nil {
		return nil, err
	}
	// Read all label bytes in one call; the previous per-label binary.Read
	// paid reflection overhead on every single byte.
	buf := make([]byte, n)
	if _, err := io.ReadFull(r, buf); err != nil {
		return nil, err
	}
	labels = make([]Label, n)
	for i, b := range buf {
		labels[i] = Label(b)
	}
	return labels, nil
}
../apps/scratchai/internal/digitguesser/app.go
package digitguesser
import (
"lib/pkg/ann"
"math/rand"
)
const (
	InputSize        = 28 * 28 // one input neuron per MNIST pixel
	HiddenLayer1Size = 16
	HiddenLayer2Size = 16 // currently unused: the second hidden layer is commented out in NewApp
	OutputLayerSize  = 10 // one output neuron per digit 0-9
	BatchSize        = 1000
	// Learning-rate schedule: start high, decay multiplicatively per epoch,
	// never below the floor.
	InitialLearningRate = ann.Number(5)
	MinLearningRate     = 0.02
	LearningRateDecay   = 0.99
)
// App ties together a network, a processor implementation and the training
// data. T is the processor's per-sample context data type.
type App[T any] struct {
	network   *ann.Network
	processor ann.Processor[T]
	dataset   ann.Dataset
	labels    []ann.Labels
}
// NewApp builds a three-layer network (input, one hidden layer, output) and
// returns an App that will train it on the given dataset/labels using
// processor.
func NewApp[T any](processor ann.Processor[T], dataset ann.Dataset, labels []ann.Labels) *App[T] {
	// Build the layers.
	input := ann.NewLayer(ann.NewNeurons(InputSize, 0.1))
	hidden := ann.NewLayer(ann.NewNeurons(HiddenLayer1Size, 0.1))
	// A second hidden layer is intentionally disabled:
	//hiddenLayer2 := ann.NewLayer(ann.NewNeurons(HiddenLayer2Size, 0.1))
	output := ann.NewLayer(ann.NewNeurons(OutputLayerSize, 0.1))
	// Assemble the network.
	net := ann.NewNetwork([]*ann.Layer{input, hidden, output})
	return &App[T]{
		network:   net,
		processor: processor,
		dataset:   dataset,
		labels:    labels,
	}
}
// Train runs the given number of epochs. Each epoch samples a random
// mini-batch, feeds it forward, back-propagates, applies the resulting
// adjustments to the network, and then decays the learning rate.
func (a *App[T]) Train(epochs int) error {
	rate := InitialLearningRate
	for epoch := 0; epoch < epochs; epoch++ {
		batchData, batchLabels := a.createMiniBatch()
		contexts := a.processor.FeedForward(a.network, batchData)
		adjustments := a.processor.BackPropagate(contexts, batchLabels)
		if err := a.updateNetwork(adjustments, rate); err != nil {
			return err
		}
		rate = a.updateLearningRate(rate, epoch)
	}
	return nil
}
// createMiniBatch samples up to BatchSize distinct random examples from the
// stored dataset and returns them with their matching labels.
func (a *App[T]) createMiniBatch() (ann.Dataset, []ann.Labels) {
	// Clamp the batch size so a dataset smaller than BatchSize no longer
	// panics on the slice expression below.
	size := BatchSize
	if len(a.dataset) < size {
		size = len(a.dataset)
	}
	indices := rand.Perm(len(a.dataset))[:size]
	batchData := make(ann.Dataset, size)
	batchLabels := make([]ann.Labels, size)
	for i, idx := range indices {
		batchData[i] = a.dataset[idx]
		batchLabels[i] = a.labels[idx]
	}
	return batchData, batchLabels
}
// updateNetwork applies the back-propagation adjustments to the network,
// scaled by the negated learning rate (plain gradient descent).
func (a *App[T]) updateNetwork(adjustments ann.Adjustments, learningRate ann.Number) error {
	// Weight updates. Layer indices are offset by one because the input
	// layer has no incoming connections.
	for layerIndex, layerAdjustments := range adjustments.WeightAdjustments() {
		for neuronIndex, neuronAdjustments := range layerAdjustments {
			scaled := make([]ann.Number, len(neuronAdjustments))
			for i, adj := range neuronAdjustments {
				scaled[i] = -learningRate * adj
			}
			if err := a.network.AdjustNeuronConnections(layerIndex+1, neuronIndex, scaled); err != nil {
				return err
			}
		}
	}
	// Bias updates, with the same layer-index offset.
	for layerIndex, layerBiasAdjustments := range adjustments.BiasAdjustments() {
		for neuronIndex, biasAdjustment := range layerBiasAdjustments {
			neuron := a.network.Layers()[layerIndex+1].Neurons()[neuronIndex]
			neuron.SetBias(neuron.Bias() - learningRate*ann.Number(biasAdjustment))
		}
	}
	return nil
}
// updateLearningRate applies exponential decay to the current rate, clamped
// below at MinLearningRate. The epoch argument is currently unused.
func (a *App[T]) updateLearningRate(currentRate ann.Number, epoch int) ann.Number {
	decayed := currentRate * LearningRateDecay
	if decayed >= MinLearningRate {
		return decayed
	}
	return MinLearningRate
}
// Predict feeds a single input through the network and returns the output
// layer's activations.
func (a *App[T]) Predict(input ann.Data) ann.LayerActivations {
	context := a.processor.FeedForward(a.network, ann.Dataset{input})[0]
	acts := context.Activations()
	return acts[len(acts)-1]
}
../apps/scratchai/internal/digitguesser/gonumprocessor.go
package digitguesser
import (
"lib/pkg/ann"
"math"
"gonum.org/v1/gonum/mat"
)
// GonumContextData carries the weight matrices and bias vectors captured
// during FeedForward so BackPropagate can reuse the same snapshot.
// Index l holds the parameters of the connections into layer l+1.
type GonumContextData struct {
	Weights []*mat.Dense
	Biases  []*mat.VecDense
}
type GonumProcessor struct{}
// Activation is the logistic sigmoid, 1 / (1 + e^-x).
func (g GonumProcessor) Activation(weightedSum ann.Number) ann.Number {
	x := float64(weightedSum)
	return ann.Number(1 / (1 + math.Exp(-x)))
}
// Derivative is the sigmoid derivative expressed through the sigmoid
// itself: sigma(x) * (1 - sigma(x)).
func (g GonumProcessor) Derivative(weightedSum ann.Number) ann.Number {
	s := g.Activation(weightedSum)
	return s * (1 - s)
}
// FeedForward runs the whole dataset through the network as batched matrix
// operations and returns one Context per sample, holding every layer's
// activations plus the weight/bias snapshot that produced them.
func (g GonumProcessor) FeedForward(network *ann.Network, dataset ann.Dataset) []ann.Context[GonumContextData] {
	contexts := make([]ann.Context[GonumContextData], len(dataset))
	layers := network.Layers()
	connections := network.Connections()
	// Snapshot the network's weights and biases as matrices/vectors.
	// weights[l] has one row per neuron of layer l+1, one column per
	// neuron of layer l.
	weights := make([]*mat.Dense, len(layers)-1)
	biases := make([]*mat.VecDense, len(layers)-1)
	for i := 1; i < len(layers); i++ {
		layerConnections := connections[i]
		rows := len(layerConnections)
		cols := len(layerConnections[0])
		weightData := make([]float64, rows*cols)
		biasData := make([]float64, rows)
		for j, neuronConns := range layerConnections {
			for k, conn := range neuronConns {
				weightData[j*cols+k] = float64(conn.Weight())
			}
			biasData[j] = float64(layers[i].Neurons()[j].Bias())
		}
		weights[i-1] = mat.NewDense(rows, cols, weightData)
		biases[i-1] = mat.NewVecDense(rows, biasData)
	}
	contextData := GonumContextData{Weights: weights, Biases: biases}
	// Pack the dataset into one row-per-sample matrix.
	dataMatrix := mat.NewDense(len(dataset), len(dataset[0]), nil)
	for i, data := range dataset {
		for j, val := range data {
			dataMatrix.Set(i, j, float64(val))
		}
	}
	// Feed forward: activations[l] holds the whole batch's layer-l output.
	activations := make([]*mat.Dense, len(layers))
	activations[0] = dataMatrix
	for i := 1; i < len(layers); i++ {
		prevActivation := activations[i-1]
		weight := weights[i-1]
		bias := biases[i-1]
		weightedSum := mat.NewDense(prevActivation.RawMatrix().Rows, weight.RawMatrix().Rows, nil)
		weightedSum.Mul(prevActivation, weight.T())
		// Add the layer bias to every row of the batch.
		rows, cols := weightedSum.Dims()
		biasMatrix := mat.NewDense(rows, cols, nil)
		for r := 0; r < rows; r++ {
			biasMatrix.SetRow(r, bias.RawVector().Data)
		}
		weightedSum.Add(weightedSum, biasMatrix)
		// Apply the activation function element-wise.
		weightedSum.Apply(func(_, _ int, v float64) float64 {
			return float64(g.Activation(ann.Number(v)))
		}, weightedSum)
		activations[i] = weightedSum
	}
	// layerSizes is identical for every sample; it was previously rebuilt
	// inside the per-sample loop below (shadowing its index variable), so
	// it is hoisted out here.
	layerSizes := make([]int, len(layers))
	for i, layer := range layers {
		layerSizes[i] = len(layer.Neurons())
	}
	// Unpack each sample's row of every layer matrix into its own context.
	for i := range dataset {
		context := ann.NewContext[GonumContextData](layerSizes)
		context.Data = contextData
		for j, activation := range activations {
			layerActivation := make(ann.LayerActivations, activation.RawMatrix().Cols)
			for k := 0; k < activation.RawMatrix().Cols; k++ {
				layerActivation[k] = ann.Number(activation.At(i, k))
			}
			if err := context.SetActivations(j, layerActivation); err != nil {
				panic(err)
			}
		}
		contexts[i] = *context
	}
	return contexts
}
// BackPropagate computes batch-averaged weight and bias gradients from the
// contexts produced by FeedForward, starting with the output-layer delta
// (activation - label) and walking backwards through the layers.
func (g GonumProcessor) BackPropagate(contexts []ann.Context[GonumContextData], labels []ann.Labels) ann.Adjustments {
	if len(contexts) == 0 || len(labels) != len(contexts) {
		panic("Invalid input: contexts and labels must have the same non-zero length")
	}
	numLayers := len(contexts[0].Activations())
	// Every context in the batch carries the same weight snapshot, so take
	// it from the first one.
	weights := contexts[0].Data.Weights
	batchSize := len(contexts)
	weightAdjustments := make([]ann.WeightAdjustmentsToLayer, numLayers-1)
	biasAdjustments := make([]ann.BiasAdjustmentsToLayer, numLayers-1)
	// Repack per-sample activations into one row-per-sample matrix per layer.
	activations := make([]*mat.Dense, numLayers)
	for l := 0; l < numLayers; l++ {
		layerSize := len(contexts[0].Activations()[l])
		activationData := make([]float64, batchSize*layerSize)
		for i, context := range contexts {
			for j, act := range context.Activations()[l] {
				activationData[i*layerSize+j] = float64(act)
			}
		}
		activations[l] = mat.NewDense(batchSize, layerSize, activationData)
	}
	labelsMatrix := mat.NewDense(batchSize, len(labels[0]), nil)
	for i, label := range labels {
		for j, val := range label {
			labelsMatrix.Set(i, j, float64(val))
		}
	}
	// Output-layer delta: activation - label.
	delta := new(mat.Dense)
	delta.Sub(activations[numLayers-1], labelsMatrix)
	for l := numLayers - 2; l >= 0; l-- {
		prevLayerSize := activations[l].RawMatrix().Cols
		currentLayerSize := activations[l+1].RawMatrix().Cols
		// Weight gradient: (prev activations)^T x delta, batch-averaged.
		weightAdj := new(mat.Dense)
		weightAdj.Mul(activations[l].T(), delta)
		weightAdj.Scale(1/float64(batchSize), weightAdj)
		// FIXME: it may be better to fill one big matrix and convert
		// everything in a single pass at the end.
		weightAdjustments[l] = make(ann.WeightAdjustmentsToLayer, currentLayerSize)
		for i := 0; i < currentLayerSize; i++ {
			weightAdjustments[l][i] = make(ann.WeightAdjustmentsToNeuron, prevLayerSize)
			for j := 0; j < prevLayerSize; j++ {
				weightAdjustments[l][i][j] = ann.Number(weightAdj.At(j, i))
			}
		}
		// Bias gradient: column means of delta.
		biasAdj := make(ann.BiasAdjustmentsToLayer, currentLayerSize)
		for i := 0; i < currentLayerSize; i++ {
			sum := 0.0
			for j := 0; j < batchSize; j++ {
				sum += delta.At(j, i)
			}
			biasAdj[i] = ann.BiasAdjustmentsToNeuron(sum / float64(batchSize))
		}
		biasAdjustments[l] = biasAdj
		if l > 0 {
			// Propagate delta to the previous layer. weights[l] is
			// (layer l+1 neurons) x (layer l neurons), matching the
			// (batch x l+1) shape of delta, so no transpose is needed.
			newDelta := new(mat.Dense)
			newDelta.Mul(delta, weights[l])
			// Element-wise multiply by the sigmoid derivative in terms of
			// the activation a: a * (1 - a). The previous code called
			// g.Derivative(a), which computes sigma(a)*(1-sigma(a)) —
			// applying the sigmoid a second time — yielding a wrong
			// gradient.
			derivativeActivation := new(mat.Dense)
			derivativeActivation.Apply(func(_, _ int, v float64) float64 {
				return v * (1 - v)
			}, activations[l])
			delta = new(mat.Dense)
			delta.MulElem(newDelta, derivativeActivation)
		}
	}
	return ann.NewAdjustments(weightAdjustments, biasAdjustments)
}
// NewGonumProcessor returns a gonum-backed ann.Processor implementation.
func NewGonumProcessor() ann.Processor[GonumContextData] {
	var p GonumProcessor
	return p
}
../apps/scratchai/internal/digitguesser/trainmnist.go
package digitguesser
import (
"fmt"
"lib/pkg/ann"
"log"
"scratchai/internal/mnist"
)
// maxIndex returns the index of the largest value in slice; on ties the
// first occurrence wins. Like the original, it panics on an empty slice.
func maxIndex(slice []ann.Number) int {
	best, bestVal := 0, slice[0]
	for i, v := range slice[1:] {
		if v > bestVal {
			best, bestVal = i+1, v
		}
	}
	return best
}
// TrainMnist loads the MNIST data, trains the digit guesser on the first
// 60000 samples for 2000 epochs, and reports accuracy on the remainder
// (the t10k test set).
func TrainMnist() {
	// NOTE(review): the data directory is hard-coded to a dev-container
	// path; consider making it configurable.
	dataset, labels, err := mnist.LoadData("/workspaces/mictlan/apps/scratchai/internal/mnist/data/")
	if err != nil {
		log.Fatalf("Failed to load MNIST data: %v", err)
	}
	fmt.Printf("Dataset size: %d, Labels size: %d\n", len(dataset), len(labels))
	// Validate the split bounds instead of assuming exactly 70000 samples.
	const trainSize = 60000
	if len(dataset) <= trainSize {
		log.Fatalf("dataset too small for a train/test split: %d <= %d", len(dataset), trainSize)
	}
	processor := NewGonumProcessor()
	guesser := NewApp(processor, dataset[:trainSize], labels[:trainSize])
	// Run training; log.Fatalf keeps error handling consistent with the
	// load failure above (previously a bare panic).
	if err := guesser.Train(2000); err != nil {
		log.Fatalf("Training failed: %v", err)
	}
	// Predict on everything past the training split and count hits.
	correct := 0
	total := 0
	for i := trainSize; i < len(dataset); i++ {
		prediction := guesser.Predict(dataset[i])
		if maxIndex(prediction) == maxIndex(labels[i]) {
			correct++
		}
		total++
	}
	// Report the accuracy.
	accuracy := float64(correct) / float64(total) * 100
	fmt.Printf("\nAccuracy: %.2f%% (%d/%d)\n", accuracy, correct, total)
}
../apps/scratchai/internal/digitguesser/gonumprocessor_test.go
package digitguesser
import (
"lib/pkg/ann"
"math/rand/v2"
"testing"
)
// TestGonumProcessor trains a small network on random inputs with a fixed
// teacher vector (0.0, 0.1, ..., 0.9) and logs the trained vs. untrained
// outputs; the trained output should approach the teacher pattern.
func TestGonumProcessor(t *testing.T) {
	inputSize := 784
	outputSize := 10
	createNetwork := func() *ann.Network {
		inputLayer := ann.NewLayer(ann.NewNeurons(inputSize, 0))
		hiddenLayer1 := ann.NewLayer(ann.NewNeurons(16, 0))
		hiddenLayer2 := ann.NewLayer(ann.NewNeurons(16, 0))
		outputLayer := ann.NewLayer(ann.NewNeurons(outputSize, 0))
		// TODO: try multiple weight initializations (or another scheme) so
		// training can find a better minimum instead of a poor local one.
		return ann.NewNetwork([]*ann.Layer{inputLayer, hiddenLayer1, hiddenLayer2, outputLayer})
	}
	network := createNetwork()
	oldNetwork := createNetwork()
	// Build random training data.
	trainDataNum := 1000
	dataset := make(ann.Dataset, trainDataNum)
	allLabels := make([]ann.Labels, trainDataNum)
	for i := range dataset {
		dataset[i] = make(ann.Data, inputSize)
		for j := range dataset[i] {
			dataset[i][j] = ann.Number(rand.Float64()*2 - 1)
		}
		allLabels[i] = make(ann.Labels, outputSize)
		for j := range allLabels[i] {
			// Every teacher vector is 0.0, 0.1, 0.2, ..., 0.9.
			allLabels[i][j] = ann.Number(float32(j) * 0.1)
		}
	}
	processor := NewGonumProcessor()
	// TODO: adapt the learning rate to the steepness of the gradient.
	learningRate := ann.Number(0.1)
	// The old comment and completion log claimed 50 epochs while the loop
	// ran 1000; the count is now named and logged accurately.
	const epochs = 1000
	for epoch := 0; epoch < epochs; epoch++ {
		contexts := processor.FeedForward(network, dataset)
		adjustments := processor.BackPropagate(contexts, allLabels)
		// Apply weight adjustments.
		for layerIndex, layerAdjustments := range adjustments.WeightAdjustments() {
			for neuronIndex, neuronAdjustments := range layerAdjustments {
				scaledAdjustments := make([]ann.Number, len(neuronAdjustments))
				for i, adj := range neuronAdjustments {
					scaledAdjustments[i] = -learningRate * adj
				}
				err := network.AdjustNeuronConnections(layerIndex+1, neuronIndex, scaledAdjustments)
				if err != nil {
					t.Errorf("Error adjusting neuron connections: %v", err)
				}
			}
		}
		// Apply bias adjustments.
		for layerIndex, layerBiasAdjustments := range adjustments.BiasAdjustments() {
			for neuronIndex, biasAdjustment := range layerBiasAdjustments {
				neuron := network.Layers()[layerIndex+1].Neurons()[neuronIndex]
				scaledBiasAdjustment := -learningRate * ann.Number(biasAdjustment)
				neuron.SetBias(neuron.Bias() + scaledBiasAdjustment)
			}
		}
		// Log progress every 10 epochs.
		if epoch%10 == 0 {
			t.Logf("Completed epoch %d", epoch)
		}
	}
	t.Logf("Training completed - %d epochs", epochs)
	testDataset := make(ann.Dataset, 1)
	for i := range testDataset {
		testDataset[i] = make(ann.Data, inputSize)
		for j := range testDataset[i] {
			testDataset[i][j] = ann.Number(rand.Float64()*2 - 1)
		}
	}
	// Whatever the input, the trained output should resemble the teacher
	// pattern 0, 0.1, 0.2, ..., 0.9.
	finalContexts := processor.FeedForward(network, testDataset)
	t.Logf("Final output after training: %+v", finalContexts[0].Activations()[len(finalContexts[0].Activations())-1])
	// The untrained copy should produce unrelated output.
	randomContexts := processor.FeedForward(oldNetwork, testDataset)
	t.Logf("Output of model before training: %+v", randomContexts[0].Activations()[len(randomContexts[0].Activations())-1])
}
../apps/scratchai/internal/digitguesser/trainmnist_test.go
package digitguesser_test
import (
"scratchai/internal/digitguesser"
"testing"
)
// TestTrainMnist runs the full MNIST training pipeline end-to-end. It needs
// the MNIST data files on disk and trains for 2000 epochs, so it is slow.
func TestTrainMnist(t *testing.T) {
	digitguesser.TrainMnist()
}
Discussion