🙄

goでNeuralNetworkをフルスクラッチしてみた(コードのみ)

2024/11/12に公開

../apps/scratchai/cmd/digitguesser/main.go

package main

import "scratchai/internal/digitguesser"

// main is the digit-guesser entry point: it delegates straight to the
// MNIST training pipeline in internal/digitguesser.
func main() {
	digitguesser.TrainMnist()
}

../apps/scratchai/internal/mnist/utils.go

package mnist

import (
	"lib/pkg/ann"
	"path"
)

// Set represents a data set of image-label pairs held in memory.
// Images[i] and Labels[i] describe the same sample.
type Set struct {
	NRow   int        // pixel rows per image
	NCol   int        // pixel columns per image
	Images []RawImage // raw pixel data, one entry per sample
	Labels []Label    // digit labels, parallel to Images
}

// ReadSet reads a set from the images file iname and the corresponding
// labels file lname, returning nil and the first error encountered.
func ReadSet(iname, lname string) (*Set, error) {
	nrow, ncol, images, err := ReadImageFile(iname)
	if err != nil {
		return nil, err
	}
	labels, err := ReadLabelFile(lname)
	if err != nil {
		return nil, err
	}
	return &Set{NRow: nrow, NCol: ncol, Images: images, Labels: labels}, nil
}

// Count returns the number of points available in the data set.
func (s *Set) Count() int {
	return len(s.Images)
}

// Get returns the i-th image and its corresponding label.
// NOTE(review): i is not range-checked; out-of-range indices panic.
func (s *Set) Get(i int) (RawImage, Label) {
	return s.Images[i], s.Labels[i]
}

// Sweeper is an iterator over the points in a data set.
type Sweeper struct {
	set *Set // the data set being iterated
	i   int  // index of the next sample to yield
}

// Next returns the next image and its label in the data set and advances
// the iterator. If the end is reached, present is set to false and the
// other results are zero values.
func (sw *Sweeper) Next() (image RawImage, label Label, present bool) {
	if sw.i >= len(sw.set.Images) {
		return nil, 0, false
	}
	image, label = sw.set.Images[sw.i], sw.set.Labels[sw.i]
	// BUG FIX: the iterator previously never advanced, so every call
	// returned the first sample and a sweep would loop forever.
	sw.i++
	return image, label, true
}

// Sweep creates a new sweep iterator over the data set, starting at the
// first sample.
func (s *Set) Sweep() *Sweeper {
	return &Sweeper{set: s}
}

// Load reads both the training and the testing MNIST data sets, given
// a local directory dir containing the MNIST distribution files.
func Load(dir string) (*Set, *Set, error) {
	join := func(name string) string { return path.Join(dir, name) }

	train, err := ReadSet(join("train-images-idx3-ubyte.gz"), join("train-labels-idx1-ubyte.gz"))
	if err != nil {
		return nil, nil, err
	}
	test, err := ReadSet(join("t10k-images-idx3-ubyte.gz"), join("t10k-labels-idx1-ubyte.gz"))
	if err != nil {
		return nil, nil, err
	}
	return train, test, nil
}

// LoadData loads the MNIST files under dir and returns the training and
// test sets concatenated (training samples first) as normalized input
// vectors plus their one-hot label vectors.
func LoadData(dir string) (dataset ann.Dataset, labels []ann.Labels, err error) {
	train, test, err := Load(dir)
	if err != nil {
		return nil, nil, err
	}

	// Pre-size for the combined train+test sample count.
	total := train.Count() + test.Count()
	dataset = make(ann.Dataset, 0, total)
	labels = make([]ann.Labels, 0, total)

	// appendSet converts one set and appends it to the combined slices.
	appendSet := func(s *Set) {
		for i := 0; i < s.Count(); i++ {
			image, label := s.Get(i)
			dataset = append(dataset, imageToData(image))
			labels = append(labels, labelToLabels(label))
		}
	}
	appendSet(train)
	appendSet(test)

	return dataset, labels, nil
}

