🦙

JavaでローカルLLMを動かす(llama.cpp, Spring AI)

2023/12/31に公開

はじめに:JavaとローカルLLM

生成AI、大規模言語モデル(LLM)が大流行りの一年でしたね。

LLMを扱うプログラミング言語としてはPythonでの情報が盛んでした。Pythonも素晴らしい言語ですが、今後LLMがエンタープライズシステムに導入されていくことを考えると、やはりJavaのエコシステムは外せないな、と思い、そうなるとJavaでどの程度LLMを使えるかな?が気になります。

また、OpenAIやAzure OpenAI Serviceを使用しない、オンプレミスで動く「ローカルLLM」という領域も興味があります。Javaが活躍している領域では、しばし、セキュアな環境での開発が求められ、LLMが導入しにくいという話も聞きます。そういった閉域でも導入しやすいというLLMで、どこまでできるかも試してみたいです。

加えて、Spring Framework界隈では、Spring AI の開発が進行中です。最近、Josh LongさんがSpring AIの紹介をしていて、少し触ってみようと思いました。

ということで、ローカルLLMをJavaで、SpringAIも含め、動かしてみます。

技術要素

今回は、以下の技術要素の組み合わせです。

  • ELYZA-japanese-Llama-2-13b
    • Meta社のLlama2をベースに日本語チューニングしたLLMモデルです。2023.12.27にリリースされた、パラメータ数13Bは、GPT3.5並に頭が良くなってきたと話題のモデルです。
  • Java Bindings for llama.cpp
    • ローカルLLMの一つであるLllam2と、それ扱うC++のライブラリllama.cpp を、Javaからでも使えるようにしたものです。
  • Spring AI
    • LLMを操作する抽象インターフェースと便利な機能を提供。今回はインターフェース導入のみです。
  • Java21/Spring Framework 6/Spring Boot3

忙しい方のために

作ったコードを以下においておきました。以下の3ステップで、ローカルLLMへ質問と回答ができます。

  1. LLMモデルをダウンロード
  2. Spring Bootアプリを起動
  3. RESTで質問を送信

https://github.com/hide212131/java-ai-llamacpp-helloworld.git

macosで試していますが、多分WindowsやLinuxでも動くはずです。
事前にJavaの開発環境はインストールしておく必要があります。

以下手順です。

  1. ここ をクリックして、ELYZA-japanese-Llama-2-13b-fast-instruct-q4_K_M.ggufローカルLLMのモデルをダウンロードします。

    • (GPUメモリが少なめのPCの場合は、ここからファイルサイズの小さいものをダウンロードするのが良いでしょう)
  2. リポジトリからgit cloneし、環境変数を設定し、Spring Bootを立ち上げます。

    git clone https://github.com/hide212131/java-ai-llamacpp-helloworld.git
    cd java-ai-llamacpp-helloworld
    
    export SPRING_AI_LLAMA_CPP_MODEL_HOME='/path/to/model' #モデルをおいたフォルダ
    export SPRING_AI_LLAMA_CPP_MODEL_NAME='ELYZA-japanese-Llama-2-13b-fast-instruct-q4_K_M.gguf' #モデルのファイル名
    
    ./mvnw spring-boot:run
    
  3. 別のターミナルから、"hello!" と伝えると、回答が帰ってきます。

    curl http://localhost:8080/ai/simple\?message=hello\!
    

    {"completion":"\n\n今日は、私の好きな映画について紹介します。\n「君の名」は、2013年公開の日本の映画です。\n私はこの映画の主人公のように、自分の力で運命を変えていきたいと思っています。"}

    日本語トレーニングしたLLMだけあって、回答も日本語色強めですね。

開発手順

方針

以下のSpring AIのOpenAI対応のコードを参考に、llama.cpp対応を真似てみます。

実装

Spring Bootの環境構築

Spring Initializrより

  • Maven / Spring Boot3.2.1 / Java21 を選択します
  • Dependency として以下を選択します
    • Spring Web
    • Spring Reactive Web

