🔐

Opacusの実装を読んでみる (DP-SGD): Part 1

2024/05/23に公開

はじめに

DP-SGDは、差分プライバシーを満たす深層学習モデルを訓練する手法の一つです。
今回は、PyTorch向けのDP-SGDライブラリであるOpacusの実装を読んでみます。

バージョン情報:Opacus v1.4.0
https://opacus.ai/

対象者:差分プライバシーDP-SGDに関する基本的な知識がある方

DP-SGDの基本的な仕組みについては以下の論文をご覧ください。
https://arxiv.org/abs/1607.00133

実装を読んでみる

以下はOpacusを用いた学習の実装例です。(モデル訓練前まで)

from opacus import PrivacyEngine

privacy_engine = PrivacyEngine()

model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    epochs=EPOCHS,
    target_epsilon=EPSILON,
    target_delta=DELTA,
    max_grad_norm=MAX_GRAD_NORM,
)

PrivacyEngine

PrivacyEngineクラスの初期化メソッドです。

privacy_engine.py
class PrivacyEngine:
    def __init__(self, *, accountant: str = "prv", secure_mode: bool = False):
        self.accountant = create_accountant(mechanism=accountant)
        self.secure_mode = secure_mode

accountant引数は、プライバシーの計算方法を指定します。デフォルトはprv、Privacy loss Random Variables を用いた方式です。
secure_mode引数は、暗号学的に強固なプライバシー保証が必要な場合にTrueに設定しますが、デフォルトはFalseのため詳細は割愛します。

accountants/__init__.py
def create_accountant(mechanism: str) -> IAccountant:
    if mechanism == "rdp":
        return RDPAccountant()
    elif mechanism == "gdp":
        return GaussianAccountant()
    elif mechanism == "prv":
        return PRVAccountant()

make_private_with_epsilon

make_private_with_epsilonメソッドでは、get_noise_multiplier関数でnoise_multiplierの値を求めて、make_privateメソッドに渡します。なお、sample_ratedata_loaderのミニバッチ数から決まります。

privacy_engine.py
    def make_private_with_epsilon(
        self,
        *,
        module: nn.Module,
        optimizer: optim.Optimizer,
        data_loader: DataLoader,
        target_epsilon: float,
        target_delta: float,
        epochs: int,
        max_grad_norm: Union[float, List[float]],
        batch_first: bool = True,
        loss_reduction: str = "mean",
        poisson_sampling: bool = True,
        clipping: str = "flat",
        noise_generator=None,
        grad_sample_mode: str = "hooks",
        **kwargs,
    ):
        sample_rate = 1 / len(data_loader)

        return self.make_private(
            module=module,
            optimizer=optimizer,
            data_loader=data_loader,
            noise_multiplier=get_noise_multiplier(
                target_epsilon=target_epsilon,
                target_delta=target_delta,
                sample_rate=sample_rate,
                epochs=epochs,
                accountant=self.accountant.mechanism(),
                **kwargs,
            ),
            max_grad_norm=max_grad_norm,
            batch_first=batch_first,
            loss_reduction=loss_reduction,
            noise_generator=noise_generator,
            grad_sample_mode=grad_sample_mode,
            poisson_sampling=poisson_sampling,
            clipping=clipping,
        )

get_noise_multiplier関数は、指定されたプライバシーバジェットを達成するためのノイズレベルを計算します。具体的には、\epsilon の値がtarget_epsilon以下かつその差がepsilon_tolerance以下になるnoise_multiplierの値を探索します。

accountants/utils.py
MAX_SIGMA = 1e6

def get_noise_multiplier(
    *,
    target_epsilon: float,
    target_delta: float,
    sample_rate: float,
    epochs: Optional[int] = None,
    steps: Optional[int] = None,
    accountant: str = "rdp",
    epsilon_tolerance: float = 0.01,
    **kwargs,
) -> float:
    if steps is None:
        steps = int(epochs / sample_rate)

    eps_high = float("inf")
    accountant = create_accountant(mechanism=accountant)

    sigma_low, sigma_high = 0, 10
    while eps_high > target_epsilon:
        sigma_high = 2 * sigma_high
        accountant.history = [(sigma_high, sample_rate, steps)]
        eps_high = accountant.get_epsilon(delta=target_delta, **kwargs)
        if sigma_high > MAX_SIGMA:
            raise ValueError("The privacy budget is too low.")

    while target_epsilon - eps_high > epsilon_tolerance:
        sigma = (sigma_low + sigma_high) / 2
        accountant.history = [(sigma, sample_rate, steps)]
        eps = accountant.get_epsilon(delta=target_delta, **kwargs)

        if eps < target_epsilon:
            sigma_high = sigma
            eps_high = eps
        else:
            sigma_low = sigma

    return sigma_high