// imageToData converts a RawImage into an ann.Data vector, scaling each
// byte pixel from [0, 255] down to [0, 1].
func imageToData(image RawImage) ann.Data {
	out := make(ann.Data, len(image))
	for i := range image {
		out[i] = ann.Number(image[i]) / 255.0
	}
	return out
}

// labelToLabels converts a digit Label into a one-hot ann.Labels vector
// of length 10: the slot for the digit is 1.0, all others stay 0.0.
func labelToLabels(label Label) ann.Labels {
	oneHot := make(ann.Labels, 10)
	oneHot[label] = 1.0
	return oneHot
}

../apps/scratchai/internal/mnist/loader_test.go

package mnist

import "testing"

import (
	"fmt"
)

// TestReadLabelFile checks that the t10k label file parses and contains
// exactly 10000 labels.
func TestReadLabelFile(t *testing.T) {
	labels, err := ReadLabelFile("data/t10k-labels-idx1-ubyte.gz")
	if err != nil {
		t.Fatalf("read (%s)", err)
	}
	if got := len(labels); got != 10000 {
		t.Errorf("unexpected count %d", got)
	}
}

// TestReadImageFile checks that the t10k image file parses and contains
// exactly 10000 images, then prints the detected image geometry.
func TestReadImageFile(t *testing.T) {
	nrow, ncol, images, err := ReadImageFile("data/t10k-images-idx3-ubyte.gz")
	if err != nil {
		t.Fatalf("read (%s)", err)
	}
	if got := len(images); got != 10000 {
		t.Errorf("unexpected count %d", got)
	}
	fmt.Printf("%d images, %dx%d format\n", len(images), nrow, ncol)
}

// TestLoad loads both MNIST sets from ./data and reports their sizes.
// FIX: the builtin println wrote to stderr outside the test framework and
// printed the raw slice header of Images[0] rather than anything useful;
// t.Logf integrates with `go test -v` instead.
func TestLoad(t *testing.T) {
	train, test, err := Load("./data")
	if err != nil {
		t.Fatalf("load (%s)", err)
	}
	t.Logf("train=%d test=%d", train.Count(), test.Count())
	t.Logf("first image: %d bytes", len(train.Images[0]))
}

../apps/scratchai/internal/mnist/loader.go

package mnist

import (
	"compress/gzip"
	"encoding/binary"
	"image"
	"image/color"
	"io"
	"os"
)

const (
	// IDX file magic numbers (big-endian) from the MNIST distribution.
	imageMagic = 0x00000803
	labelMagic = 0x00000801
	// Width and Height are the fixed pixel dimensions of an MNIST image.
	Width      = 28
	Height     = 28
)

// RawImage holds the pixel intensities of an image in row-major order.
// 255 is foreground (black), 0 is background (white).
type RawImage []byte

// ColorModel implements image.Image; MNIST pixels are 8-bit grayscale.
func (img RawImage) ColorModel() color.Model {
	return color.GrayModel
}

// Bounds implements image.Image: every MNIST image is a fixed
// Width x Height rectangle anchored at the origin.
func (img RawImage) Bounds() image.Rectangle {
	// image.Rect is the idiomatic constructor for a rectangle literal.
	return image.Rect(0, 0, Width, Height)
}

// At implements image.Image, returning the grayscale pixel at (x, y).
// NOTE(review): coordinates are not bounds-checked; values outside
// [0, Width) x [0, Height) index past the slice and panic.
func (img RawImage) At(x, y int) color.Color {
	return color.Gray{Y: img[y*Width+x]}
}

// ReadImageFile opens the named gzip-compressed image file (training or
// test), parses it and returns the image geometry plus all images in order.
func ReadImageFile(name string) (rows, cols int, imgs []RawImage, err error) {
	f, err := os.Open(name)
	if err != nil {
		return 0, 0, nil, err
	}
	defer f.Close()
	z, err := gzip.NewReader(f)
	if err != nil {
		return 0, 0, nil, err
	}
	// FIX: the gzip reader was never closed, leaking its internal state.
	defer z.Close()
	return readImageFile(z)
}

