🥒

プロジェクトからPickleを撲滅してsafetensorsに移行した話

に公開

TL;DR

  • PythonのPickleは__reduce__メソッドにより、デシリアライズ時に任意コードを実行できる。これは仕様であり、バグではない
  • 自分のプロジェクトを診断したら、pickle.load()が5箇所、torch.load()のweights_only未指定が3箇所見つかった。safetensors + JSONへの移行で全て解消した
  • 移行の労力は「思ったより軽い」。safetensorsへの変換は数行のコードで済み、パフォーマンスも向上した

背景: なぜこれをやったのか

ローカルLLMのセキュリティを調査する中で(Qiitaに詳細を書いた)、脆弱性の根本原因として「Pickle」が繰り返し登場することに気づいた。CVE-2024-50050(Llama Stack)もCVE-2024-34359(llama-cpp-python)も、突き詰めると「信頼できないデータのデシリアライゼーション」が原因だ。

ふと自分のプロジェクトを見返してみると、pickle.load()が散在していた。設定ファイルの保存、キャッシュの永続化、モデルのチェックポイント——あらゆる場面で「とりあえずpickle」していた。

「これ、全部置き換えられるのでは?」と思って実際にやってみた記録がこの記事だ。

環境

項目 詳細
OS Windows 11 Pro
Python 3.11
PyTorch 2.5.x
GPU NVIDIA RTX 5090
プロジェクト規模 Pythonファイル約40本、モデル3種

やったこと

Step 1: プロジェクト内のPickle使用箇所を洗い出す

まず現状把握から。grepでもいいが、せっかくなのでスクリプトを書いた。

#!/usr/bin/env python3
"""プロジェクト内のPickle使用箇所を検出"""

import re
from pathlib import Path
from collections import defaultdict

PATTERNS = {
    "CRITICAL": [
        (r"pickle\.loads?\s*\(", "pickle.load()/loads()の直接使用"),
        (r"torch\.load\s*\((?!.*weights_only\s*=\s*True)", "torch.load()にweights_only=Trueなし"),
    ],
    "WARNING": [
        (r"pickle\.Unpickler", "pickle.Unpicklerの使用"),
        (r"joblib\.load\s*\(", "joblibの内部pickle使用"),
        (r"shelve\.open\s*\(", "shelveの内部pickle使用"),
    ],
}

def scan(target: Path):
    results = defaultdict(list)
    for py_file in target.rglob("*.py"):
        if any(s in py_file.parts for s in ("venv", ".venv", "__pycache__")):
            continue
        try:
            lines = py_file.read_text().split("\n")
            for num, line in enumerate(lines, 1):
                if line.strip().startswith("#"):
                    continue
                for level, patterns in PATTERNS.items():
                    for pattern, msg in patterns:
                        if re.search(pattern, line):
                            results[level].append(
                                f"  {py_file}:{num}{msg}\n    {line.strip()[:80]}"
                            )
        except (UnicodeDecodeError, PermissionError):
            pass
    return results

if __name__ == "__main__":
    results = scan(Path("."))
    for level in ("CRITICAL", "WARNING"):
        if results[level]:
            print(f"\n[{level}] {len(results[level])}件")
            for r in results[level]:
                print(r)
    if not any(results.values()):
        print("[OK] Pickle使用箇所は見つかりませんでした")

自分のプロジェクトで実行した結果:

[CRITICAL] 8件
  src/cache.py:23 — pickle.load()/loads()の直接使用
    data = pickle.load(f)
  src/cache.py:31 — pickle.load()/loads()の直接使用
    pickle.dump(data, f)
  src/config_loader.py:45 — pickle.load()/loads()の直接使用
    config = pickle.loads(cached_bytes)
  src/model_manager.py:67 — torch.load()にweights_only=Trueなし
    checkpoint = torch.load(path)
  src/model_manager.py:89 — torch.load()にweights_only=Trueなし
    state_dict = torch.load(weights_path)
  ...(以下省略)

pickle.load()が5箇所、torch.load()のweights_only未指定が3箇所。想定以上に散らばっていた。

Step 2: 用途別に代替手段を選定する

全部を一括で置き換えるのではなく、用途別に最適な代替手段を選んだ。

