🥊

作って意地でも理解する HTTP + WebSocketサーバーの仕組み

2025/01/16に公開

これは何

HTTPとWebSocketサーバーのコア部分を実装しながら仕組みを理解していきます。学習目的で実用ではありません。

HTTP、WebSocketはどちらもWEBエンジニアなら当たり前に使うプロトコルですが、仕組みはふわっとしか分かっていませんでした。そんな折サーバー周りを触る必要が出てきたので改めて理解しておきたくなり、せっかくなので自分でいちから実装してみることにしました。コード全体は以下から見れます。

実装する機能

ざっくりと以下のような機能を作っています。

HTTPサーバー

  • GETメソッドのみ、HTTP/1.1のみに対応
    • リクエストラインをパースして、リクエストに応じた処理を行う(GETのみ対応なのでレスポンスヘッダーを付けてファイルを返すだけ)
    • ヘッダーフィールドもパースするが、Websocket関連以外のフィールドは無視する

WebSocketサーバー

  • Websocketハンドシェイクの処理
    • WebSocketハンドシェイクはHTTPリクエストと同じ形式なので、HTTPサーバー部分で処理
    • サブプロトコルは実装しない
  • ハンドシェイク後のフレームのやり取り
    • テキストの送受信、クローズハンドシェイク、ping送信、pong受信を実装
    • javaのintサイズに収まらないバイト長のペイロードは非対応
    • メッセージフラグメンテーションは非対応

HTTPサーバー

まずはHTTPサーバーから作っていきましょう!大まかな処理の流れは以下の通りです。

  1. ソケットからの入力を受け取る
  2. HTTPリクエストのヘッダーをパースして、メソッド、ターゲット、HTTPバージョン、ヘッダーフィールド、ボディーを取得する
  3. リクエストの内容に応じた処理を行い、 HTTPレスポンスを返す

ソケットからの入力は、今回java.net.socketを使うのでjava.net.Socket.getInputStream()でストリームとして受け取れます。

HTTPリクエストのパース

実装を始めるために仕様を見ていきましょう。HTTPリクエストのフォーマットは以下のRFC9112で定義されています。
RFC 9112 - HTTP/1.1 #2.1. Message Format

  HTTP-message = start-line CRLF
                 *( field-line CRLF )
                 CRLF
                 [ message-body ]

図にすると以下のようなイメージです。field-line -> CRLFは0以上個の連続で、message-bodyはメソッドによって有無が変わります。CRはcarriage return、LFはline feedの略称です。
HTTPリクエストヘッダー1

複数行に渡っているため、各行ずつ順番に処理していきます。パースを行うには、一度すべての行をトークナイズしてから解析を行う方法と、ストリームを順番に読みながら逐次パースを行う手法(Scannerless parsing)があります。
一度にすべてのリクエストを受け取ると、めちゃくちゃデカいリクエストだったり、悪意のあるコードも一旦読み取る必要があり好ましくないので、今回はスキャンレスパースを用いて処理していきます。
1行目のstart-lineから処理していきます。RFC9112でstart-lineは以下の様に定義されています。

 start-line     = request-line / status-line

ますはHTTPリクエストを処理するパーサーを作るので、request-lineから見ていきましょう。

request-lineパーサー

request-lineは以下のように定義されています。
RFC 9112 - HTTP/1.1 #3. Request Line

  request-line   = method SP request-target SP HTTP-version

今わかっているところまでを図にすると以下の様なイメージです。SPはsingle spaceの略称です。
HTTPリクエストヘッダー2

methodはcase-sensitiveなトークンで、RFC9110にて定義されています。今回はGETのみ対応するため、"GET"がmethodの箇所にセットされていれば処理続行、それ以外のメソッドならエラーを返す方針でいきます。
RFC 9110 - HTTP Semantics #9. Methods

request-targetはメソッドを適用するリソースのありかを指定する箇所です。GETのみの対応なので、request-targetで指定されたリソースをそのままクライアントに返します。request-targetは以下のように複数指定方法がありますが、今回は最も一般的なorigin-formのみ対応します。

 request-target = origin-form
                / absolute-form
                / authority-form
                / asterisk-form

最後にHTTP-versionの定義は以下のとおりです。1.1のみ対応するので、基本的には"HTTP/1.1"かどうかを確認する実装になります。

HTTP-version  = HTTP-name "/" DIGIT "." DIGIT
HTTP-name     = %s"HTTP"

ここまでのパーサーを実装すると以下のようになります。基本的には1バイトずつ入力を読み取っていって文字列を作っていき、SPやCRが現れたらそれまでの文字列を元にHttpRequestオブジェクトを作っていくイメージです。HttpRequestはシンプルなデータストアです。

HttpParser.java
/**  
 * Parse HTTP Request line and update the provided HttpRequest object. 
 * 
 * @param reader  input stream reader from the TCP socket input stream  
 * @param request HttpRequest object to be updated  
 */
