🍀

OllamaのK/V Context量子化の実践的検証と実装

2024/12/19に公開

はじめに

OllamaにおけるK/V context cache量子化は、VRAM使用量を大幅に削減できる革新的な技術です。本記事では、実際の検証結果を基に、その効果と実用性について詳しく解説します。また、検証に使用したスクリプトのセットアップと使用方法についても説明します。

検証環境

ハードウェア構成

  • GPU1: NVIDIA GeForce RTX 4090 (VRAM: 24GB)
  • GPU2: NVIDIA GeForce RTX 3060 (VRAM: 12GB)
  • CPU: Intel Core i9(24コア)

ソフトウェア構成

  • Ollama v0.4.7
  • Windows 11
  • CUDA 12.6
  • テストモデル: llama3.1

テスト結果

性能測定結果

🚀 応答速度

異なるタイプのプロンプトに対する応答時間を測定:

プロンプトタイプ 応答時間 トークン数 文字数
創作(SF小説) 2.65秒 365 543
コード生成 4.41秒 611 1,341
概念説明 2.03秒 282 431
技術解説 4.80秒 662 1,003
アーキテクチャ説明 4.77秒 661 1,127

📊 メモリ使用量の変化

RTX 4090とRTX 3060での比較:

RTX 4090:

  • ベースライン使用量: 6,562MB
  • 全テストを通じて使用量の変動なし
  • 安定したメモリ管理を実現

RTX 3060:

  • 初期使用量: 4,198MB
  • 最小使用量: 4,095MB
  • 最大の変動: -92MB(コード生成タスク時)
  • 平均変動: -20.4MB

🔍 重要な発見

  1. メモリ効率

    • RTX 4090では完全に安定したメモリ使用を実現
    • RTX 3060でもわずかな変動で効率的に動作
  2. タスク別の特性

    • 短い説明タスク(2-3秒)と長い生成タスク(4-5秒)で明確な差
    • コード生成時に最大のメモリ変動を観察
  3. 量子化の効果

    • メモリ使用量の安定性が向上
    • マルチGPU環境でも効率的な動作を確認

実装方法

基本設定

# Flash Attentionの有効化
export OLLAMA_FLASH_ATTENTION=1

# K/V cache量子化の設定
export OLLAMA_KV_CACHE_TYPE="q8_0"

推奨設定

  • Q8_0: 一般的な用途に最適
  • マルチGPU環境では主要なGPUにモデルを配置
  • バッチサイズの調整は不要

検証スクリプト

🔧 セットアップ

必要なパッケージのインストール

pip install requests pynvml art loguru

スクリプトの全体構造

import requests
import time
import json
from typing import Dict, List, Optional
import pynvml
import argparse
from datetime import datetime
import csv
from art import text2art
from loguru import logger
import sys

class OllamaClient:
    def __init__(self, host: str = "http://localhost:11434"):
        """Ollamaクライアントの初期化"""
        self.host = host
        self.base_url = f"{host}/api"
        pynvml.nvmlInit()
        logger.info(f"Initialized Ollama client with host: {host}")

📊 主要な機能

1. GPUメモリ監視

def get_gpu_memory(self) -> Dict[str, int]:
    """GPU使用メモリを取得"""
    memory_info = {}
    deviceCount = pynvml.nvmlDeviceGetCount()
    for i in range(deviceCount):
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        name = pynvml.nvmlDeviceGetName(handle)
        if isinstance(name, bytes):
            name = name.decode('utf-8')
        memory_info[name] = info.used // 1024 // 1024  # MB単位
        logger.debug(f"GPU {name}: {memory_info[name]}MB used")
    return memory_info

2. テキスト生成

def generate(self, model: str, prompt: str, stream: bool = True) -> dict:
    """テキスト生成を実行"""
    logger.info(f"Generating text with model: {model}")
    logger.debug(f"Prompt: {prompt[:100]}...")
    
    url = f"{self.base_url}/generate"
    data = {
        "model": model,
        "prompt": prompt,
        "stream": stream
    }
    
    if stream:
        response = requests.post(url, json=data, stream=True)
        full_response = ""
        for line in response.iter_lines():
            if line:
                json_response = json.loads(line)
                if 'response' in json_response:
                    full_response += json_response['response']
                if json_response.get('done', False):
                    logger.success("Text generation completed")
                    return {
                        'response': full_response,
                        'total_duration': json_response.get('total_duration', 0),
                        'load_duration': json_response.get('load_duration', 0),
                        'prompt_eval_count': json_response.get('prompt_eval_count', 0),
                        'eval_count': json_response.get('eval_count', 0)
                    }

3. ベンチマーク実行

