Open1

俺は一行も仕様を読まずに Claude 3 Opus に component model の canonical ABI を実装させる

mizchimizchi

まず仕様書をぶち込む。
https://raw.githubusercontent.com/WebAssembly/component-model/main/design/mvp/CanonicalABI.md

ここまで生成させた。

/**
以下は今までの実装の要約です。

-----

はい、わかりました。これまでの議論を要約しますね。
WebAssembly Component Model の Canonical ABI 仕様に基づいて、TypeScript でエンコーダとデコーダの実装を進めてきました。
主な実装内容は以下の通りです:

メモリ操作のためのユーティリティ関数 loadInt、storeInt などを実装しました。
基本的な型(bool、u32、s32 など)に対する lift と lower 関数を実装しました。
文字列型に対する liftString と lowerString 関数を実装しました。
レコード型の例として Person 型を定義し、liftPerson と lowerPerson 関数を実装しました。
Variant 型に対する liftVariant と lowerVariant 関数を実装しました。
各実装に対するテストケースを Deno を使って作成しました。

実装の過程では、以下のような課題に直面し、議論してきました:

メモリレイアウトとアライメントの正しい扱い方
ポインタの扱い(相対位置 vs 絶対位置)
可変長データ(文字列など)の扱い
Variant 型のディスクリミナントとペイロードの扱い
未実装の型や関数のエラー処理

現在、Variant 型のテストがまだ落ちている状態です。具体的には、sizeOf 関数と alignment 関数が u8 や u16 などの型に対応していないことが原因でエラーが発生しています。
これを解決するために、以下の修正を提案しました:

sizeOf 関数と alignment 関数で u8 と u16 をサポート
liftPayload と lowerPayload 関数で u8 と u16 のケースを追加
liftU8、lowerU8、liftU16、lowerU16 関数の実装

しかし、これらの修正を適用してもテストは依然として落ちている状態です。
実装を進める中で、Canonical ABI の仕様を正しく理解し、それに忠実に従うことの重要性を再認識しました。特に、メモリレイアウトとポインタの扱いには細心の注意が必要です。
また、テスト駆動開発(TDD)の有用性も実感しました。テストを書くことで、実装の正確性を確保し、リグレッションを防ぐことができます。
今後は、以下の点に注力していく必要があると考えています:

現在のエラーの原因を特定し、修正する
まだ実装されていない型(Option、Result、Tuple など)の実装を進める
エラー処理とメモリ管理を改善する
より網羅的なテストケースを追加する

Canonical ABI の完全な実装は容易ではありませんが、一歩一歩着実に進めていきたいと思います。仕様書を理解し、コードに反映していくことを繰り返し、正確で効率的な実装を目指します。
 */

import { assertEquals, assertThrows } from "jsr:@std/assert@0.223.0";

// データ型の定義
type Bool = boolean;
type S8 = number;
type S16 = number;
type S32 = number;
type S64 = bigint;
type U8 = number;
type U16 = number;
type U32 = number;
type U64 = bigint;

// メモリ操作ユーティリティ
function loadInt(memory: ArrayBuffer, ptr: number, size: number, signed: boolean): number | bigint {
  const view = new DataView(memory);
  switch (size) {
    case 1: return signed ? view.getInt8(ptr) : view.getUint8(ptr);
    case 2: return signed ? view.getInt16(ptr, true) : view.getUint16(ptr, true);
    case 4: return signed ? view.getInt32(ptr, true) : view.getUint32(ptr, true);
    case 8: return signed ? view.getBigInt64(ptr, true) : view.getBigUint64(ptr, true);
    default: throw new Error(`Invalid size: ${size}`);
  }
}

function storeInt(memory: ArrayBuffer, ptr: number, value: number | bigint, size: number, signed: boolean): void {
  const view = new DataView(memory);
  switch (size) {
    case 1: signed ? view.setInt8(ptr, Number(value)) : view.setUint8(ptr, Number(value)); break;
    case 2: signed ? view.setInt16(ptr, Number(value), true) : view.setUint16(ptr, Number(value), true); break;
    case 4: signed ? view.setInt32(ptr, Number(value), true) : view.setUint32(ptr, Number(value), true); break;
    case 8: signed ? view.setBigInt64(ptr, BigInt(value), true) : view.setBigUint64(ptr, BigInt(value), true); break;
    default: throw new Error(`Invalid size: ${size}`);
  }
}

