📘

1ファイルでニューラルネットワークを作れる Genann を試してみた

3 min read

https://github.com/nothings/single_file_libs#ai

シングルファイルで使える C/C++ 向けのライブラリが一覧にまとめられているリポジトリをなんとなく見ていたら、1ファイルでニューラルネットワークを作れるライブラリ Genann という物があったので試してみた。

https://github.com/codeplea/genann

インタフェースはとても簡単。

/* Creates and returns a new ann. */
genann *genann_init(int inputs, int hidden_layers, int hidden, int outputs);

/* Creates ANN from file saved with genann_write. */
genann *genann_read(FILE *in);

/* Sets weights randomly. Called by init. */
void genann_randomize(genann *ann);

/* Returns a new copy of ann. */
genann *genann_copy(genann const *ann);

/* Frees the memory used by an ann. */
void genann_free(genann *ann);

/* Runs the feedforward algorithm to calculate the ann's output. */
double const *genann_run(genann const *ann, double const *inputs);

/* Does a single backprop update. */
void genann_train(genann const *ann, double const *inputs, double const *desired_outputs, double learning_rate);

/* Saves the ann. */
void genann_write(genann const *ann, FILE *out);

void genann_init_sigmoid_lookup(const genann *ann);
double genann_act_sigmoid(const genann *ann, double a);
double genann_act_sigmoid_cached(const genann *ann, double a);
double genann_act_threshold(const genann *ann, double a);
double genann_act_linear(const genann *ann, double a);

この中でも簡単なニューラルネットワークを作るのであれば使うのは4つ。

genann *genann_init(int inputs, int hidden_layers, int hidden, int outputs);
double const *genann_run(genann const *ann, double const *inputs);
void genann_train(genann const *ann, double const *inputs, double const *desired_outputs, double learning_rate);
void genann_free(genann *ann);

試しに FizzBuzz を学習させてみた。

static void
bin(int n, double* d) {
  int i;
  memset(d, 0, sizeof(double)*10);
  for (i = 0; i < 10; i++) {
    d[i] = (double)((n >> i) & 1);
  }
}

static int
dec(const double* d) {
  int i, mi = 0;
  double m = 0;
  for (i = 0; i < 4; i++) {
    if (m < d[i]) {
      m = d[i];
      mi = i;
    }
  }
  return mi;
}

static void
fizz_buzz(int n, double* d) {
  memset(d, 0, sizeof(double)*4);
  if (n%15 == 0) {
    d[0] = 1;
  } else if (n%3 == 0) {
    d[1] = 1;
  } else if (n%5 == 0) {
    d[2] = 1;
  } else {
    d[3] = 1;
  }
}

まずは初期化

  srand(time(0));

  // Initialize 
  genann *ann = genann_init(10, 1, 10, 4);

隠れ層はもっと大きくていいかも。学習は以下の様に。

  // Learn output with input. input is encoded with 10 floating-point numbers.
  // output is encoded with 4 floating-point numbers.
  for (i = 0; i < 500000; ++i) {
    double input[10];
    double output[4];
    bin(i%100, input);
    fizz_buzz(i%100+1, output);
    genann_train(ann, input, output, 3);
  }

genann_train の最後のパラメータは学習率。学習が終わったら答え合わせとして FizzBuzz を出力してみる。

  // Output FizzBuzz
  for (i = 0; i < 100; i++) {
    double input[10];
    bin(i, input);
    const double* f = genann_run(ann, input);
    int n = dec(f);
    switch (n) {
      case 0:
        puts("FizzBuzz");
        break;
      case 1:
        puts("Fizz");
        break;
      case 2:
        puts("Buzz");
        break;
      case 3:
        printf("%d\n", i+1);
        break;
    }
  }

コンパイルは genann.c を一緒にコンパイルリンクするだけ。

$ gcc -o fizzbuzz fizzbuzz.c genann.c

実行すると以下の通り。

1
2
Fizz
4
Buzz
Fizz
7
8
Fizz
Buzz
11
Fizz
13
14
FizzBuzz
16

もちろん 100 以上ではだんだん結果が崩れてくるんだけども。

簡単な機械学習をお手軽にやるには便利かもしれない。ソースは以下に置いておきます。

https://gist.github.com/mattn/d14dbf39b5e5701b161ffb54af222692