📝
FLDetector: Defending Federated Learning Against Model Poisoning....
FLDetector: Defending Federated Learning Against Model Poisoning Attacks via Detecting Malicious Clients
モデルポイズニング攻撃
悪意のあるクライアントが改竄したモデル更新をサーバに送信することでグローバルモデルを破壊する
攻撃者
注入された偽のクライアントであったり、攻撃者によって侵害された本物のクライアント
非標的型と標的型
- 非標的型モデルポイズニング攻撃では、改竄されたグローバルモデルが、多数のテスト入力に対して無差別に誤った予測を行う。
- 標的型モデルポイズニング攻撃では、改竄されたグローバルモデルがユーザーのテスト入力に対して意図的に誤った予測を導く。他のテスト入力に対するグローバルモデルの精度には影響を与えない。
既存の攻撃検知手法
- Byzantine-robust FL method
- provably robust FL method
欠点
they can only resist a small number of malicious clients. It is still an open challenge how to defend against model poisoning attacks with a large number of malicious clients.
悪意のあるクライアントが少数だった場合にしか効果的でない。
→ 多数の悪意あるクライアントによるモデルポイズニング攻撃をどのように防御するかは、まだ未解決の課題
提案手法
- FLDetector
貢献
- 多数派の悪意のあるクライアントを検出し、削除。(FLDetector)
- Byzantine-robust FL method or provably robust FL methodFL手法へバトンタッチ
FLDetectorの概要
サーバは過去のモデル更新に基づいて各反復におけるクライアントのモデル更新を予測し、クライアントから受信したモデル更新と予測されたモデル更新が複数の反復において矛盾している場合、クライアントを悪意のあるものとしてフラグを立てる。
→ モデル更新の一貫性をチェックすることにより、悪意のあるクライアントを検出する
着目点
モデルポイズニング攻撃では、複数の反復tにおけるクライアントからのモデル更新が矛盾すること。
STEP
- コーシー平均値の定理を用いて、サーバーが過去のモデル更新に基づき、各反復における各クライアントのモデル更新を予測する
- ユークリッド距離を使用して、各反復における各クライアントの予測モデル更新と受信モデル更新の類似度を測定する。
- 反復tにおけるクライアントの疑わしいスコアは、過去N回の反復におけるユークリッド距離の平均(各クライアントの疑わしいスコアを定義し、各反復で動的に更新)
- クライアントの疑わしいスコアに基づくギャップ統計とk-meansを活用する(各反復において悪意のあるクライアントを検出)
- 特に、ある反復における疑わしいスコアとGap統計量に基づき、クライアントが複数のクラスタにグループ化できる場合、以下のようにグループ化する。
クライアントをk-meansを用いて2つのクラスタに分類し、平均疑わしいスコアが大きいクラスタ内のクライアントを悪意があるものとして分類する。
FLDetector
Algorithm
Input: Total training iterations 𝐼𝑡𝑒𝑟 and window size 𝑁 .
Output: Detected malicious clients or none.
1: for 𝑡 = 1, 2, · · · , 𝐼𝑡𝑒𝑟 do
2: ˆ H 𝑡 = L-BFGS(Δ𝑊𝑡 , Δ𝐺𝑡 ).
3: for 𝑖 = 1, 2, · · · , 𝑛 do
4: 𝑔ˆ𝑡 𝑖 = 𝑔𝑡−1 𝑖 +ˆ H 𝑡 (𝑤𝑡 − 𝑤𝑡−1).
5: end for
6: 𝑑𝑡 = [∥𝑔ˆ𝑡 1 − 𝑔𝑡 1 ∥2, ∥𝑔ˆ𝑡 2 − 𝑔𝑡 2 ∥2, · · · , ∥𝑔ˆ𝑡 𝑛 − 𝑔𝑡 𝑛 ∥2].
7: 𝑑ˆ𝑡 = 𝑑𝑡 /∥𝑑𝑡 ∥1.
8: 𝑠𝑡 𝑖= 1 𝑁 ∑︁𝑁 −1 𝑟 =0 𝑑ˆ𝑡−𝑟 𝑖.
9: Determine the number of clusters 𝑘 by Gap statistics.
10: if 𝑘 > 1 then
11: Perform 𝑘-means clustering based on the suspicious scores with 𝑘 = 2.
12: return The clients in the cluster with larger average suspicious score as malicious.
13: end if
14: end for
15: return None.
理解しやすいように擬似コード
# パラメータ設定
Iterations = 10 # 総学習イテレーション数
N = 3 # ウィンドウサイズ
n_clients = 5 # クライアント数
𝑔𝑡−1 = [...]
𝑤𝑡−1 = ...
d_list = [[...], [...]...]
for i in range(Iterations):
𝑔ˆ𝑡 = []
1. L-BFGS法でへシアン行列ˆ H 𝑡を推定
for j in range(n_clients):
2. 𝑔ˆ𝑡.append(𝑔𝑡−1 𝑖+ ˆ H 𝑡 (𝑤𝑡 − 𝑤𝑡−1))
𝑔𝑡 =[...]
3. dt = np.linalg.norm(𝑔ˆ𝑡 - 𝑔𝑡, axis=1)
4. d = dt / np.linalg.norm(dt, ord=1)
5. suspicious_scores = np.mean(d_list[-N:], axis=0)
6. Gap statisticsでクラスタ数が1を超えた場合、k-means(k=2)でクラスタリング
print("悪意があると思わしきクライアントが見つかった")
𝑔𝑡−1 = 𝑔𝑡
𝑤𝑡−1 = 𝑤
print("悪意のあるクライアントは見つからなかった")
なぜFLDetectorなのか?
実験の評価
1.使用されたデータセット
- MNIST
- CIFAR10
- FEMNIST
-
評価対象: 非標的モデルポイズニング攻撃と3つの標的型モデルポイズニング攻撃、さらにFLDetectorに合わせた適応型攻撃に対して評価が行われた。
-
評価結果:
- 非標的モデルポイズニング攻撃: FLDetectorはベースラインの検出方法を上回りました。
- 標的型モデルポイズニング攻撃: 多くのケースでFLDetectorはベースラインを上回り、残りのケースでは同等の検出精度を達成しました。
- 適応型攻撃: FLDetectorは適応型攻撃に対しても有効であることが示されました。
参考文献
Discussion