// lift 関数の一部
function liftBool(memory: ArrayBuffer, ptr: number): Bool {
  return loadInt(memory, ptr, 1, false) !== 0;
}

function liftS32(memory: ArrayBuffer, ptr: number): S32 {
  return loadInt(memory, ptr, 4, true) as number;
}

function liftU32(memory: ArrayBuffer, ptr: number): U32 {
  return loadInt(memory, ptr, 4, false) as number;
}

// lower 関数の一部
function lowerBool(memory: ArrayBuffer, ptr: number, value: Bool): void {
  storeInt(memory, ptr, value ? 1 : 0, 1, false);
}

function lowerS32(memory: ArrayBuffer, ptr: number, value: S32): void {
  storeInt(memory, ptr, value, 4, true);
}

function lowerU32(memory: ArrayBuffer, ptr: number, value: U32): void {
  storeInt(memory, ptr, value, 4, false);
}

// flatten 関数の一部
function flattenBool(): string[] {
  return ['i32'];
}

function flattenS32(): string[] {
  return ['i32'];
}

function flattenU32(): string[] {
  return ['i32'];
}

// 文字列型
function liftString(memory: ArrayBuffer, ptr: number): [string, string, number] {
  const view = new DataView(memory);
  const len = view.getUint32(ptr + 4, true);
  const buf = new Uint8Array(memory, view.getUint32(ptr, true), len);
  const decoder = new TextDecoder('utf-8');
  const str = decoder.decode(buf);
  return [str, 'utf8', len];
}

function lowerString(memory: ArrayBuffer, ptr: number, value: [string, string, number]): void {
  const [str, _, len] = value;
  const encoder = new TextEncoder();
  const buf = encoder.encode(str);
  const view = new DataView(memory);
  view.setUint32(ptr, ptr + 8, true);
  view.setUint32(ptr + 4, len, true);
  new Uint8Array(memory, ptr + 8, len).set(buf);
}
// function liftString(memory: ArrayBuffer, ptr: number): [string, string, number] {
//   const view = new DataView(memory);
//   const len = view.getUint32(ptr + 4, true);
//   const buf = new Uint8Array(memory, view.getUint32(ptr, true), len);
//   const decoder = new TextDecoder('utf-8');
//   const str = decoder.decode(buf);
//   return [str, 'utf8', len];
// }

// function lowerString(memory: ArrayBuffer, ptr: number, value: [string, string, number]): void {
//   const [str, encoding, len] = value;
//   const view = new DataView(memory);
//   const encoder = new TextEncoder();
//   const buf = encoder.encode(str);
//   const newPtr = realloc(memory, 0, 0, 1, buf.length);
//   new Uint8Array(memory, newPtr, buf.length).set(buf);
//   view.setUint32(ptr, newPtr, true);
//   view.setUint32(ptr + 4, buf.length, true);
// }

function flattenString(): string[] {
  return ['i32', 'i32'];
}

// 配列型
function liftArray(memory: ArrayBuffer, ptr: number, liftElem: (memory: ArrayBuffer, ptr: number) => any): any[] {
  const view = new DataView(memory);
  const len = view.getUint32(ptr + 4, true);
  const arr = [];
  let elemPtr = view.getUint32(ptr, true);
  for (let i = 0; i < len; i++) {
    arr.push(liftElem(memory, elemPtr));
    elemPtr += sizeOf(typeof liftElem);
  }
  return arr;
}

function lowerArray(memory: ArrayBuffer, ptr: number, value: any[], lowerElem: (memory: ArrayBuffer, ptr: number, value: any) => void): void {
  const view = new DataView(memory);
  const len = value.length;
  const buf = new ArrayBuffer(len * sizeOf(typeof lowerElem));
  const newPtr = realloc(memory, 0, 0, alignment(typeof lowerElem), buf.byteLength);
  view.setUint32(ptr, newPtr, true);
  view.setUint32(ptr + 4, len, true);
  let elemPtr = newPtr;
  for (let i = 0; i < len; i++) {
    lowerElem(memory, elemPtr, value[i]);
    elemPtr += sizeOf(typeof lowerElem);
  }
}

function flattenArray(flattenElem: () => string[]): string[] {
  return ['i32', 'i32', ...flattenElem()];
}

// レコード型の一部
type Person = {
  name: string;
  age: U32;
};