現状の用途 現状のフォーマット 移行先 理由
設定ファイル pickle JSON 辞書と文字列だけなので十分
推論結果のキャッシュ pickle JSON 数値と文字列のみ
NumPy配列のキャッシュ pickle numpy .npy NumPy専用で高速
PyTorchモデル重み .pt (pickle) safetensors 任意コード実行のリスク排除
チェックポイント(重み+optimizer) .pt (pickle) safetensors + JSON 重みはsafetensors、メタデータはJSON

Step 3: 設定ファイルとキャッシュの移行(簡単)

最も簡単な部分から。pickle → JSON の置き換えは機械的にできる。

# src/cache.py

- import pickle
+ import json

  class DiskCache:
      def load(self, key: str):
-         cache_path = self.cache_dir / f"{key}.pkl"
+         cache_path = self.cache_dir / f"{key}.json"
          if cache_path.exists():
-             with open(cache_path, "rb") as f:
-                 return pickle.load(f)
+             return json.loads(cache_path.read_text())
          return None
      
      def save(self, key: str, data):
-         cache_path = self.cache_dir / f"{key}.pkl"
-         with open(cache_path, "wb") as f:
-             pickle.dump(data, f)
+         cache_path = self.cache_dir / f"{key}.json"
+         cache_path.write_text(json.dumps(data, ensure_ascii=False))

注意点が一つ。JSONにはPythonのdatetimesetbytestupleが保存できない。自分のキャッシュデータにはこれらが含まれていなかったので問題なかったが、含まれている場合はカスタムエンコーダを書くか、MessagePackを検討する。

Step 4: PyTorchモデルのsafetensors移行(本題)

ここが一番重要だった部分。

Before: pickle依存のモデル保存

# 旧コード(pickle依存)
import torch

# 保存
torch.save(model.state_dict(), "model.pt")
torch.save({
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss,
}, "checkpoint.pt")

# 読み込み
state_dict = torch.load("model.pt")  # weights_only未指定 = 危険
checkpoint = torch.load("checkpoint.pt")

After: safetensors + JSONへの移行

# 新コード(pickle-free)
import json
import torch
from safetensors.torch import save_file, load_file

# --- モデル重みの保存・読み込み ---
# 保存(safetensors)
save_file(model.state_dict(), "model.safetensors")

# 読み込み(safetensors)
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)


# --- チェックポイントの保存・読み込み ---
def save_checkpoint(model, optimizer, epoch, loss, path_prefix):
    """チェックポイントをsafetensors + JSONで保存"""
    # モデル重み → safetensors
    save_file(model.state_dict(), f"{path_prefix}_model.safetensors")
    
    # Optimizer state → safetensors(テンソル部分のみ)
    opt_state = optimizer.state_dict()
    opt_tensors = {}
    opt_metadata = {"param_groups": opt_state["param_groups"]}
    
    for k, v in opt_state["state"].items():
        for param_key, param_val in v.items():
            if isinstance(param_val, torch.Tensor):
                opt_tensors[f"state.{k}.{param_key}"] = param_val
            else:
                opt_metadata.setdefault("state_scalars", {})[f"{k}.{param_key}"] = param_val
    
    if opt_tensors:
        save_file(opt_tensors, f"{path_prefix}_optimizer.safetensors")
    
    # メタデータ → JSON
    meta = {
        "epoch": epoch,
        "loss": float(loss),
        "optimizer_metadata": opt_metadata,
    }
    with open(f"{path_prefix}_meta.json", "w") as f:
        json.dump(meta, f, indent=2, default=str)


def load_checkpoint(model, optimizer, path_prefix):
    """チェックポイントをsafetensors + JSONから読み込み"""
    # モデル重み
    state_dict = load_file(f"{path_prefix}_model.safetensors")
    model.load_state_dict(state_dict)
    
    # メタデータ
    with open(f"{path_prefix}_meta.json") as f:
        meta = json.load(f)
    
    return meta["epoch"], meta["loss"]

Step 5: 移行後のパフォーマンス比較

移行ついでにベンチマークも取った。LLaMA 3.2 8Bサイズのモデルで比較。

"""保存・読み込み速度の比較"""
import time
import torch
from safetensors.torch import save_file, load_file

# ダミーのstate_dict(8B パラメータ相当のテンソル群)
state_dict = {f"layer.{i}.weight": torch.randn(4096, 4096) for i in range(32)}

# pickle (torch.save)
start = time.perf_counter()
torch.save(state_dict, "/tmp/model.pt")
pickle_save = time.perf_counter() - start