make_private

make_privateメソッドは、moduleoptimizerdata_loaderを引数として受け取り、それらのオブジェクトをDP-SGDで学習するために修正したものを返します。それぞれの処理の詳細を以下で見ていきます。

privacy_engine.py
    def make_private(
        self,
        *,
        module: nn.Module,
        optimizer: optim.Optimizer,
        data_loader: DataLoader,
        noise_multiplier: float,
        max_grad_norm: Union[float, List[float]],
        batch_first: bool = True,
        loss_reduction: str = "mean",
        poisson_sampling: bool = True,
        clipping: str = "flat",
        noise_generator=None,
        grad_sample_mode: str = "hooks",
    ) -> Tuple[GradSampleModule, DPOptimizer, DataLoader]:
        module = self._prepare_model(
            module,
            batch_first=batch_first,
            loss_reduction=loss_reduction,
            grad_sample_mode=grad_sample_mode,
        )
        if poisson_sampling:
            module.register_backward_hook(forbid_accumulation_hook)

        data_loader = self._prepare_data_loader(
            data_loader, distributed=distributed, poisson_sampling=poisson_sampling
        )

        sample_rate = 1 / len(data_loader)
        expected_batch_size = int(len(data_loader.dataset) * sample_rate)

        optimizer = self._prepare_optimizer(
            optimizer,
            noise_multiplier=noise_multiplier,
            max_grad_norm=max_grad_norm,
            expected_batch_size=expected_batch_size,
            loss_reduction=loss_reduction,
            noise_generator=noise_generator,
            distributed=distributed,
            clipping=clipping,
            grad_sample_mode=grad_sample_mode,
        )

        optimizer.attach_step_hook(
            self.accountant.get_optimizer_hook_fn(sample_rate=sample_rate)
        )

        return module, optimizer, data_loader

_prepare_model & forbid_accumulation_hook

これらはmoduleを修正するためのメソッド・関数です。
まず_prepare_modelメソッドでは、validateメソッドでmoduleが修正可能かどうかチェックした後、AbstractGradSampleModuleでラップしたものを返します。

privacy_engine.py
    def _prepare_model(
        self,
        module: nn.Module,
        *,
        batch_first: bool = True,
        loss_reduction: str = "mean",
        grad_sample_mode: str = "hooks",
    ) -> AbstractGradSampleModule:
        self.validate(module=module, optimizer=None, data_loader=None)

        # wrap
        if isinstance(module, AbstractGradSampleModule):
            ### 省略 ###
        else:
            return wrap_model(
                module,
                grad_sample_mode=grad_sample_mode,
                batch_first=batch_first,
                loss_reduction=loss_reduction,
            )

validateメソッドでは、引数のmoduleが訓練モードかつOpacusで訓練不可能なモジュールが含まれていないかどうかをチェックします。例えば、バッチ正規化はOpacusで使用できないため、デコレータ関数register_module_validatorを用いてVALIDATORSに登録されています。

validators/module_validator.py
class ModuleValidator:
    VALIDATORS = {}
    FIXERS = {}

    @classmethod
    def validate(
        cls, module: nn.Module, *, strict: bool = False
    ) -> List[UnsupportedModuleError]:
        errors = []
        # 1. validate that module is in training mode
        if not module.training:
            errors.append(
                IllegalModuleConfigurationError("Model needs to be in training mode")
            )
        # 2. perform module specific validations for trainable modules.
        for _, sub_module in trainable_modules(module):
            if type(sub_module) in ModuleValidator.VALIDATORS:
                sub_module_validator = ModuleValidator.VALIDATORS[type(sub_module)]
                errors.extend(sub_module_validator(sub_module))
        # raise/return as needed
        if strict and len(errors) > 0:
            raise UnsupportedModuleError(errors)
        else:
            return errors
validators/utils.py
DEFAULT_MODULE_VALIDATOR = ModuleValidator