function liftPerson(memory: ArrayBuffer, ptr: number): Person {
  let offset = 0;
  const [name, _, length] = liftString(memory, ptr + offset);
  offset += 8 + length;
  const age = liftU32(memory, ptr + offset);
  return { name, age };
}

function lowerPerson(memory: ArrayBuffer, ptr: number, value: Person): void {
  let offset = 0;
  const encoder = new TextEncoder();
  const buf = encoder.encode(value.name);
  const view = new DataView(memory);
  view.setUint32(ptr + offset, ptr + offset + 8, true);
  view.setUint32(ptr + offset + 4, buf.length, true);
  new Uint8Array(memory, ptr + offset + 8, buf.length).set(buf);
  offset += 8 + buf.length;
  lowerU32(memory, ptr + offset, value.age);
}

type Variant = { [key: string]: any };

function discriminant_type(cases: string[]): string {
  const n = cases.length;
  if (n <= 255) return 'u8';
  if (n <= 65535) return 'u16';
  return 'u32';
}

function max_case_alignment(cases: string[]): number {
  // 今回は簡単のため、全てのケースが同じアライメントを持つと仮定します。
  return 8;
}

function match_case(v: Variant, cases: string[]): [number, any] {
  const label = Object.keys(v)[0];
  const index = cases.indexOf(label);
  if (index === -1) throw new Error(`Unhandled case: ${label}`);
  return [index, v[label]];
}

function flattenPerson(): string[] {
  return [...flattenString(), ...flattenU32()];
}

function liftVariant(memory: ArrayBuffer, ptr: number, cases: string[]): Variant {
  const disc_type = discriminant_type(cases);
  const disc_size = sizeOf(disc_type);
  const case_index = liftInt(memory, ptr, disc_size, false);
  ptr += disc_size;
  if (case_index >= cases.length) throw new Error('Invalid discriminant');
  const label = cases[case_index];
  ptr = align_to(ptr, max_case_alignment(cases));
  const payload = label === 'void' ? null : liftPayload(memory, ptr, label);
  return { [label]: payload };
}

function align_to(ptr: number, alignment: number): number {
  return Math.ceil(ptr / alignment) * alignment;
}

function liftU8(memory: ArrayBuffer, ptr: number): number {
  return liftInt(memory, ptr, 1, false);
}

function lowerU8(memory: ArrayBuffer, ptr: number, value: number): void {
  lowerInt(memory, ptr, value, 1, false);
}

function liftU16(memory: ArrayBuffer, ptr: number): number {
  return liftInt(memory, ptr, 2, false);
}

function lowerU16(memory: ArrayBuffer, ptr: number, value: number): void {
  lowerInt(memory, ptr, value, 2, false);
}

// function liftPayload(memory: ArrayBuffer, ptr: number, label: string): any {
//   switch (label) {
//     case 'bool': return liftBool(memory, ptr);
//     case 'u32': return liftU32(memory, ptr);
//     // ... 他の型のliftを追加 ...
//     default: throw new Error(`Unsupported payload type: ${label}`);
//   }
// }

// function liftPayload(memory: ArrayBuffer, ptr: number, label: string): any {
//   switch (label) {
//     case 'bool': return liftBool(memory, ptr);
//     case 'u8': return liftU8(memory, ptr);
//     case 'u16': return liftU16(memory, ptr);
//     case 'u32': return liftU32(memory, ptr);
//     // ... 他の型のliftを追加 ...
//     default: throw new Error(`Unsupported payload type: ${label}`);
//   }
// }

function liftPayload(memory: ArrayBuffer, ptr: number, label: string): any {
  switch (label) {
    case 'bool': return liftBool(memory, ptr);
    case 'u8': return liftU8(memory, ptr);
    case 'u16': return liftU16(memory, ptr);
    case 'u32': return liftU32(memory, ptr);
    // その他の型の処理
    default: throw new Error(`Unsupported payload type: ${label}`);
  }
}

function lowerInt(memory: ArrayBuffer, ptr: number, value: number, size: number, signed: boolean): void {
  const view = new DataView(memory);
  switch (size) {
    case 1:
      if (signed) view.setInt8(ptr, value);
      else view.setUint8(ptr, value);
      break;
    case 2:
      if (signed) view.setInt16(ptr, value, true);
      else view.setUint16(ptr, value, true);
      break;
    case 4:
      if (signed) view.setInt32(ptr, value, true);
      else view.setUint32(ptr, value, true);
      break;
    default:
      throw new Error(`Unsupported integer size: ${size}`);
  }
}

