🔐

Opacusの実装を読んでみる (DP-SGD): Part 2

2024/05/23に公開

はじめに

PyTorch向けのDP-SGDライブラリであるOpacusの実装を引き続き読んでいきます。
前回の記事はこちら

バージョン情報:Opacus v1.4.0
https://opacus.ai/

実装を読んでみる

以下はOpacusを用いた学習の実装例です。(モデル訓練部分)

from opacus.utils.batch_memory_manager import BatchMemoryManager

def train(model, train_loader, optimizer, epoch, device):
    model.train()
    criterion = nn.CrossEntropyLoss()
    
    with BatchMemoryManager(
        data_loader=train_loader, 
        max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE, 
        optimizer=optimizer
    ) as memory_safe_data_loader:
        for i, (images, target) in enumerate(memory_safe_data_loader):   
            optimizer.zero_grad()
            images = images.to(device)
            target = target.to(device)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            loss.backward()
            optimizer.step()

DPOptimizer

DPOptimizerはPyTorchのOptimizerのサブクラスです。
まずzero_gradメソッドでは、PyTorch標準のgrad属性に加えて、Opacusで新たに追加されるgrad_sample属性とsummed_grad属性もクリアします。

  • grad_sample : サンプルごとの勾配(クリッピング前)
  • summed_grad : ミニバッチ集約後の勾配(ノイズ付加前)
  • grad : 最終的な勾配
optimizers/optimizer.py
    def zero_grad(self, set_to_none: bool = False):
        for p in self.params:
            p.grad_sample = None

            if not self._is_last_step_skipped:
                p.summed_grad = None

        self.original_optimizer.zero_grad(set_to_none)

stepメソッドでは、最適化を実行する前にpre_stepメソッドで勾配のクリッピングとノイズ付加を行います。clip_and_accumulateadd_noisescale_gradstep_hookの順に実行されます。

optimizers/optimizer.py
    def pre_step(
        self, closure: Optional[Callable[[], float]] = None
    ) -> Optional[float]:
        self.clip_and_accumulate()

        self.add_noise()
        self.scale_grad()

        if self.step_hook:
            self.step_hook(self)

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        if self.pre_step():
            return self.original_optimizer.step()

clip_and_accumulateメソッドは、勾配のクリッピングとミニバッチ集約を行います。具体的には、grad_sampleのL2ノルムから求めたper_sample_clip_factorgrad_sampleを乗算することで、勾配の大きさがmax_grad_normを超えないように制限します。その結果を集約してsummed_gradに格納します。

optimizers/optimizer.py
    def clip_and_accumulate(self):
        if len(self.grad_samples[0]) == 0:
            # Empty batch
            per_sample_clip_factor = torch.zeros((0,))
        else:
            per_param_norms = [
                g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
            ]
            per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
            per_sample_clip_factor = (
                self.max_grad_norm / (per_sample_norms + 1e-6)
            ).clamp(max=1.0)

        for p in self.params:
            grad_sample = self._get_flat_grad_sample(p)
            grad = contract("i,i...", per_sample_clip_factor, grad_sample)

            if p.summed_grad is not None:
                p.summed_grad += grad
            else:
                p.summed_grad = grad

add_noiseメソッドは、勾配へのノイズ付加を行います。 _generate_noise関数で平均が0、分散がnoise_multipliermax_grad_normの積である正規ノイズを生成し、summed_gradに加えたものをgradとします。

optimizers/optimizer.py
    def add_noise(self):
        for p in self.params:
            noise = _generate_noise(
                std=self.noise_multiplier * self.max_grad_norm,
                reference=p.summed_grad,
                generator=self.generator,
                secure_mode=self.secure_mode,
            )
            p.grad = (p.summed_grad + noise).view_as(p)
optimizers/optimizer.py
def _generate_noise(
    std: float,
    reference: torch.Tensor,
    generator=None,
    secure_mode: bool = False,
) -> torch.Tensor:
    if secure_mode:
        ### 省略 ###
    else:
        return torch.normal(
            mean=0,
            std=std,
            size=reference.shape,
            device=reference.device,
            generator=generator,
        )

scale_gradメソッドは、バッチサイズに基づいて勾配をスケーリングします。

optimizers/optimizer.py
    def scale_grad(self):
        if self.loss_reduction == "mean":
            for p in self.params:
                p.grad /= self.expected_batch_size * self.accumulated_iterations

step_hookメソッドは、前回の記事でアタッチしたIAccountanthook_fn関数を実行します。

accountants/accountant.py
    def get_optimizer_hook_fn(
        self, sample_rate: float
    ) -> Callable[[DPOptimizer], None]:
        def hook_fn(optim: DPOptimizer):
            self.step(
                noise_multiplier=optim.noise_multiplier,
                sample_rate=sample_rate * optim.accumulated_iterations,
            )

        return hook_fn

PRVAccountant

PRVAccountantIAccountantのサブクラスで、下記論文の方法でプライバシーバジェットを計算します。
https://arxiv.org/abs/2106.02848

stepメソッドは、上述のDPOptimizerpre_stepメソッドの最後に呼び出され、historyにノイズ乗数、サンプルレート、ステップ数を記録します。

accountants/prv.py
    def step(self, *, noise_multiplier: float, sample_rate: float):
        if len(self.history) >= 1:
            (last_noise_multiplier, last_sample_rate, num_steps) = self.history.pop()
            if (
                last_noise_multiplier == noise_multiplier
                and last_sample_rate == sample_rate
            ):
                self.history.append(
                    (last_noise_multiplier, last_sample_rate, num_steps + 1)
                )
            else:
                self.history.append(
                    (last_noise_multiplier, last_sample_rate, num_steps)
                )
                self.history.append((noise_multiplier, sample_rate, 1))

        else:
            self.history.append((noise_multiplier, sample_rate, 1))

get_epsilonメソッドは、\deltahistoryの記録から消費されたプライバシーバジェット \epsilon を計算します。計算方法の詳細は複雑なため今回は割愛します。

accountants/prv.py
    def get_epsilon(
        self, delta: float, *, eps_error: float = 0.01, delta_error: float = None
    ) -> float:
        if delta_error is None:
            delta_error = delta / 1000
        # we construct a discrete PRV from the history
        dprv = self._get_dprv(eps_error=eps_error, delta_error=delta_error)
        # this discrete PRV can be used to directly estimate and bound epsilon
        _, _, eps_upper = dprv.compute_epsilon(delta, delta_error, eps_error)
        # return upper bound as we want guarantee, not just estimate
        return eps_upper

続く

モデル訓練中の処理は以上になります。
次回は、サンプルごとの勾配を計算するGradSampleModuleクラスの詳細を見ていきます。
続きの記事はこちらです。

Discussion