🐥

ML-KEMをPythonで簡易実装試みる!?

に公開

はじめに

PQC(耐量子計算機暗号)やQKD(量子鍵配送)といった最先端の暗号技術に強い関心を持っている大学3年生です。将来的には大学院でこれらの分野を研究し、より安全なデジタル社会の構築に貢献したいと考えています。

そして現在、私はPQCやQKDの学習と情報発信に取り組んでいます。難しい内容も分かりやすく伝えることを目標に、独学で得た知識を体系的に整理し、記事シリーズとしてまとめています。

本記事では、PQCの技術の核心に迫るべく、NISTによって標準化されたML-KEMの簡易実装に挑戦しました。机上での学習だけでなく、実際にコードを書き、理論を「動く形」で体感することで本質的な理解につながると考えたからです。

ML-KEMの複雑なアルゴリズムを、Pythonを用いてわかりやすくコードに落とし込むことを試みました。しかし、その過程で私は、単なる実装の難しさだけではない、暗号アルゴリズムの繊細さに直面しました。当初の目論見通りにはいきませんでしたが、この経験は私にとって、貴重な学びとなりました。

前提

ML-KEMの手順についてざっくり知っていることを前提として進めます。
ML-KEMやK-PKEの理論について詳しく知りたい方は、先に以下の記事をご覧ください。

準備

1. 環境

今回の簡易実装は、Python 3の標準ライブラリのみで動作します。特別なパッケージをインストールする必要はありません。

2. プロジェクト構成

コードは、役割に応じて以下のディレクトリとファイルに分割して実装を進めます。

MLKEM/
├── utils/
│   ├── __init__.py                 
│   ├── hash_utils.py       # H, G, J などハッシュ・乱数生成
│   └── poly_utils.py       # 多項式演算, NTT など
│
├── kpke/
│   ├── __init__.py
│   └── kpke_core.py        # K-PKE: Keygen, Enc, Dec
│
├── mlkem/
│   ├── __init__.py
│   ├── mlkem_core.py       # ML-KEM:EnckeyGen,Encap, Decaps
│
└── main.py                 # 実行テスト
  • utils/: ハッシュ計算や多項式演算など、プロジェクト全体で使われる基本的な部品(ユーティリティ関数)を格納します。

  • kpke/: ML-KEMの中核をなす公開鍵暗号アルゴリズム(K-PKE)を実装します。

  • mlkem/: K-PKEを基に、最終的な鍵カプセル化メカニズム(ML-KEM)のロジックを構築します。

  • main.py: 実装したML-KEMが正しく動作するかを確認するための、テストプログラムです。

3. 作成物

プログラムのソースコードは下記となります。
git cloneしてご利用ください。

https://github.com/miki555555na/MLKEM

ML-KEMアルゴリズムの概要

ML-KEMのアルゴリズムは、詳細を見る前に大きく以下の流れで整理できます。

  1. 鍵生成

    • K-PKE鍵生成アルゴリズムを使用して、共通鍵を安全に受け取るためのカプセル化鍵(公開鍵)とデカプセル化鍵(秘密鍵)を生成します。
  2. カプセル化(Encapsulation)

    • 送信者は、受信者のカプセル化鍵(公開鍵)を使用し、共通鍵と暗号文を生成します。
    • 送信者は受信者に暗号文のみを送信します。
  3. デカプセル化(Decapsulation)

    • 受信者は送られてきた暗号文を秘密鍵で復元し、共通鍵を取り出します。
    • このとき、暗号文が途中で改ざんされていないかを検証し、安全性が確認できた場合のみ、その共通鍵を有効にします。一致しなかった場合は、ランダム値を代替として共通鍵にします。

ここからは、ML-KEMアルゴリズムを擬似的なパラメータで簡易的に実装することで理解を深めていきます。そのために、実際の暗号で使われる巨大なパラメータではなく、手計算でも追えるような擬似的なパラメータを用いた簡易版を実装していきます。

実装は以下のステップで進めます。

  1. 部品となる関数の実装: 全体で使われる共通の関数を準備します。
  2. K-PKEアルゴリズムの実装: ML-KEMの中核となる公開鍵暗号(Public Key Encryption)部分を構築します。
  3. ML-KEMアルゴリズムの実装: K-PKEを基に、鍵カプセル化メカニズム全体を完成させます。

