 

Permutations where each element is repeated a specific number of times, and how to generate them fast

Tags:

python

I have been looking for a function like this, but sadly I was not able to find one.

Here is code that does what I described:

import itertools

# How many times each element is repeated
data = {
  1: 3,
  2: 1,
  3: 2
}

# Build the flat list [1, 1, 1, 2, 3, 3] and its length n
n = 0
to_rep = []
for i in data:
  to_rep += [i]*data[i]
  n += data[i]

# itertools.permutations treats equal elements as distinct,
# so it also generates duplicate tuples
ret = itertools.permutations(to_rep, n)
# Remove the duplicates
ret = list(set(ret))

The code produces all tuples of length 6 that contain three 1s, one 2, and two 3s, so it works.

The problem is speed: this method is far too expensive. What is the fastest way to do this?

I tested this with a sample of 27 True values and one False value, which is not large: there are only 28 distinct arrangements in total, yet this code takes forever. I know there are other ways to write this, but I would like to know a more efficient one.
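For scale: a multiset with n elements in total and per-value counts c1, ..., ck has n! / (c1! * ... * ck!) distinct orderings, while itertools.permutations always generates all n! tuples before set() throws the duplicates away. A minimal sketch of that arithmetic (the helper names count_distinct and count_generated are hypothetical):

```python
from math import factorial, prod


def count_distinct(counts: dict) -> int:
    """Distinct orderings: the multinomial coefficient n! / (c1! * ... * ck!)."""
    n = sum(counts.values())
    return factorial(n) // prod(factorial(c) for c in counts.values())


def count_generated(counts: dict) -> int:
    """Tuples that itertools.permutations yields before deduplication: n!."""
    return factorial(sum(counts.values()))


print(count_distinct({True: 27, False: 1}))   # 28
print(count_generated({True: 27, False: 1}))  # 28!, roughly 3.05e29
print(count_distinct({1: 3, 2: 1, 3: 2}))     # 60
```

So for the 27/1 sample, the original code would have to enumerate about 3e29 tuples to find 28 distinct ones.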

I was able to write code for the case of two values, such as True/False, but it seems much more complex with more than two values.

import itertools

def f(n_true, n_false):
  # Choose which positions hold False; every other position stays True
  ret = []
  length = n_true + n_false
  full = [True]*length
  for rep in itertools.combinations(range(length), n_false):
    full_ = full.copy()
    for i in rep:
      full_[i] = False
    ret.append(full_)
  return ret
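The position-choosing idea in f can be extended to more than two values: pick positions for the first value among all positions, then positions for the next value among those still free, and so on. A minimal sketch under that approach (the name place_by_combinations is hypothetical):

```python
import itertools
from typing import Dict, Hashable, Iterator, List, Tuple


def place_by_combinations(counts: Dict[Hashable, int]) -> Iterator[Tuple]:
    """Yield each distinct arrangement once by choosing positions value by value.

    f() chooses where the False values go among all positions; here every
    value in turn chooses its positions among the ones still free.
    """
    values = list(counts)
    n = sum(counts.values())
    slots: List = [None] * n  # current arrangement, overwritten while backtracking

    def place(i: int, free: List[int]) -> Iterator[Tuple]:
        if i == len(values):
            # All counts are placed, so every slot was filled on this path
            yield tuple(slots)
            return
        value = values[i]
        for chosen in itertools.combinations(free, counts[value]):
            for pos in chosen:
                slots[pos] = value
            chosen_set = set(chosen)
            remaining = [p for p in free if p not in chosen_set]
            yield from place(i + 1, remaining)

    yield from place(0, list(range(n)))


result = list(place_by_combinations({1: 3, 2: 1, 3: 2}))
print(len(result))  # 60
```

Every choice of positions gives a different tuple, so no duplicates are generated and no set() is needed.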

Is there an existing function that already does this? What is the best way to do it?

Latot asked Dec 29 '25 06:12


1 Answer

Here are several options.

Simple recursive algorithm

You can use the following recursive function that iterates over keys:

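def multiset_permutations(d: dict) -> list:
    # d maps each element to the number of times it appears
    if not d:
        # Base case: the empty multiset has exactly one permutation, the empty tuple
        return [tuple()]
    else:
        perms = []
        for k in d.keys():
            # Use k as the first element and remove one copy of it
            remainder = d.copy()
            if remainder[k] > 1:
                remainder[k] -= 1
            else:
                del remainder[k]
            # Prepend k to every permutation of what is left
            perms += [(k,) + q for q in multiset_permutations(remainder)]
        return perms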

Example usage:

In: data = {1: 3, 2: 1, 3: 2}
In: multiset_permutations(data)
Out: 
[(1, 1, 1, 2, 3, 3),
 (1, 1, 1, 3, 2, 3),
 (1, 1, 1, 3, 3, 2),
 (1, 1, 2, 1, 3, 3),
 (1, 1, 2, 3, 1, 3),
...

To see why it works, notice that every permutation must start with one of the keys. Once you have chosen the first element k from among the keys, the permutations that start with it have the form (k,) + q, where q is a permutation of the remainder multiset in which the count of k is reduced by one (with the key deleted when its count hits zero). So the code simply tries each key as the first element, then recurses on the corresponding remainder.

For an explanation of the base case, see: https://math.stackexchange.com/questions/4293329/does-the-set-of-permutations-of-an-empty-set-contain-an-empty-set
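The same recursion can also be written as a lazy generator that backtracks on a single counts dictionary instead of copying it and building lists. This is a sketch of that variant, not part of the answer's code (iter_multiset_permutations is a hypothetical name); for positive counts it yields the same tuples in the same order as multiset_permutations:

```python
import itertools
from typing import Dict, Hashable, Iterator, List, Tuple


def iter_multiset_permutations(d: Dict[Hashable, int]) -> Iterator[Tuple]:
    """Lazily yield the permutations of the multiset d (element -> count)."""
    counts = dict(d)              # private copy, decremented/restored while backtracking
    total = sum(counts.values())
    prefix: List = []             # the permutation built so far

    def extend() -> Iterator[Tuple]:
        if len(prefix) == total:
            # Every copy of every key is used: prefix is a complete permutation
            yield tuple(prefix)
            return
        for k in counts:
            if counts[k] == 0:
                continue          # no copies of k left at this stage
            counts[k] -= 1        # choose k as the next element...
            prefix.append(k)
            yield from extend()   # ...and permute what remains
            prefix.pop()          # undo the choice before trying the next key
            counts[k] += 1

    yield from extend()


data = {1: 3, 2: 1, 3: 2}
first_three = list(itertools.islice(iter_multiset_permutations(data), 3))
print(first_three)  # [(1, 1, 1, 2, 3, 3), (1, 1, 1, 3, 2, 3), (1, 1, 1, 3, 3, 2)]
```

Because results are produced one at a time, you can stop early or stream them without holding the whole list in memory. It is still recursive, with depth equal to the total number of elements.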

This is a large algorithmic saving when the number of keys is small but the count per key is large, because the branching factor at each level of the recursion is the number of distinct keys remaining at that stage, rather than the total number of remaining elements, which may be much larger.
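To make the branching-factor argument concrete, the sketch below counts the calls the recursion makes, using a hypothetical count_calls helper that mirrors multiset_permutations but returns a call count instead of permutations. For the two-value input {0: n, 1: 1}, each call corresponds to a distinct prefix of an output permutation, which works out to (n + 1)(n + 4) / 2 calls, whereas itertools.permutations generates (n + 1)! tuples:

```python
from math import factorial


def count_calls(d: dict) -> int:
    """Number of calls multiset_permutations(d) makes (nodes of its recursion tree)."""
    calls = 1  # this call
    for k in d:
        # Same remainder logic as multiset_permutations
        remainder = d.copy()
        if remainder[k] > 1:
            remainder[k] -= 1
        else:
            del remainder[k]
        calls += count_calls(remainder)
    return calls


for n in [1, 2, 4, 8]:
    print(n, count_calls({0: n, 1: 1}), factorial(n + 1))
# 1 5 2
# 2 9 6
# 4 20 120
# 8 54 362880
```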

Using the package more_itertools

The function distinct_permutations() in the more_itertools package can do this. Thanks to Kelly Bundy for suggesting it in the comments.

import more_itertools

def multiset_to_list(d: dict) -> list:
    l = []
    for k, v in d.items():
        l += [k]*v
    return l

def multiset_permutations_more_itertools(d: dict) -> list:
    l = multiset_to_list(d)
    return list(more_itertools.distinct_permutations(l, len(l)))

Using the sympy package

You can use the function multiset_permutations() in the sympy package, as detailed in Andrej Kesely's answer.

from sympy.utilities import iterables

def multiset_permutations_sympy(d: dict) -> list:
    return [tuple(x) for x in iterables.multiset_permutations(multiset_to_list(d))]

Timing comparison on 2-group example

Here we compare the timing of these methods on a generalization of the two-value example from the question, with data {0: n, 1: 1}. The recursive method is slower than more_itertools.distinct_permutations(), faster than sympy.utilities.iterables.multiset_permutations() for small to moderate n, and slightly slower for large n. The itertools.permutations() method from the original question is not feasible for moderate or large n because of its poor asymptotic complexity: it generates (n + 1)! tuples before removing duplicates.

To summarize, if you want peak performance, use more_itertools.distinct_permutations(). The recursive code is useful if you don't want to add dependencies to your project, if you are writing in a language other than Python, or for educational purposes if you are interested in the algorithm itself. The sympy version is also competitive and could be useful if you would rather have sympy as a dependency of your project.

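import itertools
import sys
import more_itertools
from sympy.utilities import iterables
from time import time
import matplotlib.pyplot as plt

# multiset_permutations() recurses once per element, so n = 1024 needs more
# than Python's default recursion limit of 1000 frames.
sys.setrecursionlimit(10_000)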

def multiset_to_list(d: dict) -> list:
    l = []
    for k, v in d.items():
        l += [k]*v
    return l

def multiset_permutations_itertools(d: dict) -> list:
    l = multiset_to_list(d)
    return list(set(itertools.permutations(l, len(l))))

def multiset_permutations_more_itertools(d: dict) -> list:
    l = multiset_to_list(d)
    return list(more_itertools.distinct_permutations(l, len(l)))

def multiset_permutations_sympy(d: dict) -> list:
    return [tuple(x) for x in iterables.multiset_permutations(multiset_to_list(d))]

def multiset_permutations(d: dict) -> list:
    if not d:
        return [tuple()]
    else:
        perms = []
        for k in d.keys():
            remainder = d.copy()
            if remainder[k] > 1:
                remainder[k] -= 1
            else:
                del remainder[k]
            perms += [(k,) + q for q in multiset_permutations(remainder)]
        return perms

def mean_timing(f: callable, x, num_samples=10):
    total_time = 0.0
    for _ in range(num_samples):
        t = time()
        y = f(x)
        dt = time() - t
        total_time += dt
    mean_time = total_time / num_samples
    return y, mean_time

timings_itertools = []
timings_more_itertools = []
timings_sympy = []
timings = []
nn = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]#, 2048, 4096]
for n in nn:
    print('n=', n)
    data = {0:n, 1:1}

    ret,                dt                  = mean_timing(multiset_permutations,                data)
    ret_more_itertools, dt_more_itertools   = mean_timing(multiset_permutations_more_itertools, data)
    ret_sympy,          dt_sympy            = mean_timing(multiset_permutations_sympy,          data)
    print('dt=',                dt)
    print('dt_more_itertools=', dt_more_itertools)
    print('dt_sympy=',          dt_sympy)
    assert(set(ret) == set(ret_more_itertools))
    assert(set(ret) == set(ret_sympy))
    timings.append(dt)
    timings_more_itertools.append(dt_more_itertools)
    timings_sympy.append(dt_sympy)
    # itertools.permutations generates (n + 1)! tuples, so only run it for small n
    if n < 11:
        ret_itertools, dt_itertools = mean_timing(multiset_permutations_itertools, data)
        assert (set(ret) == set(ret_itertools))
        print('dt_itertools=', dt_itertools)
        timings_itertools.append(dt_itertools)

plt.figure()
plt.loglog(nn, timings)
plt.loglog(nn, timings_more_itertools)
plt.loglog(nn, timings_sympy)
plt.loglog(nn[:len(timings_itertools)], timings_itertools)
plt.xlabel('Problem size, n')
plt.ylabel('Mean time (seconds)')
plt.title('Multiset permutation timings comparison')
plt.legend(['recursive', 'more_itertools', 'sympy', 'itertools'])
# Save before show(): show() may close the figure, which would save a blank image
plt.savefig('multiset_permutations_timing_comparison.png', bbox_inches='tight')
plt.show()

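And the output:

[Figure: Multiset permutation timings comparison. Log-log plot of mean time (seconds) against problem size n for the recursive, more_itertools, sympy, and itertools methods.]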

Nick Alger answered Dec 30 '25 21:12


