
How to implement a simple greedy multiset based algorithm in python


I would like to implement the following algorithm. Given n and k, consider all combinations with repetition, in sorted order, of k numbers chosen from {0, ..., n-1}. For example, if n=5 and k=3 we have:

[(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), (0, 0, 4), (0, 1, 1), (0, 1, 2), (0, 1, 3), (0, 1, 4), (0, 2, 2), (0, 2, 3), (0, 2, 4), (0, 3, 3), (0, 3, 4), (0, 4, 4), (1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4), (1, 2, 2), (1, 2, 3), (1, 2, 4), (1, 3, 3), (1, 3, 4), (1, 4, 4), (2, 2, 2), (2, 2, 3), (2, 2, 4), (2, 3, 3), (2, 3, 4), (2, 4, 4), (3, 3, 3), (3, 3, 4), (3, 4, 4), (4, 4, 4)]

I will treat each combination as a multiset from now on. I want to go through these multisets greedily and partition the list. Each part of the partition must have the property that the intersection of all the multisets in it has size at least k-1. So in this case we have:

(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), (0, 0, 4)
(0, 1, 1), (0, 1, 2), (0, 1, 3), (0, 1, 4)
(0, 2, 2), (0, 2, 3), (0, 2, 4)
(0, 3, 3), (0, 3, 4)
(0, 4, 4)

and so on.
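A minimal sketch of the property itself, assuming each multiset is stored as a tuple: the `&` operator of `collections.Counter` keeps the minimum count of each element, which is exactly multiset intersection. The helper names `intersection_size` and `satisfies_property` are hypothetical, not part of any library.

```python
from collections import Counter
from functools import reduce


def intersection_size(group):
    """Size of the multiset intersection of every tuple in `group`."""
    # Counter & Counter keeps min(count_a, count_b) for each element.
    common = reduce(lambda acc, m: acc & Counter(m), group[1:], Counter(group[0]))
    return sum(common.values())


def satisfies_property(group, k):
    """True if all multisets in `group` share at least k - 1 elements."""
    return intersection_size(group) >= k - 1


# Part of the first group listed above: all three share the two 0s.
print(satisfies_property([(0, 0, 0), (0, 0, 1), (0, 0, 2)], 3))  # True
# (0, 0, 0) and (0, 1, 1) share only a single 0.
print(satisfies_property([(0, 0, 0), (0, 1, 1)], 3))  # False
```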

In Python you can iterate over the combinations as follows:

import itertools
for multiset in itertools.combinations_with_replacement(range(5),3):
    pass  # greedy algorithm goes here

How can I create these partitions?

One problem I have is how to compute the size of the intersection of multisets. The intersection of multisets (2,1,2) and (3,2,2) has size 2, for example.
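Since `combinations_with_replacement` yields sorted tuples, one possible way to get the intersection size of two multisets is a two-pointer merge, sketched below; `sum((Counter(a) & Counter(b)).values())` from `collections` is an equivalent one-liner. `sorted_intersection_size` is a hypothetical name, and unsorted input such as (2,1,2) has to be sorted first.

```python
def sorted_intersection_size(a, b):
    """Multiset intersection size of two sorted sequences (two-pointer merge)."""
    i = j = size = 0
    while i < len(a) and j < len(b):
        if a[i] == b[j]:
            size += 1  # one copy matched on each side
            i += 1
            j += 1
        elif a[i] < b[j]:
            i += 1  # a[i] cannot appear later in b
        else:
            j += 1  # b[j] cannot appear later in a
    return size


# (2, 1, 2) is not sorted, so sort both inputs first.
print(sorted_intersection_size(sorted((2, 1, 2)), sorted((3, 2, 2))))  # 2
```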

Here is the full answer for n=4, k=4.

(0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 0, 2), (0, 0, 0, 3)
(0, 0, 1, 1), (0, 0, 1, 2), (0, 0, 1, 3)
(0, 0, 2, 2), (0, 0, 2, 3)
(0, 0, 3, 3)
(0, 1, 1, 1), (0, 1, 1, 2), (0, 1, 1, 3)
(0, 1, 2, 2), (0, 1, 2, 3)
(0, 1, 3, 3)
(0, 2, 2, 2), (0, 2, 2, 3)
(0, 2, 3, 3), (0, 3, 3, 3)
(1, 1, 1, 1), (1, 1, 1, 2), (1, 1, 1, 3)
(1, 1, 2, 2), (1, 1, 2, 3)
(1, 1, 3, 3)
(1, 2, 2, 2), (1, 2, 2, 3)
(1, 2, 3, 3), (1, 3, 3, 3)
(2, 2, 2, 2), (2, 2, 2, 3)
(2, 2, 3, 3), (2, 3, 3, 3)
(3, 3, 3, 3)
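A literal reading of the greedy rule can be sketched as follows, assuming the running intersection of the current group is kept as a `Counter` and a multiset joins the group only if that intersection keeps size k-1 or more. This only illustrates the stated criterion and is not necessarily the fastest approach; `greedy_partition` is a hypothetical name. For n=4, k=4 it should reproduce the 17 groups listed above.

```python
import itertools
from collections import Counter


def greedy_partition(n, k):
    """Greedily group combinations_with_replacement(range(n), k) so that
    every group's multiset intersection has size >= k - 1."""
    groups = []
    common = Counter()  # running intersection of the current group
    for multiset in itertools.combinations_with_replacement(range(n), k):
        candidate = common & Counter(multiset)
        if groups and sum(candidate.values()) >= k - 1:
            groups[-1].append(multiset)  # still shares k - 1 elements
            common = candidate
        else:
            groups.append([multiset])  # start a new group
            common = Counter(multiset)
    return groups


for group in greedy_partition(4, 4):
    print(group)
```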
graffe asked Apr 15 '17 18:04


2 Answers

One way to create the partitions is to iterate over the combinations and compare each multiset* with the multiset that started the current group (prev in the code). I tested 4 ways** to compare the multisets, and the fastest I found was to test membership against an iterator over prev: each membership test consumes the iterator up to the match, and the loop short-circuits as soon as a test fails. If the first k-1 items of the multiset are all found this way (so the two share at least k-1 items), the criterion to group them is met. The function is a generator of lists: a multiset that meets the criterion is appended to the current list, and otherwise a new list containing the tuple is started. Yielding the groups one at a time keeps memory usage low:

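import itertools

def f(n,k):
    # prev is the multiset that started the current group
    prev, group = None, []
    for multiset in itertools.combinations_with_replacement(range(n),k):
        if prev:
            # `item in it` consumes `it` up to the first match, so on sorted
            # tuples this checks multiset[:idx] against prev as multisets
            it = iter(prev)
            for idx, item in enumerate(multiset):
                if item not in it:
                    break
            # idx == k - 1 means the first k - 1 items were all found in prev
            if idx == len(multiset) - 1:
                group.append(multiset)
                continue
        if group:
            yield group
        group = [multiset]
        prev = multiset
    yield group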
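The iterator-membership trick can be isolated in a small helper, as a sketch: for two sorted tuples, `all(x in it for x in a)` with `it = iter(b)` checks whether `a` is a sub-multiset of `b`, because each `in` consumes `it` up to and including the match. The name `is_sub_multiset` is hypothetical.

```python
def is_sub_multiset(a, b):
    """True if sorted tuple `a` is a sub-multiset of sorted tuple `b`."""
    it = iter(b)
    # `x in it` advances `it` just past the first match (or exhausts it), so
    # each element of b is matched at most once, in order.
    return all(x in it for x in a)


print(is_sub_multiset((0, 1, 1), (0, 1, 1, 2)))  # True
print(is_sub_multiset((0, 2, 2), (0, 1, 2, 3)))  # False: b has only one 2
```

For sorted tuples, the grouping test in `f` should be equivalent to `is_sub_multiset(multiset[:-1], prev)`: the first k-1 items of the new multiset must all be found in the multiset that started the group.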

Test cases


for item in f(4,4):
    print(item)


[(0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 0, 2), (0, 0, 0, 3)]
[(0, 0, 1, 1), (0, 0, 1, 2), (0, 0, 1, 3)]
[(0, 0, 2, 2), (0, 0, 2, 3)]
[(0, 0, 3, 3)]
[(0, 1, 1, 1), (0, 1, 1, 2), (0, 1, 1, 3)]
[(0, 1, 2, 2), (0, 1, 2, 3)]
[(0, 1, 3, 3)]
[(0, 2, 2, 2), (0, 2, 2, 3)]
[(0, 2, 3, 3), (0, 3, 3, 3)]
[(1, 1, 1, 1), (1, 1, 1, 2), (1, 1, 1, 3)]
[(1, 1, 2, 2), (1, 1, 2, 3)]
[(1, 1, 3, 3)]
[(1, 2, 2, 2), (1, 2, 2, 3)]
[(1, 2, 3, 3), (1, 3, 3, 3)]
[(2, 2, 2, 2), (2, 2, 2, 3)]
[(2, 2, 3, 3), (2, 3, 3, 3)]
[(3, 3, 3, 3)]


for item in f(5,3):
    print(item)


[(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), (0, 0, 4)]
[(0, 1, 1), (0, 1, 2), (0, 1, 3), (0, 1, 4)]
[(0, 2, 2), (0, 2, 3), (0, 2, 4)]
[(0, 3, 3), (0, 3, 4)]
[(0, 4, 4)]
[(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]
[(1, 2, 2), (1, 2, 3), (1, 2, 4)]
[(1, 3, 3), (1, 3, 4)]
[(1, 4, 4)]
[(2, 2, 2), (2, 2, 3), (2, 2, 4)]
[(2, 3, 3), (2, 3, 4)]
[(2, 4, 4)]
[(3, 3, 3), (3, 3, 4)]
[(3, 4, 4), (4, 4, 4)]

* I'm calling them multisets to match your terminology, but they're actually tuples (ordered, immutable data structures). A true multiset approach would use collections.Counter objects (for example, Counter((0, 0, 0, 1)) returns Counter({0: 3, 1: 1})) and decrement the counts, but I found this to be slower, because the ordering of the tuples is actually useful.
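To illustrate the `Counter` alternative from this footnote (a sketch; the example pair is chosen for illustration): subtraction keeps only positive counts, so `sum((Counter(prev) - Counter(multiset)).values())` is k minus the intersection size. That is why `f4` tests for `== 1`.

```python
from collections import Counter

print(Counter((0, 0, 0, 1)))  # Counter({0: 3, 1: 1}), as in the footnote

# Two tuples from the same n=4, k=4 group, used here only as an example.
prev, multiset = (0, 0, 1, 1), (0, 0, 1, 2)
# Subtraction keeps positive counts only: the items of prev not matched in multiset.
leftover = Counter(prev) - Counter(multiset)
print(sum(leftover.values()))  # 1
# & keeps the minimum of each count, i.e. the multiset intersection.
print(sum((Counter(prev) & Counter(multiset)).values()))  # 3, which is k - 1
```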

** Other, slower functions I tested that give the same output:

def f2(n,k):
    prev, group = None, []
    for multiset in itertools.combinations_with_replacement(range(n),k):
        if prev:
            # count the positions where the two sorted tuples agree
            if sum(item1 == item2 for item1, item2 in zip(prev,multiset)) == len(multiset) - 1:
                group.append(multiset)
                continue
        if group:
            yield group
        group = [multiset]
        prev = multiset
    yield group

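def f3(n,k):
    prev, group = None, []
    for multiset in itertools.combinations_with_replacement(range(n),k):
        if prev:
            lst = list(prev)
            for item in multiset:
                if item in lst:
                    lst.remove(item)  # each item of prev can match only once
            # the number of removed items is the size of the intersection
            if len(multiset) - len(lst) == len(multiset) - 1:
                group.append(multiset)
                continue
        if group:
            yield group
        group = [multiset]
        prev = multiset
    yield group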

import collections
def f4(n,k):
    prev, group = None, []
    for multiset in itertools.combinations_with_replacement(range(n),k):
        if prev:
            # prev minus multiset leaves exactly one item when they share k-1
            if sum((collections.Counter(prev) - collections.Counter(multiset)).values()) == 1:
                group.append(multiset)
                continue
        if group:
            yield group
        group = [multiset]
        prev = multiset
    yield group

Example timings:

from timeit import timeit
list(f(11,10)) == list(f2(11,10)) == list(f3(11,10)) == list(f4(11,10))
# True
timeit(lambda: list(f(11,10)), number = 10)
# 4.19157001003623
timeit(lambda: list(f2(11,10)), number = 10)
# 7.32002648897469
timeit(lambda: list(f3(11,10)), number = 10)
# 6.236868146806955
timeit(lambda: list(f4(11,10)), number = 10)
# 47.20136355608702

Note that all approaches become slow for large values of n and k, because the number of combinations generated, C(n+k-1, k), grows quickly.
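A quick way to gauge the input size beforehand, assuming Python 3.8+ for `math.comb`: the number of k-combinations with replacement from n items is C(n+k-1, k).

```python
import itertools
import math

n, k = 11, 10  # the sizes used in the timings above
# Number of k-combinations with replacement from n items: C(n + k - 1, k).
print(math.comb(n + k - 1, k))  # 184756
# Brute-force count from the generator itself, for comparison.
print(sum(1 for _ in itertools.combinations_with_replacement(range(n), k)))  # 184756
```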

Chris_Rands answered Sep 21 '22 09:09


We can view the tuples in the list you want to partition as k-digit numbers in base n. Viewed as numbers, your algorithm is greedy on a smallest-number-first basis. Let N(k,n) denote the set of all k-digit numbers in base n. Ignoring for now the fact that N(k,n) isn't exactly the list you want to partition, we can partition N(k,n) by your criterion, greedily and smallest first, quite easily: count up from 0 (i.e. 00000 for k=5) and start a new partition every time the count produces a carry (i.e. an overflow from digit i to digit i+1). In short, the rule is: carry <=> new_partition.

Proof: Suppose A is the value right after the carry, and the carry went into the i-th digit. A shares a common prefix with all numbers in the previous partition only up to, but excluding, the i-th digit, so it differs from each of them in at least one digit. A shares its suffix after i with only one previous (smaller) number, but that number is already in a partition with other numbers that differ from A in more than one digit, so A starts a new partition.
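A minimal sketch of the carry rule on the unrestricted set N(k,n), assuming `itertools.product(range(n), repeat=k)` as the counter (it yields digit tuples in the same order as counting upward in base n). Counting up by one carries exactly when the last digit wraps from n-1 to 0, i.e. when the first k-1 digits change. `carry_partition` is a hypothetical name.

```python
import itertools


def carry_partition(n, k):
    """Start a new group on every carry while counting through all k-digit
    base-n numbers, i.e. the unrestricted set N(k, n) (hypothetical helper)."""
    groups = []
    prev = None
    for digits in itertools.product(range(n), repeat=k):  # 0 .. n**k - 1
        # Counting up by one carries exactly when the last digit wraps from
        # n-1 to 0, which is when the first k-1 digits change.
        if prev is None or digits[:-1] != prev[:-1]:
            groups.append([])  # carry => new partition
        groups[-1].append(digits)
        prev = digits
    return groups


print(carry_partition(3, 2))
# [[(0, 0), (0, 1), (0, 2)], [(1, 0), (1, 1), (1, 2)], [(2, 0), (2, 1), (2, 2)]]
```

Each group is the n numbers that share their first k-1 digits, so any two members agree in at least k-1 digit positions.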

However, according to your specification, we are only considering a subset X of N(k,n), where for any x in X, x[i] >= x[j] whenever i > j (the digits never decrease from left to right). This adds a slight complication to the carry <=> new_partition rule above. Now:

  • new_partition => carry
  • but carry does not necessarily imply new_partition

There is only one condition under which a carry does not imply a new partition: a carry has just created a new partition, and then another carry follows, caused by the x[i] >= x[j] when i > j rule. This second carry changes the number in only one digit, so it does not imply a new partition.
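As a cross-check, here is a sketch that applies the carry rule plus this one exception directly to `itertools.combinations_with_replacement`, instead of incrementing digits by hand. Assumption: the carry index is recovered from consecutive tuples as one past the first digit that changed (or -1 if only the last digit changed), mirroring the `carried` value computed by the `increment` method of the class below. `carry_partition_sorted` is a hypothetical name.

```python
import itertools


def carry_partition_sorted(n, k):
    """Carry rule plus the one exception, applied to the non-decreasing
    tuples from combinations_with_replacement (hypothetical helper)."""
    groups = []
    prev = None
    just_carried = False
    for digits in itertools.combinations_with_replacement(range(n), k):
        if prev is None:
            groups.append([digits])  # 00...0 starts the first partition
            prev = digits
            continue
        # Assumption: the carry index is one past the first digit that
        # changed, or -1 if only the last digit was incremented.
        first_diff = next(i for i in range(k) if digits[i] != prev[i])
        carried = -1 if first_diff == k - 1 else first_diff + 1
        # Exception: a carry right after a carry that opened a one-element
        # partition, landing on n-1 at the carry index, joins that partition.
        exception = (just_carried and carried >= 0
                     and digits[carried] == n - 1 and len(groups[-1]) == 1)
        if carried >= 0 and not exception:
            groups.append([])  # carry => new partition
        groups[-1].append(digits)
        just_carried = carried >= 0
        prev = digits
    return groups


for group in carry_partition_sorted(4, 4):
    print(group)
```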


class ExpNum:
  ''' Represents a number with base @base, @size digits, and funny successor semantics. '''
  def __init__(self, base, size):
    if size <= 0 or base <= 1:
      raise Exception("Bad args")
    self.size = size
    self.base = base
    self.number = [0]*size
    self.zero = [0]*size

  def increment(self):
    ''' Increment number by one. If we carry return index of carry else return -1. '''
    carried = -1
    for i in reversed(range(0, len(self.number))):
      self.number[i] = (self.number[i]+1)%self.base
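      if self.number[i] != 0:
        break  # digit i did not wrap, so the carry stops here
      carried = i  # digit i wrapped to 0 and carries into digit i-1
    if carried >= 0:
      self.pullup()  # restore non-decreasing digits after a carry
    return carried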

  def pullup(self):
    ''' Ensure x[i] >= x[j] when i > j (digits never decrease) '''
    for i in range(0, len(self.number)):
      if self.number[i] == 0 and i > 0:
        self.number[i] = self.number[i-1]

  def out_by_one_partition(self):
    ''' Yield the partitions by counting up from 0 until the number wraps back to 0 (at most n**k steps). '''
    self.number = [0]*self.size
    just_carried = False
    partition = [list(self.number)]
    carried = self.increment()
    while self.number != self.zero:
      # Check for exception to carry => new partition.
      if carried >= 0 and not (just_carried and self.number[carried] == (self.base - 1) and len(partition) == 1):
        yield partition  # the carry closes the current partition
        partition = []
      partition += [list(self.number)]
      just_carried = carried >= 0
      carried = self.increment()
    yield partition  # the last partition is never closed by a carry


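from ExpNum import ExpNum  # assumes the class above is saved as ExpNum.py
from timeit import timeit
from pprint import pprint
print(timeit(lambda: list(ExpNum(11,10).out_by_one_partition()), number = 10))
pprint(list(ExpNum(4,4).out_by_one_partition()))

Test result for ExpNum(4,4):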

[[[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 0, 2], [0, 0, 0, 3]],
 [[0, 0, 1, 1], [0, 0, 1, 2], [0, 0, 1, 3]],
 [[0, 0, 2, 2], [0, 0, 2, 3]],
 [[0, 0, 3, 3]],
 [[0, 1, 1, 1], [0, 1, 1, 2], [0, 1, 1, 3]],
 [[0, 1, 2, 2], [0, 1, 2, 3]],
 [[0, 1, 3, 3]],
 [[0, 2, 2, 2], [0, 2, 2, 3]],
 [[0, 2, 3, 3], [0, 3, 3, 3]],
 [[1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 1, 3]],
 [[1, 1, 2, 2], [1, 1, 2, 3]],
 [[1, 1, 3, 3]],
 [[1, 2, 2, 2], [1, 2, 2, 3]],
 [[1, 2, 3, 3], [1, 3, 3, 3]],
 [[2, 2, 2, 2], [2, 2, 2, 3]],
 [[2, 2, 3, 3], [2, 3, 3, 3]],
 [[3, 3, 3, 3]]]
spinkus answered Sep 24 '22 09:09
