The Gumbel-Max Trick for Discrete Distributions

This is just a rewrite of this blog post for my understanding. It took me a while to parse their notation, so I wrote this as a reference. There’s no new information in this post.

Problem Statement

We have \(N\) items with scores \(x_1, x_2, \dots, x_N\), where \(x_i \in \mathbb{R}\) represents an unnormalized log-probability of selecting item \(i\).

To sample from the above discrete distribution, we can apply the softmax transformation and then draw from the resulting categorical distribution in the usual way. The following equation gives the normalized probability of item \(k\) being selected:

\[ \pi_{k} = \frac{e^{x_k}}{\sum_{j=1}^{N}e^{x_j}}\]

where \(\pi \in [0,1]^N\) is a vector representation of all the probabilities.
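As a concrete sketch of this baseline approach, here is how one might compute \(\pi\) and draw a sample with NumPy (the scores in `x` are made-up values for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical unnormalized scores x_i (illustrative values only).
x = np.array([1.0, 2.0, 3.0, 4.0])

# Softmax transformation; subtracting the max first avoids overflow
# and does not change the result.
pi = np.exp(x - x.max())
pi /= pi.sum()

# The "usual thing": draw an index from the categorical distribution pi.
sample = rng.choice(len(x), p=pi)
```

Note that \(\pi\) is invariant to shifting every \(x_i\) by the same constant, which is why the max-subtraction trick is safe.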

Adding noise and re-ranking

It turns out that there is an alternative way to arrive at such discrete samples, one that doesn’t require explicitly constructing the discrete distribution.

The algorithm turns out to be the following:

For each \(i=1, \dots, N\)

\(z_i = x_i + \alpha_i\) where the \(\alpha_i\) are drawn independently from the standard Gumbel distribution.

Then selecting \(k = \text{argmax}_{i=1, \dots, N}\, z_i\) is equivalent to sampling \(k\) from the discrete distribution \(\pi\).
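A minimal sketch of the algorithm in NumPy (the scores in `x` are illustrative), together with an empirical check that the sample frequencies approach \(\pi\):

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.array([1.0, 2.0, 3.0, 4.0])  # hypothetical unnormalized scores

def gumbel_max_sample(x, rng):
    # Standard Gumbel noise via inverse-CDF sampling: -log(-log(U)) with
    # U ~ Uniform(0, 1). Equivalently: rng.gumbel(size=x.shape).
    alpha = -np.log(-np.log(rng.uniform(size=x.shape)))
    # No softmax, no normalization: just perturb and take the argmax.
    return np.argmax(x + alpha)

# Empirical check against the softmax probabilities pi.
pi = np.exp(x - x.max())
pi /= pi.sum()
counts = np.bincount(
    [gumbel_max_sample(x, rng) for _ in range(100_000)], minlength=len(x)
)
freqs = counts / counts.sum()  # should be close to pi
```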

Proof

By the additive property of the Gumbel distribution, \(z_i\) is distributed according to the Gumbel distribution with location parameter \(x_i\).

Now condition on the realized value of the \(k\)’th perturbed score, \(z_k = x_k + \alpha_k\), and treat the remaining \(z_j = x_j + \alpha_j\) (for \(j \neq k\)) as random.

Given this conditioning, what is the probability that item \(k\) is the one selected by the algorithm, i.e. that \(z_k\) is the largest?

\[\begin{align*} \mathbb{P}\Big( z_k \text{ is the largest} \,\Big|\, z_k, \{ x_j \}_{j=1}^N \Big) &= \prod_{j \neq k} F_{x_j}(z_k) \\ \end{align*}\]

where \(F_{a}(z) = e^{-e^{-(z - a)}}\) is the CDF of the Gumbel distribution with location parameter \(a\).

The above holds for a fixed value of \(z_k\); to find the probability that item \(k\) is selected when all of the items are noisified, we need to marginalize out \(z_k\).

From the law of total probability we know

\(P(X) = \int P(X|Y)P(Y)\,dY\)

So marginalizing the above equation over \(z_k\) gives

\[\begin{align*} \mathbb{P}\Big( z_k \text{ is the largest} \,\Big|\, \{ x_j \}_{j=1}^N \Big) &= \int f_{x_k}(z) \prod_{j \neq k} F_{x_j}(z)\,dz \\ \end{align*}\]

where the integral is over the domain of \(z\) and \(f_{x_k}\) is the pdf of the Gumbel distribution with location parameter \(x_k\).

\[\begin{align*} \mathbb{P}\Big( z_k \text{ is the largest} \,\Big|\, \{ x_j \}_{j=1}^N \Big) &= \int e^{-(z - x_k) - e^{-(z - x_k)}} \prod_{j \neq k} e^{-e^{-(z - x_j)}}\,dz \\ &= \int e^{-z + x_k - e^{-z}\sum_{j=1}^{N} e^{x_j}}\, dz \end{align*}\]

The algebraic manipulation in the above step is just a rearrangement of terms. The above integral has a closed-form solution (which I haven’t verified myself; I’m taking their word for it):

\[\frac{e^{x_k}}{\sum_{j=1}^{N}e^{x_j}} \]
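For reference, the closed form follows from the substitution \(u = e^{-z}\), so that \(du = -e^{-z}\,dz\) and the integration limits flip; writing \(S = \sum_{j=1}^{N} e^{x_j}\):

\[\begin{align*} \int_{-\infty}^{\infty} e^{x_k}\, e^{-z}\, e^{-e^{-z} S}\, dz &= e^{x_k} \int_{0}^{\infty} e^{-uS}\, du = \frac{e^{x_k}}{S} = \frac{e^{x_k}}{\sum_{j=1}^{N} e^{x_j}} \end{align*}\]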

An implementation of this can be found here.

Also note: taking the top \(K\) highest perturbed scores is the same as sampling \(K\) items without replacement.
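A sketch of this top-\(K\) variant (sometimes called Gumbel-top-\(k\)), again with hypothetical scores in `x`; since argsort returns distinct indices, the \(K\) picks are always all different:

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.array([1.0, 2.0, 3.0, 4.0])  # hypothetical unnormalized scores

def gumbel_top_k(x, k, rng):
    # Perturb every score with independent standard Gumbel noise,
    # then keep the indices of the k largest perturbed scores.
    alpha = -np.log(-np.log(rng.uniform(size=x.shape)))
    return np.argsort(x + alpha)[::-1][:k]

picked = gumbel_top_k(x, 2, rng)  # two distinct item indices
```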