def run_benchmark(client: OllamaClient, 
                 model: str, 
                 prompts: List[str], 
                 output_file: str):
    """ベンチマークを実行して結果をCSVに保存"""
    logger.info("Starting benchmark...")
    results = []
    
    for i, prompt in enumerate(prompts, 1):
        logger.info(f"Running test {i}/{len(prompts)}")
        
        # 生成前のGPUメモリを記録
        initial_memory = client.get_gpu_memory()
        
        # 生成実行
        start_time = time.time()
        response = client.generate(model, prompt)
        end_time = time.time()
        
        # 生成後のGPUメモリを記録
        final_memory = client.get_gpu_memory()
        
        # 結果を記録
        result = {
            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            'prompt': prompt[:50],
            'total_duration': end_time - start_time,
            'load_duration': response.get('load_duration', 0),
            'prompt_eval_count': response.get('prompt_eval_count', 0),
            'eval_count': response.get('eval_count', 0),
            'response_length': len(response.get('response', '')),
        }
        
        # 各GPUのメモリ使用量の変化を記録
        for gpu, initial in initial_memory.items():
            result[f'{gpu}_initial_memory_mb'] = initial
            result[f'{gpu}_final_memory_mb'] = final_memory[gpu]
            result[f'{gpu}_memory_change_mb'] = final_memory[gpu] - initial
            logger.info(f"GPU {gpu} memory change: {result[f'{gpu}_memory_change_mb']}MB")
        
        results.append(result)
        logger.success(f"Test {i}/{len(prompts)} completed in {result['total_duration']:.2f}s")

    # 結果をCSVに保存
    if results:
        with open(output_file, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=results[0].keys())
            writer.writeheader()
            writer.writerows(results)
        logger.success(f"Results saved to {output_file}")

🚀 使用方法

基本的な実行

python ollama_client.py

オプションの指定

# 特定のモデルを使用
python ollama_client.py --model llama3.1:latest

# 出力ファイルの指定
python ollama_client.py --output results.csv

# デバッグモード有効
python ollama_client.py --debug

📈 出力例

CSVフォーマット

timestamp,prompt,total_duration,load_duration,prompt_eval_count,eval_count,response_length,
2024-12-05 17:23:01,創作タスク,2.65,10258500,33,365,543,4198,4192,-6,6562,6562,0
2024-12-05 17:23:05,コード生成,4.41,9704500,37,611,1341,4192,4100,-92,6562,6562,0
2024-12-05 17:23:07,概念説明,2.03,10876500,25,282,431,4100,4096,-4,6562,6562,0

⚙️ カスタマイズのポイント

  1. プロンプトの設定
prompts = [
    "1000文字のSF小説を書いてください。設定は宇宙船内での出来事です。",
    "Pythonで簡単なウェブスクレイピングプログラムを書いてください。",
    "量子コンピューティングについて500文字で説明してください。",
    "機械学習におけるバイアスとバリアンスのトレードオフについて説明してください。",
    "クリーンアーキテクチャの主要な原則について説明し、実装例を示してください。"
]
  1. ロギングの設定
logger.add(
    sys.stdout,
    format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
    level="INFO"
)
logger.add("ollama_benchmark_{time}.log")

考察

メリット

  1. 安定性の向上

    • メモリ使用量の変動が最小限(RTX 4090で変動なし)
    • マルチGPU環境での効率的なリソース利用
  2. 実用的なパフォーマンス

    • 応答時間は用途に応じて2-5秒の範囲
    • メモリ効率と処理速度のバランスが良好

最適な使用シナリオ

  1. 大規模モデルの実行

    • より大きなモデルを限られたVRAMで実行可能
    • 特にRTX 3060のような中規模GPUで効果的
  2. 長時間の推論タスク

    • メモリ使用量の安定性が重要な場合に有効
    • バッチ処理やストリーミング生成に適している

まとめ

K/V Context量子化は、特に以下の点で効果的であることが実証されました:

  1. 安定性: メモリ使用量の変動を最小限に抑制
  2. 効率性: マルチGPU環境での効率的なリソース利用
  3. 実用性: 応答速度とメモリ効率のバランスが良好

この機能は、特に限られたVRAMでより大きなモデルを実行したい場合や、長時間の安定した推論が必要な場合に非常に有用です。また、提供したベンチマークスクリプトを使用することで、独自の環境での性能評価や最適な設定の発見が可能になります。

参考情報

全体コード

import requests
import time
import json
from typing import Dict, List, Optional
import pynvml
import argparse
from datetime import datetime
import csv
from art import text2art
from loguru import logger
import sys

# ロガーの設定
logger.remove()
logger.add(
    sys.stdout,
    format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
    level="INFO"
)
logger.add("ollama_benchmark_{time}.log")

def print_banner():
    """アプリケーションバナーを表示"""
    art = text2art("Ollama Benchmark", font='block')
    print("\033[94m" + art + "\033[0m")  # 青色で表示
    print("\033[92m" + "=" * 50 + "\033[0m")  # 緑色の区切り線
    print("\033[93mK/V Context Quantization Performance Test\033[0m")  # 黄色でサブタイトル
    print("\033[92m" + "=" * 50 + "\033[0m\n")  # 緑色の区切り線