// readImageFile parses the IDX image format: a big-endian header of
// (magic, count, rows, cols) followed by count blocks of rows*cols pixels.
func readImageFile(r io.Reader) (rows, cols int, imgs []RawImage, err error) {
	var magic, n, nrow, ncol int32
	if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
		return 0, 0, nil, err
	}
	if magic != imageMagic {
		return 0, 0, nil, os.ErrInvalid
	}
	if err = binary.Read(r, binary.BigEndian, &n); err != nil {
		return 0, 0, nil, err
	}
	if err = binary.Read(r, binary.BigEndian, &nrow); err != nil {
		return 0, 0, nil, err
	}
	if err = binary.Read(r, binary.BigEndian, &ncol); err != nil {
		return 0, 0, nil, err
	}
	// Reject corrupt headers before allocating based on them.
	if n < 0 || nrow < 0 || ncol < 0 {
		return 0, 0, nil, os.ErrInvalid
	}
	imgs = make([]RawImage, n)
	pixels := int(nrow) * int(ncol)
	for i := range imgs {
		imgs[i] = make(RawImage, pixels)
		// io.ReadFull returns a nil error only when the buffer was
		// completely filled, so the old post-read length check (and the
		// non-idiomatic m_ variable) were redundant.
		if _, err := io.ReadFull(r, imgs[i]); err != nil {
			return 0, 0, nil, err
		}
	}
	return int(nrow), int(ncol), imgs, nil
}

// Label is a digit label in the range 0 through 9.
type Label uint8

// ReadLabelFile opens the named gzip-compressed label file (training or
// test), parses it and returns all labels in order.
func ReadLabelFile(name string) (labels []Label, err error) {
	f, err := os.Open(name)
	if err != nil {
		return nil, err
	}
	defer f.Close()
	z, err := gzip.NewReader(f)
	if err != nil {
		return nil, err
	}
	// FIX: the gzip reader was never closed, leaking its internal state.
	defer z.Close()
	return readLabelFile(z)
}

// readLabelFile parses the IDX label format: a big-endian header of
// (magic, count) followed by count single-byte labels.
func readLabelFile(r io.Reader) (labels []Label, err error) {
	var magic, n int32
	if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
		return nil, err
	}
	if magic != labelMagic {
		return nil, os.ErrInvalid
	}
	if err = binary.Read(r, binary.BigEndian, &n); err != nil {
		return nil, err
	}
	// Reject a corrupt negative count before allocating based on it.
	if n < 0 {
		return nil, os.ErrInvalid
	}
	// PERF: read all labels with one io.ReadFull instead of one
	// reflection-based binary.Read call per byte.
	buf := make([]byte, n)
	if _, err := io.ReadFull(r, buf); err != nil {
		return nil, err
	}
	labels = make([]Label, n)
	for i, b := range buf {
		labels[i] = Label(b)
	}
	return labels, nil
}

../apps/scratchai/internal/digitguesser/app.go

package digitguesser

import (
	"lib/pkg/ann"
	"math/rand"
)

const (
	InputSize           = 28 * 28 // one input neuron per MNIST pixel
	HiddenLayer1Size    = 16
	HiddenLayer2Size    = 16 // currently unused: the second hidden layer is commented out in NewApp
	OutputLayerSize     = 10 // one output neuron per digit 0-9
	BatchSize           = 1000
	InitialLearningRate = ann.Number(5)
	MinLearningRate     = 0.02 // floor applied by updateLearningRate
	LearningRateDecay   = 0.99 // multiplicative decay per epoch
)

// App bundles a network, a processor implementation and the training data.
// T is the processor's per-context data type.
type App[T any] struct {
	network   *ann.Network      // the network being trained
	processor ann.Processor[T]  // feed-forward / back-propagation backend
	dataset   ann.Dataset       // training inputs
	labels    []ann.Labels      // expected outputs, parallel to dataset
}

