📷

画像処理100本ノックに挑戦|平均プーリング(007/100)

2025/01/08に公開

これはなに?

画像処理100本ノックを、TypeScriptとlibvipsで挑戦してみる記事の7本目です。

前回

https://zenn.dev/nyagato_00/articles/de360085b46e19

実装

お題

ここでは画像をグリッド分割(ある固定長の領域に分ける)し、かく領域内(セル)の平均値でその領域内の値を埋める。 このようにグリッド分割し、その領域内の代表値を求める操作はPooling(プーリング) と呼ばれる。 これらプーリング操作はCNN(Convolutional Neural Network) において重要な役割を持つ。

これは次式で定義される。

v = 1/|R| * Sum_{i in R} v_i

https://github.com/minido/Gasyori100knock-1/tree/master/Question_01_10#q7-平均プーリング

Coding

import sharp from 'sharp';

export async function averagePooling(
  inputPath: string, 
  outputPath: string, 
  gridSize: number = 8
): Promise<void> {
  try {
    const image = await sharp(inputPath)
      .raw()
      .toBuffer({ resolveWithObject: true });

    const { data, info } = image;
    const { width, height, channels } = info;

    // グリッド数を計算
    const numGridsH = Math.floor(height / gridSize);
    const numGridsW = Math.floor(width / gridSize);

    // 新しい画像データ用のバッファを作成
    const newData = Buffer.from(data);  // copy the original image

    // 各グリッドに対して処理
    for (let y = 0; y < numGridsH; y++) {
      for (let x = 0; x < numGridsW; x++) {
        // 各チャネルに対して処理
        for (let c = 0; c < channels; c++) {
          let sum = 0;
          const startY = y * gridSize;
          const startX = x * gridSize;

          // グリッド内の値の平均を計算
          for (let dy = 0; dy < gridSize; dy++) {
            for (let dx = 0; dx < gridSize; dx++) {
              const pos = ((startY + dy) * width + (startX + dx)) * channels + c;
              sum += data[pos];
            }
          }

          // 平均値を計算
          const mean = Math.round(sum / (gridSize * gridSize));

          // グリッド内の全ピクセルに平均値を設定
          for (let dy = 0; dy < gridSize; dy++) {
            for (let dx = 0; dx < gridSize; dx++) {
              const pos = ((startY + dy) * width + (startX + dx)) * channels + c;
              newData[pos] = mean;
            }
          }
        }
      }
    }

    // 結果を保存
    await sharp(newData, {
      raw: {
        width,
        height,
        channels
      }
    })
    .toFile(outputPath);

    console.log('プーリング処理が完了しました');
  } catch (error) {
    console.error('画像処理中にエラーが発生しました:', error);
    throw error;
  }
}

Test

import { existsSync, unlinkSync } from 'fs';
import { join } from 'path';
import sharp from 'sharp';
import { averagePooling } from './imageProcessor';

describe('Average Pooling Tests', () => {
  const testInputPath = join(__dirname, '../test-images/test.jpeg');
  const testOutputPath = join(__dirname, '../test-images/test-pooled.png');

  afterEach(() => {
    if (existsSync(testOutputPath)) {
      unlinkSync(testOutputPath);
    }
  });

  test('should successfully process image', async () => {
    await expect(averagePooling(testInputPath, testOutputPath))
      .resolves.not.toThrow();
    expect(existsSync(testOutputPath)).toBe(true);
  });

  test('should maintain image dimensions', async () => {
    await averagePooling(testInputPath, testOutputPath);
    
    const inputMetadata = await sharp(testInputPath).metadata();
    const outputMetadata = await sharp(testOutputPath).metadata();

    expect(outputMetadata.width).toBe(inputMetadata.width);
    expect(outputMetadata.height).toBe(inputMetadata.height);
  });

  test('should create uniform blocks', async () => {
    const gridSize = 8;
    await averagePooling(testInputPath, testOutputPath, gridSize);
    
    const outputImage = await sharp(testOutputPath)
      .raw()
      .toBuffer({ resolveWithObject: true });

    // グリッドの一つを選んでチェック
    const { data, info } = outputImage;
    const { width, channels } = info;
    const startX = gridSize;
    const startY = gridSize;

    // グリッド内の全ピクセルが同じ値を持つことを確認
    const firstPixelPos = (startY * width + startX) * channels;
    const referenceValues = [
      data[firstPixelPos],
      data[firstPixelPos + 1],
      data[firstPixelPos + 2]
    ];

    // グリッド内の他のピクセルをチェック
    for (let dy = 0; dy < gridSize; dy++) {
      for (let dx = 0; dx < gridSize; dx++) {
        const pos = ((startY + dy) * width + (startX + dx)) * channels;
        expect(data[pos]).toBe(referenceValues[0]);     // R
        expect(data[pos + 1]).toBe(referenceValues[1]); // G
        expect(data[pos + 2]).toBe(referenceValues[2]); // B
      }
    }
  });
});

結果

入力 結果

Discussion