🔐

Opacusの実装を読んでみる (DP-SGD): Part 3

2024/05/24に公開

はじめに

PyTorch向けのDP-SGDライブラリであるOpacusの実装を引き続き読んでいきます。
前回の記事はこちら

バージョン情報:Opacus v1.4.0
https://opacus.ai/

実装を読んでみる

DP-SGDではサンプルごとの勾配が必要になりますが、PyTorchにはそれを取得する機能がありません。そのため、Opacusでは以下の3種類の方式が実装されています。

  • Hooks-based approach
  • Functorch approach
  • ExpandedWeigths approach

執筆時点のデフォルトは"Hooks-based approach"で、対応していないモジュールのみ"Functorch approach"を使用します。よって、この方式を実装しているGradSampleModuleクラスの処理内容を見ていきます。

GradSampleModule & add_hooks

GradSampleModuleクラスの初期化メソッドです。add_hooksメソッドでサンプルごとの勾配を計算するためのフックを追加します。

grad_sample/grad_sample_module.py
class GradSampleModule(AbstractGradSampleModule):
    def __init__(
        self,
        m: nn.Module,
        *,
        batch_first=True,
        loss_reduction="mean",
        strict: bool = True,
        force_functorch=False,
    ):
        self.hooks_enabled = False
        self.batch_first = batch_first
        self.loss_reduction = loss_reduction
        self.force_functorch = force_functorch
        self.add_hooks(
            loss_reduction=loss_reduction,
            batch_first=batch_first,
            force_functorch=force_functorch,
        )

add_hooksメソッドは、順伝播中に活性化を保存するためのcapture_activations_hookメソッドと、逆伝播中にサンプルごとの勾配を計算するためのcapture_backprops_hookメソッドを、フックとして追加します。また、これらのフックはiterate_submodulesメソッドが返すサブモジュール、すなわち訓練可能なパラメータを直接持つ全てのモジュールに追加されます。なお、GRAD_SAMPLERSに登録されていないモジュールは"Functorch approach"で処理するため、prepare_layer関数を通します。

grad_sample/grad_sample_module.py
    def iterate_submodules(self, module: nn.Module) -> Iterable[nn.Module]:
        if has_trainable_params(module):
            yield module

        # Don't recurse if module is handled by functorch
        if (
            has_trainable_params(module)
            and type(module) not in self.GRAD_SAMPLERS
            and type(module) not in [DPRNN, DPLSTM, DPGRU]
        ):
            return

        for m in module.children():
            yield from self.iterate_submodules(m)

    def add_hooks(
        self,
        *,
        loss_reduction: str = "mean",
        batch_first: bool = True,
        force_functorch: bool = False,
    ) -> None:
        for module in self.iterate_submodules(self._module):
            # Do not add hooks to DPRNN, DPLSTM or DPGRU as the hooks are handled by the `RNNLinear`
            if type(module) in [DPRNN, DPLSTM, DPGRU]:
                continue

            if force_functorch or not type(module) in self.GRAD_SAMPLERS:
                prepare_layer(module, batch_first=batch_first)

            self.autograd_grad_sample_hooks.append(
                module.register_forward_hook(self.capture_activations_hook)
            )

            self.autograd_grad_sample_hooks.append(
                module.register_backward_hook(
                    partial(
                        self.capture_backprops_hook,
                        loss_reduction=loss_reduction,
                        batch_first=batch_first,
                    )
                )
            )

prepare_layer関数は、"Functorch approach"で処理するモジュールに対して前処理を行います。具体的には、functorchのmake_functional関数を使って関数化したモジュールflayerについて、損失を計算するcompute_loss_stateless_model関数を定義します。この関数とfunctorchのgrad関数、vmap関数を用いて、サンプルごとの勾配を計算するft_compute_sample_gradメソッドを定義します。

grad_sample/functorch.py
def prepare_layer(layer, batch_first=True):
    from functorch import grad, make_functional, vmap

    flayer, _ = make_functional(layer)

    def compute_loss_stateless_model(params, activations, backprops):
        if batch_first or type(layer) is RNNLinear:
            batched_activations = activations.unsqueeze(0)
            batched_backprops = backprops.unsqueeze(0)
        else:
            # If batch_first is False, the batch dimension is the second dimension
            batched_activations = activations.unsqueeze(1)
            batched_backprops = backprops.unsqueeze(1)

        output = flayer(params, batched_activations)
        loss = (output * batched_backprops).sum()

        return loss

    ft_compute_grad = grad(compute_loss_stateless_model)
    # Note that the vmap is done on the first dimension, regardless of batch_first
    # This is because the activations and backprops given by the GradSampleModule
    # are always batch_first=True
    layer.ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0, 0))

register_grad_sampler

"Hooks-based approach"で処理できるモジュールを登録するGRAD_SAMPLERSには、デコレータ関数register_grad_samplerを使って登録します。

grad_sample/utils.py
def register_grad_sampler(
    target_class_or_classes: Union[Type[nn.Module], Sequence[Type[nn.Module]]]
):
    def decorator(f):
        target_classes = (
            target_class_or_classes
            if isinstance(target_class_or_classes, Sequence)
            else [target_class_or_classes]
        )
        for target_class in target_classes:
            GradSampleModule.GRAD_SAMPLERS[target_class] = f
        return f

    return decorator