// NewApp builds a three-layer network (input, one hidden layer, output)
// and returns an App bound to the given processor and training data.
func NewApp[T any](processor ann.Processor[T], dataset ann.Dataset, labels []ann.Labels) *App[T] {
	// Build the layers. NewNeurons' second argument is presumably an
	// initial weight/bias value of 0.1 — confirm against lib/pkg/ann.
	inputLayer := ann.NewLayer(ann.NewNeurons(InputSize, 0.1))
	hiddenLayer1 := ann.NewLayer(ann.NewNeurons(HiddenLayer1Size, 0.1))
	//hiddenLayer2 := ann.NewLayer(ann.NewNeurons(HiddenLayer2Size, 0.1))
	outputLayer := ann.NewLayer(ann.NewNeurons(OutputLayerSize, 0.1))

	// Assemble the network; the second hidden layer is disabled above.
	network := ann.NewNetwork([]*ann.Layer{inputLayer, hiddenLayer1, outputLayer})

	return &App[T]{
		network:   network,
		processor: processor,
		dataset:   dataset,
		labels:    labels,
	}
}

// Train runs the given number of epochs. Each epoch samples a random
// mini-batch, feeds it forward, back-propagates, applies the resulting
// adjustments to the network, and then decays the learning rate.
func (a *App[T]) Train(epochs int) error {
	learningRate := InitialLearningRate

	for epoch := 0; epoch < epochs; epoch++ {
		// Build a random mini-batch from the stored dataset.
		batchData, batchLabels := a.createMiniBatch()

		// Feed forward, then back-propagate against the batch labels.
		contexts := a.processor.FeedForward(a.network, batchData)
		adjustments := a.processor.BackPropagate(contexts, batchLabels)

		// Apply the weight/bias adjustments to the network.
		err := a.updateNetwork(adjustments, learningRate)
		if err != nil {
			return err
		}

		// Decay the learning rate (clamped to MinLearningRate).
		learningRate = a.updateLearningRate(learningRate, epoch)
	}

	return nil
}

// createMiniBatch samples up to BatchSize distinct points from the stored
// dataset, returning the sampled inputs and their matching labels.
func (a *App[T]) createMiniBatch() (ann.Dataset, []ann.Labels) {
	// FIX: clamp the batch size so a dataset smaller than BatchSize no
	// longer makes the [:BatchSize] slice expression panic.
	size := BatchSize
	if len(a.dataset) < size {
		size = len(a.dataset)
	}
	indices := rand.Perm(len(a.dataset))[:size]
	batchData := make(ann.Dataset, size)
	batchLabels := make([]ann.Labels, size)

	for i, idx := range indices {
		batchData[i] = a.dataset[idx]
		batchLabels[i] = a.labels[idx]
	}

	return batchData, batchLabels
}

// updateNetwork applies the scaled (negative-gradient) weight and bias
// adjustments to the network. Adjustment slices are indexed from the
// first non-input layer, hence the layerIndex+1 offset into the network.
func (a *App[T]) updateNetwork(adjustments ann.Adjustments, learningRate ann.Number) error {
	// Update weights.
	for layerIndex, layerAdjustments := range adjustments.WeightAdjustments() {
		for neuronIndex, neuronAdjustments := range layerAdjustments {
			scaledAdjustments := make([]ann.Number, len(neuronAdjustments))
			for i, adj := range neuronAdjustments {
				// Descend the gradient: subtract learningRate * gradient.
				scaledAdjustments[i] = -learningRate * adj
			}
			err := a.network.AdjustNeuronConnections(layerIndex+1, neuronIndex, scaledAdjustments)
			if err != nil {
				return err
			}
		}
	}

	// Update biases.
	for layerIndex, layerBiasAdjustments := range adjustments.BiasAdjustments() {
		for neuronIndex, biasAdjustment := range layerBiasAdjustments {
			neuron := a.network.Layers()[layerIndex+1].Neurons()[neuronIndex]
			scaledBiasAdjustment := -learningRate * ann.Number(biasAdjustment)
			newBias := neuron.Bias() + scaledBiasAdjustment
			neuron.SetBias(newBias)
		}
	}

	return nil
}

// updateLearningRate decays the current rate by LearningRateDecay and
// clamps the result at MinLearningRate. The epoch argument is currently
// unused but kept for future schedule experiments.
func (a *App[T]) updateLearningRate(currentRate ann.Number, epoch int) ann.Number {
	_ = epoch // reserved for epoch-dependent schedules
	if decayed := currentRate * LearningRateDecay; decayed >= MinLearningRate {
		return decayed
	}
	return MinLearningRate
}

