
Understanding a five-dimensional DP with bitshifts and XORs?

I was looking over the solution to this problem here, and I didn't quite understand how the dynamic programming (DP) worked.


A summary of the problem is as follows: You are given a 9x9 grid of either ones or zeroes, arranged in nine 3x3 subgrids as follows:

000 000 000
001 000 100
000 000 000

000 110 000
000 111 000
000 000 000

000 000 000
000 000 000
000 000 000

You need to find the minimum number of changes needed so that each of the nine rows, each of the nine columns, and each of the nine 3x3 subgrids contains an even number of 1's. Here, a change is defined as toggling a given element from 1 to 0 or vice versa.


The solution involves dynamic programming, where each state records the minimum number of moves needed so that all rows up to the row currently being looked at have even parity (an even number of ones).

However, I do not understand the details of their implementation. First off, in their memoization array

int memo[9][9][1<<9][1<<3][2];

what do each of the indexes represent? I gathered that the first two are for current row and column, the third is for column parity, the fourth is for subgrid parity, and the fifth is for row parity. However, why does the column parity need 2^9 elements whereas row parity needs only 2?

Next, how are the transitions between the states handled? I would assume that you go across the row trying each element and move to the next row when done, but after seeing their code I am quite confused:

  int& ref = memo[r][c][mc][mb][p];

  /* Try setting the cell to 1. */
  ref = !A[r][c] + solve(r, c + 1, mc ^ 1 << c, mb ^ 1 << c / 3, !p);

  /* Try setting the cell to 0. */
  ref = min(ref, A[r][c] + solve(r, c + 1, mc, mb, p));

How do they try setting the cell to one by flipping the current bit in the grid? I understand that when you make it a one the row parity changes, as indicated by !p, but I don't understand how the column parity would be affected, or what mc ^ 1 << c does -- why do you need XOR and bitshifts? The same goes for the subgrid parity, mb ^ 1 << c / 3. What is it doing?

Could someone please explain how these work?

asked Jun 07 '14 by 1110101001



2 Answers

I think I've figured this out. The idea is to sweep from top to bottom, left to right. At each step, we try moving to the next position by setting the current cell either to 0 or to 1.

At the end of each row, if the parity is even, we move on to the next row; otherwise we backtrack. At the end of every third row, if the parity of all three boxes is even, we move on to the next row; otherwise we backtrack. Finally, at the end of the board, if all columns have even parity, we're done; otherwise we backtrack.

The state of the recursion at any point can be described in terms of the following five pieces of information:

  • The current row and column.
  • The parities of all the columns.
  • The parities of the three boxes we're currently in (each row intersects three).
  • The current parity of the row we are filling in.

This is what the memoization table looks like:

int memo[9][9][1<<9][1<<3][2];
         ^  ^    ^     ^   ^
         |  |    |     |   |
   row --+  |    |     |   |
   col -----+    |     |   |
column parity ---+     |   |
  box parity ----------+   |
current row parity---------+

To see why there are bitshifts, let's look at the column parity. There are 9 columns, so we can write out their parities as a bitvector with 9 bits. Equivalently, we could use a nine-bit integer. 1 << 9 gives the number of possible nine-bit integers, so we can use a single integer to encode all column parities at the same time.

Why use XOR and bitshifts? Well, XORing a bitvector A with a second bitvector B inverts all the bits in A that are set in B and leaves all the other bits unchanged. If you're tracking parity, you can use XOR to toggle individual bits to represent a flip in parity; the shifting happens because we're packing multiple parity bits into a single machine word. The division you referred to is to map from a column index to the horizontal index of the box it passes through.
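To make that concrete, here is a tiny standalone illustration (with made-up column numbers, not taken from the solution) of packing one parity bit per column into an integer, toggling it with XOR, and mapping a column to its box with integer division:

#include <iostream>

int main() {
    int mc = 0;                   // all nine column parities start out even
    mc ^= 1 << 4;                 // a 1 placed in column 4: bit 4 flips to odd
    mc ^= 1 << 7;                 // a 1 placed in column 7: bit 7 flips to odd
    mc ^= 1 << 4;                 // another 1 in column 4: bit 4 flips back to even

    std::cout << ((mc >> 4) & 1) << "\n";  // 0 -- column 4 is even again
    std::cout << ((mc >> 7) & 1) << "\n";  // 1 -- column 7 is odd
    std::cout << (7 / 3)        << "\n";   // 2 -- column 7 lies in the third box of its band
}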

Hope this helps!

answered Sep 28 '22 by templatetypedef


The algorithm in the solution is an exhaustive depth-first search with a couple optimizations. Unfortunately, the description doesn't exactly explain it.

Exhaustive search means that we try to enumerate every possible combination of bits. Depth-first means we first try to set all bits to one, then set the last one to zero, then the second-to-last, then both the last and the second-to-last, etc.

The first optimization is to backtrack as soon as we detect that parity isn't even. So, for example, as we start our search and reach the first row, we check if that row has zero parity. If it doesn't, we don't continue. We stop, backtrack, and try setting the last bit in the row to zero.

The second optimization is DP-like, in that we cache partial results and re-use them. This takes advantage of the fact that, in terms of the problem, different paths in the search can converge to the same logical state. What is a logical search state? The description in the solution begins to explain it ("begins" being the key word).

In essence, the trick is that, at any given point in the search, the minimum number of additional bit flips does not depend on the exact state of the whole sudoku board, but only on the state of the various parities that we need to track. (See further explanation below.) There are 27 parities that we are tracking (9 columns, 9 rows, and 9 3x3 boxes).

Moreover, we can optimize some of them away. The parity of every higher row, given how we perform the search, will always be even, while the parities of the lower rows, not yet touched by the search, don't change; so we only need to track the parity of one row, the current one. By the same logic, the parities of the boxes above and below are disregarded, and we only need to track the three "active" boxes.

Therefore, instead of 2^9 * 2^9 * 2^9 = 134,217,728 states, we only have 2^9 * 2^1 * 2^3 = 8,192 states. Unfortunately, we need a separate cache for each depth level in the search, so we multiply by the 81 possible depths of the search to discover that we need an array of size 663,552. To borrow from templatetypedef:

int memo[9][9][1<<9][1<<3][2];
         ^  ^    ^     ^   ^
         |  |    |     |   |
   row --+  |    |     |   |
   col -----+    |     |   |
column parity ---+     |   |
  box parity ----------+   |
current row parity---------+

1<<9 simply means 2^9, given how integers and bit shifts work.
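As a quick sanity check of the 663,552 figure quoted above (a compile-time assert added here purely as illustration, assuming C++11 or later):

// 9 rows * 9 cols * 2^9 column parities * 2^3 box parities * 2 row parities
static_assert(9 * 9 * (1 << 9) * (1 << 3) * 2 == 663552,
              "memo has 663,552 entries");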

Further explanation: Due to how parity works, a bit flip always flips its 3 corresponding parities. Therefore, all the permutations of sudoku boards that have the same parities can be solved with the same winning pattern of bit flips. The function 'solve' answers the question: "Assuming you can only perform bit flips starting from the cell at position (x,y), what is the minimum number of bit flips needed to get a solved board?" All sudoku boards with the same parities will yield the same answer.

The search algorithm considers many permutations of boards. It starts modifying them from the top, counts how many bit flips it has already done, then asks 'solve' how many more it would need. If 'solve' has already been called with the same values of (x,y) and the same parities, we can just return the cached result.
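Putting those pieces together, the overall shape of 'solve' might look roughly like the sketch below. Only the memo declaration and the two transition lines are taken from the question; the boundary checks, the INF sentinel, the -1 "not yet cached" marker, and the input-reading main are reconstructions based on the description above, not the original code.

#include <algorithm>
#include <cstring>
#include <iostream>

const int INF = 1 << 29;                  // hypothetical "this branch is impossible" sentinel
int A[9][9];                              // the input grid, as in the question
int memo[9][9][1 << 9][1 << 3][2];        // filled with -1 before the first call

int solve(int r, int c, int mc, int mb, int p) {
    if (c == 9) {                             // finished row r
        if (p) return INF;                    // the row's parity must be even
        if (r % 3 == 2 && mb) return INF;     // finished a band of three rows: all three boxes must be even
        if (r == 8) return mc ? INF : 0;      // finished the board: all columns must be even
        return solve(r + 1, 0, mc, mb, 0);    // move on to the next row
    }

    int& ref = memo[r][c][mc][mb][p];
    if (ref != -1) return ref;                // this position with these parities was solved before

    /* Try setting the cell to 1. */
    ref = !A[r][c] + solve(r, c + 1, mc ^ 1 << c, mb ^ 1 << c / 3, !p);

    /* Try setting the cell to 0. */
    ref = std::min(ref, A[r][c] + solve(r, c + 1, mc, mb, p));
    return ref;
}

int main() {
    for (int r = 0; r < 9; r++)               // read the 81 digits, ignoring whitespace
        for (int c = 0; c < 9; c++) {
            char ch;
            std::cin >> ch;
            A[r][c] = ch - '0';
        }
    std::memset(memo, -1, sizeof memo);       // mark every state as "not cached yet"
    std::cout << solve(0, 0, 0, 0, 0) << "\n";  // minimum number of changes
}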

The confusing part is the code that actually does the search and updates state:

/* Try setting the cell to 1. */
ref = !A[r][c] + solve(r, c + 1, mc ^ 1 << c, mb ^ 1 << c / 3, !p);

/* Try setting the cell to 0. */
ref = min(ref, A[r][c] + solve(r, c + 1, mc, mb, p));

It could be more clearly rendered as:

/* Try having this cell equal 0 */
bool areWeFlipping = A[r][c] == 1;
int nAdditionalFlipsIfCellIs0 = (areWeFlipping ? 1 : 0) + solve(r, c + 1, mc, mb, p); // Continue the search

/* Try having this cell equal 1 */
areWeFlipping = A[r][c] == 0;
// At the start, we assume the sudoku board is all zeroes, so the column parity is all even.
// With each additional cell, we update the column parity with the value of that cell. Here, we assume it to be 1.
int newMc = mc ^ (1 << c); // Update the parity of column c. ^ (1 << c) means "flip the bit denoting the parity of column c"
int newMb = mb ^ (1 << (c / 3)); // Update the parity of the 'active' box (c/3) (i.e., if we're in column 5, we're in box 1)
int newP = p ^ 1; // Update the current row parity
int nAdditionalFlipsIfCellIs1 = (areWeFlipping ? 1 : 0) + solve(r, c + 1, newMc, newMb, newP); // Continue the search

ref = min( nAdditionalFlipsIfCellIs0, nAdditionalFlipsIfCellIs1 );

Personally, I would've implemented the two sides of the search as "flip" and "don't flip", which makes the algorithm easier to reason about. It would make the second paragraph read: "Depth-first means we first try to not flip any bits, then flip the last one, then the second-to-last, then both the last and the second-to-last, etc." In addition, before we start the search, we would need to pre-calculate the values of 'mc', 'mb', and 'p' for our board, instead of passing 0's (a sketch of this is given after the code below).

/* Try not flipping the current cell */
int nAdditionalFlipsIfDontFlip = 0 + solve(r, c + 1, mc, mb, p);

/* Try flipping it */
int newMc = mc ^ (1 << c);
int newMb = mb ^ (1 << (c / 3));
int newP = p ^ 1;
int nAdditionalFlipsIfFlip = 1 + solve(r, c + 1, newMc, newMb, newP);

ref = min( nAdditionalFlipsIfDontFlip, nAdditionalFlipsIfFlip );

However, this change doesn't seem to affect performance.
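For what it's worth, here is a rough sketch of the pre-calculation mentioned above (a hypothetical helper, not part of the original solution): seed mc with the column parities of the whole board, and seed p and mb with the parities of the first row and the top band of boxes; the latter two would have to be re-seeded the same way whenever the search crosses a row or band boundary.

// Hypothetical helper for the flip/don't-flip variant: compute the initial
// parities of the original board A instead of passing 0's.
void seedParities(const int A[9][9], int& mc, int& mb, int& p) {
    mc = mb = p = 0;
    for (int r = 0; r < 9; r++)
        for (int c = 0; c < 9; c++)
            mc ^= A[r][c] << c;              // parity bit of each full column
    for (int c = 0; c < 9; c++)
        p ^= A[0][c];                        // parity of row 0, the first row searched
    for (int r = 0; r < 3; r++)
        for (int c = 0; c < 9; c++)
            mb ^= A[r][c] << (c / 3);        // parity bits of the three boxes in the top band
}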

UPDATE

Most surprisingly, the key to the algorithm's blazing speed seems to be that the memoization array ends up rather sparse. At each depth level, there are typically 512 (sometimes 256 or 128) states visited (out of 8,192). Moreover, there is always exactly one visited state per column parity; the box and row parities don't seem to matter! Omitting them from the memoization array improves performance another 30-fold. Yet, can we prove that this is always true?
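If that observation holds in general, the cache could in principle be indexed by position and column parity alone; a sketch of the reduced declaration (not something proven here):

// Reduced cache suggested by the observation above: drop the box- and
// row-parity dimensions entirely.
int memo[9][9][1 << 9];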

answered Sep 28 '22 by Aleksandr Dubinsky