private void parseRequestLine(InputStreamReader reader, HttpRequest request) throws IOException, HttpParsingException {  
    StringBuilder processingDataBuffer = new StringBuilder();  
  
    boolean methodParsed = false;  
    boolean requestTargetParsed = false;  
  
    int b;  
    while ((b = reader.read()) != -1) {  
        if (b == SP) { // Tokenise the request line by SP  
            if (!methodParsed) {  
                request.setMethod(processingDataBuffer.toString());  
                methodParsed = true;  
            } else if (!requestTargetParsed) {  
                request.setRequestTarget(processingDataBuffer.toString());  
                requestTargetParsed = true;  
            } else {  
                throw new HttpParsingException(HttpStatusCode.CLIENT_ERROR_400_BAD_REQUEST);  
            }  
            processingDataBuffer.delete(0, processingDataBuffer.length());  
        } else if (b == CR) { // End of the request line  
            b = reader.read();  
            // Line feed must come after carriage return  
            if (b != LF || !methodParsed || !requestTargetParsed) {  
                throw new HttpParsingException(HttpStatusCode.CLIENT_ERROR_400_BAD_REQUEST);  
            }  
  
            // HTTP version is placed right before the CRLF in the first line  
            try {  
                request.setHttpVersion(processingDataBuffer.toString());  
            } catch (BadHttpVersionException e) {  
                throw new HttpParsingException(HttpStatusCode.CLIENT_ERROR_400_BAD_REQUEST);  
            }  
  
            return;  
        } else {  
            processingDataBuffer.append((char) b);  
            if (!methodParsed && processingDataBuffer.length() > HttpMethod.MAX_LENGTH) {  
                throw new HttpParsingException(HttpStatusCode.SERVER_ERROR_501_NOT_IMPLEMENTED);  
            }  
        }  
    }  
}
HttpRequest.java
/**  
 * A class to hold HTTP request data. 
 * */
public class HttpRequest extends HttpMessage {  
    private HttpMethod method;  
    private String requestTarget;  
    private HttpVersion httpVersion;  
  
    /**  
     * Set HTTP version by string.     
     *    
     * @param httpVersionString HTTP version as string  
     */    
	 public void setHttpVersion(String httpVersionString) throws HttpParsingException, BadHttpVersionException {  
        this.httpVersion = HttpVersion.fromString(httpVersionString);  
    }   
  
    /**  
     * Set HTTP method by string.     
     *     
     * @param methodName method name as string  
     */    
     void setMethod(String methodName) throws HttpParsingException {  
        for (HttpMethod method : HttpMethod.values()) {  
            if (methodName.equals(method.name())) {  
                this.method = method;  
                return;  
            }  
        }  
        throw new HttpParsingException(  
                HttpStatusCode.SERVER_ERROR_501_NOT_IMPLEMENTED  
        );  
    } 
    
    public void setRequestTarget(String requestTarget) throws HttpParsingException {  
        if (requestTarget == null || requestTarget.isEmpty()) {  
            throw new HttpParsingException(HttpStatusCode.SERVER_ERROR_500_INTERNAL_SERVER_ERROR);  
        }  
        this.requestTarget = requestTarget;  
    }  
     
	public HttpMethod getMethod() {  
        return method;  
    } 
  
    public HttpVersion getHttpVersion() {  
        return httpVersion;  
    }  
  
    public String getRequestTarget() {  
        return requestTarget;  
    }  
  
    /**  
     * Check that the HTTP request has a header field with the provided field name.     
     *    
     * @param fieldName Target header field name  
     * @return True if the request has header field  
     */    
     private boolean hasHeaderField(String fieldName) {  
        return getHeaderFields(fieldName) != null;  
    }  
  
  
    /**  
     * Check that the HTTP request has a header field with the specified value in the provided field name.     
     *     
     * @param fieldName Target header field name  
     * @param value     Target header field value  
     * @return True if the request has the value in the header field  
     */    
     private boolean hasHeaderValue(String fieldName, String value) {  
        String headerValue = getHeaderFields(fieldName);  
        return headerValue != null && headerValue.contains(value);  
    }  
}
HttpMessage.java
/**
 * Abstract class for HTTP response and request.
 */
public abstract class HttpMessage {
    private final HashMap<String, String> headerFields = new HashMap<>();
    private byte[] messageBody = new byte[0];

    public String toString() {
        String res = "";
        for (Map.Entry<String, String> entry : headerFields.entrySet()) {
            String key = entry.getKey();
            Object value = entry.getValue();
            res = String.format("%s%s: %s\n", res, key, value);
        }
        return res;
    }

    public Set<String> getHeaderFieldNames() {
        return headerFields.keySet();
    }

    public byte[] getMessageBody() {
        return messageBody;
    }

    public void setMessageBody(byte[] messageBody) {
        this.messageBody = messageBody;
    }

    /**
     * Add new header field.
     *
     * @param fieldName  Header field name
     * @param fieldValue Header field value
     */
    void addHeaderField(String fieldName, String fieldValue) {
        headerFields.put(fieldName.toLowerCase(), fieldValue);
    }