// Predict feeds a single input through the network and returns the
// activations of the output layer.
func (a *App[T]) Predict(input ann.Data) ann.LayerActivations {
	contexts := a.processor.FeedForward(a.network, ann.Dataset{input})
	acts := contexts[0].Activations()
	return acts[len(acts)-1]
}

../apps/scratchai/internal/digitguesser/gonumprocessor.go

package digitguesser

import (
	"lib/pkg/ann"
	"math"

	"gonum.org/v1/gonum/mat"
)

// GonumContextData carries the weight matrices and bias vectors captured
// during FeedForward so that BackPropagate can reuse the exact snapshot.
type GonumContextData struct {
	Weights []*mat.Dense    // one (rows x cols) matrix per non-input layer
	Biases  []*mat.VecDense // one bias vector per non-input layer
}

// GonumProcessor implements the ann processor interface with gonum matrices.
type GonumProcessor struct{}

// Activation is the logistic sigmoid: 1 / (1 + e^-x).
func (g GonumProcessor) Activation(weightedSum ann.Number) ann.Number {
	x := float64(weightedSum)
	return ann.Number(1.0 / (1.0 + math.Exp(-x)))
}

// Derivative is the sigmoid derivative expressed through the activation
// itself: sigma'(x) = sigma(x) * (1 - sigma(x)).
func (g GonumProcessor) Derivative(weightedSum ann.Number) ann.Number {
	sigma := g.Activation(weightedSum)
	return sigma * (1 - sigma)
}

// FeedForward runs the whole dataset through the network as one batched
// matrix computation and returns one context per input sample. Each
// context also carries the weight/bias snapshot used for the pass.
func (g GonumProcessor) FeedForward(network *ann.Network, dataset ann.Dataset) []ann.Context[GonumContextData] {
	contexts := make([]ann.Context[GonumContextData], len(dataset))
	layers := network.Layers()
	connections := network.Connections()

	// Convert the network's weights and biases into matrices; entry i-1
	// maps layer i-1 to layer i.
	weights := make([]*mat.Dense, len(layers)-1)
	biases := make([]*mat.VecDense, len(layers)-1)

	for i := 1; i < len(layers); i++ {
		layerConnections := connections[i]
		rows := len(layerConnections)
		cols := len(layerConnections[0])

		weightData := make([]float64, rows*cols)
		biasData := make([]float64, rows)

		for j, neuronConns := range layerConnections {
			for k, conn := range neuronConns {
				weightData[j*cols+k] = float64(conn.Weight())
			}
			biasData[j] = float64(layers[i].Neurons()[j].Bias())
		}

		weights[i-1] = mat.NewDense(rows, cols, weightData)
		biases[i-1] = mat.NewVecDense(rows, biasData)
	}

	// Snapshot shared by every context of this pass.
	contextData := GonumContextData{
		Weights: weights,
		Biases:  biases,
	}

	// Convert the dataset into a (samples x features) matrix.
	dataMatrix := mat.NewDense(len(dataset), len(dataset[0]), nil)
	for i, data := range dataset {
		for j, val := range data {
			dataMatrix.Set(i, j, float64(val))
		}
	}

	// Feed-forward: activations[l] holds the batch activations of layer l.
	activations := make([]*mat.Dense, len(layers))
	activations[0] = dataMatrix

	for i := 1; i < len(layers); i++ {
		prevActivation := activations[i-1]
		weight := weights[i-1]
		bias := biases[i-1]

		weightedSum := mat.NewDense(prevActivation.RawMatrix().Rows, weight.RawMatrix().Rows, nil)
		weightedSum.Mul(prevActivation, weight.T())

		// Add the bias to every row of the batch.
		rows, cols := weightedSum.Dims()
		biasMatrix := mat.NewDense(rows, cols, nil)
		for r := 0; r < rows; r++ {
			biasMatrix.SetRow(r, bias.RawVector().Data)
		}
		weightedSum.Add(weightedSum, biasMatrix)

		// Apply the activation function element-wise.
		activationFunc := func(_, _ int, v float64) float64 {
			return float64(g.Activation(ann.Number(v)))
		}
		weightedSum.Apply(activationFunc, weightedSum)

		activations[i] = weightedSum
	}

	// FIX: layerSizes is identical for every sample, so compute it once
	// here instead of once per dataset row (the old inner loop also
	// shadowed the outer loop variable i).
	layerSizes := make([]int, len(layers))
	for l, layer := range layers {
		layerSizes[l] = len(layer.Neurons())
	}

	// Slice the batched activations back into per-sample contexts.
	for i := range dataset {
		context := ann.NewContext[GonumContextData](layerSizes)
		context.Data = contextData
		for j, activation := range activations {
			layerActivation := make(ann.LayerActivations, activation.RawMatrix().Cols)
			for k := 0; k < activation.RawMatrix().Cols; k++ {
				layerActivation[k] = ann.Number(activation.At(i, k))
			}
			err := context.SetActivations(j, layerActivation)
			if err != nil {
				panic(err)
			}
		}
		contexts[i] = *context
	}

	return contexts
}

