
Implementing the generalized birthday paradox in Python

My question is about numerical problems I am running into when implementing a probability function, and not about the probability/mathematics behind it. I'm also aware that my code below is probably not well-optimized (e.g. I could vectorize the first function if I use exact=False in comb). So I'm open to optimization suggestions, but it's not really my main concern right now.

I am trying to numerically verify the formula given here for "the probability of getting m unique values from [0,k) when choosing n times".

To do this, in Python 3.6.5, I am using numpy.random.choice(k, n, replace=True) to obtain a multiset, then counting the unique values in the multiset and saving this number. I repeat this many times to build up a sample.

For smallish values of k and n I get good agreement between the simulations and the formula, so I'm pretty happy that it is more-or-less correct. However, when k and n are slightly larger, I obtain negative values from the formula. I suspect this is because it includes products of tiny fractions and very large factorials, and so precision can be lost at some of these stages.

To try and combat this, I implemented the same formula but using logs wherever I could, before finally exponentiating. Annoyingly, it didn't really help, as can be seen in the output of my code given below.

My question is therefore, does anyone have a suggestion as to how I can continue to implement this formula for larger values of n and k? Am I right in thinking it's numerical weirdness introduced by the products of large and small numbers?

My code:

import numpy as np
import numpy.random as npr
from scipy.special import comb, gammaln
import matplotlib.pyplot as plt

def p_unique_birthdays(m, k, n):
    """PMF for obtaining m unique elements when selecting from [0,k) n times.

    I wanted to use exact=True to see if that helped, hence why this is not
    vectorised.
    """
    total = 0
    for i in range(m):
        total += (-1)**i * comb(m, i, exact=True) * ((m-i)/k)**n
    return comb(k, m, exact=True) * total

def p_unique_birthdays_logs(m, k, n):
    """PMF for obtaining m unique elements when selecting from [0,k) n times.

    I use logs to try and deal with some of the numerical craziness that seems
    to arise.
    """
    total = 0
    for i in range(m):
        log_mCi = gammaln(m+1) - gammaln(i+1) - gammaln(m-i+1)
        log_exp_bit = n * (np.log(m-i) - np.log(k))
        total += (-1)**i * np.exp(log_mCi + log_exp_bit)
    return comb(k, m, exact=True) * total

def do_stuff(k, n, pmf):
    n_samples = 50000
    p_ms = np.zeros(n)
    for i in range(n):
        temp_p = pmf(i+1, k, n)
        p_ms[i] = temp_p
    print("Sum of probabilities:", p_ms.sum())

    samples = np.zeros(n_samples)
    for i in range(n_samples):
        samples[i] = np.unique(npr.choice(k, n, replace=True)).size

    # So that the histogram is centered on the correct integers.
    d = np.diff(np.unique(samples)).min()
    left_of_first_bin = samples.min() - float(d)/2
    right_of_last_bin = samples.max() + float(d)/2
    fig = plt.figure(figsize=(8,5))
    ax = fig.add_subplot(111)
    ax.grid()
    ax.bar(range(1, n+1), p_ms, color="C0",
            label=labels[j])
    ax.hist(samples, np.arange(left_of_first_bin, right_of_last_bin + d, d),
            alpha=0.5, color="C1", density=True, label="Samples")
    ax.legend()
    ax.set_xlabel("Unique birthdays")
    ax.set_ylabel("Normalised frequency")
    ax.set_title(f"k = {k}, n = {n}")
    #fig.savefig(f"k{k}_n{n}_{labels[j]}.png")
    plt.show()

random_seed = 1234
npr.seed(random_seed)

labels = ["PMF", "PMF (logs)"]
pmfs = [p_unique_birthdays, p_unique_birthdays_logs]
for j in range(2):
    for k, n in [(30, 20), (60, 40)]:
        do_stuff(k, n, pmfs[j])

The outputted figures: [four plots, one for each combination of PMF implementation and (k, n) pair; images not preserved]

Thanks for any ideas/advice/suggestions.

combinatoricky asked Feb 24 '26 23:02



2 Answers

You were right, it was some odd numeric reason.

Change this line:

total += (-1)**i * comb(m, i, exact=True) * ((m-i)/k)**n

to this:

total += (-1)**i * comb(m, i, exact=True) * ((m-i)**n)/(k**n)

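If you force this operation order, things come out nicely. The likely reason: with integer inputs, (-1)**i * comb(m, i, exact=True) * (m-i)**n is evaluated in exact integer arithmetic, so each term is rounded only once, at the division by k**n. The original line rounds (m-i)/k first and then amplifies that error by raising it to the power n.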

You might have to spend some more time figuring out how to modify your "log'd" version, but given that the change above fixes things, you might just want to discard the "log'd" version altogether.
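Taking the same idea one step further, here is a minimal sketch that keeps the entire alternating sum in integers and divides only once at the end, so no cancellation error can creep in. The name p_unique_exact is made up for illustration, and math.comb assumes Python 3.8 or newer (on 3.6, scipy.special.comb(..., exact=True) should be a drop-in substitute).

```python
from fractions import Fraction
from math import comb  # assumption: Python 3.8+; otherwise use scipy.special.comb(..., exact=True)

def p_unique_exact(m, k, n):
    """Probability of exactly m distinct values in n draws from [0, k).

    Hypothetical helper: same formula as p_unique_birthdays, but evaluated
    with exact integers so the alternating sum cannot lose precision.
    """
    # Every factor is a Python int, so this sum is exact.
    signed_sum = sum((-1) ** i * comb(m, i) * (m - i) ** n for i in range(m + 1))
    # One exact division at the end; call float() on the result if needed.
    return Fraction(comb(k, m) * signed_sum, k ** n)

probs = [p_unique_exact(m, 60, 40) for m in range(1, 41)]
print(float(sum(probs)))   # the exact sum is 1, so this prints 1.0
print(min(probs) >= 0)     # no negative probabilities
```

Converting each probability with float(p) only at plotting time keeps the rounding to a single step per value.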

Hope it helps!

Felipe D. answered Feb 27 '26 13:02



You can use the built-in decimal module to increase precision.

from decimal import Decimal, getcontext

getcontext().prec = 10000

def factorial(n):
    res = Decimal(1)
    for i in range(int(n)):
        res = res * Decimal(i + 1)
    return res

def binomial_coefficient(n, k):
    return factorial(n) / factorial(k) / factorial(n - k)

def p_unique_birthdays(m, k, n):
    m = Decimal(m)
    k = Decimal(k)
    n = Decimal(n)
    total = Decimal(0)
    for i in range(int(m) + 1):
        total += Decimal((-1) ** i) * binomial_coefficient(m, i) * binomial_coefficient(k, m) * ((m - i) / k) ** n
    return total

print(p_unique_birthdays(49, 365, 50))

The above code prints 0.11484925, which matches the WolframAlpha result: http://www.wolframalpha.com/input/?i=sum+combination(49,x)combination(365,49)++(((49-x)%2F365)%5E50)+*+(-1)%5Ex,+x%3D0+to+49
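To get a feel for how much precision the sum actually needs, a rough diagnostic is to compare the largest term of the alternating sum with the final total. The sketch below does this with exact fractions; cancellation_digits is a hypothetical helper name, and math.comb assumes Python 3.8 or newer.

```python
from fractions import Fraction
from math import comb, log10  # assumption: Python 3.8+ for math.comb

def cancellation_digits(m, k, n):
    """Rough number of decimal digits lost to cancellation in the alternating sum.

    Hypothetical diagnostic: log10 of (largest |term| / exact total).
    """
    terms = [Fraction((-1) ** i * comb(m, i) * (m - i) ** n, k ** n)
             for i in range(m + 1)]
    total = sum(terms)                      # exact, never negative
    largest = max(abs(t) for t in terms)
    if total == 0:                          # happens when m > n
        return float("inf")
    return log10(largest / total)

print(cancellation_digits(49, 365, 50))
```

A double carries roughly 16 significant decimal digits, so a result well above that suggests plain floats cannot represent the sum, while a Decimal precision somewhat larger than the reported value should already be enough.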

shaun shia answered Feb 27 '26 12:02