    /**
     * Get header field value by field name.
     *
     * @param fieldName Target header field name
     * @return Header field value
     */
    @Nullable
    public String getHeaderFields(String fieldName) {
        return headerFields.get(fieldName.toLowerCase());
    }
}

field-lineパーサー

2行目以降には0個以上のヘッダーフィールドが続きます。field-lineは以下のように定義されています。OWSはOptional Whitespaceで、人間用の読みやすさのためのスペースです。パースする際にはトリムの必要があります。
RFC 9112 - HTTP/1.1 #5. Field Syntax

field-line   = field-name ":" OWS field-value OWS

field-lineにフォーカスした図は以下のようになります。ヘッダーフィールドは連続したCRLFによってその終了が示されます。上のHttpMessageの実装にあるように、パースしたfield-namefield-valueはハッシュマップで保存しておきます。
HTTPリクエストヘッダー3

HttpParser.java
/**  
 * Parse HTTP header fields and update the provided HttpRequest object. 
 * 
 * @param reader  input stream reader from the TCP socket input stream  
 * @param request HttpRequest object to be updated  
 */
private void parseHeaderFields(InputStreamReader reader, HttpRequest request) throws IOException, HttpParsingException {  
    StringBuilder processingDataBuffer = new StringBuilder();  
    boolean crlfFound = false;  
  
    int b;  
    while ((b = reader.read()) >= 0) {  
        if (b == CR) {  
            b = reader.read();  
            // Line feed must come after carriage return  
            if (b != LF) {  
                throw new HttpParsingException(HttpStatusCode.CLIENT_ERROR_400_BAD_REQUEST);  
            }  
            // Two CRLF received, end of header fields section  
            if (crlfFound) return;  
            // Handle header field line
            crlfFound = true;  
            processSingleHeaderField(processingDataBuffer.toString(), request);  
            processingDataBuffer.delete(0, processingDataBuffer.length());  
        } else {  
            crlfFound = false;  
            processingDataBuffer.append((char) b);  
        }  
    }  
}  
  
/**  
 * Add header field and value pair to the provided HttpRequest object. 
 * 
 * @param rawFieldLine raw HTTP header field line  
 * @param request      HttpRequest object to be updated  
 */
private void processSingleHeaderField(String rawFieldLine, HttpRequest request) throws HttpParsingException {  
    String[] headerFieldParts = rawFieldLine.split(":", 2);  
    if (headerFieldParts.length != 2) {  
        throw new HttpParsingException(HttpStatusCode.CLIENT_ERROR_400_BAD_REQUEST);  
    }  
  
    String fieldName = headerFieldParts[0].trim();  
    String fieldValue = headerFieldParts[1].trim();  
    request.addHeaderField(fieldName, fieldValue);  
}

レスポンスを返す

HTTPレスポンスも以下の基本形に沿いますが、start-linestatus-lineになります。
GETリクエストに対しては、message-bodyとしてファイルを送信します。そしてmessage-bodyを持つとき、content-lengthcontent-typeのふたつのヘッダーフィールドも合わせて送信します。(Transfer-Encodingは今回考慮しません)
content-lengthmessage-bodyのバイト数、content-typeにはmedia-typeを付与します。content-typeフィールドは必須ではないですが一般的にはだいたい付与されています。
RFC 9112 - HTTP/1.1 #6. Message Body
RFC 9110 - HTTP Semantics #8.3. Content-Type

HTTP-message  = start-line CRLF
                *( field-line CRLF )
                CRLF
                [ message-body ]
			   
status-line = HTTP-version SP status-code SP [ reason-phrase ]

Content-Type = media-type

media-type = type "/" subtype parameters
type       = token
subtype    = token

今回はJavaの練習も兼ねているので、ビルダーパターンでレスポンスが構築できるように実装しました。
全部載せると長くなってしまうので割愛しますが、ファイル操作系の実装はsrc/main/java/com/server/core/io/WebRootHandler.javaに入っています。

HttpConnectionWorkerThread.java
/**  
 * Handle GET request and return HTTP response with body. 
 * 
 * @param request Whether the response includes the body  
 * @return HTTP response for the GET request  
 */
private HttpResponse handleGetRequest(HttpRequest request) {  
    try {  
        HttpResponse.Builder builder = new HttpResponse.Builder()  
                .httpVersion(request.getHttpVersion().literal)  
                .statusCode(HttpStatusCode.OK)  
                .addHeader(HttpHeaderFieldName.CONTENT_TYPE.headerName, webRootHandler.getFileMimeType(request.getRequestTarget()));  
        byte[] messageBody = webRootHandler.getFileByteArrayData(request.getRequestTarget());  
        builder.addHeader(HttpHeaderFieldName.CONTENT_LENGTH.headerName, String.valueOf(messageBody.length))  
                .messageBody(messageBody);  
        return builder.build();  
    } catch (FileNotFoundException e) {  
        return new HttpResponse.Builder()  
                .httpVersion(request.getHttpVersion().literal)  
                .statusCode(HttpStatusCode.CLIENT_ERROR_404_NOT_FOUND)  
                .build();  
    } catch (ReadFileException e) {  
        return new HttpResponse.Builder()  
                .httpVersion(request.getHttpVersion().literal)  
                .statusCode(HttpStatusCode.SERVER_ERROR_500_INTERNAL_SERVER_ERROR)  
                .build();  
    }  
}
HttpResponse.java
/**  
 * A class to hold HTTP response data. 
 * */
