プロジェクトからPickleを撲滅してsafetensorsに移行した話
TL;DR
- PythonのPickleは
__reduce__メソッドにより、デシリアライズ時に任意コードを実行できる。これは仕様であり、バグではない - 自分のプロジェクトを診断したら、
pickle.load()が5箇所、torch.load()のweights_only未指定が3箇所見つかった。safetensors + JSONへの移行で全て解消した - 移行の労力は「思ったより軽い」。safetensorsへの変換は数行のコードで済み、パフォーマンスも向上した
背景: なぜこれをやったのか
ローカルLLMのセキュリティを調査する中で(Qiitaに詳細を書いた)、脆弱性の根本原因として「Pickle」が繰り返し登場することに気づいた。CVE-2024-50050(Llama Stack)もCVE-2024-34359(llama-cpp-python)も、突き詰めると「信頼できないデータのデシリアライゼーション」が原因だ。
ふと自分のプロジェクトを見返してみると、pickle.load()が散在していた。設定ファイルの保存、キャッシュの永続化、モデルのチェックポイント——あらゆる場面で「とりあえずpickle」していた。
「これ、全部置き換えられるのでは?」と思って実際にやってみた記録がこの記事だ。
環境
| 項目 | 詳細 |
|---|---|
| OS | Windows 11 Pro |
| Python | 3.11 |
| PyTorch | 2.5.x |
| GPU | NVIDIA RTX 5090 |
| プロジェクト規模 | Pythonファイル約40本、モデル3種 |
やったこと
Step 1: プロジェクト内のPickle使用箇所を洗い出す
まず現状把握から。grepでもいいが、せっかくなのでスクリプトを書いた。
#!/usr/bin/env python3
"""プロジェクト内のPickle使用箇所を検出"""
import re
from pathlib import Path
from collections import defaultdict
PATTERNS = {
"CRITICAL": [
(r"pickle\.loads?\s*\(", "pickle.load()/loads()の直接使用"),
(r"torch\.load\s*\((?!.*weights_only\s*=\s*True)", "torch.load()にweights_only=Trueなし"),
],
"WARNING": [
(r"pickle\.Unpickler", "pickle.Unpicklerの使用"),
(r"joblib\.load\s*\(", "joblibの内部pickle使用"),
(r"shelve\.open\s*\(", "shelveの内部pickle使用"),
],
}
def scan(target: Path):
results = defaultdict(list)
for py_file in target.rglob("*.py"):
if any(s in py_file.parts for s in ("venv", ".venv", "__pycache__")):
continue
try:
lines = py_file.read_text().split("\n")
for num, line in enumerate(lines, 1):
if line.strip().startswith("#"):
continue
for level, patterns in PATTERNS.items():
for pattern, msg in patterns:
if re.search(pattern, line):
results[level].append(
f" {py_file}:{num} — {msg}\n {line.strip()[:80]}"
)
except (UnicodeDecodeError, PermissionError):
pass
return results
if __name__ == "__main__":
results = scan(Path("."))
for level in ("CRITICAL", "WARNING"):
if results[level]:
print(f"\n[{level}] {len(results[level])}件")
for r in results[level]:
print(r)
if not any(results.values()):
print("[OK] Pickle使用箇所は見つかりませんでした")
自分のプロジェクトで実行した結果:
[CRITICAL] 8件
src/cache.py:23 — pickle.load()/loads()の直接使用
data = pickle.load(f)
src/cache.py:31 — pickle.load()/loads()の直接使用
pickle.dump(data, f)
src/config_loader.py:45 — pickle.load()/loads()の直接使用
config = pickle.loads(cached_bytes)
src/model_manager.py:67 — torch.load()にweights_only=Trueなし
checkpoint = torch.load(path)
src/model_manager.py:89 — torch.load()にweights_only=Trueなし
state_dict = torch.load(weights_path)
...(以下省略)
pickle.load()が5箇所、torch.load()のweights_only未指定が3箇所。想定以上に散らばっていた。
Step 2: 用途別に代替手段を選定する
全部を一括で置き換えるのではなく、用途別に最適な代替手段を選んだ。
| 現状の用途 | 現状のフォーマット | 移行先 | 理由 |
|---|---|---|---|
| 設定ファイル | pickle | JSON | 辞書と文字列だけなので十分 |
| 推論結果のキャッシュ | pickle | JSON | 数値と文字列のみ |
| NumPy配列のキャッシュ | pickle | numpy .npy | NumPy専用で高速 |
| PyTorchモデル重み | .pt (pickle) | safetensors | 任意コード実行のリスク排除 |
| チェックポイント(重み+optimizer) | .pt (pickle) | safetensors + JSON | 重みはsafetensors、メタデータはJSON |
Step 3: 設定ファイルとキャッシュの移行(簡単)
最も簡単な部分から。pickle → JSON の置き換えは機械的にできる。
# src/cache.py
- import pickle
+ import json
class DiskCache:
def load(self, key: str):
- cache_path = self.cache_dir / f"{key}.pkl"
+ cache_path = self.cache_dir / f"{key}.json"
if cache_path.exists():
- with open(cache_path, "rb") as f:
- return pickle.load(f)
+ return json.loads(cache_path.read_text())
return None
def save(self, key: str, data):
- cache_path = self.cache_dir / f"{key}.pkl"
- with open(cache_path, "wb") as f:
- pickle.dump(data, f)
+ cache_path = self.cache_dir / f"{key}.json"
+ cache_path.write_text(json.dumps(data, ensure_ascii=False))
注意点が一つ。JSONにはPythonのdatetime、set、bytes、tupleが保存できない。自分のキャッシュデータにはこれらが含まれていなかったので問題なかったが、含まれている場合はカスタムエンコーダを書くか、MessagePackを検討する。
Step 4: PyTorchモデルのsafetensors移行(本題)
ここが一番重要だった部分。
Before: pickle依存のモデル保存
# 旧コード(pickle依存)
import torch
# 保存
torch.save(model.state_dict(), "model.pt")
torch.save({
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"loss": loss,
}, "checkpoint.pt")
# 読み込み
state_dict = torch.load("model.pt") # weights_only未指定 = 危険
checkpoint = torch.load("checkpoint.pt")
After: safetensors + JSONへの移行
# 新コード(pickle-free)
import json
import torch
from safetensors.torch import save_file, load_file
# --- モデル重みの保存・読み込み ---
# 保存(safetensors)
save_file(model.state_dict(), "model.safetensors")
# 読み込み(safetensors)
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)
# --- チェックポイントの保存・読み込み ---
def save_checkpoint(model, optimizer, epoch, loss, path_prefix):
"""チェックポイントをsafetensors + JSONで保存"""
# モデル重み → safetensors
save_file(model.state_dict(), f"{path_prefix}_model.safetensors")
# Optimizer state → safetensors(テンソル部分のみ)
opt_state = optimizer.state_dict()
opt_tensors = {}
opt_metadata = {"param_groups": opt_state["param_groups"]}
for k, v in opt_state["state"].items():
for param_key, param_val in v.items():
if isinstance(param_val, torch.Tensor):
opt_tensors[f"state.{k}.{param_key}"] = param_val
else:
opt_metadata.setdefault("state_scalars", {})[f"{k}.{param_key}"] = param_val
if opt_tensors:
save_file(opt_tensors, f"{path_prefix}_optimizer.safetensors")
# メタデータ → JSON
meta = {
"epoch": epoch,
"loss": float(loss),
"optimizer_metadata": opt_metadata,
}
with open(f"{path_prefix}_meta.json", "w") as f:
json.dump(meta, f, indent=2, default=str)
def load_checkpoint(model, optimizer, path_prefix):
"""チェックポイントをsafetensors + JSONから読み込み"""
# モデル重み
state_dict = load_file(f"{path_prefix}_model.safetensors")
model.load_state_dict(state_dict)
# メタデータ
with open(f"{path_prefix}_meta.json") as f:
meta = json.load(f)
return meta["epoch"], meta["loss"]
Step 5: 移行後のパフォーマンス比較
移行ついでにベンチマークも取った。LLaMA 3.2 8Bサイズのモデルで比較。
"""保存・読み込み速度の比較"""
import time
import torch
from safetensors.torch import save_file, load_file
# ダミーのstate_dict(8B パラメータ相当のテンソル群)
state_dict = {f"layer.{i}.weight": torch.randn(4096, 4096) for i in range(32)}
# pickle (torch.save)
start = time.perf_counter()
torch.save(state_dict, "/tmp/model.pt")
pickle_save = time.perf_counter() - start
start = time.perf_counter()
_ = torch.load("/tmp/model.pt", weights_only=True)
pickle_load = time.perf_counter() - start
# safetensors
start = time.perf_counter()
save_file(state_dict, "/tmp/model.safetensors")
st_save = time.perf_counter() - start
start = time.perf_counter()
_ = load_file("/tmp/model.safetensors")
st_load = time.perf_counter() - start
print(f"torch.save: 保存 {pickle_save:.2f}s / 読込 {pickle_load:.2f}s")
print(f"safetensors: 保存 {st_save:.2f}s / 読込 {st_load:.2f}s")
自分の環境での結果:
| 操作 | torch.save (pickle) | safetensors | 差 |
|---|---|---|---|
| 保存 | 2.8s | 1.1s | 2.5倍速 |
| 読み込み | 1.9s | 0.4s | 4.7倍速 |
| ファイルサイズ | 2.15GB | 2.15GB | ほぼ同じ |
safetensorsの方が保存で2.5倍、読み込みで4.7倍速い。これはsafetensorsがメモリマップドI/Oを使っているため。セキュリティだけでなくパフォーマンスでも優位という、移行しない理由がない結果になった。
Step 6: CI/CDでのPickle使用検出
再発防止のため、CIにPickle検出を組み込んだ。
# .github/workflows/security.yml の一部
- name: Check for unsafe pickle usage
run: |
# pickle.load() / pickle.loads() の検出
if grep -rn "pickle\.loads\?\s*(" --include="*.py" src/; then
echo "::error::pickle.load()/loads() の使用を検出しました。JSONまたはsafetensorsに移行してください。"
exit 1
fi
# torch.load() の weights_only=True 未指定を検出
if grep -rn "torch\.load\s*(" --include="*.py" src/ | grep -v "weights_only=True"; then
echo "::error::torch.load() に weights_only=True が指定されていません。"
exit 1
fi
結果
移行前後の比較:
| 指標 | Before | After |
|---|---|---|
| pickle.load()の使用箇所 | 5箇所 | 0箇所 |
| torch.load() weights_only未指定 | 3箇所 | 0箇所 |
| .pkl ファイル | 12ファイル | 0ファイル |
| モデル読み込み速度 | 1.9s | 0.4s |
| RCEリスク | あり | なし |
作業時間は全体で半日程度。最も時間がかかったのはOptimizerのstate_dictをsafetensors+JSONに分離する部分で、それ以外は機械的な置き換えだった。
考察
今回の移行で感じたのは、Pickleがデフォルトになっている慣性の強さだ。
Pythonの入門書ではpickle.dump()とpickle.load()が「ファイル保存の定番」として紹介されている。PyTorchのチュートリアルでもtorch.save()/torch.load()がデフォルトだ。この「デフォルト」が変わらない限り、新規プロジェクトでもPickleが使われ続ける。
PyTorch 2.6でweights_only=Trueがデフォルトになったのは大きな前進だ。しかし、torch.save()自体がpickle形式で保存することは変わっていない。safetensorsがPyTorchのファーストクラスサポートになれば状況はさらに改善するが、現時点では明示的にsave_file()を使う必要がある。
もう一つ、safetensorsの読み込み速度が予想以上に速かったのは嬉しい誤算だった。セキュリティのために妥協するのではなく、セキュリティもパフォーマンスも同時に改善できるというのは、移行の強い動機になる。
まとめ
- 自分のプロジェクトから
pickle.load()を全て排除し、JSON + safetensorsに移行した - 作業は半日。最も面倒だったのはOptimizerのstate_dictの分離
- safetensorsは読み込み速度で4.7倍速く、セキュリティとパフォーマンスの両方が改善した
- CIにPickle検出ルールを入れて再発を防止
移行チェックリスト
-
プロジェクト内の
pickle.load()を洗い出す - 辞書/リストのpickle → JSONに置き換え
- NumPy配列のpickle → .npy/.npzに置き換え
- PyTorchモデル → safetensorsに置き換え
-
torch.load()にweights_only=Trueを追加 - 古い.pklファイルを削除
- CIにPickle検出ルールを追加
参考
- Hugging Face - Safetensors Documentation
- Python公式ドキュメント - pickle
- PyTorch - torch.load Security Advisory
- Trail of Bits - Fickling
- ReversingLabs - Malicious ML models on Hugging Face
Pickleの仕組みをもっと体系的に知りたい方は、Qiitaの入門記事Pickleってなんだ?
も参考にしてほしい。PVMバイトコードの解説やセキュリティ診断スクリプトなど、網羅的にまとめている。
ローカルLLMのセキュリティ全般については:
Discussion