依存関係の追加

展開したプロジェクトのpom.xmlに、Java Bindings for llama.cppの依存関係をつけます。

        <!-- https://github.com/kherud/java-llama.cpp#quick-start -->
        <dependency>
            <groupId>de.kherud</groupId>
            <artifactId>llama</artifactId>
            <version>2.3.1</version>
        </dependency>

Spring AIは、現時点では正式リリース前であり、Spring InitializrやMaven Central Repositoryに登録されていないため、依存関係やリポジトリは手動でつけます。

    <dependencies>
        <!-- https://github.com/spring-projects/spring-ai/blob/main/spring-ai-spring-boot-starters/spring-ai-starter-openai/pom.xml -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-core</artifactId>
            <version>0.8.0-SNAPSHOT</version>
        </dependency>
    </dependencies>

    <!-- https://docs.spring.io/spring-ai/reference/getting-started.html#_dependencies -->
    <repositories>
        <repository>
            <id>spring-snapshots</id>
            <name>Spring Snapshots</name>
            <url>https://repo.spring.io/snapshot</url>
            <releases>
                <enabled>false</enabled>
            </releases>
        </repository>
    </repositories>

コードの追加

Spring AIのインターフェースChatClientに合わせて、Llamacppクライアントのコードを書いていきます。

public class LlamaCppChatClient implements ChatClient {
    // プロンプトからレスポンスを生成
    @Override
    public ChatResponse generate(Prompt prompt) {
        LlamaModel.setLogger((level, message) -> System.out.print(message));
        var modelParams = new ModelParameters()
                .setNGpuLayers(1);
        var inferParams = new InferenceParameters()
                .setTemperature(0.7f)
                .setPenalizeNl(true)
//                .setNProbs(10)
                .setMirostat(InferenceParameters.MiroStat.V2)
                .setAntiPrompt("User:");

        //ダウンロードしたLLMモデルのファイル
	var modelPath = "/path/to/ELYZA-japanese-Llama-2-13b-fast-instruct-q4_K_M.gguf";

        var sb = new StringBuilder();
	// モデルの生成
        try (var model = new LlamaModel(modelPath, modelParams)) {
	    // モデルからチャンク文字列を取り出し結合
            Iterable<LlamaModel.Output> outputs = model.generate(prompt.getContents(), inferParams);
            for (LlamaModel.Output output : outputs) {
                sb.append(output.text);
            }
        }
        return new ChatResponse(List.of(new Generation(sb.toString())));
    }
}

AutoConfigurationできるようファイルを作ります。

@AutoConfiguration
public class LlamaCppAutoConfiguration {

    @Bean
    @ConditionalOnMissingBean
    public LlamaCppChatClient llamaCppChatClient() {
        LlamaCppChatClient llamaCppChatClient = new LlamaCppChatClient();
        return llamaCppChatClient;
    }
}

src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports

com.example.aillamacpphelloworld.LlamaCppAutoConfiguration

送受信用のRestControllerを作ります。messageのパラメータに質問文を送り、返信してもらう形式です。

@RestController
public class SimpleAiController {

    private final LlamaCppChatClient chatClient;

    @Autowired
    public SimpleAiController(LlamaCppChatClient chatClient) {
        this.chatClient = chatClient;
    }

    @GetMapping("/ai/simple")
    public Completion completion(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) {
        return new Completion(chatClient.generate(message));
    }

}

返信を格納するクラス。クライアント側で受け取るJSONに対応します。

public class Completion {

    private String completion;

    public Completion(String completion) {
        this.completion = completion;
    }

    public String getCompletion() {
        return completion;
    }
}

動作確認

アプリケーションを実行し、別ターミナルから質問します。

./mvnw spring-boot:run
curl http://localhost:8080/ai/simple\?message=hello\!

これだけで動きます。自分のPCでLLMが文章を生成するのは、ちょっと感動しますね。

流れるような文章を生成してみる