function liftInt(memory: ArrayBuffer, ptr: number, size: number, signed: boolean): number {
  const view = new DataView(memory);
  switch (size) {
    case 1: return signed ? view.getInt8(ptr) : view.getUint8(ptr);
    case 2: return signed ? view.getInt16(ptr, true) : view.getUint16(ptr, true);
    case 4: return signed ? view.getInt32(ptr, true) : view.getUint32(ptr, true);
    default: throw new Error(`Unsupported integer size: ${size}`);
  }
}
// function liftInt(memory: ArrayBuffer, ptr: number, size: number, signed: boolean): number {
//   const view = new DataView(memory);
//   if (ptr < 0 || ptr >= memory.byteLength) {
//       throw new Error("Pointer out of bounds");
//   }
//   // 他のエラー条件も同様にチェック
// }

function lowerVariant(memory: ArrayBuffer, ptr: number, v: Variant, cases: string[]): void {
  const [case_index, payload] = match_case(v, cases);
  const disc_type = discriminant_type(cases);
  const disc_size = sizeOf(disc_type);
  lowerInt(memory, ptr, case_index, disc_size, false);
  ptr += disc_size;
  ptr = align_to(ptr, max_case_alignment(cases));
  const label = cases[case_index];
  if (label !== 'void') {
    lowerPayload(memory, ptr, payload, label);
  }
}


function lowerPayload(memory: ArrayBuffer, ptr: number, payload: any, label: string): void {
  switch (label) {
    case 'bool': lowerBool(memory, ptr, payload); break;
    case 'u8': lowerU8(memory, ptr, payload); break;
    case 'u16': lowerU16(memory, ptr, payload); break;
    case 'u32': lowerU32(memory, ptr, payload); break;
    // ... 他の型のlowerを追加 ...
    default: throw new Error(`Unsupported payload type: ${label}`);
  }
}


// realloc 関数 (簡易版)
function realloc(memory: ArrayBuffer, ptr: number, oldSize: number, alignment: number, newSize: number): number {
  const alignedPtr = Math.ceil(ptr / alignment) * alignment;
  if (alignedPtr + newSize > memory.byteLength) {
    throw new Error('Out of memory');
  }
  return alignedPtr;
}

function sizeOf(type: string): number {
  switch (type) {
    case 'bool': return 1;
    case 'u8': return 1;
    case 'u16': return 2;
    case 'u32': case 'i32': return 4;
    case 'u64': case 'i64': return 8;
    case 'f32': return 4;
    case 'f64': return 8;
    default: throw new Error(`Unsupported type: ${type}`);
  }
}

// alignment 関数 (簡易版)
function alignment(type: string): number {
  return sizeOf(type);
}