class OllamaClient:
    def __init__(self, host: str = "http://localhost:11434"):
        """Ollamaクライアントの初期化"""
        self.host = host
        self.base_url = f"{host}/api"
        pynvml.nvmlInit()
        logger.info(f"Initialized Ollama client with host: {host}")

    def get_gpu_memory(self) -> Dict[str, int]:
        """GPU使用メモリを取得"""
        memory_info = {}
        deviceCount = pynvml.nvmlDeviceGetCount()
        for i in range(deviceCount):
            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
            info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            name = pynvml.nvmlDeviceGetName(handle)
            if isinstance(name, bytes):
                name = name.decode('utf-8')
            memory_info[name] = info.used // 1024 // 1024  # MB単位
            logger.debug(f"GPU {name}: {memory_info[name]}MB used")
        return memory_info

    def list_models(self) -> List[Dict]:
        """利用可能なモデルの一覧を取得"""
        logger.info("Fetching available models...")
        response = requests.get(f"{self.base_url}/tags")
        models = response.json()['models']
        logger.info(f"Found {len(models)} models")
        return models

    def generate(self, model: str, prompt: str, stream: bool = True) -> dict:
        """テキスト生成を実行"""
        logger.info(f"Generating text with model: {model}")
        logger.debug(f"Prompt: {prompt[:100]}...")
        
        url = f"{self.base_url}/generate"
        data = {
            "model": model,
            "prompt": prompt,
            "stream": stream
        }
        
        if stream:
            response = requests.post(url, json=data, stream=True)
            full_response = ""
            for line in response.iter_lines():
                if line:
                    json_response = json.loads(line)
                    if 'response' in json_response:
                        full_response += json_response['response']
                    if json_response.get('done', False):
                        logger.success("Text generation completed")
                        return {
                            'response': full_response,
                            'total_duration': json_response.get('total_duration', 0),
                            'load_duration': json_response.get('load_duration', 0),
                            'prompt_eval_count': json_response.get('prompt_eval_count', 0),
                            'eval_count': json_response.get('eval_count', 0)
                        }
        else:
            response = requests.post(url, json=data)
            return response.json()

def run_benchmark(client: OllamaClient, 
                 model: str, 
                 prompts: List[str], 
                 output_file: str):
    """ベンチマークを実行して結果をCSVに保存"""
    logger.info("Starting benchmark...")
    results = []
    
    for i, prompt in enumerate(prompts, 1):
        logger.info(f"Running test {i}/{len(prompts)}")
        logger.debug(f"Prompt: {prompt[:50]}...")
        
        # 生成前のGPUメモリを記録
        initial_memory = client.get_gpu_memory()
        
        # 生成実行
        start_time = time.time()
        response = client.generate(model, prompt)
        end_time = time.time()
        
        # 生成後のGPUメモリを記録
        final_memory = client.get_gpu_memory()
        
        # 結果を記録
        result = {
            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            'prompt': prompt[:50],
            'total_duration': end_time - start_time,
            'load_duration': response.get('load_duration', 0),
            'prompt_eval_count': response.get('prompt_eval_count', 0),
            'eval_count': response.get('eval_count', 0),
            'response_length': len(response.get('response', '')),
        }
        
        # 各GPUのメモリ使用量の変化を記録
        for gpu, initial in initial_memory.items():
            result[f'{gpu}_initial_memory_mb'] = initial
            result[f'{gpu}_final_memory_mb'] = final_memory[gpu]
            result[f'{gpu}_memory_change_mb'] = final_memory[gpu] - initial
            logger.info(f"GPU {gpu} memory change: {result[f'{gpu}_memory_change_mb']}MB")
        
        results.append(result)
        logger.success(f"Test {i}/{len(prompts)} completed in {result['total_duration']:.2f}s")

    # 結果をCSVに保存
    if results:
        with open(output_file, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=results[0].keys())
            writer.writeheader()
            writer.writerows(results)
        logger.success(f"Results saved to {output_file}")

def main():
    print_banner()
    
    parser = argparse.ArgumentParser(description='Ollama Benchmark Client')
    parser.add_argument('--model', default='llama3.1:latest', help='Model to use')
    parser.add_argument('--output', default='benchmark_results.csv', help='Output CSV file')
    parser.add_argument('--debug', action='store_true', help='Enable debug logging')
    args = parser.parse_args()

    if args.debug:
        logger.remove()
        logger.add(sys.stdout, level="DEBUG")

    client = OllamaClient()
    
    # テスト用プロンプト
    prompts = [
        "1000文字のSF小説を書いてください。設定は宇宙船内での出来事です。",
        "Pythonで簡単なウェブスクレイピングプログラムを書いてください。BeautifulSoupを使用してください。",
        "量子コンピューティングについて500文字で説明してください。",
        "機械学習におけるバイアスとバリアンスのトレードオフについて説明してください。",
        "クリーンアーキテクチャの主要な原則について説明し、実装例を示してください。"
    ]

    logger.info(f"Using model: {args.model}")
    logger.info("Checking available models...")
    
    models = client.list_models()
    print("\nAvailable models:")
    for model in models:
        print(f"- {model['name']}: {model.get('digest', 'N/A')}")

    print("\n" + "=" * 50)
    run_benchmark(client, args.model, prompts, args.output)

if __name__ == "__main__":
    main()

<script async src="https://platform.twitter.com/widgets.js" charset="utf-8"></script>

Discussion