
Efficient algorithm for random sampling from a distribution while allowing updates?

This is a question I was asked in an interview some time ago, and I could not find an answer to it.

Given samples S1, S2, ..., Sn and their probability weights P1, P2, ..., Pn, design an algorithm that randomly chooses a sample, taking its probability into account. The solution I came up with is as follows:

  1. Build a cumulative array of weights Ci, such that

    C0 = 0; Ci = C(i-1) + Pi.

    At the same time, compute the total T = P1 + P2 + ... + Pn. This takes O(n) time.

  2. Generate a uniformly random number R = T * random[0..1).
  3. Using binary search, return the least i such that Ci >= R; the result is Si. This takes O(log n) time.
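
A minimal sketch of this static scheme, assuming Python's standard bisect module (names are illustrative):

import bisect
import itertools
import random

def build_cumulative(weights):
    """O(n): prefix sums C1..Cn of the weights; the last entry is T."""
    return list(itertools.accumulate(weights))

def sample(cumulative):
    """O(log n): binary-search the prefix sums for a uniform draw."""
    r = random.random() * cumulative[-1]       # R = T * random[0..1)
    return bisect.bisect_right(cumulative, r)  # index of the first Ci exceeding R

weights = [3, 2, 2, 2, 2, 1, 1]
c = build_cumulative(weights)
counts = [0] * len(weights)
for _ in range(100000):
    counts[sample(c)] += 1
print(counts)  # counts come out roughly proportional to the weights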

Now the actual question: suppose I want to change one of the initial weights Pj. How can I do this in better than O(n) time? Other data structures are acceptable, but the random sampling algorithm should not get worse than O(log n).

asked Jan 20 '12 by board reader


3 Answers

One way to solve this is to rethink how your binary search tree containing the cumulative totals is built. Rather than building a binary search tree, think about having each node interpreted as follows:

  • Each node stores a range of values that are dedicated to the node itself.
  • Nodes in the left subtree represent sampling from the probability distribution just to the left of that range.
  • Nodes in the right subtree represent sampling from the probability distribution just to the right of that range.

For example, suppose our weights are 3, 2, 2, 2, 2, 1, and 1 for events A, B, C, D, E, F, and G. We build this binary tree holding A, B, C, D, E, F, and G:

               D
             /   \
           B       F
          / \     / \
         A   C   E   G

Now, we annotate the tree with probabilities. Since A, C, E, and G are all leaves, we give each of them probability mass one:

               D
             /   \
           B       F
          / \     / \
         A   C   E   G
         1   1   1   1

Now, look at the subtree rooted at B. B has weight 2 of being chosen, A has weight 3, and C has weight 2. If we normalize these to the range [0, 1), then A accounts for 3/7 of the probability and B and C each account for 2/7. Thus the node for B says that anything in the range [0, 3/7) goes to the left subtree, anything in the range [3/7, 5/7) maps to B itself, and anything in the range [5/7, 1) maps to the right subtree:

                   D
                 /   \
           B              F
 [0, 3/7) / \  [5/7, 1)  / \
         A   C          E   G
         1   1          1   1

Similarly, let's process F. E has weight 2 of being chosen, while F and G each have weight 1. Thus the subtree rooted at E accounts for 1/2 of the probability mass here, the node F itself accounts for 1/4, and the subtree for G accounts for the remaining 1/4. This means we can assign probabilities as

                       D
                     /   \
           B                        F
 [0, 3/7) / \  [5/7, 1)   [0, 1/2) / \  [3/4, 1)
         A   C                    E   G
         1   1                    1   1

Finally, let's look at the root. The combined weight of the left subtree is 3 + 2 + 2 = 7. The combined weight of the right subtree is 2 + 1 + 1 = 4. The weight of D itself is 2. Thus the left subtree has probability 7/13 of being picked, D has probability 2/13 of being picked, and the right subtree has probability 4/13 of being picked. We can thus finalize the probabilities as

                       D
           [0, 7/13) /   \ [9/13, 1)
           B                        F
 [0, 3/7) / \  [5/7, 1)   [0, 1/2) / \  [3/4, 1)
         A   C                    E   G
         1   1                    1   1

To generate a random value, start at the root and repeat the following:

  • Choose a uniformly-random value in the range [0, 1).
  • If it's in the range for the left subtree, descend into it.
  • If it's in the range for the right subtree, descend into it.
  • Otherwise, return the value corresponding to the current node.

The probabilities themselves can be determined recursively when the tree is built:

  • The left and right probabilities are 0 for any leaf node.
  • If an interior node itself has weight W, its left tree has total weight WL, and its right tree has total weight WR, then the left probability is (WL) / (W + WL + WR) and the right probability is (WR) / (W + WL + WR).

The reason this reformulation is useful is that it gives us a way to update probabilities in O(log n) time per updated weight. In particular, let's think about which invariants change if we update some particular node's weight. For simplicity, assume for now that the node is a leaf. When we update the leaf's weight, the probabilities are still correct for the leaf itself, but they're incorrect for the node just above it, because the weight of one of that node's subtrees has changed. We can therefore recompute the probabilities for the parent node in O(1) time using the same formula as above. But then the parent of that node no longer has correct values, because one of its subtree weights has changed, so we recompute the probability there as well.

This process repeats all the way back up to the root of the tree, with O(1) computation per level to rectify the probabilities assigned to each edge. Assuming the tree is balanced, updating one weight therefore takes O(log n) total work. The logic is identical if the node isn't a leaf; we just start partway up the tree.

In short, this gives

  • O(n) time to construct the tree (using a bottom-up approach),
  • O(log n) time to generate a random value, and
  • O(log n) time to update any one value.
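
A minimal sketch of this scheme (illustrative names; instead of caching the [0, 1) ranges per node it caches each subtree's total weight and derives the split during the descent, which is equivalent and keeps the update path simple):

import random

class Node:
    """One item, its weight, a cached subtree total, and tree links."""
    def __init__(self, item, weight, left=None, right=None):
        self.item, self.weight = item, weight
        self.left, self.right = left, right
        self.parent = None
        for child in (left, right):
            if child is not None:
                child.parent = self
        self.total = weight \
            + (left.total if left else 0) \
            + (right.total if right else 0)

def sample(root):
    """O(log n) for a balanced tree: descend using each node's split."""
    node = root
    while True:
        left_total = node.left.total if node.left else 0
        r = random.random() * node.total
        if r < left_total:
            node = node.left        # falls in the left subtree's range
        elif r < left_total + node.weight:
            return node.item        # falls in this node's own range
        else:
            node = node.right       # falls in the right subtree's range

def set_weight(node, weight):
    """O(log n): change one weight, then fix cached totals up to the root."""
    node.weight = weight
    while node is not None:
        node.total = node.weight \
            + (node.left.total if node.left else 0) \
            + (node.right.total if node.right else 0)
        node = node.parent

# The example tree from above: weights 3, 2, 2, 2, 2, 1, 1 for A..G.
a, c, e, g = Node('A', 3), Node('C', 2), Node('E', 2), Node('G', 1)
b, f = Node('B', 2, a, c), Node('F', 1, e, g)
d = Node('D', 2, b, f)
set_weight(a, 5)  # afterwards, sampling reflects the new distribution
print(sample(d))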

Hope this helps!

answered by templatetypedef


Instead of an array, store the search structure as a balanced binary tree. Every node of the tree should store the total weight of the elements it contains. Depending on the value of R, the search procedure either returns the current node or descends into the left or right subtree.

When the weight of an element is changed, the updating of the search structure is a matter of adjusting the weights on the path from the element to the root of the tree.

Since the tree is balanced, the search and the weight update operations are both O(log N).
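
The same bounds can also be achieved with a Fenwick (binary indexed) tree over the weights; a minimal sketch of that variant, assuming non-negative weights (not part of the original answer):

import random

class FenwickSampler:
    """Fenwick (binary indexed) tree: O(log n) weight updates and sampling."""
    def __init__(self, weights):
        self.n = len(weights)
        self.tree = [0.0] * (self.n + 1)  # 1-indexed internal array
        self.weights = [0.0] * self.n
        for i, w in enumerate(weights):
            self.set_weight(i, w)

    def set_weight(self, idx, weight):
        """Apply the weight change along the O(log n) update path."""
        delta = weight - self.weights[idx]
        self.weights[idx] = weight
        i = idx + 1
        while i <= self.n:
            self.tree[i] += delta
            i += i & (-i)

    def total(self):
        """Prefix sum over all n items, i.e. T."""
        s, i = 0.0, self.n
        while i > 0:
            s += self.tree[i]
            i -= i & (-i)
        return s

    def sample(self):
        """Binary-lift down the implicit tree to find where R lands."""
        r = random.random() * self.total()
        pos, step = 0, 1 << self.n.bit_length()
        while step:
            nxt = pos + step
            if nxt <= self.n and self.tree[nxt] <= r:
                r -= self.tree[nxt]
                pos = nxt
            step >>= 1
        return pos  # 0-based index of the sampled item

sampler = FenwickSampler([3, 2, 2, 2, 2, 1, 1])
sampler.set_weight(0, 5.0)  # change one weight in O(log n)
print(sampler.sample())     # sample in O(log n)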

answered by antonakos


For those of you who would like some code, here's a Python implementation:

import numpy


class DynamicProbDistribution(object):
  """ Given a set of weighted items, randomly samples an item with probability
  proportional to its weight. This class also supports fast modification of the
  distribution, so that changing an item's weight requires O(log N) time. 
  Sampling requires O(log N) time. """

  def __init__(self, weights):
    self.num_weights = len(weights)
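    # Arrays are 1-based: item j is stored at index j+1, and node i's
    # children live at indices 2*i and 2*i+1 (an implicit binary tree).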
    self.weights = numpy.empty((1+len(weights),), 'float32')
    self.weights[0] = 0 # Not necessary but easier to read after printing
    self.weights[1:] = weights
    self.weight_tree = numpy.zeros((1+len(weights),), 'float32')
    self.populate_weight_tree()

  def populate_weight_tree(self):
    """ The value of every node in the weight tree is equal to the sum of all 
    weights in the subtree rooted at that node. """
    i = self.num_weights
    while i > 0:
      weight_sum = self.weights[i]
      twoi = 2*i
      if twoi < self.num_weights:
        weight_sum += self.weight_tree[twoi] + self.weight_tree[twoi+1]
      elif twoi == self.num_weights:
        weight_sum += self.weights[twoi]
      self.weight_tree[i] = weight_sum
      i -= 1

  def set_weight(self, item_idx, weight):
    """ Changes the weight of the given item. """
    i = item_idx + 1
    self.weights[i] = weight
    while i > 0:
      weight_sum = self.weights[i]
      twoi = 2*i
      if twoi < self.num_weights:
        weight_sum += self.weight_tree[twoi] + self.weight_tree[twoi+1]
      elif twoi == self.num_weights:
        weight_sum += self.weights[twoi]
      self.weight_tree[i] = weight_sum
      i //= 2 # Only need to modify the parents of this node

  def sample(self):
    """ Returns an item index sampled from the distribution. """
    i = 1
    while True:
      twoi = 2*i

      if twoi < self.num_weights:
        # Two children
        val = numpy.random.random() * self.weight_tree[i]
        if val < self.weights[i]:
          # all indices are offset by 1 for fast traversal of the
          # internal binary tree
          return i-1
        elif val < self.weights[i] + self.weight_tree[twoi]:
          i = twoi # descend into the subtree
        else:
          i = twoi + 1

      elif twoi == self.num_weights:
        # One child
        val = numpy.random.random() * self.weight_tree[i]
        if val < self.weights[i]:
          return i-1
        else:
          i = twoi

      else:
        # No children
        return i-1


def validate_distribution_results(dpd, weights, samples_per_item=1000):
  import time

  bins = numpy.zeros((len(weights),), 'float32')
  num_samples = samples_per_item * numpy.sum(weights)

  start = time.time()
  for i in range(num_samples):
    bins[dpd.sample()] += 1
  duration = time.time() - start

  bins *= numpy.sum(weights)
  bins /= num_samples

  print "Time to make %s samples: %s" % (num_samples, duration)

  # These should be very close to each other
  print "\nWeights:\n", weights
  print "\nBins:\n", bins

  sdev_tolerance = 10 # very unlikely to be exceeded
  tolerance = float(sdev_tolerance) / numpy.sqrt(samples_per_item)
  print "\nTolerance:\n", tolerance

  error = numpy.abs(weights - bins)
  print "\nError:\n", error

  assert (error < tolerance).all()


#@test
def test_DynamicProbDistribution():
  # First test that the initial distribution generates valid samples.

  weights = [2,5,4, 8,3,6, 6,1,3, 4,7,9]
  dpd = DynamicProbDistribution(weights)

  validate_distribution_results(dpd, weights)

  # Now test that we can change the weights and still sample from the 
  # distribution.

  print "\nChanging weights..."
  dpd.set_weight(4, 10)
  weights[4] = 10
  dpd.set_weight(9, 2)
  weights[9] = 2
  dpd.set_weight(5, 4)
  weights[5] = 4
  dpd.set_weight(11, 3)
  weights[11] = 3

  validate_distribution_results(dpd, weights)

  print "\nTest passed"


if __name__ == '__main__':
  test_DynamicProbDistribution()
answered by Ken