今回の実装で用いるパラメータは以下の通りです:

  • 多項式の次数:n = 8
  • 係数の法:q = 17
  • 多項式の数:k = 2
  • n次の原始根:psi = 3^{{q-1}/n} \mod q = 9

※n次の原始根は、多項式演算の高速化で使用されるNTT(数論変換)で中心的な役割を果たすため、ご存知でない方は先にこちらをご覧ください。

部品となる関数の実装

ML-KEMのような複雑な暗号アルゴリズムは、多くの小さな部品(関数)の組み合わせで成り立っています。ここでは、それらの部品を役割に応じて2つのファイルに分けて実装していきます。

汎用ユーティリティ関数(hash_utils.py)

一つ目は、暗号の数学的な構造とは直接関係なく、データそのものを操作するための低レベルな関数群です。ここには、ハッシュ化、XOF(拡張可能出力関数)、乱数生成、そしてバイトとビットの相互変換といった、基本的なデータ処理の道具を揃えます。

hash_utils.py
#ハッシュ、乱数生成
#n=8 q=17 k=2 の実装例(簡略版)
import hashlib

#パラメータ
n = 8 #多項式の次数
q = 17 #係数の法
k = 2 #多項式の数 (k=eta) 

#=====ハッシュ関数=====

def hash_G(data: bytes, k: int) -> tuple[bytes, bytes]:
    data = data + k.to_bytes(1, 'little')
    h = hashlib.sha3_512(data).digest()
    return h[:32], h[32:]

def hash_H(data: bytes) -> bytes:
    return hashlib.sha3_256(data).digest()

def hash_J(data: bytes, output_length: int) -> bytes:
    return hashlib.shake_256(data).digest(output_length)

def prf(seed: bytes, eta: int, b: int) -> bytes:
    #役割:疑似ランダム関数(実際には規則性があるが、規則性がないように見える数字を作ってくれる関数)で、指定された長さの乱数を生成する
    assert len(seed) == 32
    assert eta in [2,3]
    assert 0 <= b < 256 # 1バイト

    output_len = 64 * eta #出力バイト数
    hasher = hashlib.shake_256()
    hasher.update(seed)
    hasher.update(bytes([b]))
    
    return hasher.digest(output_len)

#=====XOF(Extended-Output Function)=====

class XOF:
    #役割:任意の長さの暗号ハッシュを生成する
    def __init__(self):
        self.__hash__obj = hashlib.shake_128()

    def absorb(self, input_data: bytes):
        self.__hash__obj.update(input_data)

    def squeeze(self, output_length: int) -> bytes:
        return self.__hash__obj.digest(output_length)
    
#=====バイト・ビット=====
def BytesToBits(input_Byte: bytes) -> list[int]:
    bits = []
    for byte in input_Byte:
        for j in range(8):
            bits.append(byte & 1) #最下位ビットを取り出す
            byte //= 2 #次のビットに移動
    return bits

def BitsToBytes(input_data: list[int]) -> bytes:
    n_bits = len(input_data)
    assert n_bits % 8 == 0
    n_bytes = n_bits // 8
    output = [0] * n_bytes
    for i in range(n_bits):
        byte_index = i // 8
        bit_index = i % 8
        output[byte_index] |= (input_data[i] << bit_index)  
    return bytes(output)

def bit_rev(x: int, bits: int) -> int:
    #役割:指定されたビット幅(bits)でビットを反転する
    result = 0
    for i in range(bits):
        bit = (x >> i) & 1 #1ビットだけ抽出する
        result |= bit << (bits - 1 - i) #反転位置にセット
    return result

多項式関連アルゴリズム(poly_utils.py)

二つ目は、ML-KEMの核となる「多項式」というデータ構造に特化した、より高レベルなアルゴリズム群です。先のhash_utils.pyで定義した関数を内部で利用しつつ、多項式同士の加算や乗算、高速な計算を実現するNTT(数論変換)、そしてデータのサイズを圧縮・解凍するための関数などを実装します。

poly_utils.py
#K-PKE実装
#n=8 q=17 k=2 の実装例(簡略版)
from . import hash_utils
import math

#パラメータ
n = 8 #多項式の次数
q = 17 #係数の法
k = 2 #多項式の数 (k=eta)

#=====エンコード・デコード=====

def ByteEncode(F: list[int], d: int):
    #役割:(0,q-1)範囲の整数リストの各要素をdビット表現にした(n x d ビット)後、バイト化する。
    b = [0] * (n * d)
    for i in range(n):
        a = F[i] 
        for j in range(d):
            b[i * d + j] = a & 1
            a >>= 1
    return hash_utils.BitsToBytes(b)

