 

Finding a better way to count matrices

I would like to count the number of n by n arrays with only 0 and 1 entries that contain two disjoint pairs of rows with equal vector sums, for example rows 0 and 1 summing elementwise to the same vector as rows 2 and 3. For a 4 by 4 matrix, the following code does this by iterating over all of them and testing each one in turn.

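import numpy as np
from itertools import combinations

n = 4
# nxn[r, c] = r*n + c: the bit of i that holds entry (r, c)
nxn = np.arange(n * n).reshape(n, -1)
count = 0
for i in range(2 ** (n * n)):
    A = (i >> nxn) % 2  # decode the integer i into an n x n 0/1 matrix
    found = False
    for firstpair in combinations(range(n), 2):
        for secondpair in combinations(range(n), 2):
            # look at each unordered pair of disjoint row pairs once
            if firstpair < secondpair and not set(firstpair) & set(secondpair):
                if np.array_equal(A[firstpair[0]] + A[firstpair[1]],
                                  A[secondpair[0]] + A[secondpair[1]]):
                    found = True
    if found:  # count the matrix once, however many pairs match
        count += 1
print(count)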

The output is 3136.

The problem is that this takes 2^(n^2) iterations (2^16 for n = 4), and I would like to run it for n up to 8. Is there a cleverer way to count these without iterating over all the matrices? For example, it seems pointless to generate permutations of the same matrix over and over again.

Asked Jan 15 '14 by marshall


2 Answers

Computed in about a minute on my machine, with CPython 3.3:

4 3136
5 3053312
6 7247819776
7 53875134036992
8 1372451668676509696

Code, based on memoized inclusion-exclusion:

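#!/usr/bin/env python3
import collections
import itertools

def pairs_of_pairs(n):
    """Yield every unordered pair of disjoint row pairs.

    Each 4-subset {i, j, k, m} of rows splits into two pairs in exactly three ways.
    """
    for i, j, k, m in itertools.combinations(range(n), 4):
        yield (i, j), (k, m)
        yield (i, k), (j, m)
        yield (i, m), (j, k)

def columns(n):
    """All 2**n possible 0/1 columns of an n-row matrix."""
    return itertools.product(range(2), repeat=n)

def satisfied(pair_of_pairs, column):
    """True if this column has equal sums on the two row pairs."""
    (i, j), (k, m) = pair_of_pairs
    return column[i] + column[j] == column[k] + column[m]

def pop_count(valid_columns):
    """Number of columns present in the bitmask."""
    return bin(valid_columns).count('1')

def main(n):
    pairs_of_pairs_n = list(pairs_of_pairs(n))
    columns_n = list(columns(n))
    universe = (1 << len(columns_n)) - 1  # bitmask with every column allowed
    # counter maps a bitmask of allowed columns to its signed
    # inclusion-exclusion coefficient; subsets of conditions that allow
    # the same columns are merged into one entry (the memoization).
    counter = collections.defaultdict(int)
    counter[universe] = -1
    for pair_of_pairs in pairs_of_pairs_n:
        # bit i of mask is set if column i satisfies this pair of pairs
        mask = 0
        for i, column in enumerate(columns_n):
            mask |= int(satisfied(pair_of_pairs, column)) << i
        # multiply the running signed sum by (1 - [this condition])
        for valid_columns, count in list(counter.items()):
            counter[valid_columns & mask] -= count
    counter[universe] += 1  # cancel the empty-set term
    # Columns are chosen independently, so a set of conditions that allows
    # c columns is met by exactly c**n matrices.
    return sum(count * pop_count(valid_columns) ** n
               for valid_columns, count in counter.items())

if __name__ == '__main__':
    for n in range(4, 9):
        print(n, main(n))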
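To see what the memoized counter is computing, here is a minimal non-memoized sketch of the same inclusion-exclusion (an illustration, not the answer's code; `condition_masks` and `naive_count` are hypothetical names). Because the n columns are chosen independently, the number of matrices satisfying every condition in a set S is c(S)**n, where c(S) is the number of 0/1 columns satisfying all of S; summing (-1)**(|S|+1) * c(S)**n over every nonempty S gives the count. It walks all 2**(3 * C(n, 4)) - 1 subsets, so it is only practical up to about n = 5; merging subsets that allow the same columns is what makes the version above fast.

```python
from functools import reduce
from itertools import combinations, product
from operator import and_


def condition_masks(n):
    """One bitmask per pair of disjoint row pairs; bit number `bit` is set
    when the bit-th column (in itertools.product order) satisfies it."""
    cols = list(product(range(2), repeat=n))
    masks = []
    for i, j, k, m in combinations(range(n), 4):
        for (a, b), (c, d) in (((i, j), (k, m)), ((i, k), (j, m)), ((i, m), (j, k))):
            mask = 0
            for bit, col in enumerate(cols):
                if col[a] + col[b] == col[c] + col[d]:
                    mask |= 1 << bit
            masks.append(mask)
    return masks


def naive_count(n):
    """Inclusion-exclusion over every nonempty subset of conditions (no memoization)."""
    masks = condition_masks(n)
    total = 0
    for size in range(1, len(masks) + 1):
        sign = 1 if size % 2 else -1  # (-1) ** (size + 1)
        for subset in combinations(masks, size):
            allowed = reduce(and_, subset)  # columns satisfying every condition in subset
            total += sign * bin(allowed).count('1') ** n
    return total


print(naive_count(4))  # 3136 = 3 * 6**4 - 3 * 4**4 + 2**4
```

For n = 4 there are three conditions: each one alone allows 6 of the 16 columns, any two together allow 4, and all three allow only the all-zeros and all-ones columns, which gives 3 * 6**4 - 3 * 4**4 + 2**4 = 3136.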
Answered Oct 01 '22 21:10 by David Eisenstat


You can file this one under "better than nothing" ;-) Here's plain Python 3 code that rethinks the problem a bit. Perhaps numpy tricks could speed it up substantially, but it's hard to see how.

  1. "A row" here is an integer in range(2**n). So the array is just a tuple of integers.
  2. Because of that, it's dead easy to generate all arrays that are unique under row permutation via combinations_with_replacement(), weighting each one by the multinomial coefficient that counts its distinct row orderings. That reduces the trip count on the outer loop from 2**(n**2) to (2**n+n-1) choose n. An enormous reduction, but still ...
  3. A precomputed dict maps pairs of rows (which means pairs of integers here!) to their vector sum as a tuple. So no array operations are required when testing, except to test the tuples for equality. With some more trickery, the tuples could be coded as (say) base-3 integers, reducing the inner-loop test to comparing two integers retrieved from a pair of dict lookups.
  4. The time and space required for that precomputed dict is relatively trivial, so no attempt was made to speed that part.
  5. The inner loop picks row indices 4 at a time, instead of your pair of loops each picking two indices at a time. It's faster to do all 4 in one gulp, in large part because there's no need to weed out pairs with a duplicated index.
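Point 2 only works if every matrix is counted exactly once after weighting: a sorted tuple of rows stands for multinomial(n, multiplicities) distinct matrices. Here is a small sketch to check that, assuming rows are encoded as integers as in point 1 (`multiset_weight_total` is a hypothetical helper, not part of the answer): the weights should add up to 2**(n*n), and the number of tuples visited should be C(2**n + n - 1, n).

```python
from collections import Counter
from itertools import combinations_with_replacement as cwr
from math import comb, factorial


def multinomial(n, ks):
    """n! / (k1! * k2! * ...): distinct orderings of a multiset of size n."""
    result = factorial(n)
    for k in ks:
        result //= factorial(k)
    return result


def multiset_weight_total(n):
    """Visit each multiset of n rows once; return (#visited, sum of weights)."""
    visited = 0
    weight = 0
    for rows in cwr(range(2 ** n), n):  # rows come out in nondecreasing order
        visited += 1
        weight += multinomial(n, Counter(rows).values())
    return visited, weight


for n in range(2, 5):
    print(n, multiset_weight_total(n), (comb(2 ** n + n - 1, n), 2 ** (n * n)))

# Outer-loop trip counts for the sizes in the question, without enumerating anything.
for n in range(4, 9):
    print(n, 2 ** (n * n), comb(2 ** n + n - 1, n))
```

The second loop shows the size of the reduction: for n = 8 it is from 2**64 matrices down to roughly 10**14 multisets, still far too many to enumerate in pure Python.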

Here's the code:

def calc_row_pairs(n):
    # map each ordered pair of rows (given as integers) to their vector sum, as a tuple
    fmt = "0%db" % n
    rowpair2sum = dict()
    for i in range(2**n):
        row1 = list(map(int, format(i, fmt)))
        for j in range(2**n):
            row2 = map(int, format(j, fmt))
            total = tuple(a+b for a, b in zip(row1, row2))
            rowpair2sum[i, j] = total
    return rowpair2sum

def multinomial(n, ks):
    from math import factorial as f
    assert n == sum(ks)
    result = f(n)
    for k in ks:
        result //= f(k)
    return result

def count(n):
    from itertools import combinations_with_replacement as cwr
    from itertools import combinations
    from collections import Counter
    rowpair2sum = calc_row_pairs(n)
    total = 0
    class NextPlease(Exception):
        # raised to break out of both inner loops at once
        pass
    # each a is a sorted tuple of n rows: one representative per row multiset
    for a in cwr(range(2**n), n):
        try:
            for ix in combinations(range(n), 4):
                # the three ways to split rows ix into two disjoint pairs
                for ix1, ix2, ix3, ix4 in (
                       ix,
                       (ix[0], ix[2], ix[1], ix[3]),
                       (ix[0], ix[3], ix[1], ix[2])):
                    if rowpair2sum[a[ix1], a[ix2]] == \
                       rowpair2sum[a[ix3], a[ix4]]:
                        # count every distinct row ordering of this multiset
                        total += multinomial(n, Counter(a).values())
                        raise NextPlease
        except NextPlease:
            pass
    return total

That sufficed to find results through n=6, although it took a long time to finish the last one (how long? don't know - didn't time it - on the order of an hour, though - "long time" is relative ;-) ):

>>> count(4)
3136
>>> count(5)
3053312
>>> count(6)
7247819776

EDIT - removing some needless indexing

Changing the main function to this gives a nice speedup:

def count(n):
    from itertools import combinations_with_replacement as cwr
    from itertools import combinations
    from collections import Counter
    rowpair2sum = calc_row_pairs(n)
    total = 0
    for a in cwr(range(2**n), n):
        # pick 4 row values directly instead of 4 indices into a
        for r0, r1, r2, r3 in combinations(a, 4):
            if rowpair2sum[r0, r1] == rowpair2sum[r2, r3] or \
               rowpair2sum[r0, r2] == rowpair2sum[r1, r3] or \
               rowpair2sum[r0, r3] == rowpair2sum[r1, r2]:
                total += multinomial(n, Counter(a).values())
                break  # one matching split is enough for this multiset
    return total
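Both answers rely on the three splits of each 4-subset of rows enumerating exactly the unordered pairs of disjoint row pairs that the question's double loop finds by filtering, which is why no duplicated-index check is needed. A quick sketch to confirm that for the sizes in the question (`by_filtering` and `by_pairings` are illustrative names only):

```python
from itertools import combinations


def by_filtering(n):
    """The question's approach: all pairs of pairs, kept if ordered and disjoint."""
    found = set()
    for first in combinations(range(n), 2):
        for second in combinations(range(n), 2):
            if first < second and not set(first) & set(second):
                found.add((first, second))
    return found


def by_pairings(n):
    """Split each 4-subset i < j < k < m into its three pairings."""
    found = set()
    for i, j, k, m in combinations(range(n), 4):
        found.update({((i, j), (k, m)), ((i, k), (j, m)), ((i, m), (j, k))})
    return found


for n in range(4, 9):
    assert by_filtering(n) == by_pairings(n)
    print(n, len(by_pairings(n)))  # 3 * C(n, 4)
```

The pairing version never generates a pair that would later be thrown away, and it produces each pair of pairs in the same orientation as the filter (the pair holding the smallest row comes first).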

EDIT - speeding the sum test

This is minor, but since this seems to be the best exact approach on the table so far, may as well squeeze some more out of it. As noted before, since each sum is in range(3), each tuple of sums can be replaced with an integer (viewing the tuple as giving the digits of a base-3 integer). Replace calc_row_pairs() like so:

def calc_row_pairs(n):
    # as before, but encode each vector sum as a base-3 integer, most significant digit first
    fmt = "0%db" % n
    rowpair2sum = dict()
    for i in range(2**n):
        row1 = list(map(int, format(i, fmt)))
        for j in range(2**n):
            row2 = map(int, format(j, fmt))
            total = 0
            for a, b in zip(row1, row2):
                t = a+b
                assert 0 <= t <= 2
                total = total * 3 + t
            rowpair2sum[i, j] = total
    return rowpair2sum

I'm sure numpy has a much faster way to do that, but the time taken by calc_row_pairs() is insignificant, so why bother? BTW, the advantage to doing this is that the inner-loop == tests change from needing to compare tuples to just comparing small integers. Plain Python benefits from that, but I bet pypy could benefit even more.
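A variation on the same idea, offered as a sketch rather than as part of the answer: because every entry is 0 or 1, reading a row's bits as base-3 digits gives a per-row code, and adding two codes can never carry (each digit sum is at most 2). So the integer sum of two row codes equals the base-3 value that calc_row_pairs() above stores, and code equality holds exactly when the vector sums are equal. `row_to_base3`, `vector_sum` and `check_encoding` are hypothetical names.

```python
from itertools import product


def row_to_base3(row, n):
    """Read the n bits of `row` (most significant first) as base-3 digits."""
    code = 0
    for shift in range(n - 1, -1, -1):
        code = code * 3 + ((row >> shift) & 1)
    return code


def vector_sum(r1, r2, n):
    """Elementwise sum of two n-bit rows, most significant bit first."""
    return tuple(((r1 >> s) & 1) + ((r2 >> s) & 1) for s in range(n - 1, -1, -1))


def check_encoding(n):
    """True if code sums agree exactly when vector sums agree, for all row quadruples."""
    codes = [row_to_base3(r, n) for r in range(2 ** n)]
    for a, b, c, d in product(range(2 ** n), repeat=4):
        same_vector = vector_sum(a, b, n) == vector_sum(c, d, n)
        same_code = codes[a] + codes[b] == codes[c] + codes[d]
        if same_vector != same_code:
            return False
    return True


print(row_to_base3(0b101, 3))  # 10, i.e. 1*9 + 0*3 + 1
print(check_encoding(3))       # True
```

With a precomputed list of per-row codes, the inner-loop test becomes two additions and one integer comparison instead of two dict lookups.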

Answered Oct 01 '22 21:10 by Tim Peters