例えばnn.Linearモジュールが登録されています。

grad_sample/linear.py
@register_grad_sampler(nn.Linear)
def compute_linear_grad_sample(
    layer: nn.Linear, activations: List[torch.Tensor], backprops: torch.Tensor
) -> Dict[nn.Parameter, torch.Tensor]:
    activations = activations[0]
    ret = {}
    if layer.weight.requires_grad:
        gs = contract("n...i,n...j->nij", backprops, activations)
        ret[layer.weight] = gs
    if layer.bias is not None and layer.bias.requires_grad:
        ret[layer.bias] = contract("n...k->nk", backprops)
    return ret

capture_activations_hook & capture_backprops_hook

capture_activations_hookメソッドは、モジュールが順伝播を行うたびに呼び出され、activations属性に順伝播の入力を保存します。

grad_sample/grad_sample_module.py
    def capture_activations_hook(
        self,
        module: nn.Module,
        forward_input: List[torch.Tensor],
        _forward_output: torch.Tensor,
    ):
        if not hasattr(module, "activations"):
            module.activations = []
        module.activations.append([t.detach() for t in forward_input])  # pyre-ignore

        for _, p in trainable_parameters(module):
            p._forward_counter += 1

capture_backprops_hookメソッドは、モジュールの逆伝播中に呼び出され、逆伝播の出力backpropsと順伝播で保存したactivationsからサンプルごとの勾配を計算します。この計算はgrad_sampler_fn関数、すなわちGRAD_SAMPLERSに登録されている関数(例:nn.Linearの場合、compute_linear_grad_sample関数)もしくはft_compute_per_sample_gradient関数が行います。前者は"Hooks-based approach"、後者は"Functorch approach"です。

grad_sample/grad_sample_module.py
    def capture_backprops_hook(
        self,
        module: nn.Module,
        _forward_input: torch.Tensor,
        forward_output: torch.Tensor,
        loss_reduction: str,
        batch_first: bool,
    ):
        backprops = forward_output[0].detach()
        activations, backprops = self.rearrange_grad_samples(
            module=module,
            backprops=backprops,
            loss_reduction=loss_reduction,
            batch_first=batch_first,
        )
        if not self.force_functorch and type(module) in self.GRAD_SAMPLERS:
            grad_sampler_fn = self.GRAD_SAMPLERS[type(module)]
        else:
            grad_sampler_fn = ft_compute_per_sample_gradient

        grad_samples = grad_sampler_fn(module, activations, backprops)
        for param, gs in grad_samples.items():
            create_or_accumulate_grad_sample(
                param=param, grad_sample=gs, max_batch_len=module.max_batch_len
            )

        for _, p in trainable_parameters(module):
            p._forward_counter -= 1
            if p._forward_counter == 0:
                promote_current_grad_sample(p)

ft_compute_per_sample_gradient関数は、上述のprepare_layer関数で定義したft_compute_sample_gradメソッドを実行して、サンプルごとの勾配を計算します。

grad_sample/functorch.py
def ft_compute_per_sample_gradient(layer, activations, backprops):
    parameters = list(layer.parameters(recurse=True))

    per_sample_grads = layer.ft_compute_sample_grad(
        parameters, activations[0], backprops
    )

    ret = {}
    for i_p, p in enumerate(parameters):
        ret[p] = per_sample_grads[i_p]

    return ret

create_or_accumulate_grad_sample関数により、サンプルごとの勾配が_current_grad_sample属性に蓄積されます。最後にpromote_current_grad_sample関数によって、_current_grad_sample属性がgrad_sample属性に昇格されます。

grad_sample/grad_sample_module.py
def create_or_accumulate_grad_sample(
    *, param: torch.Tensor, grad_sample: torch.Tensor, max_batch_len: int
) -> None:
    if param.requires_grad:
        if hasattr(param, "_current_grad_sample"):
            param._current_grad_sample[: grad_sample.shape[0]] += grad_sample
        else:
            param._current_grad_sample = torch.zeros(
                torch.Size([max_batch_len]) + grad_sample.shape[1:],
                device=grad_sample.device,
                dtype=grad_sample.dtype,
            )
            param._current_grad_sample[: grad_sample.shape[0]] = grad_sample

def promote_current_grad_sample(p: nn.Parameter) -> None:
    if p.requires_grad:
        if p.grad_sample is not None:
            if isinstance(p.grad_sample, list):
                p.grad_sample.append(p._current_grad_sample)
            else:
                p.grad_sample = [p.grad_sample, p._current_grad_sample]
        else:
            p.grad_sample = p._current_grad_sample

        del p._current_grad_sample

続く

GradSampleModuleクラスの処理は以上になります。
次回は、PRVAccountantクラスによるプライバシーバジェット計算の詳細を見ていきます。

参考記事

https://medium.com/pytorch/differential-privacy-series-part-2-efficient-per-sample-gradient-computation-in-opacus-5bf4031d9e22

Discussion