Need help in understanding Dynamic Programming approach for "balanced 0-1 matrix"?

Problem: I am struggling to understand and visualize the Dynamic Programming approach for "A type of balanced 0-1 matrix" in the Wikipedia article on Dynamic Programming.

Wikipedia Link: https://en.wikipedia.org/wiki/Dynamic_programming#A_type_of_balanced_0.E2.80.931_matrix

I couldn't understand how memoization works when the state is multidimensional. For example, when solving the Fibonacci series with DP, using an array to store previous results is easy, because the array index identifies the state and the value stored there is the solution for that state.

Can someone explain the DP approach for the "balanced 0-1 matrix" in a simpler manner?

Asked by ruthra, May 16 '16 11:05



1 Answer

Wikipedia offered both a crappy explanation and a less-than-ideal algorithm. But let's work with it as a starting place.

First let's take the backtracking algorithm. Rather than put the cells of the matrix "in some order", let's fill everything in the first row, then everything in the second row, then everything in the third row, and so on. Clearly that will work.

Now let's modify the backtracking algorithm slightly. Instead of going cell by cell, we'll go row by row. So we make a list of the n choose n/2 possible rows which are half 0 and half 1. Then have a recursive function that looks something like this:

def count_0_1_matrices(n, filled_rows=None):
    if filled_rows is None:
        filled_rows = []
    if some_column_exceeds_threshold(n, filled_rows):
        # Cannot have more than n/2 0s or 1s in any column
        return 0
    elif len(filled_rows) == n:
        # All n rows are filled and no column is over the limit,
        # so this is one valid matrix.
        return 1
    else:
        answer = 0
        for row in possible_rows(n):
            answer = answer + count_0_1_matrices(n, filled_rows + [row])
        return answer

This is a backtracking algorithm like what we had before. We are just doing whole rows at a time, not cells.
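
The two helpers used above are not defined in this answer at this point. Here is a minimal sketch of what they could look like, assuming each row is a tuple of 0s and 1s and n is even. The bodies (and the use of itertools.combinations) are my assumption for illustration, not the only way to write them.

```python
from itertools import combinations

def possible_rows(n):
    """All rows of length n with exactly n // 2 ones (assumed: rows are tuples)."""
    rows = []
    for ones in combinations(range(n), n // 2):
        chosen = set(ones)
        rows.append(tuple(1 if i in chosen else 0 for i in range(n)))
    return rows

def some_column_exceeds_threshold(n, filled_rows):
    """True if any column already holds more than n // 2 ones or n // 2 zeros."""
    half = n // 2
    for col in range(n):
        ones = sum(row[col] for row in filled_rows)
        zeros = len(filled_rows) - ones
        if ones > half or zeros > half:
            return True
    return False
```

For n = 4 this produces the 6 rows that contain exactly two 1s.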

But notice, we're passing around more information than we need. There is no need to pass in the exact arrangement of rows. All that we need to know is how many 1s are needed in each remaining column. So we can make the algorithm look more like this:

def count_0_1_matrices(n, still_needed=None):
    if still_needed is None:
        still_needed = [int(n/2) for _ in range(n)]

    # Did we overrun any column?
    for i in still_needed:
        if i < 0:
            return 0

    # Did we reach the end of our matrix?
    if 0 == sum(still_needed):
        return 1

    # Calculate the answer by recursion.
    answer = 0
    for row in possible_rows(n):
        next_still_needed = [still_needed[i] - row[i] for i in range(n)]
        answer = answer + count_0_1_matrices(n, next_still_needed)

    return answer

This version is almost the recursive function in the Wikipedia version. The main difference is that our base case is that after every row is finished, we need nothing, while Wikipedia would have us code up the base case to check the last row after every other is done.

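To get from this to a top-down DP, you only need to memoize the function. In Python you can do that by defining a memoize decorator and then adding @memoize to the function. The cache is a dict keyed by the arguments, so the state has to be passed as something hashable (a tuple, not a list). The decorator looks like this: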

from functools import wraps

def memoize(func):
    cache = {}
    @wraps(func)
    def wrap(*args):
        if args not in cache:
            cache[args] = func(*args)
        return cache[args]
    return wrap
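
This is also the answer to the "multidimensional" part of the question: instead of an array index, the whole state becomes the dictionary key. A small sketch of that idea follows, with the decorator repeated so it runs on its own. The function remaining_total and the calls list are hypothetical, there only to show when the cache is hit and that a list cannot be used as a key.

```python
from functools import wraps

def memoize(func):
    cache = {}
    @wraps(func)
    def wrap(*args):
        if args not in cache:
            cache[args] = func(*args)
        return cache[args]
    return wrap

calls = []  # records every real (uncached) evaluation

@memoize
def remaining_total(n, still_needed):
    # Hypothetical stand-in for an expensive computation on a DP state.
    calls.append(still_needed)
    return sum(still_needed)

remaining_total(4, (1, 1, 2, 2))  # computed
remaining_total(4, (1, 1, 2, 2))  # served from the cache
assert len(calls) == 1

try:
    remaining_total(4, [1, 1, 2, 2])  # a list is unhashable
except TypeError:
    list_state_failed = True
else:
    list_state_failed = False
```

The tuple (1, 1, 2, 2) plays the role that the index n plays in the Fibonacci example.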

But remember that I criticized the Wikipedia algorithm? Let's start improving it! The first big improvement is this. Do you notice that the order of the elements of still_needed doesn't matter, just their values? So just sorting the elements will stop you from doing the calculation separately for each permutation. (There can be a lot of permutations!)

@memoize
def count_0_1_matrices(n, still_needed=None):
    if still_needed is None:
        still_needed = [int(n/2) for _ in range(n)]

    # Did we overrun any column?
    for i in still_needed:
        if i < 0:
            return 0

    # Did we reach the end of our matrix?
    if 0 == sum(still_needed):
        return 1

    # Calculate the answer by recursion.
    answer = 0
    for row in possible_rows(n):
        next_still_needed = [still_needed[i] - row[i] for i in range(n)]
        # tuple(...) because the memoize cache needs a hashable key.
        answer = answer + count_0_1_matrices(n, tuple(sorted(next_still_needed)))

    return answer
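
To get a feel for how many states the sort collapses, here is a rough sketch that walks every reachable still_needed state for a 4x4 matrix, once keeping the column order and once sorting. The helper names and the possible_rows body are assumptions for illustration.

```python
from itertools import combinations

def possible_rows(n):
    # Assumed helper: every 0/1 row of length n with exactly n // 2 ones.
    return [tuple(1 if i in ones else 0 for i in range(n))
            for ones in map(set, combinations(range(n), n // 2))]

def reachable_states(n, canonical):
    """Collect every non-negative still_needed state reachable from the start."""
    start = tuple([n // 2] * n)
    seen = {start}
    frontier = [start]
    rows = possible_rows(n)
    while frontier:
        state = frontier.pop()
        for row in rows:
            nxt = [state[i] - row[i] for i in range(n)]
            if min(nxt) < 0:
                continue  # overran a column, dead end
            nxt = tuple(sorted(nxt)) if canonical else tuple(nxt)
            if nxt not in seen:
                seen.add(nxt)
                frontier.append(nxt)
    return seen

unsorted_count = len(reachable_states(4, canonical=False))
sorted_count = len(reachable_states(4, canonical=True))
print(unsorted_count, sorted_count)
```

Each distinct state is one cache entry, so the smaller second number is roughly the amount of work the memoized function has to do.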

That little innocuous sorted doesn't look important, but it saves a lot of work! And now that we know that still_needed is always sorted, we can simplify our checks for whether we are done, and whether anything went negative. Plus we can add an easy check to filter out the case where we have too many 0s in a column.

@memoize
def count_0_1_matrices(n, still_needed=None):
    if still_needed is None:
        still_needed = [int(n/2) for _ in range(n)]

    # Did we overrun any column? The state is sorted ascending, so the
    # smallest entry comes first.
    if still_needed[0] < 0:
        return 0

    total = sum(still_needed)
    if 0 == total:
        # We reached the end of our matrix.
        return 1
    elif total*2/n < still_needed[-1]:
        # We have total*2/n rows left, but the neediest column (the last,
        # largest entry) needs more 1s than that.
        return 0

    # Calculate the answer by recursion.
    answer = 0
    for row in possible_rows(n):
        next_still_needed = [still_needed[i] - row[i] for i in range(n)]
        # tuple(...) because the memoize cache needs a hashable key.
        answer = answer + count_0_1_matrices(n, tuple(sorted(next_still_needed)))

    return answer

And, assuming you implement possible_rows, this should both work and be significantly more efficient than what Wikipedia offered.
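
If you want something to check a DP result against, a brute-force count is easy to write for small sizes. This is only a sanity-check sketch (the function name is mine), and it is hopeless beyond about 4x4 because it tries all 2^(n*n) matrices.

```python
from itertools import product

def brute_force_count(n):
    """Count n x n 0/1 matrices whose rows and columns each hold n // 2 ones."""
    half = n // 2
    count = 0
    for cells in product((0, 1), repeat=n * n):
        rows = [cells[r * n:(r + 1) * n] for r in range(n)]
        if any(sum(row) != half for row in rows):
            continue  # some row is unbalanced
        if all(sum(row[c] for row in rows) == half for c in range(n)):
            count += 1
    return count

print(brute_force_count(2))  # 2
print(brute_force_count(4))  # 90
```

The memoized function should return the same numbers for n = 2 and n = 4.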

=====

Here is a complete working implementation. On my machine it calculated the 6th term (the 12x12 case, since the script counts 2n x 2n matrices) in under 4 seconds.

#! /usr/bin/env python

from sys import argv
from functools import wraps

def memoize(func):
    cache = {}
    @wraps(func)
    def wrap(*args):
        if args not in cache:
            cache[args] = func(*args)
        return cache[args]
    return wrap

@memoize
def count_0_1_matrices(n, still_needed=None):
    if 0 == n:
        return 1

    if still_needed is None:
        still_needed = [int(n/2) for _ in range(n)]

    # Did we overrun any column?
    if still_needed[0] < 0:
        return 0

    total = sum(still_needed)
    if 0 == total:
        # We reached the end of our matrix.
        return 1
    elif total*2/n < still_needed[-1]:
        # We have total*2/n rows left, but won't get enough 1s for a
        # column.
        return 0
    # Calculate the answer by recursion.
    answer = 0
    for row in possible_rows(n):
        next_still_needed = [still_needed[i] - row[i] for i in range(n)]
        answer = answer + count_0_1_matrices(n, tuple(sorted(next_still_needed)))

    return answer

@memoize
def possible_rows(n):
    # Integer division, so k stays an int on Python 3 as well.
    return [row for row in _possible_rows(n, n // 2)]


def _possible_rows(n, k):
    if 0 == n:
        yield tuple()
    else:
        if k < n:
            for row in _possible_rows(n-1, k):
                yield tuple(row + (0,))
        if 0 < k:
            for row in _possible_rows(n-1, k-1):
                yield tuple(row + (1,))

n = 2
if 1 < len(argv):
    n = int(argv[1])

print(count_0_1_matrices(2*n))
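
As a variation, the same idea can be written with the standard library's functools.lru_cache instead of a hand-written memoize, and itertools.combinations instead of the recursive row generator. This is a sketch under those substitutions, not a replacement for the program above; it assumes n is even, and the inner function name count is mine.

```python
from functools import lru_cache
from itertools import combinations

@lru_cache(maxsize=None)
def possible_rows(n):
    # Every 0/1 row of length n with exactly n // 2 ones, as tuples.
    return tuple(tuple(1 if i in ones else 0 for i in range(n))
                 for ones in map(frozenset, combinations(range(n), n // 2)))

def count_0_1_matrices(n):
    @lru_cache(maxsize=None)
    def count(still_needed):
        # still_needed is a sorted tuple, so it is hashable and canonical.
        if not still_needed:
            return 1  # the empty 0x0 matrix
        if still_needed[0] < 0:
            return 0  # some column received too many 1s
        total = sum(still_needed)
        if total == 0:
            return 1  # every column is satisfied
        if total * 2 // n < still_needed[-1]:
            return 0  # not enough rows left for the neediest column
        return sum(
            count(tuple(sorted(s - r for s, r in zip(still_needed, row))))
            for row in possible_rows(n)
        )
    return count(tuple([n // 2] * n))

print(count_0_1_matrices(4))  # 90
print(count_0_1_matrices(6))  # 297200
```

Note that this version takes the matrix size directly, so count_0_1_matrices(6) here corresponds to running the script above with argument 3.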

Answered by btilly, Nov 15 '22 09:11