public class HttpResponse extends HttpMessage {  
    private final String httpVersion;  
    private final HttpStatusCode statusCode;  
  
    private HttpResponse(HttpResponse.Builder builder) {  
        this.httpVersion = builder.httpVersion;  
        this.statusCode = builder.statusCode;  
        builder.getHeaderFieldNames().forEach(name -> {  
            this.addHeaderField(name, builder.getHeaderFields(name));  
        });  
        this.setMessageBody(builder.getMessageBody());  
    }  
  
    public String getReasonPhrase() {  
        if (statusCode != null) {  
            return statusCode.MESSAGE;  
        }  
        return null;  
    }  
  
    /**  
     * Generate a byte array to be sent to clients.     
     *     
     * @return Byte array of the response data  
     */    
     public byte[] generateResponseBytes() {  
        StringBuilder responseBuilder = new StringBuilder();  
        String CRLF = "\r\n";  
  
        responseBuilder.append(httpVersion)  
                .append(" ")  
                .append(statusCode.STATUS_CODE)  
                .append(" ")  
                .append(getReasonPhrase())  
                .append(CRLF);  
        for (String headerName : getHeaderFieldNames()) {  
            responseBuilder.append(headerName)  
                    .append(": ")  
                    .append(getHeaderFields(headerName))  
                    .append(CRLF);  
        }  
        responseBuilder.append(CRLF);  
        byte[] responseBytes = responseBuilder.toString().getBytes();  
        if (getMessageBody().length == 0)  
            return responseBytes;  
        byte[] responseWithBody = new byte[responseBytes.length + getMessageBody().length];  
        System.arraycopy(responseBytes, 0, responseWithBody, 0, responseBytes.length);  
        System.arraycopy(getMessageBody(), 0, responseWithBody, responseBytes.length, getMessageBody().length);  
        return responseWithBody;  
    }  
  
    /**  
     * Builder of an HTTP response object.     
     * */
    public static class Builder extends HttpMessage {  
        private String httpVersion;  
        private HttpStatusCode statusCode;  
  
        public Builder httpVersion(String httpVersion) {  
            this.httpVersion = httpVersion;  
            return this;  
        }  
  
        public Builder statusCode(HttpStatusCode statusCode) {  
            this.statusCode = statusCode;  
            return this;  
        }  
  
        public Builder addHeader(String headerName, String headerField) {  
            this.addHeaderField(headerName, headerField);  
            return this;  
        }  
  
        public Builder messageBody(byte[] messageBody) {  
            this.setMessageBody(messageBody);  
            return this;  
        }  
  
        public HttpResponse build() {  
            return new HttpResponse(this);  
        }  
    }
}

ここまでの実装で、HTTPリクエストに対して正しくHTMLや画像ファイルが送信されるようになります。これでHTTPサーバーのコア部分は完了です。お疲れ様でした!

HTTPサーバー動作確認1
HTTPサーバー動作確認2

WebSocketサーバー

続けて本丸のWebSocketサーバーも作っていきましょう!大まかな処理の流れは以下のとおりです。

  1. ソケットからの入力を受け取る
  2. HTTPリクエストがWebSocketクライアントハンドシェイクなら、サーバーハンドシェイクのHTTPレスポンスを返す
  3. ハンドシェイクが完了したら、WebSocketフレーム受付を開始
  4. 送信されたフレームをパース、必要に応じてフレームを送り返す

今回のサーバーには以下の機能を実装します。

  • request-targetに関係なくWebSocketハンドシェイクを受付
  • テキストを受け取ったら、同じテキストに飾りを付けて送り返す
  • 5秒に一回サーバーから自動でpingを送信する
  • 接続解除ハンドシェイクを受信したら、接続解除ハンドシェイクを送り返す

ハンドシェイクの処理

WebSocketのハンドシェイクはシンプルで、特別な条件を満たすHTTP GETリクエストをクライアントが送信し、その内容に問題がなければサーバーからハンドシェイクを送信、その後WebSocketフレームを使った双方向のやり取りを行う、という流れになります。
どちらのピアから送信されるハンドシェイクかをわかりやすくするため、ここではクライアントから送信されるものを「クライアントハンドシェイク」、サーバーからは「サーバーハンドシェイク」と呼ぶことにします。
※RFC6455では"The handshake from the client", "The handshake from the server"などとそれぞれ記載されています。

WebSocketプロトコルへの変更時、クライアントハンドシェイクが以下の様に送られてきます。

GET /chat HTTP/1.1
Host: server.example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Origin: http://example.com
Sec-WebSocket-Version: 13