def ByteDecode(B: bytes, d: int, k: int) -> list[list[int]]:
    #役割:エンコードでバイト化されたものを整数リスト(k個)に戻す。
    bit_list = hash_utils.BytesToBits(B)
    polynomials = []
    for poly_idx in range(k):
        poly = []
        for coeff_idx in range(n):
            coeff = 0
            for bit_idx in range(d):
                bit_position = poly_idx * n * d + coeff_idx * d + bit_idx
                coeff |= (bit_list[bit_position] << bit_idx)
            coeff %= q  # mod q
            poly.append(coeff)
        polynomials.append(poly)
    return polynomials

#=====乱数分布=====

def sample_poly_cbd(seed: bytes, eta: int):
    #役割:乱数バイト列から中心二項分布に従う多項式の係数をサンプリングし、n次元多項式のリストを生成する
    b = hash_utils.BytesToBits(seed)
    coefficients = [0] * n
    for i in range(n):
        x = 0
        y = 0
        for j in range(eta):
            x += b[2 * i * eta + j]
            y += b[2 * i * eta + j + eta]
        coefficients[i] = (x - y) % q
    return coefficients

#=====圧縮・復元=====

def Compress(vec: list[int], d: int) -> list[int]:
    #役割:{0,q-1}の範囲のリストを{0,1}の範囲に変換する
    n = len(vec)
    result = [0] * n
    for i in range(n):
        if(((vec[i] + 1) // 2) >= 5):
            result[i] = 1
        else:
            result[i] = 0
    return result

def Decompress(vec: list[int], d: int) -> list[int]:
    #役割:{0,1}範囲のリストをqの範囲に変換する
    n = len(vec)
    result = [0] * n
    for i in range(n):
        if (vec[i] == 0):
            result[i] = 0
        else:
            result[i] = (q + 1) // 2
    return result

#=====多項式演算=====
def poly_add(poly1: list[int], poly2: list[int]) -> list[int]:
    # zipで2つのリストの要素を同時に取り出し、計算結果を新しいリストとして返す
    return [(p1 + p2 + q) % q for p1, p2 in zip(poly1, poly2)]

def poly_mul_ntt(poly1: list[int], poly2: list[int]) -> list[int]:

    return [(p1 * p2 + q) % q for p1, p2 in zip(poly1, poly2)]

def poly_sub(poly1: list[int], poly2: list[int]) -> list[int]:

    return [(p1 - p2 + q) % q for p1, p2 in zip(poly1, poly2)]

#=====NTT関連=====

def sample_ntt(input_data: bytes) -> list[int]:
    #役割:NTT多項式の係数リストを生成する
    ctx = hash_utils.XOF()
    ctx.absorb(input_data) #乱数ストリームを生成
    coefficients = []
    chunk_size = 16 #乱数ストリームの出力を16バイトに設定(まとめて乱数生成することで効率性向上)
    buffer = b''
    while len(coefficients) < n:
        if len(buffer) < chunk_size:
            buffer += ctx.squeeze(chunk_size)
        b = buffer[0] #バッファから1バイト取り出す
        buffer = buffer[1:] #バッファの更新
        val = b & 0x1F  #下位5ビットの抽出
        if val < q: #qによるフィルタリング(q(=17) < 2^5 )
            coefficients.append(val)
    return coefficients 

def NTT(poly: list[int], psi: int):
    #役割:係数リストをNTT多項式の係数リストに変換する

    f = poly.copy()
    #リストの順序を入れ替える
    for i in range(n):
        rev_i = hash_utils.bit_rev(i, int(math.log2(n)))
        if i < rev_i:
            f[i], f[rev_i] = f[rev_i], f[i] 

    # バタフライ演算
    length = 2 #ステージを表しlength 2,4,8 と計算規模を大きくする
    while length <= n:
        half_len = length // 2
        # ステージごとの回転因子
        zeta_base = pow(psi, n // length, q) 
        for start in range(0, n, length): #処理するブロックの開始位置
            zeta = 1
            for j in range(start, start + half_len): #length長のブロック
                t = (zeta * f[j + half_len]) % q
                f[j + half_len] = (f[j] - t + q) % q
                f[j] = (f[j] + t) % q
                zeta = (zeta * zeta_base) % q
        length *= 2
    return f

def NTT_rev(poly: list[int], psi_rev: int):
    # NTTとほぼ同じ構造で逆の回転因子を使う
    f_ntt = NTT(poly, psi_rev) # NTT関数を psi_rev で再利用

    # 最後に n の逆元を掛ける(NTT変換でn倍されているため)
    n_inv = pow(n, -1, q)
    return [(coeff * n_inv) % q for coeff in f_ntt]

K-PKEアルゴリズムの実装(kpke_core.py)

K-PKEアルゴリズムは、大きく三つのプロセスから成っています。

  • 鍵生成:受信者は乱数シードから公開鍵(ek_PKE), 秘密鍵(dk_PKE)を生成。
  • 暗号化:送信者は受信者の公開鍵でランダムメッセージを暗号化し、暗号文を生成。
  • 復号:受信者は秘密鍵で暗号文を復号。
kpke_core.py
#K-PKE実装
#n=8 q=17 k=2 の実装例(簡略版)
from ..utils import hash_utils
from ..utils import poly_utils


#パラメータ
n = 8 #多項式の次数
q = 17 #係数の法
k = 2 #多項式の数 (k=eta)
eta = k
psi = pow(3, (q-1) // n, q) #n次の原始根
psi_rev = pow(psi, -1, q) #(ζ^-1)

#K-PKE鍵生成
def k_pke_keygen(seed: bytes) -> tuple:
    #入力:32バイトの乱数
    #出力:公開鍵(ek_PKE), 秘密鍵(dk_PKE)

    #1.乱数生成 (p,S)
    p,S = hash_utils.hash_G(seed,k)

    #2.公開鍵行列 ntt_A (k x k)生成
    ntt_A = [[0] * k for _ in range(k)]
    for i in range(k):
        for j in range(k):
            input_data = p + j.to_bytes(1, 'little') + i.to_bytes(1, 'little')
            ntt_A[i][j] =poly_utils.sample_ntt(input_data) #n次元多項式の係数リストを生成
              
    #3.秘密鍵ベクトル s と誤差ベクトル e の生成
    nonce = 0
    s = [0] * k
    for i in range(k):
        s[i] = poly_utils.sample_poly_cbd(hash_utils.prf(S, eta, nonce), eta) #乱数を生成し、n次元多項式のリストを生成
        nonce += 1
    
    e = [0] * k
    for i in range(k):
        e[i] = poly_utils.sample_poly_cbd(hash_utils.prf(S, eta, nonce), eta) #乱数を生成し、n次元多項式のリストを生成
        nonce += 1

    #4.NTT変換
    ntt_vec_s = [poly_utils.NTT(poly, psi) for poly in s]
    ntt_vec_e = [poly_utils.NTT(poly, psi) for poly in e]

    #5.公開鍵 b の生成
    # b = ntt_A * ntt_vec_s + ntt_vec_e 
    b = [0] * k
    for i in range(k):
        term_ntt = [0] * n
        for j in range(k):
            prod = poly_utils.poly_mul_ntt(ntt_A[i][j], ntt_vec_s[j]) #リスト同士の乗算
            term_ntt = poly_utils.poly_add(term_ntt, prod)
        b[i] = poly_utils.poly_add(term_ntt, ntt_vec_e[i])

    #6.エンコード
    encoded_b = []
    for i in range(k):
        encoded_part = poly_utils.ByteEncode(b[i], 5) #n次元多項式の係数リストからバイト列を生成
        encoded_b.append(encoded_part)
    
    encoded_s = []
    for i in range(k):
        encoded_part = poly_utils.ByteEncode(ntt_vec_s[i], 5)
        encoded_s.append(encoded_part)

    #7.鍵ペアの生成

    # 10 bytes + 32 bytes = 42 bytes
    ek_PKE = b''.join(encoded_b) + p 
    # 10 bytes
    dk_PKE = b''.join(encoded_s)  

    return (ek_PKE,dk_PKE)

#K-PKE暗号化
def k_pke_enc(ek_PKE: bytes, m:bytes, r:bytes) -> bytes:
    #入力:公開鍵(ek_PKE),メッセージ(m,1バイト),乱数(r,32バイト)
    #出力:暗号文(c)

    #1.公開鍵のデコード
    p = ek_PKE[-32:]
    encorded_b = ek_PKE[:-32]
    b = poly_utils.ByteDecode(encorded_b, 5, k) #バイト列からk個のn次元多項式の(0,q-1)範囲の整数リストを返す

    #2.公開鍵行列 ntt_A (k x k)再現(生成)
    ntt_A = [[0] * k for _ in range(k)]
    for i in range(k):
        for j in range(k):
            input_data = p + j.to_bytes(1, 'little') + i.to_bytes(1, 'little')
            ntt_A[i][j] =poly_utils.sample_ntt(input_data) #n次元多項式の係数リストを生成

    #3.一時乱数ベクトル y と誤差ベクトル e1 を生成
    y = [0] * k
    nonce = 0
    for i in range(k):
        y[i] = poly_utils.sample_poly_cbd(hash_utils.prf(r, eta, nonce), eta) #n次元多項式の係数リストを生成
        nonce += 1
    
    e1 = [0] * k
    for i in range(k):
        e1[i] = poly_utils.sample_poly_cbd(hash_utils.prf(r, eta, nonce), eta) 
        nonce += 1

    #4.誤差多項式 e2 を生成
    e2 = poly_utils.sample_poly_cbd(hash_utils.prf(r, eta, nonce), eta)

    #5.NTT変換
    ntt_vec_y = [poly_utils.NTT(poly, psi) for poly in y]

    #6.ベクトル U の生成
    #U = ntt_A * ntt_vec_y + e1
    U = [0] * k
    for i in range(k):
        term_ntt = [0] * n
        for j in range(k):
            prod = poly_utils.poly_mul_ntt(ntt_A[i][j], ntt_vec_y[j]) #リスト同士の乗算
            term_ntt = poly_utils.poly_add(term_ntt, prod)
        U[i] = poly_utils.poly_add(term_ntt, e1[i])
    
    #7.メッセージ m を多項式化
    μ_poly_list = poly_utils.ByteDecode(m, 1, 1) #n次元の多項式の{0,1}の係数リストを生成
    μ = poly_utils.Decompress(μ_poly_list[0], 1) #0 => 0 , 1 =>(q + 1)/2 の範囲に変換

    #8.暗号文 V の生成
    #V = b * ntt_vec_y + e2 + μ
    term_ntt = [0] * n
    for i in range(k):
        prod = poly_utils.poly_mul_ntt(b[i], ntt_vec_y[i])
        term_ntt = poly_utils.poly_add(term_ntt, prod)
    V = poly_utils.poly_add(term_ntt, poly_utils.poly_add(e2, μ)) #n次元の多項式の係数リストを生成

    #9.エンコード
    u = []
    for i in range(k):
        compressed_U = poly_utils.Compress(U[i], 1) #(0~q-1)表現から{0,1}表現の係数リストを生成
        encorded_U = poly_utils.ByteEncode(compressed_U, 1) #バイト列生成
        u.append(encorded_U)

    v = []
    compressed_V = poly_utils.Compress(V, 1) #(0~q-1)表現から{0,1}表現の係数リストを生成
    v = poly_utils.ByteEncode(compressed_V, 1) #バイト列生成
    
    c = b''.join(u) + v #2+1=3バイト

    return c

#K-PKE復号
def k_pke_dec(dk_PKE: list[int], c: bytes) -> bytes:
    #入力:秘密鍵(dk_PKE),暗号文(c)
    #出力:復号文(m')

    #1.暗号文のデコード
    u_byte = c[0:2]
    v = c[2:]

    #2.UとVを復元
    # バイト列から係数リストに直し、(0,q-1)表現の係数リストを復元
    U = [poly_utils.Decompress(i, 1) for i in poly_utils.ByteDecode(u_byte, 1, k)]
    V = [poly_utils.Decompress(i, 1) for i in poly_utils.ByteDecode(u_byte, 1, 1)]

    #3.秘密鍵を復元
    # バイト列から(0,q-1)表現の係数リストを復元
    s = poly_utils.ByteDecode(dk_PKE, 5, k)

    #4.中間多項式を生成
    #w = V - s * NTT(U)
    ntt_U = [poly_utils.NTT(poly,psi) for poly in U]
    term_ntt = [0] * n
    for j in range(k):
        prod = poly_utils.poly_mul_ntt(s[j], ntt_U[j])  
        term_ntt = poly_utils.poly_add(term_ntt, prod)
    w = poly_utils.poly_sub(V[0], poly_utils.NTT_rev(term_ntt,psi_rev))
    
    #5.メッセージの復元
    #バイト列を復元
    m = poly_utils.ByteEncode(poly_utils.Compress(w,1), 1)

    return m

K-PKEアルゴリズムを単体で実装し、その挙動を確かめる

実際のML-KEMアルゴリズムの実装に入る前に、基盤となるK-PKEが単体でどのように振る舞うかを見たいと思います。

実は、この過程で、私は大きな壁にぶつかり、そして同時に非常に重要な気づきを得ることになりました。

それは、「暗号アルゴリズムは、たとえ簡易版であっても、数学的な前提条件から少しでも外れると、途端に正しく動作しなくなる」 という事実です。

当初、私は簡易パラメータで復号が頻繁に失敗する原因を、コードのバグだと考えていました。しかし、いくらコードを修正しても、成功と失敗が確率的に発生する状況は変わりませんでした。

Kyber PKE Simplified Test (n=8, q=17, k=2)
==================================================

--- 1. 鍵生成中... ---
✅ 公開鍵 (ek_PKE) 生成完了 (長さ: 42 bytes)
✅ 秘密鍵 (dk_PKE) 生成完了 (長さ: 10 bytes)

--- 2. 暗号化中... ---
メッセージ (平文): 01
乱数: bc84d84fc66d736f34e7e7e8e3f32983cd79905e10a44a84a93f3aa1f3bf0ec3
✅ 暗号文 生成完了 (長さ: 3 bytes)
暗号文 (c): 8589a9

--- 3. 復号中... ---
✅ 復号完了
復号されたメッセージ: 01

--- 4. 検証結果 ---
🥳 成功! 元のメッセージと復号されたメッセージは一致します。
==================================================

--- NTT Self-Test ---
Original:  [1, 2, 3, 4, 5, 6, 7, 8]
Restored:  [1, 2, 3, 4, 5, 6, 7, 8]
NTT/NTT_rev pair is CORRECT.
Kyber PKE Simplified Test (n=8, q=17, k=2)
==================================================

--- 1. 鍵生成中... ---
✅ 公開鍵 (ek_PKE) 生成完了 (長さ: 42 bytes)
✅ 秘密鍵 (dk_PKE) 生成完了 (長さ: 10 bytes)

--- 2. 暗号化中... ---
メッセージ (平文): 01
乱数: 154e701334e918e87a8f8c1b5dd9ca2af27a61d5efcec4d3a6e4ffc7686be5bb
✅ 暗号文 生成完了 (長さ: 3 bytes)
暗号文 (c): 04243a

--- 3. 復号中... ---
✅ 復号完了
復号されたメッセージ: 8e

--- 4. 検証結果 ---
❌ 失敗! メッセージが一致しませんでした。
==================================================

--- NTT Self-Test ---
Original:  [1, 2, 3, 4, 5, 6, 7, 8]
Restored:  [1, 2, 3, 4, 5, 6, 7, 8]
NTT/NTT_rev pair is CORRECT.

時折成功するものの、ほとんどの試行で失敗するというこの結果こそが、PQC(耐量子計算機暗号)の概要と必要性 で述べた「PQC導入の課題」を身をもって体験させてくれました。

この失敗はバグではなく、 このパラメータ設定における「正しい挙動」 だったのです。簡易パラメータはアルゴリズムの構造を理解するにはとても役立ちますが、小さすぎるがゆえに暗号化で加わる「ノイズ」に耐えきれず、確率的に復号が失敗してしまっています。
実際のML-KEMで巨大なパラメータが使われているのは、この失敗確率を天文学的に低くするためでした。

この「大きな悲しみ」と「重要な気づき」を経て、私は暗号実装の繊細さと、パラメータ設計の重要性を深く理解することができました。ここではアルゴリズムの構造理解に焦点を当てるため、この簡易実装のまま、次はいよいよML-KEM全体の構築へと進みます。

ML-KEMアルゴリズムの実装(mlkem_core.py)

ML-KEMアルゴリズムは、大きく三つのプロセスからなっています。

  • 鍵生成:受信者はK-PKE鍵生成アルゴリズムから、共通鍵を安全に受け取るためのカプセル化鍵(ek)とデカプセル化鍵(dk)を生成
  • カプセル化:送信者はK-PKE暗号化アルゴリズムを使用して、受信者のカプセル化鍵(公開鍵)から共通鍵を、さらに暗号文を生成
  • デカプセル化:受信者は送られてきた暗号文をデカプセル化鍵(秘密鍵)で復元し、共通鍵を取り出す。このとき、暗号文が途中で改ざんされていないかを検証し、安全性が確認できた場合のみ、その共通鍵を有効にする
mlkem_core.py
#ML-KEM実装
#n=8 q=17 k=2 の実装例(簡略版)
from ..utils import hash_utils
from ..kpke import kpke_core

#パラメータ
n = 8 #多項式の次数
q = 17 #係数の法
k = 2 #多項式の数 (k=eta) 
psi = pow(3, (q-1) // n, q) #n次の原始根

#ML-KEM鍵生成
def mlkem_keygen(seed1: bytes,seed2: bytes):
    #入力:32バイトの乱数 d,z
    #出力:鍵カプセル化鍵(ek),鍵デカプセル化鍵(dk)
    (ek_PKE, dk_PKE) = kpke_core.k_pke_keygen(seed1)
    ek = ek_PKE
    dk = dk_PKE + ek + hash_utils.hash_H(ek) + seed2
    return(ek,dk)

#ML-KEMカプセル化
def mlkem_encaps(ek: bytes, m: bytes):
    #入力:鍵カプセル化鍵(ek)、乱数メッセージ(m)
    #出力;共通鍵(k_enc),暗号文(c)

    #1.セッション鍵シード生成
    k_seed = m + hash_utils.hash_H(ek)

    #2.共通鍵と乱数生成
    (K, r) = hash_utils.hash_G(k_seed, 1) 

    #3.K-PKE暗号化
    c = kpke_core.k_pke_enc(ek, m, r)

    return(K, c)

#ML-KEMデカプセル化
def mlkem_decaps(dk: bytes, c: bytes):
    #入力:デカプセル化鍵(dk)、暗号文(c_dec)
    #出力:共通鍵(K_dec)

    #1デカプセル化鍵を分解
    dk_PKE = dk[0 : 10]           #10バイト
    ek_PKE = dk[10 : 52]          #10 + 32 = 42バイト
    hashed_ek = dk[52 : 52 + 32]  #ハッシュ関数Hの出力は32バイト
    z = dk[84 : 84 + 32]          #乱数zは32バイト

    #2.暗号文の復元
    m_dec = kpke_core.k_pke_dec(dk_PKE, c) #暗号文を復元し、乱数メッセージを得る

    """
    ML-KEMデカプセル化では、受け取った暗号文を復号して得た情報をもとに、もう一度ゼロから暗号文を再計算
    再計算した暗号文と、受け取った暗号文が完全に一致するかを検証
    <= 一致しなかった場合は、偽の共通鍵を返すことでCCA安全性を保証する
    """

    #3.共通鍵と乱数の生成
    (K_dec, r_dec) = hash_utils.hash_G(m_dec + hashed_ek, 1)

    #4.偽の共通鍵の生成
    fake_K = hash_utils.hash_J(z + c, 32)

    #5.暗号文の再暗号化
    c_dec = kpke_core.k_pke_enc(ek_PKE, m_dec, r_dec)
    if(c != c_dec):
        K_dec = fake_K
    
    return K_dec

ML-KEMアルゴリズムを実装し、その挙動を確かめる

いよいよML-KEMを実装してみます。
ところが実際に走らせてみると、送信者と受信者の共通鍵は残念ながら一致しませんでした。やはり、簡易パラメータでの実装は、暗号化で加わる「ノイズ」に耐えきれず、失敗してしまいます。

ML-KEM Simplified Test (n=8, q=17, k=2)
==================================================

--- 1. 鍵生成中... ---
✅ 公開鍵 (ek) 生成完了 (長さ: 42 bytes)
✅ 秘密鍵 (dk) 生成完了 (長さ: 116 bytes)

--- 2. カプセル化中... ---
クライアント側で生成した秘密の値 (m): e1ef0e733a0e3e2b182a421e01b54b779bdcf060cb9a10b5709d3341d0386a66
✅ 共通鍵と暗号文の生成完了
生成された共通鍵 (K): 6a9a54d2f65439e67d8fb1a9875200061be1c3101ff62d9c22a44113986ae4e3
生成された暗号文 (c): 60199b

--- 3. デカプセル化中... ---
✅ 共通鍵の復元完了
復元された共通鍵 (K_dec): 93386dea91d71a6ae2f3ec98a85ddcf35ee3cd1a26348b499d6c3afc1b43b866

--- 4. 検証結果 ---
❌ 失敗! 共通鍵が一致しませんでした。
==================================================
main.py
import os

from MLKEM.utils import poly_utils as pol
from MLKEM.kpke import kpke_core as pke
from MLKEM.mlkem import mlkem_core as mlkem

n = 8 #多項式の次数
q = 17 #係数の法
k = 2 #多項式の数 (k=eta)
eta = k
psi = pow(3, (q-1) // n, q) #n次の原始根
psi_rev = pow(psi, -1, q) #(ζ^-1)


def run_mlkem_test():
    """
    ML-KEMの鍵生成、カプセル化、デカプセル化のサイクルをテストします。
    """
    print("ML-KEM Simplified Test (n=8, q=17, k=2)")
    print("="*50)

    # --- 1. 鍵生成 (Key Generation) ---
    print("\n--- 1. 鍵生成中... ---")
    # 32バイトの乱数シードを2つ生成 (dとz用)
    d_seed = os.urandom(32)
    z_seed = os.urandom(32)
    ek, dk = mlkem.mlkem_keygen(d_seed, z_seed)
    print(f"✅ 公開鍵 (ek) 生成完了 (長さ: {len(ek)} bytes)")
    print(f"✅ 秘密鍵 (dk) 生成完了 (長さ: {len(dk)} bytes)")

    # --- 2. カプセル化 (Encapsulation) ---
    print("\n--- 2. カプセル化中... ---")
    # 32バイトのランダムなメッセージシードmを生成
    m_seed = os.urandom(32)
    print(f"クライアント側で生成した秘密の値 (m): {m_seed.hex()}")

    K_encaps, c = mlkem.mlkem_encaps(ek, m_seed)
    print(f"✅ 共通鍵と暗号文の生成完了")
    print(f"生成された共通鍵 (K): {K_encaps.hex()}")
    print(f"生成された暗号文 (c): {c.hex()}")

    # --- 3. デカプセル化 (Decapsulation) ---
    print("\n--- 3. デカプセル化中... ---")
    K_decaps = mlkem.mlkem_decaps(dk, c)
    print(f"✅ 共通鍵の復元完了")
    print(f"復元された共通鍵 (K_dec): {K_decaps.hex()}")

    # --- 4. 検証 (Verification) ---
    print("\n--- 4. 検証結果 ---")
    if K_encaps == K_decaps:
        print("🥳 \033[92m成功!\033[0m 生成された共通鍵と復元された共通鍵は一致します。")
    else:
        print("❌ \033[91m失敗!\033[0m 共通鍵が一致しませんでした。")
    
    print("="*50)


if __name__ == "__main__":
    run_mlkem_test()

終わりに

ここまでで、ML-KEMおよびその基盤となるK-PKEアルゴリズムの簡易実装を追いかけてきました。そして、Module-LWE問題の困難性を使用した暗号化、CCA安全なカプセル化、NTTによる高速な多項式演算といった、ML-KEMを支える要素の連携を深く理解することができました。

今回の簡易実装で当初目標としていた「100%の確率での共通鍵の一致」は、最終的に達成できませんでした。しかし、その失敗に至るまでの道のりは、二つの重要な学びを与えてくれました。

一つ目は、暗号の理論と実装の間に存在する、深く、そして見過ごされがちな溝です。論文上の数式は完璧に見えても、それをコードとして現実の世界に持ち込むには、パラメータの選定、ノイズの許容範囲、圧縮による誤差といった、極めて繊細な要素が完璧なバランスで成り立っていなければなりません。今回の実装で直面した「100%の失敗」は、このバランスが簡易パラメータでは原理的に成立しないことを示す、何よりの証拠でした。

二つ目は、PQC導入のハードルがいかに高いかを肌で感じられたことです。これまでは「既存のシステムがPQCに未対応だから大変だ」という表面的な理解でした。しかし、実際に簡易実装でさえパラメータの繊細さに直面したことで、実用レベルの巨大なパラメータが要求する計算量や、それを搭載するためのメモリといった物理的な制約の重みを実感しました。

今後は、この経験を糧に、ソフトウェアや組み込みシステムにPQCをどう統合していくか、アルゴリズムの効率化やサイドチャネル攻撃への耐性といった、より実践的なテーマに挑戦していきたいと考えています。


参考文献:

  1. https://www.cryptrec.go.jp/report/cryptrec-gl-2004-2022.pdf
  2. https://www.cryptrec.go.jp/report/cryptrec-tr-2001-2024.pdf

Discussion