I have some code that works with log-probabilities. When I want to draw a sample from the corresponding probability distribution, I use
import numpy as np
probs = np.exp(logprobs)
probs /= probs.sum()
sample = np.random.choice(X, p=probs, size=1)[0]
But there is some overhead here in the exponentiation and division. And numpy's random.choice function requires that the probabilities be between 0 and 1 and sum to 1.
Are there any fast techniques for drawing a sample directly from the unnormalized log-probability array? I only ever need one sample at a time, and the frequency of each element just needs to be proportional to the probability implied by its log-probability.
Use the Gumbel-max trick. See more explanation and references in this answer on Cross Validated. Here's a minimal code example:
import numpy as np
# Assume we only have log-probabilities (for sampling, even logits will do)
log_probs = np.log([0.1, 0.2, 0.3, 0.4])
num_categories = len(log_probs)
# Sample a single category
gumbels = np.random.gumbel(size=num_categories)
sample = np.argmax(log_probs + gumbels)
Note that this function just returns indices. And I haven't yet verified that it's correct when replace=False, or in all cases where gumbels is not None.
from typing import Union
import numpy as np
def sample_gumbels(size: Union[int, tuple[int]] = 1,
rng: Union[np.random.Generator, int] = None) -> np.ndarray:
if not isinstance(rng, np.random.Generator):
rng = np.random.default_rng(seed=rng)
return -np.log(-np.log(rng.uniform(size=size)))
def random_choice_log_space(
logits: np.ndarray,
replace: bool = True,
size: int = 1,
gumbels: np.ndarray = None,
rng: Union[np.random.Generator, int] = None,
shuffle_gumbels: bool = True,
) -> np.ndarray:
"""
Sample from a categorical distribution parametrized by logits or log-probabilities.
Parameters
----------
logits : np.ndarray
the last dimension contains log-probabilities or unnormalized logits
corresponding to the categorical distribution(s)
replace : bool, optional
whether or not to sample with or without replacement, by default True
size : int, optional
sample size, by default 1
gumbels : np.ndarray, optional
experimental feature: pre-computed Gumbel samples which will be added to
`logits`, by default None
rng : Union[np.random.Generator, int], optional
``np.random.Generator`` object or an integer seed, by default None
shuffle_gumbels : bool, optional
experimental feature: whether or not to shuffle `gumbels` (if provided)
before taking the argmax, by default True. If you're not careful, setting this
to False causes serial categorical samples to become correlated through
`gumbels`, which is wrong
Returns
-------
np.ndarray
sampled indices
Raises
------
ValueError
if `size` is less than 1
"""
if size < 1:
raise ValueError("size must be at least 1.")
is_gumbels_precomputed = gumbels is not None
# Create a Generator if needed
if not isinstance(rng, np.random.Generator):
rng = np.random.default_rng(seed=rng)
if not is_gumbels_precomputed:
# Independently sample as many Gumbels as needed
if replace:
# Pretty sure we have to generate Gumbels for each sample in this case.
# During addition w/ logits, they'll be broadcasted
_gumbels_shape = (size,) + logits.shape if size > 1 else logits.shape
else:
# We'll just take the top k
_gumbels_shape = logits.shape
        gumbels = sample_gumbels(size=_gumbels_shape, rng=rng)
if is_gumbels_precomputed and shuffle_gumbels:
# Shuffling is unnecessary if the Gumbels were just randomly sampled.
# It's necessary if the Gumbels were pre-computed and plan to be re-used.
gumbels_original_shape = gumbels.shape
if len(gumbels_original_shape) > 1:
gumbels = gumbels.ravel()
# gumbels is 1-D. For some reason, choice is faster than shuffle and permuted
gumbels = rng.choice(gumbels, size=len(gumbels), replace=False, shuffle=True)
if len(gumbels_original_shape) > 1:
gumbels = gumbels.reshape(gumbels_original_shape)
gumbels_rescaled: np.ndarray = logits + gumbels
if replace:
# gumbels_rescaled has shape (size, logits.shape) b/c of broadcasting
return gumbels_rescaled.argmax(axis=-1)
else:
# take the top k (k=sample size) indices, as noted here:
# https://timvieira.github.io/blog/post/2014/08/01/gumbel-max-trick-and-weighted-reservoir-sampling/
return np.argpartition(gumbels_rescaled, -size, axis=-1)[..., -size:]
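For the single-sample case in your question, usage would look something like this (a quick sketch; the logits values are just made up):
import numpy as np
logits = np.array([-1.2, 0.3, 0.8, 2.0])  # unnormalized logits for 4 categories
# Draw one index, reproducibly, by passing an integer seed
idx = random_choice_log_space(logits, size=1, rng=0)
# Draw 3 distinct indices (sampling without replacement)
idxs = random_choice_log_space(logits, replace=False, size=3, rng=0)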
As noted in the docstring, you can pass log-probabilities or unnormalized logits as the logits input. That's because these two inputs only differ by an additive constant (the log of the normalization constant), which is irrelevant because an argmax is taken.
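To see the invariance concretely, here's a small sketch (the numbers are arbitrary): shifting the log-probabilities by any constant leaves the argmax, and hence the sample, unchanged.
import numpy as np
probs = np.array([0.1, 0.2, 0.3, 0.4])
log_probs = np.log(probs)
logits = log_probs + 5.0  # any constant shift yields valid unnormalized logits
gumbels = -np.log(-np.log(np.random.default_rng(0).uniform(size=4)))
# Same Gumbel noise, shifted inputs: the argmax (i.e., the sample) is identical
assert np.argmax(log_probs + gumbels) == np.argmax(logits + gumbels)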
For random_choice_log_space to be correct, it needs to independently sample from the probability distribution implied by logits. The independence part is clear from the construction, so we just need to compare the empirical distribution of samples to the actual distribution.
import pandas as pd
from scipy.special import logsumexp, softmax
from scipy.stats import chisquare
_probs = np.array([0.1, 0.2, 0.3, 0.4])
log_probs = np.log(_probs)
# Any constant shift of log_probs gives valid unnormalized logits
logits = np.log(_probs) + logsumexp(_probs, axis=-1)
# You start out with access to log_probs or logits
num_categories = len(_probs)
sample_size = 500_000
seed = 123
rng = np.random.default_rng(seed)
# helper function
def empirical_distr(discrete_samples):
return (pd.Series(discrete_samples)
.value_counts(normalize=True)
.sort_index()
.to_numpy())
# np.random.choice (select one at a time) AKA vanilla sampling
def random_choice_log_space_vanilla(logits, size, rng=None):
probs = softmax(logits, axis=-1)
if not isinstance(rng, np.random.Generator):
rng = np.random.default_rng(seed=rng)
return rng.choice(len(probs), p=probs, size=size, replace=True)
samples = random_choice_log_space_vanilla(logits, size=sample_size, rng=rng)
distr_vanilla = empirical_distr(samples)
# random_choice_log_space for log-probabilities input
samples = random_choice_log_space(log_probs, size=sample_size, rng=rng)
distr_log_probs = empirical_distr(samples)
# random_choice_log_space for logits input
samples = random_choice_log_space(logits, size=sample_size, rng=rng)
distr_logits = empirical_distr(samples)
print(pd.DataFrame({'rel error (vanilla)': (distr_vanilla - _probs)/_probs,
'rel error (log-probs)': (distr_log_probs - _probs)/_probs,
'rel error (logits)': (distr_logits - _probs)/_probs},
index=pd.Index(range(num_categories), name='category')))
rel error (vanilla) rel error (log-probs) rel error (logits)
category
0 -0.004760 0.006600 0.001760
1 0.000170 -0.004080 -0.000170
2 0.002773 0.001373 -0.000693
3 -0.000975 -0.000640 0.000165
Even at this super high sample size, the p-value for a chi-squared test is nearly 1 for all methods.
print(chisquare(f_obs=distr_vanilla, f_exp=_probs).pvalue)
print(chisquare(f_obs=distr_log_probs, f_exp=_probs).pvalue)
print(chisquare(f_obs=distr_logits, f_exp=_probs).pvalue)
0.999999997062783
0.9999999935077725
0.999999999914127
None of this work matters if random_choice_log_space is always slower than softmaxing and using np.random.choice. Luckily, there's one common-enough problem where random_choice_log_space is faster. It's the problem in your question: you have logits, and you want to sample exactly one element.
from time import time
from scipy.stats import trim_mean
def time_func(func, *args, num_replications: int=50, **kwargs) -> list[float]:
'''
Returns a list, `times`, where `times[i]` is the time it took to run
`func(*args, **kwargs)` at replication `i` for `i in range(num_replications)`.
'''
times = []
for _ in range(num_replications):
time_start = time()
_ = func(*args, **kwargs)
time_end = time()
times.append(time_end - time_start)
return times
category_sizes = np.power(2, np.arange(1, 14+1))
num_replications = 100
times_vanilla = []
times_gumbel = []
for size in category_sizes:
logits = np.random.normal(size=size)
times_vanilla.append(time_func(random_choice_log_space_vanilla, logits, size=1,
num_replications=num_replications))
times_gumbel.append(time_func(random_choice_log_space, logits, size=1,
num_replications=num_replications))
(pd.DataFrame({'vanilla': trim_mean(times_vanilla, 0.1, axis=1),
'Gumbel': trim_mean(times_gumbel, 0.1, axis=1)},
index=pd.Index(category_sizes, name='# categories'))
.plot.bar(title='Categorical sampling',
figsize=(8,5),
ylabel='mean wall-clock time (sec)'));
The plot (and repeated runs on my machine) shows that the comparison becomes unstable for larger category sizes.
A note: precomputing and reusing Gumbel samples across draws would significantly speed up the trick. But you need to shuffle the Gumbel samples before adding them to logits. Otherwise, categorical samples across draws become dependent through the common Gumbel sample. (The Gumbel samples being independent of the logits data is not enough, which is what the linked comment says.) These facts are demonstrated in my notebook here.
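For illustration, reusing a pre-computed pool with the function above might look like the sketch below (the pool size and seed are arbitrary choices, not recommendations):
import numpy as np
rng = np.random.default_rng(0)
logits = rng.normal(size=1000)
# Pre-compute one pool of Gumbel noise with the same shape as the logits
gumbels = sample_gumbels(size=logits.shape, rng=rng)
# Reuse the pool across draws; shuffle_gumbels=True reshuffles it on each call
# so that successive categorical samples aren't tied together by the same noise
samples = [random_choice_log_space(logits, gumbels=gumbels, rng=rng)
           for _ in range(10)]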