📘

torch.distributionsにおける負の二項分布 (negative binomial)

2024/03/01に公開

負の二項分布には複数の定義があります。そのうちどれが torch.distributionで採用されているか忘れがちなのでまとめました。

負の二項分布

結果が成功か失敗かの2択になるような確率的試行をBernoulli試行と呼びます。
Bernoulli試行を何度も行う場合を考えます。成功確率を p とし、r失敗したら試行を止めることとします。このとき、試行を止めるまでに (r 回失敗するまでに) 成功した回数 k は以下の確率分布に従います。

P(X = k) = \binom{k+r-1}{k} (1-p)^r p^k

これを負の二項分布 (negative binomial distribution) と呼びます。確率変数とパラメータはそれぞれ,

  • 確率変数: k \in \{0, 1, 2, \cdots\}, 試行を止めるまでの成功回数
  • パラメータ:
    • r \in \mathbb{N}, 試行を止めるまでの失敗回数
    • p \in [0, 1], 成功確率

となります。torch.distributions の実装 (torch.distributions.negative_binomial.NegativeBinomial) はこの定義です。

複数の定義

上記の定義における成功回数と失敗回数は、文献により逆になる場合があります。即ち、試行を止めるまでの成功回数をパラメータ r、失敗回数を確率変数 k とするパターンです。この場合は、確率質量関数の指数部分が逆になります。

P(X = k) = \begin{cases} \binom{k+r-1}{k} (1-p)^r p^k & r \text{ is failure, } k \text{ is success} \\ \\ \binom{k+r-1}{k} (1-p)^k p^r & r \text{ is succes, } k \text{ is failure} \end{cases}

例えば、pytorch日本語版Wikipediaでは前者、英語版Wikipediaでは後者の定義が紹介されています。

Gamma関数を用いた定義

ガンマ関数 \Gamma(n) = (n-1)! を用いると、

\binom{k+r-1}{k} = \frac{(k+r-1)!}{k!(r-1)!} = \frac{\Gamma(k+r)}{k!\Gamma(r)}

となるため、確率質量関数は

P(X = k) = \frac{\Gamma(k+r)}{k!\Gamma(r)} (1-p)^r p^k

と書けます。これを用いると、パラメータ r の定義域を自然数から正の実数へと拡張することができます。上記の定義で言う失敗「回数」が実数になるのには違和感を感じますが、例えば回帰モデルを作る場合等はこちらの方が最適化が容易というメリットがあります。

torch.distributionsにおける実装

torch.distributions.negative_binomial.NegativeBinomial は, 以下の3つの引数を受け取ります。

  • total_count: r \geqq 0, 試行を止めるまでの失敗回数
  • probs: p \in [0, 1), 成功確率
  • logits: \mathrm{logit}(p) \in \mathbb{R}

\mathrm{logit}(p) = \ln \left( \dfrac{p}{1-p} \right) で、sigmoid関数 \sigma(\alpha) = \dfrac{1}{1 + e^{-\alpha}} の逆関数です。

\alpha = \mathrm{logit}(p) とすると、1-p = \sigma(-\alpha) より

  • 平均 \mathbb{E}[k] = \dfrac{rp}{1-p} = r e^\alpha
  • 分散 \mathbb{V}[k] = \dfrac{rp}{(1-p)^2} = \dfrac{r e^\alpha}{\sigma(-\alpha)}

と表記できます。torch.distributionsにおけるmean及びvarianceも上式で計算されています。
また、対数尤度log_probは以下のように計算されます。

\ln P(k) = \ln \Gamma(k+r) - \ln \Gamma(k+1) - \ln \Gamma(r) + r \ln \sigma(-\alpha) + k \ln \sigma(\alpha)

\ln \Gamma, \ln \sigma はそれぞれ関数torch.lgamma, torch.nn.functional.logsigmoid が存在します。

その他のparametarization

回帰モデルを作る場合は、平均や分散に基づくパラメータを用いた方が何かと便利です。
そこでパラメータ \{r, p\} を以下のように \{\mu, \theta\} に置き換えます。

\mu = \dfrac{rp}{1-p}, \hspace{0.5in} \theta = r

平均は \mu、分散は \mu + \dfrac{\mu^2}{\theta} となるので、
\mu をmean parameter、\theta をdispersion parameterと呼ぶことがあります
(1/\theta をdispersionとし、\theta をinverse dispersionと呼ぶこともあります)。

p = \dfrac{\mu}{\theta + \mu},\ 1 - p = \dfrac{\theta}{\theta + \mu} より、確率質量関数は

P(X = k) = \frac{\Gamma(k+\theta)}{k!\Gamma(\theta)} \left( \dfrac{\theta}{\theta + \mu} \right)^\theta \left( \dfrac{\mu}{\theta + \mu} \right)^k

と表せます。このパラメータを用いてtorch.distributionsのインスタンスを作る場合は以下のように設定します。

  • total_count: \theta
  • logits: \ln \mu - \ln \theta

Reference

Discussion