Opacusの実装を読んでみる (DP-SGD): Part 3
はじめに
PyTorch向けのDP-SGDライブラリであるOpacusの実装を引き続き読んでいきます。
前回の記事はこちら。
バージョン情報:Opacus v1.4.0
実装を読んでみる
DP-SGDではサンプルごとの勾配が必要になりますが、PyTorchにはそれを取得する機能がありません。そのため、Opacusでは以下の3種類の方式が実装されています。
- Hooks-based approach
- Functorch approach
- ExpandedWeigths approach
執筆時点のデフォルトは"Hooks-based approach"で、対応していないモジュールのみ"Functorch approach"を使用します。よって、この方式を実装しているGradSampleModule
クラスの処理内容を見ていきます。
GradSampleModule & add_hooks
GradSampleModule
クラスの初期化メソッドです。add_hooks
メソッドでサンプルごとの勾配を計算するためのフックを追加します。
class GradSampleModule(AbstractGradSampleModule):
def __init__(
self,
m: nn.Module,
*,
batch_first=True,
loss_reduction="mean",
strict: bool = True,
force_functorch=False,
):
self.hooks_enabled = False
self.batch_first = batch_first
self.loss_reduction = loss_reduction
self.force_functorch = force_functorch
self.add_hooks(
loss_reduction=loss_reduction,
batch_first=batch_first,
force_functorch=force_functorch,
)
add_hooks
メソッドは、順伝播中に活性化を保存するためのcapture_activations_hook
メソッドと、逆伝播中にサンプルごとの勾配を計算するためのcapture_backprops_hook
メソッドを、フックとして追加します。また、これらのフックはiterate_submodules
メソッドが返すサブモジュール、すなわち訓練可能なパラメータを直接持つ全てのモジュールに追加されます。なお、GRAD_SAMPLERS
に登録されていないモジュールは"Functorch approach"で処理するため、prepare_layer
関数を通します。
def iterate_submodules(self, module: nn.Module) -> Iterable[nn.Module]:
if has_trainable_params(module):
yield module
# Don't recurse if module is handled by functorch
if (
has_trainable_params(module)
and type(module) not in self.GRAD_SAMPLERS
and type(module) not in [DPRNN, DPLSTM, DPGRU]
):
return
for m in module.children():
yield from self.iterate_submodules(m)
def add_hooks(
self,
*,
loss_reduction: str = "mean",
batch_first: bool = True,
force_functorch: bool = False,
) -> None:
for module in self.iterate_submodules(self._module):
# Do not add hooks to DPRNN, DPLSTM or DPGRU as the hooks are handled by the `RNNLinear`
if type(module) in [DPRNN, DPLSTM, DPGRU]:
continue
if force_functorch or not type(module) in self.GRAD_SAMPLERS:
prepare_layer(module, batch_first=batch_first)
self.autograd_grad_sample_hooks.append(
module.register_forward_hook(self.capture_activations_hook)
)
self.autograd_grad_sample_hooks.append(
module.register_backward_hook(
partial(
self.capture_backprops_hook,
loss_reduction=loss_reduction,
batch_first=batch_first,
)
)
)
prepare_layer
関数は、"Functorch approach"で処理するモジュールに対して前処理を行います。具体的には、functorchのmake_functional
関数を使って関数化したモジュールflayer
について、損失を計算するcompute_loss_stateless_model
関数を定義します。この関数とfunctorchのgrad
関数、vmap
関数を用いて、サンプルごとの勾配を計算するft_compute_sample_grad
メソッドを定義します。
def prepare_layer(layer, batch_first=True):
from functorch import grad, make_functional, vmap
flayer, _ = make_functional(layer)
def compute_loss_stateless_model(params, activations, backprops):
if batch_first or type(layer) is RNNLinear:
batched_activations = activations.unsqueeze(0)
batched_backprops = backprops.unsqueeze(0)
else:
# If batch_first is False, the batch dimension is the second dimension
batched_activations = activations.unsqueeze(1)
batched_backprops = backprops.unsqueeze(1)
output = flayer(params, batched_activations)
loss = (output * batched_backprops).sum()
return loss
ft_compute_grad = grad(compute_loss_stateless_model)
# Note that the vmap is done on the first dimension, regardless of batch_first
# This is because the activations and backprops given by the GradSampleModule
# are always batch_first=True
layer.ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0, 0))
register_grad_sampler
"Hooks-based approach"で処理できるモジュールを登録するGRAD_SAMPLERS
には、デコレータ関数register_grad_sampler
を使って登録します。
def register_grad_sampler(
target_class_or_classes: Union[Type[nn.Module], Sequence[Type[nn.Module]]]
):
def decorator(f):
target_classes = (
target_class_or_classes
if isinstance(target_class_or_classes, Sequence)
else [target_class_or_classes]
)
for target_class in target_classes:
GradSampleModule.GRAD_SAMPLERS[target_class] = f
return f
return decorator
例えばnn.Linear
モジュールが登録されています。
@register_grad_sampler(nn.Linear)
def compute_linear_grad_sample(
layer: nn.Linear, activations: List[torch.Tensor], backprops: torch.Tensor
) -> Dict[nn.Parameter, torch.Tensor]:
activations = activations[0]
ret = {}
if layer.weight.requires_grad:
gs = contract("n...i,n...j->nij", backprops, activations)
ret[layer.weight] = gs
if layer.bias is not None and layer.bias.requires_grad:
ret[layer.bias] = contract("n...k->nk", backprops)
return ret
capture_activations_hook & capture_backprops_hook
capture_activations_hook
メソッドは、モジュールが順伝播を行うたびに呼び出され、activations
属性に順伝播の入力を保存します。
def capture_activations_hook(
self,
module: nn.Module,
forward_input: List[torch.Tensor],
_forward_output: torch.Tensor,
):
if not hasattr(module, "activations"):
module.activations = []
module.activations.append([t.detach() for t in forward_input]) # pyre-ignore
for _, p in trainable_parameters(module):
p._forward_counter += 1
capture_backprops_hook
メソッドは、モジュールの逆伝播中に呼び出され、逆伝播の出力backprops
と順伝播で保存したactivations
からサンプルごとの勾配を計算します。この計算はgrad_sampler_fn
関数、すなわちGRAD_SAMPLERS
に登録されている関数(例:nn.Linear
の場合、compute_linear_grad_sample
関数)もしくはft_compute_per_sample_gradient
関数が行います。前者は"Hooks-based approach"、後者は"Functorch approach"です。
def capture_backprops_hook(
self,
module: nn.Module,
_forward_input: torch.Tensor,
forward_output: torch.Tensor,
loss_reduction: str,
batch_first: bool,
):
backprops = forward_output[0].detach()
activations, backprops = self.rearrange_grad_samples(
module=module,
backprops=backprops,
loss_reduction=loss_reduction,
batch_first=batch_first,
)
if not self.force_functorch and type(module) in self.GRAD_SAMPLERS:
grad_sampler_fn = self.GRAD_SAMPLERS[type(module)]
else:
grad_sampler_fn = ft_compute_per_sample_gradient
grad_samples = grad_sampler_fn(module, activations, backprops)
for param, gs in grad_samples.items():
create_or_accumulate_grad_sample(
param=param, grad_sample=gs, max_batch_len=module.max_batch_len
)
for _, p in trainable_parameters(module):
p._forward_counter -= 1
if p._forward_counter == 0:
promote_current_grad_sample(p)
ft_compute_per_sample_gradient
関数は、上述のprepare_layer
関数で定義したft_compute_sample_grad
メソッドを実行して、サンプルごとの勾配を計算します。
def ft_compute_per_sample_gradient(layer, activations, backprops):
parameters = list(layer.parameters(recurse=True))
per_sample_grads = layer.ft_compute_sample_grad(
parameters, activations[0], backprops
)
ret = {}
for i_p, p in enumerate(parameters):
ret[p] = per_sample_grads[i_p]
return ret
create_or_accumulate_grad_sample
関数により、サンプルごとの勾配が_current_grad_sample
属性に蓄積されます。最後にpromote_current_grad_sample
関数によって、_current_grad_sample
属性がgrad_sample
属性に昇格されます。
def create_or_accumulate_grad_sample(
*, param: torch.Tensor, grad_sample: torch.Tensor, max_batch_len: int
) -> None:
if param.requires_grad:
if hasattr(param, "_current_grad_sample"):
param._current_grad_sample[: grad_sample.shape[0]] += grad_sample
else:
param._current_grad_sample = torch.zeros(
torch.Size([max_batch_len]) + grad_sample.shape[1:],
device=grad_sample.device,
dtype=grad_sample.dtype,
)
param._current_grad_sample[: grad_sample.shape[0]] = grad_sample
def promote_current_grad_sample(p: nn.Parameter) -> None:
if p.requires_grad:
if p.grad_sample is not None:
if isinstance(p.grad_sample, list):
p.grad_sample.append(p._current_grad_sample)
else:
p.grad_sample = [p.grad_sample, p._current_grad_sample]
else:
p.grad_sample = p._current_grad_sample
del p._current_grad_sample
続く
GradSampleModule
クラスの処理は以上になります。
次回は、PRVAccountant
クラスによるプライバシーバジェット計算の詳細を見ていきます。
参考記事
Discussion