// BackPropagate computes the mean weight and bias gradients over the batch
// from the contexts produced by FeedForward and the expected labels.
func (g GonumProcessor) BackPropagate(contexts []ann.Context[GonumContextData], labels []ann.Labels) ann.Adjustments {
	if len(contexts) == 0 || len(labels) != len(contexts) {
		panic("Invalid input: contexts and labels must have the same non-zero length")
	}

	numLayers := len(contexts[0].Activations())
	weights := contexts[0].Data.Weights
	batchSize := len(contexts)

	// Initialize weight and bias adjustments, one slot per non-input layer.
	weightAdjustments := make([]ann.WeightAdjustmentsToLayer, numLayers-1)
	biasAdjustments := make([]ann.BiasAdjustmentsToLayer, numLayers-1)

	// Convert per-sample activations back into (batch x layerSize) matrices.
	activations := make([]*mat.Dense, numLayers)
	for l := 0; l < numLayers; l++ {
		layerSize := len(contexts[0].Activations()[l])
		activationData := make([]float64, batchSize*layerSize)
		for i, context := range contexts {
			for j, act := range context.Activations()[l] {
				activationData[i*layerSize+j] = float64(act)
			}
		}
		activations[l] = mat.NewDense(batchSize, layerSize, activationData)
	}

	labelsMatrix := mat.NewDense(batchSize, len(labels[0]), nil)
	for i, label := range labels {
		for j, val := range label {
			labelsMatrix.Set(i, j, float64(val))
		}
	}

	// Output-layer error: activation minus target.
	delta := new(mat.Dense)
	delta.Sub(activations[numLayers-1], labelsMatrix)

	for l := numLayers - 2; l >= 0; l-- {
		prevLayerSize := activations[l].RawMatrix().Cols
		currentLayerSize := activations[l+1].RawMatrix().Cols

		// Mean weight gradient over the batch.
		weightAdj := new(mat.Dense)
		weightAdj.Mul(activations[l].T(), delta)
		weightAdj.Scale(1/float64(batchSize), weightAdj)

		// TODO: building one big matrix and converting once at the end
		// would avoid these per-layer nested copy loops.
		weightAdjustments[l] = make(ann.WeightAdjustmentsToLayer, currentLayerSize)
		for i := 0; i < currentLayerSize; i++ {
			weightAdjustments[l][i] = make(ann.WeightAdjustmentsToNeuron, prevLayerSize)
			for j := 0; j < prevLayerSize; j++ {
				weightAdjustments[l][i][j] = ann.Number(weightAdj.At(j, i))
			}
		}

		// Mean bias gradient over the batch.
		biasAdj := make(ann.BiasAdjustmentsToLayer, currentLayerSize)
		for i := 0; i < currentLayerSize; i++ {
			sum := 0.0
			for j := 0; j < batchSize; j++ {
				sum += delta.At(j, i)
			}
			biasAdj[i] = ann.BiasAdjustmentsToNeuron(sum / float64(batchSize))
		}
		biasAdjustments[l] = biasAdj

		if l > 0 {
			// Propagate the error to the previous layer. weights[l] maps
			// layer l to layer l+1, so with delta on the left no transpose
			// is required.
			newDelta := new(mat.Dense)
			newDelta.Mul(delta, weights[l])

			// BUG FIX: activations[l] already holds sigma(z), so the
			// sigmoid derivative is a*(1-a). The old code passed the
			// activation back through Derivative, which computed
			// sigma(a)*(1-sigma(a)) — a wrong gradient.
			derivativeActivation := new(mat.Dense)
			derivativeActivation.Apply(func(_, _ int, a float64) float64 {
				return a * (1 - a)
			}, activations[l])

			delta = new(mat.Dense)
			delta.MulElem(newDelta, derivativeActivation)
		}
	}

	return ann.NewAdjustments(weightAdjustments, biasAdjustments)
}

