🔄

Rustで実装しながら学ぶWebSocketの基本

2023/11/20に公開

WebSocketを双方向通信のために使うプロトコルでしょ、という感じのうっすら理解で誤魔化していた[1]のですが、IoTアプリケーションぽいものを作ることがあって、理解を深めるためにあらためて学びました、という投稿です。

今回はWebSocketでエコーするサーバーを、TCPライブラリだけを用いてRustで実装していきます。コードは以下です。

https://github.com/ohke/rust-websocket

WebSocketとは

RFC 6455で定義された、主に双方向でやり取りするために用いられる通信プロトコルです。

WebSocketが直接用いるプロトコルはTCPですが、ハンドシェイクはHTTP(S)によって行われます。WebSocketを使うプロトコルとして、MQTT over WebSocketSTOMPなどがあります。

最近のWebアプリケーションではごく普通に用いられてます。企業サイトなどでよく見る問い合わせ用のチャットフォームなどは、WebSocketで実装されていることが多いかと思います。OSS、例えばKubeflowでは、実行中のDAGのグラフやログのリアルタイムな表示などはWebSocketで行われます。

| MQTT over WebSocket, STOMP, ... |
| WebSocket                       |
| TCP, HTTP(S)                    |

WebSocketでの通信は、大まかに以下のステップで行われます。

  • URIは ws://... または wss://...
    • デフォルトのポート番号は、wsの場合は80番、wssの場合は443番
    • スキーム (ws://) 以降は通常のURL等と同じだが、ハッシュ (#) は使えない
  • 接続開始のハンドシェイク (opening handshake) は、HTTP(S)で行われ、クライアントからのリクエストヘッダにconnection, upgrade, sec-websocket-version, sec-websocket-keyを;付与する必要がある
    • サーバにはconnection, upgrade, sec-websocket-acceptが付与したレスポンスをステータスコード101で返すことで、以降はWebSocketで通信することになる
  • 接続が完了したら、in以降はWebSocketのフレームにデータ (テキスト or バイナリ) を乗せて、TCP上でお互いに送受信し合う
  • 接続終了のハンドシェイク (closing handshake) は、WebSocketの制御フレームで行う

WebSocketサーバの実装

Rustでサーバを実装しながら、上述したWebSocketプロトコルの流れを見ていきましょう。
以下のクレートを使います。WebSocketの仕組みを学ぶために、今回はstd::net::TcpLisgenerを使ってTCPで実装していきます。

// [dependencies]
// base64 = "0.21.5"
// rand = "0.8.5"
// sha1 = "0.10.6"
use base64::{engine::general_purpose, Engine as _};
use rand::Rng;
use sha1::{Digest, Sha1};
use std::{
    io::{Read, Write},
    net::TcpListener,
    thread::sleep,
    time::Duration,
};

動作確認では、クライアントとしてWebSocket Test Clientを使いました。
ハンドシェイクとテキストデータの送受信だけできるシンプルなものですが、今回実装するエコーサーバ (ぽいもの) の動作確認としては十分でした。

サーバの実行

TcpListenerで、シングルスレッドのサーバを立ち上げます。
opening handshakeではHTTPで通信する必要があるので、フラグ (is_websocket) にてどちらのプロトコルで通信しているのか区別できるようにしています。

fn main() {
    let listener = TcpListener::bind("127.0.0.1:7778").unwrap();
    let mut buffer = [0; 4096];

    for stream in listener.incoming() {
        let mut stream = stream.unwrap();
        let mut is_websocket = false;

        loop {
            stream.read(&mut buffer).unwrap();
            if is_websocket {
                /* WebSocketでの処理 (後述) */
            } else {
                /* HTTPでの処理 (後述) */
            }
        }
    }
}

opening handshake (HTTP)

opening handshakeでは、クライアントから以下のようなGETリクエストが送られます。

  • Connection: Upgrade (必須)
  • Upgrade: websocket (必須)
  • Sec-WebSocket-Version: 13 (必須)
  • Sec-WebSocket-Key (必須)
  • Sec-WebSocket-Extensions (オプション)
GET ws://127.0.0.1:7778/ HTTP/1.1
Host: 127.0.0.1:7778
Connection: Upgrade
Upgrade: websocket
Sec-WebSocket-Version: 13
Sec-WebSocket-Key: 9Kl3Zz3tA0ibMWQwyn/9kQ==
Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits

実装としては以下のようになります。
httpクレートなどは使っていないので非常に素朴ですが、空行が現れるまでリクエストヘッダから1つずつ値を取り出しています。

/* HTTPでの処理 */
let mut method = None;
let mut upgrade = None;
let mut connection = None;
let mut sec_websocket_version = None;
let mut sec_websocket_key = None;

// リクエストのパース
let request_text = String::from_utf8_lossy(&buffer[..]);
for (i, line) in request_text.lines().enumerate() {
    if i == 0 {
        let values = line.split(" ").map(|s| s.trim()).collect::<Vec<&str>>();
        method = Some(values[0]);
        continue;
    }

    if line == "" {
        break;
    }

    let values = line.split(":").map(|s| s.trim()).collect::<Vec<&str>>();
    let key = values[0].to_ascii_lowercase();
    let value = values[1];
    if key == "upgrade" {
        upgrade = Some(value);
    }
    if key == "connection" {
        connection = Some(value);
    }
    if key == "sec-websocket-version" {
        sec_websocket_version = Some(value);
    }
    if key == "sec-websocket-key" {
        sec_websocket_key = Some(value);
    }
}

WebSocket接続を許可する場合、サーバからのレスポンスは以下のようになります。

  • ステータスコードは101
    • 続く値はなんでもいい (OKでもSwitching Protocolsでもなんでも)
  • Upgrade: websocket (必須)
  • Connection: Upgrade (必須)
  • Sec-WebSocket-Accept (必須)
    • リクエストのSec-WebSocket-Keyから計算された値をセットする
    • クライアントでもこの値を計算しており、一致しなければ、サーバ側にWebSocketに受け入れる準備がないと解釈されて、接続には失敗します
HTTP/1.1 101 OK
Upgrade: websocket
Connection: upgrade
Sec-WebSocket-Accept: EK2cqLXRG/oxQwrUdEVXGrPDBuA=

レスポンス部分の実装です。
Sec-WebSocket-Acceptヘッダの値ですが、Sec-WebSocket-KeyにRFCで定義された特定のGUID文字列(258EAFA5-E914-47DA-95CA-C5AB0DC85B11)を連結させ、さらにそれをSHA-1でエンコードして作っています。こうすることで、サーバが確実にWebSocket通信可能なことを確認しています。

/* HTTPでの処理 (続き) */

// レスポンスの作成と送信
// ex. 0CBldYnlIlaeSy6juzli7g== => 6mUsN+jbuye0zMbRm4w9VfzxDGM=
let plain_text = format!(
    "{}258EAFA5-E914-47DA-95CA-C5AB0DC85B11",
    sec_websocket_key.unwrap()
);
let mut hasher = Sha1::new();
hasher.update(plain_text);
let sec_websocket_accept = general_purpose::STANDARD.encode(hasher.finalize());

let response = format!("HTTP/1.1 101 OK\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {}\r\n\r\n", sec_websocket_accept);
stream.write(response.as_bytes()).unwrap();
stream.flush().unwrap();

// WebSocketモードにする
is_websocket = true;

WebSocketフレーム

opening handshakeの後は、WebSocketのフレームで通信します。フレームは以下のフォーマット (RFCからの抜粋) となっています。

  • FIN (1ビット)
    • 最後のフラグメントの場合、1
  • RSV1, RSV2, RSV3 (各1ビット、必須)
    • これらを使う拡張機能がネゴシエーションされなければ、0
  • opcode (4ビット、必須)
    • 0x0-0x7が非制御フレーム、0x8-0xFが制御フレームを表す
    • 0x0 ... 継続フレーム
    • 0x1 ... テキストフレーム
    • 0x2 ... バイナリフレーム
    • 0x8 ... クローズフレーム (後述)
    • 0x9 ... Pingフレーム
      • キープアライブを確認するための制御フレームで、任意のタイミングで送信できる
    • 0xA ... Pongフレーム
      • Pingフレームへの応答で、クローズフレームを受信していない場合はPingのペイロードと同じデータをすぐに返す
    • それ以外は予約領域
  • MASK (1ビット、必須)
    • Payload DataにXORマスクをかける場合、1
  • Payload len (7ビット、必須)
    • Payload Dataの長さを表す
    • ただし、126 or 127の場合、Extended payload lengthがPayload Dataの長さになる
  • Extended payload length (16ビット or 64ビット、オプション)
    • Payload lenが126なら、16ビット長 (最大2^16-1バイトまで表現)
    • Payload lenが127なら、64ビット長 (最大2^64-1バイトまで表現)
  • Masking-key (32ビット、MASK=1時に必須)
    • 受信側はPayload Dataを32ビットずつXORを取ってデコードする必要がある
  • Payload Data (最大Xバイト、オプション)
0                   1                   2                   3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-------+-+-------------+-------------------------------+
|F|R|R|R| opcode|M| Payload len |    Extended payload length    |
|I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
|N|V|V|V|       |S|             |   (if payload len==126/127)   |
| |1|2|3|       |K|             |                               |
+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
|     Extended payload length continued, if payload len == 127  |
+ - - - - - - - - - - - - - - - +-------------------------------+
|                               |Masking-key, if MASK set to 1  |
+-------------------------------+-------------------------------+
| Masking-key (continued)       |          Payload Data         |
+-------------------------------- - - - - - - - - - - - - - - - +
:                     Payload Data continued ...                :
+ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
|                     Payload Data continued ...                |
+---------------------------------------------------------------+

シンプルなヘッダとなっていて、最小2バイトに収まります。ただし、クライアントからのWebSocketフレームは必ずマスクする必要がある (All frames sent from client to server have this bit set to 1.) ので、最小でも6バイトになります。一方、サーバからクライアントへのフレームでは、通常マスクしません。

フレームの実装 (抜粋) は以下のようになっています。フレームのデシリアライズはFromトレイト、シリアライズはto_bytesメソッドで実装しています。

pub enum Opcode {
    Continuation, // %x0
    Text,         // %x1
    Binary,       // %x2
    Close,        // %x8
    Ping,         // %x9
    Pong,         // %xA
}

pub struct Frame {
    pub fin: bool,
    pub rsv1: bool,
    pub rsv2: bool,
    pub rsv3: bool,
    pub opcode: Opcode,
    pub mask: bool,
    pub payload_len: usize, // included extended payload length
    pub masking_key: Option<[u8; 4]>,
    pub payload_data: Vec<u8>, // decoded (with masking_key)
}

impl Frame {
    pub fn new(opcode: Opcode, payload_data: Option<Vec<u8>>) -> Self {
        let (payload_len, payload_data) = match payload_data, masking_key {
            Some(payload_data) => (payload_data.len(), payload_data),
            None => (0, vec![]),
        };

        Frame {
            fin: true, // Fragmentation is not supported, so always 1
            rsv1: false,
            rsv2: false,
            rsv3: false,
            opcode,
            mask,
            payload_len,
            masking_key: None,
            payload_data,
        }
    }

    pub fn to_bytes(self) -> Vec<u8> {
        let mut buffer = Vec::new();

        buffer.push(
            (self.fin as u8) << 7
                | (self.rsv1 as u8) << 6
                | (self.rsv2 as u8) << 5
                | (self.rsv3 as u8) << 4
                | u8::from(self.opcode),
        );

        if self.payload_len < 126 {
            buffer.push((self.mask as u8) << 7 | self.payload_len as u8)
        } else if self.payload_len < 65536 {
            buffer.push((self.mask as u8) << 7 | 126_u8);
            buffer.extend_from_slice((self.payload_len as u16).to_be_bytes().as_ref());
        } else {
            buffer.push((self.mask as u8) << 7 | 127_u8);
            buffer.extend_from_slice((self.payload_len as u64).to_be_bytes().as_ref());
        }

        if self.mask {
            buffer.extend(self.masking_key.unwrap().clone());
        }

        for (i, b) in self.payload_data.iter().enumerate() {
            buffer.push(if self.mask {
                b ^ self.masking_key.unwrap()[i % 4]
            } else {
                *b
            });
        }

        return buffer;
    }
}

impl From<&[u8]> for Frame {
    fn from(buffer: &[u8]) -> Self {
        let fin = buffer[0] & 0x80 != 0x00;
        let rsv1 = buffer[0] & 0x40 != 0x00;
        let rsv2 = buffer[0] & 0x20 != 0x00;
        let rsv3 = buffer[0] & 0x10 != 0x00;
        let opcode = Opcode::from(buffer[0]);

        let mask = buffer[1] & 0x80 != 0;

        let (payload_len, mut i) = match buffer[1] & 0x7F {
            0x7E => {
                let mut payload_len = [0; 2];
                payload_len.copy_from_slice(&buffer[2..4]);
                (u16::from_be_bytes(payload_len) as usize, 4)
            }
            0x7F => {
                let mut payload_len = [0; 8];
                payload_len.copy_from_slice(&buffer[2..10]);
                (usize::from_be_bytes(payload_len), 10)
            }
            n => (n as usize, 2),
        };
        let masking_key = if mask {
            let mut masking_key = [0; 4];
            masking_key.copy_from_slice(&buffer[i..i + 4]);
            i += 4;
            Some(masking_key)
        } else {
            None
        };
        let payload_data: Vec<u8> = if mask {
            buffer[i..i + payload_len]
                .iter()
                .enumerate()
                .map(|(i, b)| b ^ masking_key.unwrap()[i % 4])
                .collect()
        } else {
            buffer[i..i + payload_len].to_vec()
        };

        Frame {
            fin,
            rsv1,
            rsv2,
            rsv3,
            opcode,
            mask,
            payload_len,
            masking_key,
            payload_data,
        }
    }
}

サーバ側の処理は、opcodeに応じて分岐するものとなります。
テキストフレームの場合、即座に同じテキストメッセージを返し、さらに3秒後に同じテキストメッセージを再送しています。

/* WebSocketでの処理 */
let frame = Frame::from(&buffer[..]);

if frame.opcode == Opcode::Text {
    let payload_data = echo(frame.payload_data.as_slice());
    let response: Frame = Frame::new(Opcode::Text, Some(payload_data));

    stream.write(&response.clone().to_bytes()).unwrap();
    stream.flush().unwrap();

    sleep(Duration::from_secs(3));

    stream.write(&response.to_bytes()).unwrap();
    stream.flush().unwrap();
} else if frame.opcode == Opcode::Close {
    /* closing handshake (後述) */
} // さらに上記以外のフレームの処理など...

では、WebSocket Test Clientでメッセージを送って動作確認してみます。
ws://127.0.0.1:7778/でOpenすると、上述したリクエストヘッダが付与されてステータスコード101でレスポンスされることがわかります。

Alt text

適当な文字列をSendしてMessagesタブを開くと、テキストメッセージがやりとされていることも確認できます。

Alt text

なぜマスクするのか?

マスクは、通信経路上のセキュリティのために行われるそうです。
仲介するプロキシサーバにて、WebSocket通信をHTTP通信と誤認して、攻撃者がキャッシュさせたコンテンツ (任意のスクリプト) を他のユーザに読ませることができます。
そのため、予測不能なmasking-keyを使ってXORを取ることで、意図的にキャッシュヒットさせないようにしています。
詳しくは、RFCで参照されている論文や https://please-sleep.cou929.nu/websocket-protocol.html を読んでみてください。

フラグメンテーション

1つのフレームでほぼ無限の長さのペイロードを扱うことができますが、バッファリングせずに送ったりするユースケース (例えば、メッセージ開始時にペイロード長が不明の場合など) のために、WebSocketでは非制御フレームに限ってフラグメンテーションをサポートしています。

フラグメンテーションする場合、以下の順番でメッセージが送られます。順番が重要なので、送信順で受信する必要があります。

  1. FIN=0, opcode!=0x0 のメッセージ
  2. FIN=0, opcode=0x0 のメッセージ (複数)
  3. FIN=1, opcode=0x0 のメッセージ

なお、今回の実装ではフラグメンテーションはサポートしていません。

closing handshake (WebSocket)

切断はどちらからでも開始でき、お互いがクローズ制御フレーム (opcode=0x8) を送受信することで実現されます。

  1. ピアAが、クローズフレームを送信する
  2. ピアBは、クローズフレームを受信すると、未送信のフレームを送りきってから、クローズフレームを送信し、以降は何も送信しない
  3. ピアAは、クローズフレームを受信すると、それ以上フレームが送られてこないことを確信して、コネクションを閉じる

クライアントからクローズフレームを受信した時の実装を示します。

/* closing handshake */
let response = Frame::new(Opcode::Close, None);

stream.write(&response.to_bytes()).unwrap();
stream.flush().unwrap();

break;

実装にあたってのTips

RFCの翻訳

tex2eさんが作ってるRFC TransというWebページがとっても便利でした。

RFCのページをGoogle翻訳にかけてアップロードしてます。WebSocketはここで、ほぼこれを見ながらサーバを実装しました。
原文・翻訳の2カラム構成で、原文と突き合わせて (なんならDeepLにかけて) すぐに確認できる、というのが嬉しいポイントです。

WebSocket通信だけを見る

Google Chromeであれば デベロッパーツール > Network にて、WSでフィルタすればWebSocketの通信に絞ることができます。

Alt text

まとめ

Rustでサーバを実装しながら、WebSocketプロトコルについて学びました。
HTTPとTCPをベースにしたシンプルで軽量なプロトコルなので、(機能を絞ったとはいえ)実装でもさほど難しいところはなかったと思います。

参考

脚注
  1. お仕事で以前から運用していたKubeflowStreamlitなどで使われていました ↩︎

Discussion