def register_module_validator(
    target_class_or_classes: Union[type, Sequence[type]],
    validator_class: type = DEFAULT_MODULE_VALIDATOR,
):
    def decorator(f):
        target_classes = (
            target_class_or_classes
            if isinstance(target_class_or_classes, Sequence)
            else [target_class_or_classes]
        )
        for target_class in target_classes:
            validator_class.VALIDATORS[target_class] = f
        return f

    return decorator
validators/batch_norm.py
@register_module_validator(
    [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm]
)
def validate(module: BATCHNORM) -> List[UnsupportedModuleError]:
    return [ ShouldReplaceModuleError( ### 省略 ### ) ]

moduleのラップはwrap_model関数で行います。執筆時点でgrad_sample_modeのデフォルトであるhooksの場合、GradSampleModuleクラスでラップされます。

grad_sample/utils.py
def wrap_model(model: nn.Module, grad_sample_mode: str, *args, **kwargs):
    cls = get_gsm_class(grad_sample_mode)
    if grad_sample_mode == "functorch":
        kwargs["force_functorch"] = True
    return cls(model, *args, **kwargs)

def get_gsm_class(grad_sample_mode: str) -> Type[AbstractGradSampleModule]:
    if grad_sample_mode in ["hooks", "functorch"]:
        return GradSampleModule
    elif grad_sample_mode == "ew":
        return GradSampleModuleExpandedWeights
    elif grad_sample_mode == "no_op":
        return GradSampleModuleNoOp

また、Opacusでは原則ポアソンサンプリングを使用するため、勾配の蓄積が許されません。そこで、PyTorchのregister_backward_hookメソッドを用いて、勾配蓄積を検知するためのforbid_accumulation_hook関数を登録します。

privacy_engine.py
def forbid_accumulation_hook(
    module: AbstractGradSampleModule,
    _grad_input: torch.Tensor,
    _grad_output: torch.Tensor,
):
    for _, p in trainable_parameters(module):
        if p.grad_sample is not None:
            if isinstance(p.grad_sample, torch.Tensor):
                accumulated_iterations = 1
            elif isinstance(p.grad_sample, list):
                accumulated_iterations = len(p.grad_sample)

            if accumulated_iterations > 1:
                raise ValueError( ### 省略 ### )

_prepare_data_loader

_prepare_data_loaderメソッドは、data_loaderを修正するためのメソッドです。引数で渡されたdata_loaderDPDataLoaderクラスに変更します。

privacy_engine.py
    def _prepare_data_loader(
        self,
        data_loader: DataLoader,
        *,
        poisson_sampling: bool,
        distributed: bool,
    ) -> DataLoader:
        if poisson_sampling:
            return DPDataLoader.from_data_loader(
                data_loader, generator=self.secure_rng, distributed=distributed
            )

DPDataLoaderクラスは、ポアソンサンプリングを行うため、batch_samplerとしてUniformWithReplacementSamplerクラスを使用します。また、空のミニバッチが生成された場合に対応するため、wrap_collate_with_empty関数を定義します。

data_loader.py
def wrap_collate_with_empty(
    *,
    collate_fn: Optional[_collate_fn_t],
    sample_empty_shapes: Sequence[Tuple],
    dtypes: Sequence[Union[torch.dtype, Type]],
):
    def collate(batch):
        if len(batch) > 0:
            return collate_fn(batch)
        else:
            return [
                torch.zeros(shape, dtype=dtype)
                for shape, dtype in zip(sample_empty_shapes, dtypes)
            ]

    return collate

class DPDataLoader(DataLoader):
    def __init__(
        self,
        dataset: Dataset,
        *,
        sample_rate: float,
        collate_fn: Optional[_collate_fn_t] = None,
        drop_last: bool = False,
        generator=None,
        distributed: bool = False,
        **kwargs,
    ):
        self.sample_rate = sample_rate
        self.distributed = distributed

        if distributed:
            ### 省略 ###
        else:
            batch_sampler = UniformWithReplacementSampler(
                num_samples=len(dataset),  # type: ignore[assignment, arg-type]
                sample_rate=sample_rate,
                generator=generator,
            )
        sample_empty_shapes = [(0, *shape_safe(x)) for x in dataset[0]]
        dtypes = [dtype_safe(x) for x in dataset[0]]
        if collate_fn is None:
            collate_fn = default_collate

        super().__init__(
            dataset=dataset,
            batch_sampler=batch_sampler,
            collate_fn=wrap_collate_with_empty(
                collate_fn=collate_fn,
                sample_empty_shapes=sample_empty_shapes,
                dtypes=dtypes,
            ),
            generator=generator,
            **kwargs,
        )

    @classmethod
    def from_data_loader(
        cls, data_loader: DataLoader, *, distributed: bool = False, generator=None
    ):
        return cls(
            dataset=data_loader.dataset,
            sample_rate=1 / len(data_loader),
            num_workers=data_loader.num_workers,
            collate_fn=data_loader.collate_fn,
            pin_memory=data_loader.pin_memory,
            drop_last=data_loader.drop_last,
            timeout=data_loader.timeout,
            worker_init_fn=data_loader.worker_init_fn,
            multiprocessing_context=data_loader.multiprocessing_context,
            generator=generator if generator else data_loader.generator,
            prefetch_factor=data_loader.prefetch_factor,
            persistent_workers=data_loader.persistent_workers,
            distributed=distributed,
        )

UniformWithReplacementSamplerクラスは、sample_rateの確率で1を持つmaskテンソルを生成し、そのインデックスを取得することでポアソンサンプリングを実現しています。

utils/uniform_sampler.py
class UniformWithReplacementSampler(Sampler[List[int]]):
    def __init__(
        self, *, num_samples: int, sample_rate: float, generator=None, steps=None
    ):
        self.num_samples = num_samples
        self.sample_rate = sample_rate
        self.generator = generator

        if steps is not None:
            self.steps = steps
        else:
            self.steps = int(1 / self.sample_rate)

    def __len__(self):
        return self.steps

    def __iter__(self):
        num_batches = self.steps
        while num_batches > 0:
            mask = (
                torch.rand(self.num_samples, generator=self.generator)
                < self.sample_rate
            )
            indices = mask.nonzero(as_tuple=False).reshape(-1).tolist()
            yield indices

            num_batches -= 1

_prepare_optimizer & attach_step_hook

これらはoptimizerを修正するためのメソッドです。
_prepare_optimizerメソッドは、引数で渡されたoptimizerDPOptimizerクラスに変更します。

privacy_engine.py
    def _prepare_optimizer(
        self,
        optimizer: optim.Optimizer,
        *,
        noise_multiplier: float,
        max_grad_norm: Union[float, List[float]],
        expected_batch_size: int,
        loss_reduction: str = "mean",
        distributed: bool = False,
        clipping: str = "flat",
        noise_generator=None,
        grad_sample_mode="hooks",
    ) -> DPOptimizer:
        optim_class = get_optimizer_class(
            clipping=clipping,
            distributed=distributed,
            grad_sample_mode=grad_sample_mode,
        )

        return optim_class(
            optimizer=optimizer,
            noise_multiplier=noise_multiplier,
            max_grad_norm=max_grad_norm,
            expected_batch_size=expected_batch_size,
            loss_reduction=loss_reduction,
            generator=generator,
            secure_mode=self.secure_mode,
        )

デフォルトの設定ではDPOptimizerクラスが選ばれます。

optimizers/__init__.py
def get_optimizer_class(clipping: str, distributed: bool, grad_sample_mode: str = None):
    if clipping == "flat" and distributed is False:
        return DPOptimizer
    elif clipping == "flat" and distributed is True:
        return DistributedDPOptimizer
    elif clipping == "per_layer" and distributed is False:
        return DPPerLayerOptimizer
    elif clipping == "per_layer" and distributed is True:
        if grad_sample_mode == "hooks":
            return DistributedPerLayerOptimizer
        elif grad_sample_mode == "ew":
            return SimpleDistributedPerLayerOptimizer
    elif clipping == "adaptive" and distributed is False:
        return AdaClipDPOptimizer

次に、attach_step_hookメソッドを用いて、最適化ステップ毎にget_optimizer_hook_fnメソッド内で定義されているhook_fn関数が実行されるようにします。

optimizers/optimizer.py
    def attach_step_hook(self, fn: Callable[[DPOptimizer], None]):
        self.step_hook = fn
accountants/accountant.py
    def get_optimizer_hook_fn(
        self, sample_rate: float
    ) -> Callable[[DPOptimizer], None]:
        def hook_fn(optim: DPOptimizer):
            self.step(
                noise_multiplier=optim.noise_multiplier,
                sample_rate=sample_rate * optim.accumulated_iterations,
            )

        return hook_fn

続く

モデルを訓練する前の処理は以上になります。
次回は、モデル訓練中の処理を見ていきます。
続きの記事はこちらです。

Discussion