// NewGonumProcessor returns the gonum-backed processor implementation.
func NewGonumProcessor() ann.Processor[GonumContextData] {
	return GonumProcessor{}
}

../apps/scratchai/internal/digitguesser/trainmnist.go

package digitguesser

import (
	"fmt"
	"log"
	"os"

	"lib/pkg/ann"
	"scratchai/internal/mnist"
)

// maxIndex returns the index of the largest value in slice, or -1 when
// the slice is empty (the old code panicked on empty input).
func maxIndex(slice []ann.Number) int {
	if len(slice) == 0 {
		return -1
	}
	maxIdx := 0
	for i, val := range slice {
		if val > slice[maxIdx] {
			maxIdx = i
		}
	}
	return maxIdx
}

// TrainMnist loads the combined MNIST data, trains a digit classifier on
// the first 60000 samples, and reports accuracy on the remaining 10000.
// The data directory can be overridden with the MNIST_DATA_DIR environment
// variable; otherwise the historical dev-container path is used.
func TrainMnist() {
	dataDir := os.Getenv("MNIST_DATA_DIR")
	if dataDir == "" {
		dataDir = "/workspaces/mictlan/apps/scratchai/internal/mnist/data/"
	}
	dataset, labels, err := mnist.LoadData(dataDir)
	if err != nil {
		log.Fatalf("Failed to load MNIST data: %v", err)
	}
	fmt.Printf("Dataset size: %d, Labels size: %d\n", len(dataset), len(labels))

	processor := NewGonumProcessor()
	guesser := NewApp(processor, dataset[:60000], labels[:60000])

	// Run the training loop.
	err = guesser.Train(2000)
	if err != nil {
		panic(err)
	}

	// Predict on the held-out range and count correct answers.
	correct := 0
	total := 0
	for i := 60000; i < 70000; i++ {
		prediction := guesser.Predict(dataset[i])

		predictedIndex := maxIndex(prediction)
		actualIndex := maxIndex(labels[i])

		if predictedIndex == actualIndex {
			correct++
		}
		total++

		// Per-sample output (optional):
		//fmt.Printf("Sample %d: Predicted %d, Actual %d, %v\n", i, predictedIndex, actualIndex, predictedIndex == actualIndex)
	}

	// Compute and report accuracy.
	accuracy := float64(correct) / float64(total) * 100
	fmt.Printf("\nAccuracy: %.2f%% (%d/%d)\n", accuracy, correct, total)
}

../apps/scratchai/internal/digitguesser/gonumprocessor_test.go

package digitguesser

import (
	"lib/pkg/ann"
	"math/rand/v2"
	"testing"
)