HTTPからのシームレスな切り替えを行えるようにするため、HTTPリクエストに沿った形式が取られていますが、以下の条件を満たしている必要があります。(重要な箇所のみ抜粋、詳細はRFC 6455 - The WebSocket Protocol #4.2.1. Reading the Client's Opening Handshakeを参照)

  • GETメソッドである
  • ヘッダーフィールドにHostを持つ
  • ヘッダーフィールドにUpgradeを持ち、その値はwebsocketを含む
  • ヘッダーフィールドにConnectionを持ち、その値はUpgradeを含む
  • ヘッダーフィールドにSec-WebSocket-Keyを持ち、その値はコネクションごとにランダムに選ばれた16byte長のnonceをbase64エンコードしたもの
  • WEBブラウザからのリクエストの場合、ヘッダーフィールドにOriginを持つ
  • ヘッダーフィールドにSec-WebSocket-Versionを持ち、その値は13である

これらをHTTPリクエストヘッダのパース後に確認し、全てに当てはまる場合は正しいクライアントハンドシェイクとみなします。その後サーバーハンドシェイクを送信し、WebSocketプロトコルへの切り替えを完了します。

まずはHTTPリクエストがWebSocketハンドシェイクかどうか確認する必要があるため、以下のメソッドをHttpRequestクラスに生やしました。
Sec-WebSocket-Keyはnonceが元になっていて具体的な値の検証はできないので、長さを確認することにします。(16バイトのデータをBase64でエンコードすると24文字になる)

HttpRequest.java
/**  
 * Check if the HTTP request is a WebSocket handshake. 
 * 
 * @return True if it is a WebSocket handshake  
 */
public boolean isWebsocketHandshake() {  
    final String websocketKeyValue = getHeaderFields("Sec-WebSocket-Key");  
    return hasHeaderField("Host")  
            && hasHeaderValue("Upgrade", "websocket")  
            && hasHeaderValue("Connection", "Upgrade")  
            && hasHeaderField("Origin")  
            && hasHeaderField("Sec-WebSocket-Key")  
            && websocketKeyValue != null
            && websocketKeyValue.length() == 24 // Base64 encoded 16 byte nonce should have 24 characters  
            && "13".equals(getHeaderFields("Sec-WebSocket-Version"));  
}

WebSocketクライアントハンドシェイクに問題がないことを確認したら、サーバーハンドシェイクを送り返します。サーバーハンドシェイクもHTTPレスポンスがベースになっていて、以下の条件を満たしている必要があります。(重要箇所のみ、詳細はRFC 6455 - The WebSocket Protocol #4.2.2. Sending the Server's Opening Handshake を参照)

  • 101レスポンス
  • ヘッダーフィールドにUpgradeを持ち、その値はwebsocketを含む
  • ヘッダーフィールドにConnectionを持ち、その値はUpgradeを含む
  • ヘッダーフィールドにSec-WebSocket-Acceptを持ち、その値はリクエストで受け取ったSec-WebSocket-Keyの値の後ろに"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"を結合し、SHA-1でハッシュ化してbase64エンコードしたもの

Sec-WebSocket-Keyフィールドの処理が若干ややこしいですが、サーバーがWebSocketに対応していて、特定のクライアントとコネクションを許可したかどうかの判定に使われます。
WebSocketハンドシェイク完了後にはWebSocketフレーム処理担当のスレッドを起動し、ソケットを渡しておきます。

HttpConnectionWorkerThread.java
/**  
 * Handle WebSocket upgrade request. Start a worker thread to listen WebSocket frames. 
 * 
 * @param request WebSocket handshake from the client  
 */
private void handleWebSocketUpgradeRequest(HttpRequest request) throws HttpParsingException, IOException {  
    HttpResponse handshakeResponse = new HttpResponse.Builder()  
            .httpVersion(request.getHttpVersion().literal)  
            .statusCode(HttpStatusCode.WEBSOCKET_UPGRADE)  
            .addHeader("Upgrade", "websocket")  
            .addHeader("Connection", "Upgrade")  
            .addHeader("Sec-WebSocket-Accept", request.generateSecWebsocketAcceptFieldValue())  
            .build();  
            
    sendResponse(handshakeResponse);  
    WebsocketWorkerThread websocketWorkerThread = new WebsocketWorkerThread(socket);  
    websocketWorkerThread.start();  
}
HttpRequest.java
/**  
 * Generate the value for the `Sec-WebSocket-Accept` header field of the WebSocket server handshake. 
 * 
 * @return Handshake header field value for `Sec-WebSocket-Accept`  
 */
 public String generateSecWebsocketAcceptFieldValue() throws HttpParsingException {  
    final String websocketKeyValue = getHeaderFields("Sec-WebSocket-Key");  
    if (websocketKeyValue == null || websocketKeyValue.length() != 24) {  
        // Base64 encoded 16 byte nonce should have 24 characters  
        throw new HttpParsingException(HttpStatusCode.CLIENT_ERROR_400_BAD_REQUEST);  
    }  
    final byte[] hash = DigestUtils.sha1(websocketKeyValue + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11");  
    return Base64.getEncoder().encodeToString(hash);  
}

データの送受信

ハンドシェイクの完了後、クライアント、サーバー両方からメッセージの送受信ができるようになります。メッセージは1つ以上のフレームによって構成されていて、フレームの構造は以下のように定義されています。
RFC 6455 - The WebSocket Protocol #5.2. Base Framing Protocol

      0                   1                   2                   3
      0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
     +-+-+-+-+-------+-+-------------+-------------------------------+
     |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
     |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
     |N|V|V|V|       |S|             |   (if payload len==126/127)   |
     | |1|2|3|       |K|             |                               |
     +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
     |     Extended payload length continued, if payload len == 127  |
     + - - - - - - - - - - - - - - - +-------------------------------+
     |                               |Masking-key, if MASK set to 1  |
     +-------------------------------+-------------------------------+
     | Masking-key (continued)       |          Payload Data         |
     +-------------------------------- - - - - - - - - - - - - - - - +
     :                     Payload Data continued ...                :
     + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
     |                     Payload Data continued ...                |
     +---------------------------------------------------------------+

ヘッダーの最短長は2バイトで、2バイト目まではどのビットに何の情報が入っているかが固定されています。3バイト目以降は2バイト目の情報を元に構築されます。

1バイト目

  • FIN(1 bit): これがシリーズの最後のメッセージかどうかのフラグ(今回はFIN=1のみ実装)
  • RSV1 ~ RSV3(1 bit): 将来の拡張用のため無視
  • opcode(4 bits): ペイロードデータの解釈方法およびコントロールフレームの指定
    • 0x0: 継続 (FIN=0との組み合わせで使われる、今回は実装しない)
    • 0x1: UTF-8テキスト
    • 0x2: バイナリ (今回は実装しない)
    • 0x8: クローズ
    • 0x9: ping
    • 0xA: pong

2バイト目

  • mask(1 bit): メッセージがエンコードされているかどうかのフラグ
    • クライアントからのフレームは常にマスクされている必要があるので、サーバー側では基本的にこのフラグが1かどうかを確認する
    • 逆にサーバーからのフレームは常にマスクされていない必要がある
  • payload length(7 bits): ペイロードデータのバイトサイズ
    • ビット9~15までを符号なし整数として読み取る
      • 125以下:それがそのまま長さになる
      • 126:次の2バイトを符号なし整数として解釈したものが長さになる
      • 127:次の8バイトを符号なし整数として解釈したものが長さになる

3バイト目以降

  • Extended payload length(16~64 bits): 上記の通り、payload lengthが126なら2バイト分、127なら8バイト分を符号なし整数として読み取ると最終的なpayload lengthを得られる
  • Masking key(32 bits): MASKフラグがセットされている場合、マスクキーをセットする
    • マスクキーはエンコードされたペイロードデータに対して1バイト単位でループしXOR計算を行なうことでデコードされた文字列を得られる
  • Payload data: メッセージの本体 前述のpayload lengthで得られたバイト長のメッセージが格納されている

WebSocketフレームパーサーの実装

上記の仕様に従ってパーサーを実装していきます。
基本的には最初の2バイトにはビットフラグと整数に変換したほうが楽な値が混ざっているのでビットマスクで処理し、それ以降は適宜バイト列に保存しつつ扱いやすいように変換していきます。
クライアントからのフレームは必ずマスクされている必要があるので、受け取り時にそのデコードを行います。decoded[i] = payload[i] \oplus mask[i\mod{4}] でデコードできます。

WebSocketParser.java
public class WebSocketParser {
    /**
     * The entry point of the WebSocket parser. Does not support following features.
     * <ul>
     * <li>Message fragmentation</li>
     * <li>Bytes of payload longer than the int-holdable length</li>
     * </ul>
     *
     * @param inputStream input stream from the TCP socket
     * @return Parsed WebSocketFrame object
     */
    public WebSocketFrame parseWebsocketFrame(InputStream inputStream) throws WebSocketParsingException, IOException {
        // First byte: FIN flag, opcode
        final int firstByte = inputStream.read();
        boolean fin = (firstByte & 0b10000000) != 0;
        Opcode opcode = Opcode.fromCode(firstByte & 0b00001111);

        // Second byte: MASK flag, payload length
        final int secondByte = inputStream.read();
        boolean mask = (secondByte & 0b10000000) != 0;
        if (!mask) {
            throw new WebSocketParsingException("Frame from client must be masked.");
        }
        int payloadLength = (secondByte & 0b01111111);

        // Handle extended payload length
        if (payloadLength == 126) {
            byte[] extendedPayloadLengthBytes = new byte[2];
            inputStream.read(extendedPayloadLengthBytes, 0, 2);
            payloadLength = ((extendedPayloadLengthBytes[0] & 0xFF) << 8) |
                    (extendedPayloadLengthBytes[1] & 0xFF);
        } else if (payloadLength == 127) {
            byte[] extendedPayloadLengthBytes = new byte[8];
            inputStream.read(extendedPayloadLengthBytes, 0, 8);
            payloadLength = 0;
            for (byte b : extendedPayloadLengthBytes) {
                payloadLength = (payloadLength << 8) | (b & 0xFF);
            }
        }

        // Masking key
        byte[] maskingKey = new byte[4];
        int bytesRead = inputStream.read(maskingKey);
        if (bytesRead != 4) {
            throw new WebSocketParsingException("Failed to read masking key");
        }

        // Get payload and decode by the masking key
        byte[] decodedPayload = new byte[payloadLength];
        for (int i = 0; i < payloadLength; i++) {
            final byte b = (byte) inputStream.read();
            if (b == -1) {
                throw new WebSocketParsingException("Unexpected end of stream");
            }
            final int maskByte = maskingKey[i % 4];
            decodedPayload[i] = (byte) (b ^ maskByte);
        }

        return new WebSocketFrame.Builder().fin(fin).opcode(opcode).payload(decodedPayload).build();
    }
}

WebSocketフレームビルダーの実装

クライアントから受け取ったメッセージを元にサーバーからテキストメッセージを送り返したいので、フレームの構築処理も作っていきましょう。
HttpResponseクラスと同様に、フレームを送信する際にオブジェクトからバイト配列に変換する必要があるので、WebSocketFrameクラスにgenerateFrameBytesメソッドを生やしています。
基本的にはパーサーの逆の処理で特筆すべき箇所はないですが、サーバーからのフレームのpayloadはマスクされていない必要がある点が異なります。

WebSocketFrame.java
public class WebSocketFrame {
    private final boolean fin;
    private final boolean mask;
    private final Opcode opcode;
    private final byte[] payload;

    private WebSocketFrame(Builder builder) {
        this.fin = builder.fin;
        this.mask = builder.mask;
        this.opcode = builder.opcode;
        this.payload = builder.payload;
    }


    public Opcode getOpcode() {
        return opcode;
    }

    public byte[] getPayload() {
        return payload;
    }

    public String getPayloadAsString() {
        return new String(payload, StandardCharsets.UTF_8);
    }

    /**
     * Generate a byte array of the WebSocket frame.
     *
     * @return a byte array of the WebSocket frame
     */
    public byte[] generateFrameBytes() {
        ArrayList<Byte> bytes = new ArrayList<>();
        // First byte: FIN flag(1st bit), opcode(5th - 8th bits)
        // 2nd - 4th bits are for RSV, which are not used
        byte firstByte = (byte) (fin ? 0b10000000 : 0b00000000);
        bytes.add((byte) (firstByte | opcode.code));

        // Second byte: MASK flag(1st bit), payload length(2nd - 8th bits)
        // payload from server side must not be masked
        byte secondByte = (byte) (mask ? 0b10000000 : 0b00000000);
        final int length = payload.length;
        if (length <= 125) {
            bytes.add((byte) (secondByte | payload.length));
        } else if (length <= 65535) { // 2^16 - 1
            bytes.add((byte) (secondByte | 126));
            // Insert 2 bytes for extended payload length
            bytes.add((byte) ((length >> 8) & 0xFF));
            bytes.add((byte) (length & 0xFF));
        } else {
            // As the max. possible length of payload is 2^64, we omit the payload length check for it
            // Insert 8 bytes for extended payload length
            bytes.add((byte) (secondByte | 127));
            for (int i = 7; i >= 0; i--) {
                bytes.add((byte) ((length >> (i * 8)) & 0xFF));
            }
        }

        // Add masking bytes
        // Server must not mask the payload, so this is only for the testing purpose
        if (mask) {
            for (int i = 0; i < 4; i++) bytes.add((byte) 0x00);
        }

        // Append payload
        for (byte b : payload) bytes.add(b);

        // Convert the arraylist to array
        byte[] res = new byte[bytes.size()];
        for (int i = 0; i < res.length; i++) {
            res[i] = bytes.get(i);
        }
        return res;
    }

    /**
     * Builder of a WebSocketFrame object.
     */
    public static class Builder {
        private boolean fin = true;
        private boolean mask = false; // mask should be false as default since server cannot send masked payload
        private Opcode opcode;
        private byte[] payload;

        public Builder fin(boolean fin) {
            this.fin = fin;
            return this;
        }

        public Builder opcode(Opcode opcode) {
            this.opcode = opcode;
            return this;
        }

        // only for testing purpose
        public Builder mask(boolean mask) {
            this.mask = mask;
            return this;
        }

        public Builder payload(byte[] payload) {
            this.payload = payload;
            return this;
        }

        public WebSocketFrame build() {
            if (opcode == null) {
                throw new NullPointerException();
            }
            if (payload == null) {
                payload = new byte[0];
            }
            return new WebSocketFrame(this);
        }

    }
}

コントロールフレームの実装

サーバーからメッセージを送信できるようになったので、コントロールフレームも実装してみます。まずはping送信です。
pingはopcodeが0x9のフレームで、pingフレームを受け取ったクライアントはすぐに同じpayloadを持ったpong(opcodeが0xA)のフレームを送り返します。その名の通り相手がまだ接続しているか確認するために使われます。
RFC 6455 - The WebSocket Protocol #5.5.2. Ping
今回は5秒に一回サーバーからクライアントにpingを送信する機能を付けてみました。余談ですが、javascriptのWebSocketオブジェクトはping受信でちゃんとpongを送り返してくれるのですね。ping/pongを手動で送受信するAPIが無いようなのでどうかなーと思いつつ実装したら、pongが帰ってきたのでちょっと驚きました。でも確かにping/pongのほとんどの用途はサーバーからのクライアントの生存確認で、クライアントからpingを送信するシーンは通常ならほぼないため、APIを用意するまでもなくping受け取りですぐpong返却する実装で納得です。

WebsocketWorkerThread.java
/**
 * Set the scheduler to send a ping frame to the client every 5 sec.
 */
private void sendPing() {
	executor.scheduleAtFixedRate(() -> {
		WebSocketFrame pingFrame = new WebSocketFrame.Builder().fin(true).opcode(Opcode.PING).build();
		try {
			sendFrame(pingFrame);
		} catch (IOException | WebSocketParsingException e) {
			throw new RuntimeException(e);
		}
		LOGGER.info("Ping frame sent.");
	}, 0, 5, TimeUnit.SECONDS);
}
20:53:32.884 [pool-1-thread-1] INFO com.server.core.WebsocketWorkerThread -- Ping frame sent.
20:53:32.884 [Thread-5] INFO com.server.core.WebsocketWorkerThread -- Pong frame received.
20:53:37.883 [pool-1-thread-1] INFO com.server.core.WebsocketWorkerThread -- Ping frame sent.
20:53:37.883 [Thread-5] INFO com.server.core.WebsocketWorkerThread -- Pong frame received.
...

続いてクローズフレームです。クローズフレームを受け取ったサーバーは、クライアントにクローズフレームを送り返す必要があります。双方がクローズフレームを受け取った時点でWebSocket接続は終了したと判断し、サーバー側はTCPコネクションを切断します。今回はクライアントからクローズフレームが送信されてきた場合のみ実装しています。(フレーム受け取り処理全体のコードですがご勘弁ください)

WebsocketWorkerThread.java
public WebsocketWorkerThread(Socket socket) throws IOException {  
    this.socket = socket;  
    this.inputStream = socket.getInputStream();  
    this.outputStream = socket.getOutputStream();  
}  
  
    @Override  
    public void run() {  
        try {  
            LOGGER.info("WebSocket worker thread started.");  
            sendPing();  
            while (!Thread.currentThread().isInterrupted()) {  
                WebSocketFrame clientFrame = parser.parseWebsocketFrame(inputStream);  
                handleFrame(clientFrame);  
                if (clientFrame.getOpcode() == Opcode.CLOSE) break;  
            }  
        } catch (IOException e) {  
            LOGGER.error("Error in WebSocket worker thread: ", e);  
        } catch (WebSocketParsingException e) {  
            LOGGER.error("WebSocket parsing error: ", e);  
        } finally {  
            try {  
                socket.close();  
            } catch (IOException e) {  
                throw new RuntimeException(e);  
            }  
        }  
    }

    /**
     * Handle the WebSocket frame sent by the client and respond to it based on the opcode.
     *
     * @param clientFrame WebSocket frame from the client
     */
    private void handleFrame(WebSocketFrame clientFrame) throws WebSocketParsingException, IOException {
        WebSocketFrame serverFrame = null;
        switch (clientFrame.getOpcode()) {
            case Opcode.TEXT -> {
                LOGGER.info("Payload string: {}", clientFrame.getPayloadAsString());
                // Build a "duck say" payload based on the client frame payload
                final byte[] serverPayload = Duck.Say(clientFrame.getPayload());
                serverFrame = new WebSocketFrame.Builder().fin(true).opcode(Opcode.TEXT).payload(serverPayload).build();
            }
            case Opcode.CLOSE -> {
                LOGGER.info("Close frame received.");
                // Build a close response frame
                serverFrame = new WebSocketFrame.Builder().fin(true).opcode(Opcode.CLOSE).build();
            }
            case Opcode.PONG -> LOGGER.info("Pong frame received.");
            default -> throw new WebSocketParsingException("Unknown opcode " + clientFrame.getOpcode());
        }
    
        if (serverFrame != null) {
            sendFrame(serverFrame);
        }
    }
...
}

これで以下のようにブラウザとサーバーでHTTPとWebSocket通信ができるようになりました。簡易的でも自力で実装したサーバーがブラウザと会話できているのを見るのはグッとくるものがありますね!

WebSocket動作確認

参考

RFC 9112 - HTTP/1.1
RFC 9110 - HTTP Semantics
GitHub - CoderFromScratch/simple-java-http-server: Create a Simple HTTP Server in Java Tutorial Series
RFC 6455 - The WebSocket Protocol
WebSocket サーバーを書く - Web API | MDN
Java で WebSocket サーバーを書く - Web API | MDN

Discussion