 

How to sample from a log-probability distribution?

I have some code that uses log-probability. When I want to draw a sample from the probability distribution, I use

import numpy as np

probs = np.exp(logprobs)
probs /= probs.sum()
sample = np.random.choice(X, p=probs, size=1)[0]

But there is some overhead here in the exponentiation and division. And NumPy's random.choice function requires that the probabilities be between 0 and 1 and sum to 1.

Are there any fast techniques for drawing a sample directly from the unnormalized log-probability array? I only ever need one sample at a time, and each outcome's frequency just needs to be proportional to the probability implied by its log-probability.

asked Oct 18 '25 by cschlick

1 Answer

Use the Gumbel-max trick. See more explanation and references in this answer on Cross Validated. Here's a minimal code example:

import numpy as np

# Assume we only have log-probabilities (for sampling, even logits will do)
log_probs = np.log([0.1, 0.2, 0.3, 0.4])
num_categories = len(log_probs)

# Sample a single category
gumbels = np.random.gumbel(size=num_categories)
sample = np.argmax(log_probs + gumbels)
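
For reproducible draws, the same two lines work with an explicit Generator (a minimal sketch; the seed value is just illustrative):

rng = np.random.default_rng(0)  # illustrative seed

# Same Gumbel-max draw, but sourced from a seeded Generator
gumbels = rng.gumbel(size=num_categories)
sample = np.argmax(log_probs + gumbels)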

Vectorized implementation

Note that this function only returns indices. Also, I haven't yet verified that it's correct when replace=False, or in all cases where gumbels is not None.

from typing import Union

import numpy as np


def sample_gumbels(size: Union[int, tuple[int, ...]] = 1,
                   rng: Union[np.random.Generator, int] = None) -> np.ndarray:
    if not isinstance(rng, np.random.Generator):
        rng = np.random.default_rng(seed=rng)
    return -np.log(-np.log(rng.uniform(size=size)))


def random_choice_log_space(
    logits: np.ndarray,
    replace: bool = True,
    size: int = 1,
    gumbels: np.ndarray = None,
    rng: Union[np.random.Generator, int] = None,
    shuffle_gumbels: bool = True,
) -> np.ndarray:
    """
    Sample from a categorical distribution parametrized by logits or log-probabilities.

    Parameters
    ----------
    logits : np.ndarray
        the last dimension contains log-probabilities or unnormalized logits
        corresponding to the categorical distribution(s)
    replace : bool, optional
        whether or not to sample with or without replacement, by default True
    size : int, optional
        sample size, by default 1
    gumbels : np.ndarray, optional
        experimental feature: pre-computed Gumbel samples which will be added to
        `logits`, by default None
    rng : Union[np.random.Generator, int], optional
        ``np.random.Generator`` object or an integer seed, by default None
    shuffle_gumbels : bool, optional
        experimental feature: whether or not to shuffle `gumbels` (if provided)
        before taking the argmax, by default True. If you're not careful, setting this
        to False causes serial categorical samples to become correlated through
        `gumbels`, which is wrong

    Returns
    -------
    np.ndarray
        sampled indices

    Raises
    ------
    ValueError
        if `size` is less than 1
    """
    if size < 1:
        raise ValueError("size must be at least 1.")
    is_gumbels_precomputed = gumbels is not None
    # Create a Generator if needed
    if not isinstance(rng, np.random.Generator):
        rng = np.random.default_rng(seed=rng)
    if not is_gumbels_precomputed:
        # Independently sample as many Gumbels as needed
        if replace:
            # Pretty sure we have to generate Gumbels for each sample in this case.
            # During addition w/ logits, they'll be broadcasted
            _gumbels_shape = (size,) + logits.shape if size > 1 else logits.shape
        else:
            # We'll just take the top k
            _gumbels_shape = logits.shape
        gumbels = sample_gumbels(size=_gumbels_shape, rng=rng)
    if is_gumbels_precomputed and shuffle_gumbels:
        # Shuffling is unnecessary if the Gumbels were just randomly sampled.
        # It's necessary if the Gumbels were pre-computed and plan to be re-used.
        gumbels_original_shape = gumbels.shape
        if len(gumbels_original_shape) > 1:
            gumbels = gumbels.ravel()
        # gumbels is 1-D. For some reason, choice is faster than shuffle and permuted
        gumbels = rng.choice(gumbels, size=len(gumbels), replace=False, shuffle=True)
        if len(gumbels_original_shape) > 1:
            gumbels = gumbels.reshape(gumbels_original_shape)
    gumbels_rescaled: np.ndarray = logits + gumbels
    if replace:
        # gumbels_rescaled has shape (size, logits.shape) b/c of broadcasting
        return gumbels_rescaled.argmax(axis=-1)
    else:
        # take the top k (k=sample size) indices, as noted here:
        # https://timvieira.github.io/blog/post/2014/08/01/gumbel-max-trick-and-weighted-reservoir-sampling/
        return np.argpartition(gumbels_rescaled, -size, axis=-1)[..., -size:]

As noted in the docstring, you can pass either log-probabilities or unnormalized logits as the logits input. That's because the two only differ by an additive constant (the log of the normalizing constant, i.e., the log-sum-exp of the logits), which is irrelevant because an argmax is taken.
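
As a quick illustration (reusing sample_gumbels from above; the constant shift is arbitrary), adding any constant to the logits leaves the argmax, and hence the sample, unchanged:

# Shifting the logits by a constant doesn't change the Gumbel-max winner
shift = 123.45  # arbitrary constant
gumbels = sample_gumbels(size=log_probs.shape, rng=0)
assert np.argmax(log_probs + gumbels) == np.argmax(log_probs + shift + gumbels)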

Quick and dirty statistical check

For random_choice_log_space to be correct, it needs to independently sample from the probability distribution implied by logits. The independence part is already clear. So we just need to compare the empirical distribution of samples to the actual distribution.

import pandas as pd
from scipy.special import logsumexp, softmax
from scipy.stats import chisquare

_probs = np.array([0.1, 0.2, 0.3, 0.4])

log_probs = np.log(_probs)
logits = np.log(_probs) + logsumexp(_probs, axis=-1)  # shift by an arbitrary constant
# You start out with access to log_probs or logits

num_categories = len(_probs)

sample_size = 500_000
seed = 123
rng = np.random.default_rng(seed)


# helper function
def empirical_distr(discrete_samples):
    return (pd.Series(discrete_samples)
            .value_counts(normalize=True)
            .sort_index()
            .to_numpy())


# np.random.choice (select one at a time) AKA vanilla sampling
def random_choice_log_space_vanilla(logits, size, rng=None):
    probs = softmax(logits, axis=-1)
    if not isinstance(rng, np.random.Generator):
        rng = np.random.default_rng(seed=rng)
    return rng.choice(len(probs), p=probs, size=size, replace=True)

samples = random_choice_log_space_vanilla(logits, size=sample_size, rng=rng)
distr_vanilla = empirical_distr(samples)


# random_choice_log_space for log-probabilities input
samples = random_choice_log_space(log_probs, size=sample_size, rng=rng)
distr_log_probs = empirical_distr(samples)


# random_choice_log_space for logits input
samples = random_choice_log_space(logits, size=sample_size, rng=rng)
distr_logits = empirical_distr(samples)
print(pd.DataFrame({'rel error (vanilla)': (distr_vanilla - _probs)/_probs,
                    'rel error (log-probs)': (distr_log_probs - _probs)/_probs,
                    'rel error (logits)': (distr_logits - _probs)/_probs},
                    index=pd.Index(range(num_categories), name='category')))
          rel error (vanilla)  rel error (log-probs)  rel error (logits)
category                                                                
0                   -0.004760               0.006600            0.001760
1                    0.000170              -0.004080           -0.000170
2                    0.002773               0.001373           -0.000693
3                   -0.000975              -0.000640            0.000165

Even at this super high sample size, the p-value for a chi-squared test is nearly 1 for all methods.

print(chisquare(f_obs=distr_vanilla, f_exp=_probs).pvalue)
print(chisquare(f_obs=distr_log_probs, f_exp=_probs).pvalue)
print(chisquare(f_obs=distr_logits, f_exp=_probs).pvalue)
0.999999997062783
0.9999999935077725
0.999999999914127

Efficiency

None of this work matters if random_choice_log_space is always slower than softmaxing and using np.random.choice. Luckily, there's one common-enough problem where random_choice_log_space is faster. It's the problem in your question: you have logits, and you want to sample exactly one element.

from time import time

from scipy.stats import trim_mean


def time_func(func, *args, num_replications: int=50, **kwargs) -> list[float]:
    '''
    Returns a list, `times`, where `times[i]` is the time it took to run
    `func(*args, **kwargs)` at replication `i` for `i in range(num_replications)`.
    '''
    times = []
    for _ in range(num_replications):
        time_start = time()
        _ = func(*args, **kwargs)
        time_end = time()
        times.append(time_end - time_start)
    return times


category_sizes = np.power(2, np.arange(1, 14+1))
num_replications = 100

times_vanilla = []
times_gumbel = []
for size in category_sizes:
    logits = np.random.normal(size=size)
    times_vanilla.append(time_func(random_choice_log_space_vanilla, logits, size=1,
                                   num_replications=num_replications))
    times_gumbel.append(time_func(random_choice_log_space, logits, size=1,
                                  num_replications=num_replications))


(pd.DataFrame({'vanilla': trim_mean(times_vanilla, 0.1, axis=1),
               'Gumbel': trim_mean(times_gumbel, 0.1, axis=1)},
              index=pd.Index(category_sizes, name='# categories'))
 .plot.bar(title='Categorical sampling',
           figsize=(8,5),
           ylabel='mean wall-clock time (sec)'));

[plot: bar chart of trimmed-mean wall-clock time vs. number of categories, comparing vanilla and Gumbel sampling]

The plot (and repeated runs of it) shows that the comparison becomes unstable for larger category sizes.

A note: precomputing and reusing Gumbel samples across draws would significantly speed up the trick. But you need to shuffle the Gumbel samples before adding them to logits. Otherwise, categorical samples across draws become dependent through the common Gumbel sample. (The Gumbel samples being independent of the logits data is not enough, which is what the linked comment says.) These facts are demonstrated in my notebook here.
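
Roughly, that reuse pattern looks like this (a sketch only; the pool size and the number of draws here are illustrative):

# Pre-compute a pool of Gumbel samples once, then reuse it across draws.
# shuffle_gumbels=True re-permutes the pool on each call, which prevents
# successive draws from being coupled through the shared pool.
precomputed_gumbels = sample_gumbels(size=logits.shape, rng=0)
samples = [random_choice_log_space(logits, gumbels=precomputed_gumbels,
                                   rng=rng, shuffle_gumbels=True)
           for _ in range(10)]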

answered Oct 19 '25 by chicxulub