Open6

Durable Objects + WebSocket 復習

mizchimizchi

毎回忘れるので、websocket でつなぐまでをステップバイステップで書き残す

Simple Counter

まずはカウンターを作る。このカウンターの wrangler.toml は後のサンプルでもずっと使い回す。

wrangler.toml
# Shared wrangler config for every sample in this article.
# NOTE(review): `name` and `main` are not set here; presumably they are supplied elsewhere
# (e.g. the entry file given on the `wrangler dev` command line) — confirm.
compatibility_date = "2024-03-20"

[durable_objects]
# Binds the exported class `Counter` so the Worker can reach it as `env.COUNTER`.
bindings = [{name = "COUNTER", class_name = "Counter"}]

[[migrations]]
tag = "v1" # Should be unique for each entry
# Declares `Counter` as a newly introduced Durable Object class in this migration.
new_classes = ["Counter"]
/** Worker bindings: `COUNTER` is the Durable Object namespace declared in wrangler.toml. */
interface Env {
  COUNTER: DurableObjectNamespace;
}

export default {
  /**
   * Forwards every request to the Counter named "test" and prefixes its reply with that name,
   * e.g. `test: 0` while the counter is still empty.
   */
  async fetch(request: Request, env: Env) {
    const name = "test";
    const id = env.COUNTER.idFromName(name);
    const stub = env.COUNTER.get(id);
    const reply = await stub.fetch(request.url);
    const count = await reply.text();
    return new Response(`${name}: ${count}`);
  }
}

/**
 * Durable Object that counts requests in its persistent storage.
 * Each fetch responds with the count *before* this request and stores count + 1.
 */
export class Counter implements DurableObject {
  constructor(public state: DurableObjectState, public env: Env) { }

  async fetch(_request: Request) {
    const { storage } = this.state;
    const previous: number = (await storage.get("value")) || 0;
    await storage.put("value", previous + 1);
    return new Response(String(previous));
  }
}

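DO の状態には、インスタンスのフィールドなどに置くインメモリの状態と、this.state.storage による永続ストレージがある。インメモリの状態は、アクセスが一定時間途絶えて DO がハイバネート（メモリから退避）するときに消えるが、storage に書いた値は残る。このカウンターは storage に保存しているので、退避後も値が引き継がれる。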

DO クラスの fetch ハンドラがレスポンスを返し、Worker 側はその本文を受け取ってクライアントに返す。

$ wrangler dev

Debug

wscat で簡単なデバッグ

$ npx wscat -c http://localhost:8787
> xxx
mizchimizchi

WebSocket

まずは Worker 自身で WebSocket コネクションを確立する。この時点ではまだ DO には接続しておらず、Counter クラスは使われていない。

/** Bindings visible to this Worker; `COUNTER` is the Durable Object namespace from wrangler.toml. */
interface Env {
  COUNTER: DurableObjectNamespace;
}

export default {
  /**
   * Terminates the WebSocket in the Worker itself (no Durable Object involved yet):
   * accepts the server end, logs every message it receives, and returns the client end with 101.
   * Non-upgrade requests get 426.
   */
  async fetch(request: Request, env: Env) {
    // The Upgrade token is compared ASCII case-insensitively (RFC 6455 §4.2.1), so "WebSocket" is accepted too.
    if (request.headers.get('Upgrade')?.toLowerCase() !== 'websocket') {
      return new Response('Expected Upgrade: websocket', { status: 426 });
    }

    const webSocketPair = new WebSocketPair();
    const [client, server] = Object.values(webSocketPair);

    // @ts-ignore -- accept() is Workers-specific; presumably missing from the WebSocket typing in scope.
    server.accept();
    server.addEventListener('message', event => {
      console.log(event.data);
    });

    return new Response(null, {
      status: 101,
      webSocket: client,
    });
  }
}

/**
 * Same storage-backed counter as before. The Worker above does not call it in this step;
 * presumably it is kept so the wrangler.toml binding/migration still finds the class.
 */
export class Counter implements DurableObject {
  constructor(public state: DurableObjectState, public env: Env) { }

  /** Responds with the stored counter value, then persists it incremented by one. */
  async fetch(_request: Request) {
    const stored: number = (await this.state.storage.get("value")) || 0;
    const next = stored + 1;
    await this.state.storage.put("value", next);
    return new Response(`${stored}`);
  }
}
mizchimizchi

WebSocket on DO fetch handler

worker 側はオブジェクトの振り分け等を最小限ハンドルしつつ、DO 側で WebSocket コネクションを確立する。

/** Worker environment; `COUNTER` is the Durable Object namespace bound in wrangler.toml. */
interface Env {
  COUNTER: DurableObjectNamespace;
}

export default {
  /**
   * Routes every WebSocket upgrade to the single Counter instance named "test",
   * which performs the handshake itself. Non-upgrade requests get 426.
   */
  async fetch(request: Request, env: Env) {
    // The Upgrade token is compared ASCII case-insensitively (RFC 6455 §4.2.1), so "WebSocket" is accepted too.
    if (request.headers.get('Upgrade')?.toLowerCase() !== 'websocket') {
      return new Response('Expected Upgrade: websocket', { status: 426 });
    }
    const obj = env.COUNTER.get(env.COUNTER.idFromName("test"));
    return obj.fetch(request);
  }
}

/**
 * Durable Object that performs the WebSocket handshake itself and keeps each socket alive
 * only through event listeners registered on this instance.
 */
export class Counter implements DurableObject {
  constructor(public state: DurableObjectState, public env: Env) { }

  /** Accepts the upgrade forwarded by the Worker and returns the client end with status 101. */
  async fetch(_request: Request) {
    const webSocketPair = new WebSocketPair();
    const [client, server] = Object.values(webSocketPair);
    this.#handleSession(server);
    return new Response(null, { status: 101, webSocket: client });
  }

  /**
   * Wires up one server-side socket: every message increments the stored counter
   * and is answered with the value it had before the increment.
   */
  #handleSession(socket: WebSocket) {
    console.log('Socket connected', socket);
    // @ts-ignore -- accept() is Workers-specific; presumably missing from the WebSocket typing in scope.
    socket.accept();
    socket.addEventListener('message', async event => {
      try {
        const value: number = await this.state.storage.get("value") || 0;
        await this.state.storage.put("value", value + 1);
        console.log('Socket onmessage', event.data, value);
        // Reply over the socket: a value returned from an event listener is discarded,
        // so a Response built here would never reach the client.
        socket.send(value.toString());
      } catch (error) {
        // The listener is async; without this a storage/send failure would only surface
        // as an unhandled rejection.
        console.log('Socket onmessage failed', error);
      }
    });
    socket.addEventListener('close', () => {
      console.log('Socket closed');
    });

    socket.addEventListener('error', (event) => {
      console.log('Socket errored', event);
    });
  }
}
mizchimizchi

WebSocket Hibernate

前のサンプルのように DO がインメモリでソケットを保持し addEventListener で扱う方式だと、接続中の DO はハイバネートできず（その間も稼働時間として課金される）、もしメモリから退避されれば保持していたソケットとコネクション情報は失われる。

なので、組み込みの this.state.acceptWebSocket(server) で DO にコネクション一覧を管理させる。
webSocketMessage() で message を受け、 webSocketClose で disconnect 時の処理を書く。

/** Environment bindings; `COUNTER` refers to the Durable Object namespace configured in wrangler.toml. */
interface Env {
  COUNTER: DurableObjectNamespace;
}

export default {
  /**
   * Sends every WebSocket upgrade to the Counter named "test", which accepts it via the
   * hibernation API. Non-upgrade requests get 426.
   */
  async fetch(request: Request, env: Env) {
    // The Upgrade token is compared ASCII case-insensitively (RFC 6455 §4.2.1), so "WebSocket" is accepted too.
    if (request.headers.get('Upgrade')?.toLowerCase() !== 'websocket') {
      return new Response('Expected Upgrade: websocket', { status: 426 });
    }
    const obj = env.COUNTER.get(env.COUNTER.idFromName("test"));
    return obj.fetch(request);
  }
}

/**
 * Durable Object whose WebSockets are registered with the runtime via `state.acceptWebSocket`,
 * so incoming frames and closes are delivered to `webSocketMessage` / `webSocketClose`.
 */
export class Counter implements DurableObject {
  constructor(public state: DurableObjectState, public env: Env) { }

  /** Accepts the upgrade and hands the server end to the runtime; returns the client end with 101. */
  async fetch(_request: Request) {
    const webSocketPair = new WebSocketPair();
    const [client, server] = Object.values(webSocketPair);
    this.state.acceptWebSocket(server);
    return new Response(null, { status: 101, webSocket: client });
  }

  /** Echoes each message back with the number of sockets currently attached to this object. */
  async webSocketMessage(socket: WebSocket, message: string | ArrayBuffer) {
    const len = this.state.getWebSockets().length;
    socket.send(`[echo] ${message} in ${len}`);
  }

  /** The client started the close handshake; complete it from this side. */
  async webSocketClose(socket: WebSocket, code: number, reason: string, wasClean: boolean) {
    // 1004/1005/1006/1015 are reserved: they can be *reported* to us (a browser `ws.close()`
    // with no arguments arrives as 1005) but must never be sent, and close() rejects them.
    const replyCode = [1004, 1005, 1006, 1015].includes(code) ? 1000 : code;
    socket.close(replyCode, "Durable Object is closing WebSocket");
  }
}
mizchimizchi

Browser WebSocket Client

さすがに wscat だけではだめだろうと思うので、ブラウザ版も用意する。

<html>

<head>
  <meta charset="utf-8">
</head>

<body>
  <h1>Hello World!</h1>
  <script type="module">
    // Minimal browser client: connects, sends one message, and logs everything the server sends back.
    const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
    // Hard-coded to the `wrangler dev` address; window.location.host works instead when the Worker serves this page.
    const host = 'localhost:8787';
    const wsUrl = `${protocol}//${host}`;
    const socket = new WebSocket(wsUrl);
    socket.addEventListener('error', console.error);
    socket.addEventListener('close', console.log);
    socket.addEventListener('message', (event) => {
      console.log('[socket:message]', event);
    });
    // Wait for the handshake, but fail loudly instead of hanging forever if it never opens.
    await new Promise((resolve, reject) => {
      socket.addEventListener('open', resolve, { once: true });
      socket.addEventListener('close', () => reject(new Error(`could not connect to ${wsUrl}`)), { once: true });
    });
    socket.send("Hello");
  </script>
</body>

</html>
mizchimizchi

RPC

そのままのインターフェースだと使いづらいので、 comlink 的なインターフェースでラップするユーティリティを作る。

rpc.ts
/**
 * Creates a promise together with its settle functions, so it can be resolved or rejected
 * from outside the executor (here: when the matching RPC response frame arrives).
 */
const defer = <T>() => {
  // Definite-assignment assertions: the Promise executor runs synchronously, so both are set before return.
  let resolve!: (arg: T) => void;
  let reject!: (arg: any) => void;
  const promise = new Promise<T>((res, rej) => {
    resolve = res;
    reject = rej;
  });
  return { promise, resolve, reject };
}

/**
 * Returns a short, probably-unique id: a base-36 timestamp followed by base-36 random digits.
 * Despite its name this is not a spec-compliant ULID (no fixed length, no Crockford base32);
 * it is only used to correlate RPC requests with their responses.
 */
function ulid() {
  const timePart = Date.now().toString(36);
  // Drop the leading "0." of the fractional base-36 representation.
  const randomPart = Math.random().toString(36).slice(2);
  return `${timePart}${randomPart}`;
}

// Tag stored in slot 0 of every RPC frame so each side can tell requests, successes and failures apart.
const REQUEST_MARKER = '$r' as const;
const RESPONSE_MARKER = '$s' as const;
const ERROR_MARKER = '$e' as const;

// Wire formats (JSON arrays sent as WebSocket text frames):
//   request:  ['$r', rid, method, ...args]
//   response: ['$s', rid, result]
//   error:    ['$e', rid, message, stack?]   (stack is optional)
type RPC_REQUEST = [typeof REQUEST_MARKER, rid: string, method: string, ...any[]];
type RPC_RESPONSE = [typeof RESPONSE_MARKER, rid: string, any];
type RPC_ERROR_RESPONSE = [typeof ERROR_MARKER, rid: string, error: string, stack?: string];

/**
 * Type guard for a request frame `['$r', rid, method, ...args]`.
 * Checks that `rid` and `method` really are strings, so the predicate does not over-promise.
 */
function isRpcRequest(parsed: unknown): parsed is RPC_REQUEST {
  return (
    Array.isArray(parsed) &&
    parsed[0] === REQUEST_MARKER &&
    typeof parsed[1] === 'string' &&
    typeof parsed[2] === 'string'
  );
}

/** Type guard for a success frame `['$s', rid, result]`; verifies that `rid` is a string. */
function isRpcResponse(parsed: unknown): parsed is RPC_RESPONSE {
  return Array.isArray(parsed) && parsed[0] === RESPONSE_MARKER && typeof parsed[1] === 'string';
}

/**
 * Type guard for a failure frame `['$e', rid, message, stack?]`; verifies that `rid` is a string.
 * The message slot is deliberately not type-checked so a malformed message still settles its call.
 */
function isRpcErrorResponse(parsed: unknown): parsed is RPC_ERROR_RESPONSE {
  return Array.isArray(parsed) && parsed[0] === ERROR_MARKER && typeof parsed[1] === 'string';
}

/**
 * Server side of the RPC protocol: decodes one incoming WebSocket frame and, if it is a request
 * `['$r', rid, method, ...args]`, calls `handlers[method](...args)` and replies on `socket` with
 * `['$s', rid, result]` or `['$e', rid, message]`.
 *
 * Frames that are not RPC requests (non-JSON text, other JSON) are answered with `[unknown] <text>`
 * instead of throwing.
 *
 * @param handlers - methods callable by clients; only the object's own properties are exposed.
 * @param socket - socket the frame arrived on; every reply is sent back on it.
 * @param message - raw frame; binary frames are decoded as UTF-8 text before parsing.
 */
export async function runHandlers<T extends Record<string, (...args: any[]) => any>>(
  handlers: T,
  socket: WebSocket,
  message: string | ArrayBuffer
) {
  const text = typeof message === 'string' ? message : new TextDecoder().decode(message);

  let parsed: unknown;
  try {
    parsed = JSON.parse(text);
  } catch {
    // Not JSON (e.g. a line typed into wscat): handled as "not an RPC request" below.
    parsed = undefined;
  }

  if (!isRpcRequest(parsed)) {
    socket.send(`[unknown] ${text}`);
    return;
  }

  const [, rid, method, ...args] = parsed;
  // Own properties only: otherwise names inherited from Object.prototype ("constructor",
  // "toString", ...) would be invocable by any client.
  const handler = Object.prototype.hasOwnProperty.call(handlers, method) ? handlers[method] : undefined;
  if (!handler) {
    socket.send(JSON.stringify([ERROR_MARKER, rid, `[unknown-method] ${method}`]));
    return;
  }

  try {
    const result = await handler(...args);
    socket.send(JSON.stringify([RESPONSE_MARKER, rid, result]));
  } catch (error) {
    // Only the message travels back. JSON.stringify(undefined) yields undefined, hence the String() fallback.
    const detail = error instanceof Error ? error.message : (JSON.stringify(error) ?? String(error));
    socket.send(JSON.stringify([ERROR_MARKER, rid, detail]));
  }
}

/**
 * Client side of the RPC protocol. Returns a proxy whose properties are async functions: calling
 * `rpc.method(...args)` sends `['$r', rid, method, ...args]` over `socket` and settles with the
 * matching `'$s'` frame (resolve with the result) or `'$e'` frame (reject with the error message).
 *
 * Arguments and results travel as JSON, so only JSON-serializable values survive the round trip.
 * Calls still pending when the socket closes are rejected.
 */
export function createRpc<T extends { [key: string]: (...args: any[]) => any }>(socket: WebSocket): {
  [K in keyof T]: (...args: Parameters<T[K]>) => Promise<Awaited<ReturnType<T[K]>>>
} {
  // rid -> [resolve, reject] for every call that has been sent but not yet answered.
  const resolverMap: Map<string, [r: (arg: any) => any, j: (arg: any) => any]> = new Map();

  // Removes the pending entry for `rid` and returns its settle functions (undefined for unknown/late rids).
  const take = (rid: string) => {
    const resolver = resolverMap.get(rid);
    resolverMap.delete(rid);
    return resolver;
  };

  socket.addEventListener('message', (event) => {
    // Only text frames carry RPC JSON. Anything else (including non-JSON text such as the server's
    // "[unknown] ..." reply) is reported rather than thrown from inside the event listener.
    if (typeof event.data !== 'string') {
      console.warn('[unknown-rpc]', event.data);
      return;
    }
    let parsed: unknown;
    try {
      parsed = JSON.parse(event.data);
    } catch {
      console.warn('[unknown-rpc]', event.data);
      return;
    }

    if (isRpcErrorResponse(parsed)) {
      const [, rid, error] = parsed;
      take(rid)?.[1](error);
    } else if (isRpcResponse(parsed)) {
      const [, rid, result] = parsed;
      take(rid)?.[0](result);
    } else {
      console.warn('[unknown-rpc]', event.data);
    }
  });

  socket.addEventListener('close', () => {
    // No answer can arrive any more; fail outstanding calls instead of leaving them pending forever.
    for (const [, [, reject]] of resolverMap) {
      reject(new Error('[rpc] socket closed before a response arrived'));
    }
    resolverMap.clear();
  });

  return new Proxy({}, {
    get(_, method) {
      // Never expose `then` (or symbol keys): a proxy with a callable `then` is treated as a thenable,
      // so `await rpc` or returning `rpc` from an async function would send a bogus call and hang.
      if (typeof method !== 'string' || method === 'then') return undefined;
      return (...args: any[]) => {
        // A fresh id per call (not per property access) so a cached method can be invoked repeatedly.
        const rid = ulid();
        const d = defer();
        resolverMap.set(rid, [d.resolve, d.reject]);
        try {
          socket.send(JSON.stringify([
            REQUEST_MARKER, rid, method, ...args
          ] satisfies RPC_REQUEST));
        } catch (error) {
          // e.g. socket not open or arguments not serializable: nothing will ever answer this rid.
          resolverMap.delete(rid);
          d.reject(error);
        }
        return d.promise;
      };
    }
  }) as any;
}

ワーカー側の実装

worker.ts
import { runHandlers } from "./rpc";

/** Bindings for worker.ts; `COUNTER` is the Durable Object namespace bound in wrangler.toml. */
interface Env {
  COUNTER: DurableObjectNamespace;
}

export default {
  /**
   * Routes every WebSocket upgrade to the Counter named "test", which speaks the RPC protocol.
   * Non-upgrade requests get 426.
   */
  async fetch(request: Request, env: Env) {
    // The Upgrade token is compared ASCII case-insensitively (RFC 6455 §4.2.1), so "WebSocket" is accepted too.
    if (request.headers.get('Upgrade')?.toLowerCase() !== 'websocket') {
      return new Response('Expected Upgrade: websocket', { status: 426 });
    }
    const obj = env.COUNTER.get(env.COUNTER.idFromName("test"));
    return obj.fetch(request);
  }
}

/**
 * RPC surface exposed to clients over the WebSocket; each key becomes callable as `rpc.<key>(...)`
 * on the client. `panic` demonstrates how a remote error is propagated.
 */
const handlers = {
  foo: async (a: number) => a + 1,
  bar: async () => 3,
  panic() {
    throw new Error("panic!");
  },
}

/** Type-only view of `handlers`, imported by the client to type its RPC proxy. */
export type Handlers = typeof handlers;

/**
 * Durable Object that serves the RPC protocol over hibernation-API WebSockets.
 */
export class Counter implements DurableObject {
  constructor(public state: DurableObjectState, public env: Env) { }

  /** Accepts the upgrade and hands the server end to the runtime; returns the client end with 101. */
  async fetch(_request: Request) {
    const webSocketPair = new WebSocketPair();
    const [client, server] = Object.values(webSocketPair);
    this.state.acceptWebSocket(server);
    return new Response(null, { status: 101, webSocket: client });
  }

  /** Dispatches each incoming frame to `handlers`; runHandlers replies on the same socket. */
  async webSocketMessage(socket: WebSocket, message: string | ArrayBuffer) {
    await runHandlers(handlers, socket, message);
  }

  /** The client started the close handshake; complete it from this side. */
  async webSocketClose(socket: WebSocket, code: number, reason: string, wasClean: boolean) {
    // 1004/1005/1006/1015 are reserved: they can be *reported* to us (a browser `ws.close()`
    // with no arguments arrives as 1005) but must never be sent, and close() rejects them.
    const replyCode = [1004, 1005, 1006, 1015].includes(code) ? 1000 : code;
    socket.close(replyCode, "Durable Object is closing WebSocket");
  }
}

呼び出し側

import { createRpc } from "./rpc";
import type { Handlers } from "./worker"

/**
 * Demo client: opens a WebSocket to the local `wrangler dev` Worker and exercises the typed RPC
 * proxy (two successful calls, then one call whose remote handler throws).
 */
async function main() {
  const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
  const host = 'localhost:8787' // window.location.host;
  const wsUrl = `${protocol}//${host}`;
  const socket = new WebSocket(wsUrl);
  socket.addEventListener('error', console.error);
  socket.addEventListener('close', console.log);
  // Wait for the handshake, but fail instead of hanging forever if the socket closes before opening.
  await new Promise<void>((resolve, reject) => {
    socket.addEventListener('open', () => resolve(), { once: true });
    socket.addEventListener('close', () => reject(new Error(`could not connect to ${wsUrl}`)), { once: true });
  });
  const rpc = createRpc<Handlers>(socket);

  // run: the two calls are independent, so send both before awaiting either
  const [fooResult, barResult] = await Promise.all([rpc.foo(1), rpc.bar()]);
  console.log(fooResult, barResult);

  // remote error: the handler's error message arrives as the rejection value
  try {
    await rpc.panic();
  } catch (error) {
    console.log('[rpc] remote error (expected):', error);
  }
}

main().catch(console.error);