start = time.perf_counter()
_ = torch.load("/tmp/model.pt", weights_only=True)
pickle_load = time.perf_counter() - start

# safetensors
start = time.perf_counter()
save_file(state_dict, "/tmp/model.safetensors")
st_save = time.perf_counter() - start

start = time.perf_counter()
_ = load_file("/tmp/model.safetensors")
st_load = time.perf_counter() - start

print(f"torch.save:  保存 {pickle_save:.2f}s / 読込 {pickle_load:.2f}s")
print(f"safetensors: 保存 {st_save:.2f}s / 読込 {st_load:.2f}s")

自分の環境での結果:

操作 torch.save (pickle) safetensors
保存 2.8s 1.1s 2.5倍速
読み込み 1.9s 0.4s 4.7倍速
ファイルサイズ 2.15GB 2.15GB ほぼ同じ

safetensorsの方が保存で2.5倍、読み込みで4.7倍速い。これはsafetensorsがメモリマップドI/Oを使っているため。セキュリティだけでなくパフォーマンスでも優位という、移行しない理由がない結果になった。

Step 6: CI/CDでのPickle使用検出

再発防止のため、CIにPickle検出を組み込んだ。

# .github/workflows/security.yml の一部
- name: Check for unsafe pickle usage
  run: |
    # pickle.load() / pickle.loads() の検出
    if grep -rn "pickle\.loads\?\s*(" --include="*.py" src/; then
      echo "::error::pickle.load()/loads() の使用を検出しました。JSONまたはsafetensorsに移行してください。"
      exit 1
    fi
    
    # torch.load() の weights_only=True 未指定を検出
    if grep -rn "torch\.load\s*(" --include="*.py" src/ | grep -v "weights_only=True"; then
      echo "::error::torch.load() に weights_only=True が指定されていません。"
      exit 1
    fi

結果

移行前後の比較:

指標 Before After
pickle.load()の使用箇所 5箇所 0箇所
torch.load() weights_only未指定 3箇所 0箇所
.pkl ファイル 12ファイル 0ファイル
モデル読み込み速度 1.9s 0.4s
RCEリスク あり なし

作業時間は全体で半日程度。最も時間がかかったのはOptimizerのstate_dictをsafetensors+JSONに分離する部分で、それ以外は機械的な置き換えだった。

考察

今回の移行で感じたのは、Pickleがデフォルトになっている慣性の強さだ。

Pythonの入門書ではpickle.dump()pickle.load()が「ファイル保存の定番」として紹介されている。PyTorchのチュートリアルでもtorch.save()/torch.load()がデフォルトだ。この「デフォルト」が変わらない限り、新規プロジェクトでもPickleが使われ続ける。

PyTorch 2.6でweights_only=Trueがデフォルトになったのは大きな前進だ。しかし、torch.save()自体がpickle形式で保存することは変わっていない。safetensorsがPyTorchのファーストクラスサポートになれば状況はさらに改善するが、現時点では明示的にsave_file()を使う必要がある。

もう一つ、safetensorsの読み込み速度が予想以上に速かったのは嬉しい誤算だった。セキュリティのために妥協するのではなく、セキュリティもパフォーマンスも同時に改善できるというのは、移行の強い動機になる。

まとめ

  • 自分のプロジェクトからpickle.load()を全て排除し、JSON + safetensorsに移行した
  • 作業は半日。最も面倒だったのはOptimizerのstate_dictの分離
  • safetensorsは読み込み速度で4.7倍速く、セキュリティとパフォーマンスの両方が改善した
  • CIにPickle検出ルールを入れて再発を防止
移行チェックリスト
  1. プロジェクト内のpickle.load()を洗い出す
  2. 辞書/リストのpickle → JSONに置き換え
  3. NumPy配列のpickle → .npy/.npzに置き換え
  4. PyTorchモデル → safetensorsに置き換え
  5. torch.load()weights_only=Trueを追加
  6. 古い.pklファイルを削除
  7. CIにPickle検出ルールを追加

参考


Pickleの仕組みをもっと体系的に知りたい方は、Qiitaの入門記事Pickleってなんだ?
も参考にしてほしい。PVMバイトコードの解説やセキュリティ診断スクリプトなど、網羅的にまとめている。

ローカルLLMのセキュリティ全般については:


X(Twitter)でもAI/ML系の情報を発信中 → @geneLab_999

Discussion