Deno.test('Canonical ABI', async (t) => {
  let memory: ArrayBuffer;
  const memorySize = 1024;

  await t.step({
    name: 'lift',
    fn: async (t) => {
      await t.step({
        name: 'liftBool',
        fn: () => {
          memory = new ArrayBuffer(memorySize);
          const ptr = 0;
          const view = new DataView(memory);
          view.setUint8(ptr, 1);
          assertEquals(liftBool(memory, ptr), true);
          view.setUint8(ptr, 0);
          assertEquals(liftBool(memory, ptr), false);
        },
      });

      await t.step({
        name: 'liftS32',
        fn: () => {
          memory = new ArrayBuffer(memorySize);
          const ptr = 0;
          const view = new DataView(memory);
          view.setInt32(ptr, -42, true);
          assertEquals(liftS32(memory, ptr), -42);
        },
      });

      await t.step({
        name: 'liftU32',
        fn: () => {
          memory = new ArrayBuffer(memorySize);
          const ptr = 0;
          const view = new DataView(memory);
          view.setUint32(ptr, 42, true);
          assertEquals(liftU32(memory, ptr), 42);
        },
      });
    },
  });

  await t.step({
    name: 'lower',
    fn: async (t) => {
      await t.step({
        name: 'lowerBool',
        fn: () => {
          memory = new ArrayBuffer(memorySize);
          const ptr = 0;
          lowerBool(memory, ptr, true);
          assertEquals(new DataView(memory).getUint8(ptr), 1);
          lowerBool(memory, ptr, false);
          assertEquals(new DataView(memory).getUint8(ptr), 0);
        },
      });

      await t.step({
        name: 'lowerS32',
        fn: () => {
          memory = new ArrayBuffer(memorySize);
          const ptr = 0;
          lowerS32(memory, ptr, -42);
          assertEquals(new DataView(memory).getInt32(ptr, true), -42);
        },
      });

      await t.step({
        name: 'lowerU32',
        fn: () => {
          memory = new ArrayBuffer(memorySize);
          const ptr = 0;
          lowerU32(memory, ptr, 42);
          assertEquals(new DataView(memory).getUint32(ptr, true), 42);
        },
      });
    },
  });


  await t.step({
    name: 'liftString and lowerString',
    fn: () => {
      memory = new ArrayBuffer(memorySize);
      const ptr = 0;
      const str = 'Hello, world!';
      lowerString(memory, ptr, [str, 'utf8', str.length]);

      const view = new DataView(memory);
      const offset = view.getUint32(ptr, true);
      const len = view.getUint32(ptr + 4, true);
      assertEquals(len, str.length);

      const buf = new Uint8Array(memory, offset, len);
      const decoder = new TextDecoder('utf-8');
      const decoded = decoder.decode(buf);
      assertEquals(decoded, str);

      const [lifted, encoding, liftedLen] = liftString(memory, ptr);
      assertEquals(lifted, str);
      assertEquals(encoding, 'utf8');
      assertEquals(liftedLen, str.length);
    },
  });
  await t.step({
    name: 'liftPerson and lowerPerson',
    fn: () => {
      memory = new ArrayBuffer(memorySize);
      const ptr = 0;
      const person: Person = { name: 'Alice', age: 30 };
      lowerPerson(memory, ptr, person);
      const lifted = liftPerson(memory, ptr);
      assertEquals(lifted.name, person.name);
      assertEquals(lifted.age, person.age);
    },
  });

  await t.step({
    name: 'liftU8 and lowerU8',
    fn: () => {
      memory = new ArrayBuffer(memorySize);
      const ptr = 0;
      const value: U8 = 255;
      lowerU8(memory, ptr, value);
      assertEquals(liftU8(memory, ptr), value);
    },
  });

  await t.step({
    name: 'liftU16 and lowerU16',
    fn: () => {
      memory = new ArrayBuffer(memorySize);
      const ptr = 0;
      const value: U16 = 65535;
      lowerU16(memory, ptr, value);
      assertEquals(liftU16(memory, ptr), value);
    },
  });
});

Deno.test('Variant', async (t) => {
  const memory = new ArrayBuffer(1024);

  // failed
  await t.step('bool payload', () => {
    const cases = ['A', 'B'];
    const ptr = 0;
    const variant: Variant = { A: true };

    lowerVariant(memory, ptr, variant, cases);
    const lifted = liftVariant(memory, ptr, cases);

    assertEquals(lifted, variant);
  });

  // await t.step('u32 payload', () => {
  //   const cases = ['A', 'B'];
  //   const ptr = 0;
  //   const variant: Variant = { B: 42 };

  //   lowerVariant(memory, ptr, variant, cases);
  //   const lifted = liftVariant(memory, ptr, cases);

  //   assertEquals(lifted, variant);
  // });

  // await t.step('no payload', () => {
  //   const cases = ['A', 'B'];
  //   const ptr = 0;
  //   const variant: Variant = { A: null };

  //   lowerVariant(memory, ptr, variant, cases);
  //   const lifted = liftVariant(memory, ptr, cases);

  //   assertEquals(lifted, variant);
  // });

  // await t.step('invalid discriminant', () => {
  //   const cases = ['A', 'B'];
  //   const ptr = 0;
  //   const variant: Variant = { C: null };

  //   assertThrows(
  //     () => lowerVariant(memory, ptr, variant, cases),
  //     Error,
  //     'Unhandled case: C',
  //   );
  // });

  // await t.step('invalid payload type', () => {
  //   const cases = ['A', 'B'];
  //   const ptr = 0;
  //   const variant: Variant = { A: 'hello' };

  //   lowerVariant(memory, ptr, variant, cases);

  //   assertThrows(
  //     () => liftVariant(memory, ptr, cases),
  //     Error,
  //     'Unsupported payload type: string',
  //   );
  // });

  // Add more test cases for other payload types and edge cases
});