🐈‍⬛

JUnitでWebSocketのテストを行う

に公開

概要

Spring Boot で構築した WebSocket アプリケーションを、JUnitでテストします。

環境

  • Java 21
  • Gradle 8.5
  • Spring Boot 3.5.6

構成

websocket-test-example
|-- src
     |-- main
     |    |-- java
     |          |-- org
     |                |-- example
     |                      |-- Application.java
     |                      |-- ExampleController.java  // コントローラ
     |                      |-- StompConfig.java        // STOMP 設定
     |-- test
          |-- java
                |-- org
                      |-- example
                            |-- ExampleControllerTest.java  // コントローラの WebSocket 通信テスト
                            |-- TestStompFrameHandler.java  // テスト用のハンドラ

依存関係

dependencies {
    implementation("org.springframework.boot:spring-boot-starter-web:3.5.6")
    implementation("org.springframework.boot:spring-boot-starter-websocket:3.5.6")
    testImplementation("org.springframework.boot:spring-boot-starter-test:3.5.6")
    testRuntimeOnly("org.junit.platform:junit-platform-launcher:1.11.3")
}

プログラム

Application.java

package org.example;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

@SpringBootApplication
public class Application {
    public static void main(String[] args) {
        SpringApplication.run(Application.class, args);
    }
}

ExampleController.java

package org.example;

import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.stereotype.Controller;

@Controller
public class ExampleController {
    private final SimpMessagingTemplate template;

    public ExampleController(SimpMessagingTemplate template) {
        this.template = template;
    }

    @MessageMapping("message")
    public void message(String request) {
        template.convertAndSend("/topic/response", request + " を受信しました");
    }
}

@MessageMapping を用いて WebSocket による通信を行います。レスポンスを返すのには SimpMessagingTemplate#convertAndSend を使用します。

StompConfig.java

package org.example;

import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;

@Configuration
@EnableWebSocketMessageBroker
public class StompConfig implements WebSocketMessageBrokerConfigurer {
    @Override
    public void registerStompEndpoints(StompEndpointRegistry registry) {
        registry.addEndpoint("/ws").setAllowedOriginPatterns("*").withSockJS();
    }

    @Override
    public void configureMessageBroker(MessageBrokerRegistry registry) {
        registry.setApplicationDestinationPrefixes("/app");
        registry.enableSimpleBroker("/topic");
    }
}

STOMP の設定です。ExampleController.java で使用している SimpMessagingTemplate@EnableWebSocketMessageBroker を付与した設定を読み込みます。

registerStompEndpoints では接続エンドポイントを登録し、configureMessageBroker でルーティングを設定します。

ExampleControllerTest.java

package org.example;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.server.LocalServerPort;
import org.springframework.messaging.converter.StringMessageConverter;
import org.springframework.messaging.simp.stomp.StompSession;
import org.springframework.messaging.simp.stomp.StompSessionHandlerAdapter;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import org.springframework.web.socket.messaging.WebSocketStompClient;
import org.springframework.web.socket.sockjs.client.SockJsClient;
import org.springframework.web.socket.sockjs.client.Transport;
import org.springframework.web.socket.sockjs.client.WebSocketTransport;

import java.util.List;
import java.util.concurrent.*;

@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
public class ExampleControllerTest {
    @LocalServerPort
    private int port;

    @Test
    public void test() throws Exception {
        List<Transport> transports = List.of(new WebSocketTransport(new StandardWebSocketClient()));
        SockJsClient sockJsClient = new SockJsClient(transports);
        WebSocketStompClient client = new WebSocketStompClient(sockJsClient);
        client.setMessageConverter(new StringMessageConverter());

        BlockingQueue<String> queue = new LinkedBlockingQueue<>();
        StompSession session = client.connectAsync("ws://localhost:%d/ws".formatted(port), new StompSessionHandlerAdapter() {}).get(3, TimeUnit.SECONDS);
        TestStompFrameHandler<String> handler = new TestStompFrameHandler<>(queue, String.class);
        session.subscribe("/topic/response", handler);

        String message = "テスト";
        session.send("/app/message", message);
        String response = queue.poll(3, TimeUnit.SECONDS);
        Assertions.assertEquals("テスト を受信しました", response);

        client.stop();
        session.disconnect();
    }
}

WebSocket のテストを行うには SpringBootTest.WebEnvironment.RANDOM_PORT を用いてランダムポートを設定します。

後述する TestStompFrameHandler を使用して受信を BlockingQueue で行うようにしています。また、サーバが text/plain を返すため StringMessageConverter を使用しています。

TestStompFrameHandler.java

package org.example;

import org.springframework.lang.Nullable;
import org.springframework.messaging.simp.stomp.StompFrameHandler;
import org.springframework.messaging.simp.stomp.StompHeaders;

import java.lang.reflect.Type;
import java.util.concurrent.BlockingQueue;

public class TestStompFrameHandler<T> implements StompFrameHandler {
    private final BlockingQueue<T> blockingQueue;
    private final Class<T> type;

    public TestStompFrameHandler(BlockingQueue<T> blockingQueue, Class<T> type) {
        this.blockingQueue = blockingQueue;
        this.type = type;
    }

    @Override
    public Type getPayloadType(StompHeaders headers) {
        return type;
    }

    @Override
    public void handleFrame(StompHeaders headers, @Nullable Object payload) {
        blockingQueue.offer((T) payload);
    }
}

テスト用のハンドラです。受信したフレームを BlockingQueue にセットし、テストメソッドから参照できるようにしています。

Discussion