
Gumbel-max sampling

Published on 2024/02/07

The cumulative distribution function of the Gumbel distribution with location \mu and scale \eta > 0 is

\begin{align*} F(x | \mu, \eta) &= \exp \left[ - \exp\left( -\frac{x - \mu}{\eta} \right) \right] ,\quad -\infty < x < \infty , \end{align*}

and its probability density function is

\begin{align*} f(x|\mu, \eta) &= \frac{1}{\eta} \exp\left( -\frac{x - \mu}{\eta} \right) \exp \left[ - \exp\left(- \frac{x-\mu}{\eta} \right) \right] \end{align*}
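Since F has a closed-form inverse, a Gumbel sample can be drawn from a uniform one: solving u = F(x | \mu, \eta) for x gives x = \mu - \eta \log(-\log u). The following is a minimal sketch of that idea using only the Python standard library. The function names `gumbel_cdf`, `gumbel_pdf`, and `sample_gumbel` are chosen here for illustration and are not from the original article.

```python
import math
import random


def gumbel_cdf(x, mu=0.0, eta=1.0):
    """F(x | mu, eta) = exp(-exp(-(x - mu) / eta))."""
    return math.exp(-math.exp(-(x - mu) / eta))


def gumbel_pdf(x, mu=0.0, eta=1.0):
    """f(x | mu, eta) = (1 / eta) * t * exp(-t), with t = exp(-(x - mu) / eta)."""
    t = math.exp(-(x - mu) / eta)
    return t * math.exp(-t) / eta


def sample_gumbel(rng, mu=0.0, eta=1.0):
    """Inverse-CDF sampling: x = mu - eta * log(-log(u)) with u uniform on (0, 1)."""
    u = rng.random()
    while u == 0.0:  # random() may return 0.0, which would give log(0)
        u = rng.random()
    return mu - eta * math.log(-math.log(u))


# Usage: the fraction of standard Gumbel samples below 0 should be close to F(0).
rng = random.Random(0)
samples = [sample_gumbel(rng) for _ in range(50_000)]
fraction_below_zero = sum(x <= 0.0 for x in samples) / len(samples)
print("empirical:", fraction_below_zero, "F(0):", gumbel_cdf(0.0))
```

The loop that rejects u = 0 is a defensive choice for this sketch; other implementations clip u to a small positive value instead.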

Let l = \set{l_1, l_2, \dots, l_C} ,\quad l_c \in \mathbb{R} be logits.

  • for c \in \set{1, 2, \dots, C}
    • g_c \sim f(G_c | \mu=0, \eta=1)
    • z_c = l_c + g_c

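Then the probability that z_c is the largest perturbed logit, written z_c > z_{\backslash c} (that is, z_c > z_i for every i \neq c), equals the softmax probability \pi_c = \frac{\exp(l_c)}{\sum_{i=1}^C \exp(l_i)}. In other words, \arg\max_c z_c is a sample from the categorical distribution with probabilities \pi_1, \dots, \pi_C.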
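The procedure above can be checked empirically. The sketch below, written with only the Python standard library, perturbs each logit with standard Gumbel noise, takes the arg max, and compares the observed frequencies with the softmax probabilities. The helper names (`softmax`, `gumbel_max_sample`, `empirical_frequencies`) and the example logits are assumptions made for illustration.

```python
import math
import random


def softmax(logits):
    """pi_c = exp(l_c) / sum_i exp(l_i), shifted by max(logits) for stability."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_max_sample(logits, rng):
    """Return argmax_c (l_c + g_c) with g_c drawn from Gumbel(mu=0, eta=1)."""
    best_index, best_value = -1, -math.inf
    for c, l_c in enumerate(logits):
        u = rng.random()
        while u == 0.0:  # avoid log(0)
            u = rng.random()
        g_c = -math.log(-math.log(u))  # inverse-CDF sample of standard Gumbel
        z_c = l_c + g_c
        if z_c > best_value:
            best_index, best_value = c, z_c
    return best_index


def empirical_frequencies(logits, n_samples, seed=0):
    """Fraction of draws in which each class wins the arg max."""
    rng = random.Random(seed)
    counts = [0] * len(logits)
    for _ in range(n_samples):
        counts[gumbel_max_sample(logits, rng)] += 1
    return [k / n_samples for k in counts]


# Usage: the two printed lists should agree up to Monte Carlo error.
logits = [1.0, 0.0, -1.0]
print("softmax  :", softmax(logits))
print("empirical:", empirical_frequencies(logits, 100_000))
```

Note that only the arg max is used, so adding the same constant to every logit leaves the sampled class distribution unchanged, just as it does for softmax.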


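Proof of p(z_c > z_{\backslash c}) = \pi_c.

Conditioned on G_c = g_c, the events z_c > z_i for i \neq c are independent because the noise terms G_i are independent, so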

\begin{align*} p(z_c > z_{\backslash c} | G_c = g_c) &= \prod_{i \neq c} p(z_c > z_i | G_c = g_c) \\ &= \prod_{i \neq c} p(l_c + g_c > l_i + G_i) \\ &= \prod_{i \neq c} p(G_i < l_c + g_c - l_i) \\ &= \prod_{i \neq c} F(l_c + g_c - l_i | \mu=0, \eta=1) \\ &= \prod_{i \neq c} \exp \left[ - \exp\left( l_i - l_c - g_c \right) \right] \end{align*}
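
The density of G_c at g_c is

\begin{align*} p(G_c = g_c) &= f(g_c | \mu=0, \eta=1) \\ &= \exp(- g_c) \exp \left[ - \exp\left( -g_c \right) \right] \end{align*}

Marginalizing over g_c, and noting that \exp \left[ - \exp\left( -g_c \right) \right] is exactly the i = c term of the product, gives
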
\begin{align*} p(z_c > z_{\backslash c}) &= \int p(z_c > z_{\backslash c} | G_c = g_c) p(G_c = g_c) dg_c \\ &= \int \prod_{i \neq c} \exp \left[ - \exp\left( l_i - l_c - g_c \right) \right] \exp(- g_c) \exp \left[ - \exp\left( -g_c \right) \right] d g_c \\ &= \int \prod_{i \neq c} \exp \left[ - \exp\left( l_i - l_c - g_c \right) \right] \exp(- g_c) \exp \left[ - \exp\left(l_c - l_c -g_c \right) \right] d g_c \\ &= \int \prod_{i=1}^C \exp \left[ - \exp\left( l_i - l_c - g_c \right) \right] \exp(- g_c) d g_c \\ &= \int \exp \left[ - \frac{\sum_{i=1}^C \exp\left( l_i \right)}{\exp(l_c)} \exp(-g_c) \right] \exp(- g_c) d g_c \\ &= \int \exp \left[ - \frac{\exp(-g_c)}{\pi_c} \right] \exp(- g_c) d g_c \\ &= \int \exp \left[ -g_c - \frac{\exp(-g_c)}{\pi_c} \right] d g_c \\ &= \left[ \pi_c \exp\left( \frac{-\exp(-g_c)}{\pi_c} \right) \right]_{-\infty}^{\infty} \\ &= \pi_c \end{align*}
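
The last steps of the derivation can be sanity-checked numerically: for any \pi_c in (0, 1), the integral of \exp[-g_c - \exp(-g_c)/\pi_c] over the real line should return \pi_c. The sketch below uses a plain midpoint rule on a truncated interval. The function name `argmax_probability_integral`, the integration bounds, and the step count are assumptions picked for illustration, not part of the original proof.

```python
import math


def argmax_probability_integral(pi_c, lower=-30.0, upper=40.0, n_steps=70_000):
    """Midpoint-rule approximation of the integral of exp(-g - exp(-g) / pi_c) dg.

    The integrand is negligible outside [lower, upper]: it underflows to 0
    for very negative g and decays like exp(-g) for large g.
    """
    h = (upper - lower) / n_steps
    total = 0.0
    for k in range(n_steps):
        g = lower + (k + 0.5) * h  # midpoint of the k-th subinterval
        total += math.exp(-g - math.exp(-g) / pi_c)
    return total * h


# Usage: each printed pair should agree closely.
for pi_c in (0.1, 0.5, 0.9):
    print(pi_c, argmax_probability_integral(pi_c))
```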

Discussion