ChatGPTでは、文字を連続的に生成したものを随時クライアントに返却し、表示に時間がかかるストレスを多少は軽減させてます。ここまでの実装は、せっかくLLMが連続的に生成した文字を一つにまとめてしまっているので、同じような振る舞いができるよう、Webfluxを使い、生成した時点で次々と返却するようにしてみます。

依存関係の変更

Tomcatを依存関係から外しspring-boot-starter-webを依存関係からはずし(もともと)、Stream処理に適したNettyを使うように変更します。

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
            <exclusions>
                <!-- Exclude the Tomcat dependency -->
                <exclusion>
                    <groupId>org.springframework.boot</groupId>
                    <artifactId>spring-boot-starter-tomcat</artifactId>
                </exclusion>
            </exclusions>
        </dependency>

コードの追加

Spring AIには、ストリーム処理用の StreamingChatClient も用意されているので、その流儀に沿って実装していきます。

public class LlamaCppChatClient implements ChatClient, StreamingChatClient {

    // プロンプトからレスポンスのストリームを生成
    @Override
    public Flux<ChatResponse> generateStream(Prompt prompt) {

	// モデルの準備は上述の同様
        return Flux.using(
                () -> new LlamaModel(modelPath, modelParams),
                model -> Flux.fromIterable(model.generate(prompt.getContents(), inferParams))
                        .map(output -> {
                            var text = output.text;
                            System.out.print(text);
                            return new ChatResponse(List.of(new Generation(text)));
                        }),
                LlamaModel::close
        );
    }
}

RestControllerにも、Webflux実装を追加していきます。

@RestController
public class SimpleAiController {

    @GetMapping(value = "/ai/simple-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<Completion> completionStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) {
        var prompt = new Prompt(new UserMessage(message));
        Flux<ChatResponse> chatResponseFlux = chatClient.generateStream(prompt);
        return chatResponseFlux
                .map(chatResponse -> new Completion(chatResponse.getGeneration().getContent()));
    }
}

動作確認

Webfluxは、Server-Sent Events(SSE)プロトコルを用いて通信するため、返却される電文は以下のようになります。

data:{"completion":"This"}

data:{"completion":" is"}

data:{"completion":" my"}

data:{"completion":" first"}

data:{"completion":" post"}

必要な部分だけ抽出し、流れるような文章を表示するための、クライアント側のコードを作ります。

@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
public class AiLlamacppHelloworldApplicationTests {

    @LocalServerPort
    private int port;

    @Test
    public void testExample() throws Exception {
        // 「ジョークを言って!」
        String message = "tell me a joke";
        String urlString = "http://localhost:" + port + "/ai/simple-stream?message=" +
                java.net.URLEncoder.encode(message, "UTF-8");

        URL url = new URL(urlString);
        HttpURLConnection connection = (HttpURLConnection) url.openConnection();
        connection.setRequestMethod("GET");

        BufferedReader in = new BufferedReader(new InputStreamReader(connection.getInputStream()));
        String inputLine;
        while ((inputLine = in.readLine()) != null) {
            try {	        
                JSONObject obj = new JSONObject(inputLine.replace("data:", ""));
                System.out.print(obj.getString("completion"));
                System.out.flush(); // バッファリングせずにすぐ出力
            } catch (JSONException e) {}
        }
        in.close();
    }
}
./mvnw test

ジョークなのかなこれは...。

おわりに

今回は、一番基本的なLLMの実装をトライしてみました。
ローカルLLMの実現は、今回挙げた他にもLocalAIなどローカルサーバを立てて通信する手段があります。
また、Spring AIには、VectorDBや外部データと連携してRAG(Retrieval-Augmented Generation)を実現する機能もありますし、インターフェースや機能の抽象化という目的では、有名なLangChainのJava版である[Langchain4j]https://github.com/langchain4j)もあります。

Java界も活性化してますね!

自分の方でも今後、可能な限り紹介していきます。

Discussion