// TestGonumProcessor trains a small network on synthetic data and logs the
// outputs before and after training as a manual sanity check.
func TestGonumProcessor(t *testing.T) {
	inputSize := 784
	outputSize := 10

	createNetwork := func() *ann.Network {
		inputLayer := ann.NewLayer(ann.NewNeurons(inputSize, 0))
		hiddenLayer1 := ann.NewLayer(ann.NewNeurons(16, 0))
		hiddenLayer2 := ann.NewLayer(ann.NewNeurons(16, 0))
		outputLayer := ann.NewLayer(ann.NewNeurons(outputSize, 0))

		// Build the network (all weights initialized to the same value).
		// TODO: try multiple weight initializations, or another method, to
		// find the global minimum instead of converging to a local one.
		return ann.NewNetwork([]*ann.Layer{inputLayer, hiddenLayer1, hiddenLayer2, outputLayer})
	}

	network := createNetwork()
	oldNetwork := createNetwork()

	// Build random training inputs.
	trainDataNum := 1000
	dataset := make(ann.Dataset, trainDataNum)
	allLabels := make([]ann.Labels, trainDataNum)
	for i := range dataset {
		dataset[i] = make(ann.Data, inputSize)
		for j := range dataset[i] {
			dataset[i][j] = ann.Number(rand.Float64()*2 - 1)
		}
		allLabels[i] = make(ann.Labels, outputSize)
		for j := range allLabels[i] {
			// Every target vector is the ramp 0.0 0.1 0.2 ... 0.9.
			allLabels[i][j] = ann.Number(float32(j) * 0.1)
		}
	}

	processor := NewGonumProcessor()
	// TODO: adapt the learning rate to the steepness of the gradient.
	learningRate := ann.Number(0.1)

	// Training loop. NOTE(review): runs 1000 epochs, although the log
	// message below still says "50 epochs".
	for epoch := 0; epoch < 1000; epoch++ {
		contexts := processor.FeedForward(network, dataset)
		adjustments := processor.BackPropagate(contexts, allLabels)

		// Adjust the network weights (descend the gradient).
		for layerIndex, layerAdjustments := range adjustments.WeightAdjustments() {
			for neuronIndex, neuronAdjustments := range layerAdjustments {
				scaledAdjustments := make([]ann.Number, len(neuronAdjustments))
				for i, adj := range neuronAdjustments {
					scaledAdjustments[i] = -learningRate * adj
				}

				err := network.AdjustNeuronConnections(layerIndex+1, neuronIndex, scaledAdjustments)
				if err != nil {
					t.Errorf("Error adjusting neuron connections: %v", err)
				}
			}
		}

		// Adjust the biases.
		for layerIndex, layerBiasAdjustments := range adjustments.BiasAdjustments() {
			for neuronIndex, biasAdjustment := range layerBiasAdjustments {
				neuron := network.Layers()[layerIndex+1].Neurons()[neuronIndex]
				scaledBiasAdjustment := -learningRate * ann.Number(biasAdjustment)
				newBias := neuron.Bias() + scaledBiasAdjustment
				neuron.SetBias(newBias)
			}
		}

		// Report progress every 10 epochs.
		if epoch%10 == 0 {
			t.Logf("Completed epoch %d", epoch)
		}
	}

	t.Log("Training completed - 50 epochs")

	testDataset := make(ann.Dataset, 1)
	for i := range testDataset {
		testDataset[i] = make(ann.Data, inputSize)
		for j := range testDataset[i] {
			testDataset[i][j] = ann.Number(rand.Float64()*2 - 1)
		}
	}

	// Whatever the input, the output should approach the target ramp
	// 0 0.1 0.2 ... 0.9 after training.
	finalContexts := processor.FeedForward(network, testDataset)
	t.Logf("Final output after training: %+v", finalContexts[0].Activations()[len(finalContexts[0].Activations())-1])

	// The untrained model should produce arbitrary output by comparison.
	randomContexts := processor.FeedForward(oldNetwork, testDataset)
	t.Logf("Output of model before training: %+v", randomContexts[0].Activations()[len(randomContexts[0].Activations())-1])
}

../apps/scratchai/internal/digitguesser/trainmnist_test.go

package digitguesser_test

import (
	"scratchai/internal/digitguesser"
	"testing"
)

// TestTrainMnist runs the full MNIST training pipeline end to end.
// NOTE(review): this is a long-running smoke test with no assertions,
// not a unit test; it also requires the MNIST data files on disk.
func TestTrainMnist(t *testing.T) {
	digitguesser.TrainMnist()
}

Discussion