Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

How to memoize recursive path of length n search

First time posting thought I would try this community out.

I have researched for hours and I just can't seem to find an example close enough to get ideas from. I don't care what language the answers are in, but I would prefer Java, C/C++, or pseudocode.

I am looking to find consecutive paths of length n in a grid.

I found a recursive solution that I think is clean and always works, but the runtime is poor when the number of paths is large. I realize I could implement it iteratively, but I want to find a recursive solution first.

I dont care what language answers are in but would prefer java, c/c++.

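The problem is this: given a String[] grid and an int pathLength, how many paths of that length are there? A path starts at an 'A', and each step moves to one of the 8 neighbouring cells that holds the next letter of the alphabet.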

{ "ABC", "CBZ", "CZC", "BZZ", "ZAA" } of length 3

enter image description here This is the 3rd and 7th path from below.

A B C    A . C    A B .    A . .    A . .    A . .    . . .
. . .    . B .    C . .    C B .    . B .    . B .    . . .
. . .    . . .    . . .    . . .    C . .    . . C    C . .
. . .    . . .    . . .    . . .    . . .    . . .    B . .
. . .    . . .    . . .    . . .    . . .    . . .    . A .
(spaces are for clarity only) 

return 7 possible paths of length 3 (A-B-C)

This was the original recursive solution

/**
 * Brute-force counter for "alphabet paths" in a grid of letters.
 *
 * <p>A path starts on an {@code 'A'} and repeatedly steps to one of the eight
 * neighbouring cells whose character is exactly one greater than the current
 * one. No memoization is used, so the running time grows with the number of
 * paths.
 */
public class SimpleRecursive {

    /** Counting stops as soon as this many paths have been found. */
    private static final int MAX_PATHS = 1_000_000_000;

    // Number of cells a complete path must contain (the starting 'A' included).
    private int ofLength;
    // Paths found so far by the current count() call.
    private int paths = 0;
    // Rows of the grid; rows may have different lengths.
    private String[] grid;

    /**
     * Counts the paths of {@code ofLength} cells that start at any {@code 'A'}.
     *
     * <p>The elapsed time in milliseconds is printed to standard output.
     *
     * @param grid     rows of the letter grid; rows may differ in length
     * @param ofLength number of cells in a complete path, counting the {@code 'A'}
     * @return the number of paths found; counting stops at 1,000,000,000
     */
    public int count(String[] grid, int ofLength) {
        this.grid = grid;
        this.ofLength = ofLength;
        paths = 0;

        long startTime = System.currentTimeMillis();
        for (int j = 0; j < grid.length; j++) {
            // Visit every 'A' in row j; each one is the first cell (layer 1) of a path.
            for (int index = grid[j].indexOf('A'); index >= 0; index = grid[j].indexOf('A', index + 1)) {
                recursiveFind(1, index, j);
            }
        }
        System.out.println(System.currentTimeMillis() - startTime);
        return paths;
    }

    /**
     * Extends a partial path that currently ends at column {@code x}, row {@code y}.
     *
     * @param layer number of cells in the partial path so far
     * @param x     column of the current cell
     * @param y     row of the current cell
     */
    private void recursiveFind(int layer, int x, int y) {
        if (paths >= MAX_PATHS) {
            // Cap reached: unwind without counting anything further.
            return;
        }
        if (layer == ofLength) {
            paths++;
            return;
        }

        // Compared as int, matching the original char + 1 arithmetic.
        int wanted = grid[y].charAt(x) + 1;

        for (int dx = -1; dx <= 1; ++dx) {
            for (int dy = -1; dy <= 1; ++dy) {
                if (dx == 0 && dy == 0) {
                    continue;
                }
                int nx = x + dx;
                int ny = y + dy;
                if (ny < 0 || ny >= grid.length) {
                    continue;
                }
                // Bound the column by the neighbour's own row so ragged grids are handled.
                if (nx < 0 || nx >= grid[ny].length()) {
                    continue;
                }
                if (grid[ny].charAt(nx) == wanted) {
                    recursiveFind(layer + 1, nx, ny);
                }
            }
        }
    }
}

This was very slow because each new letter can spin off up to 8 recursive calls, so the running time grows exponentially with the path length.

I decided to use memoization to improve performance.

This is what i came up with.

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

/**
 * Attempt at speeding up the recursive path count by remembering which cells
 * belonged to a completed path.
 *
 * NOTE(review): the memo only records that a cell was on some completed path,
 * not how many completions start from it, and a memo hit adds just one path.
 * See the notes inside recursiveFind for the key problems.
 */
public class AlphabetCount {

    // Number of steps after the starting 'A'; count() sets it to (requested length - 1).
    private int ofLength;
    // Paths counted so far by the current count() call.
    private int paths = 0;
    // Rows of the letter grid being searched.
    private String[] grid;
//  This was an optimization that helped a little.  It would store the possible next cells.
//  private HashMap<Integer, ArrayList<int[]>> memoStack = new HashMap<Integer, ArrayList<int[]>>();
    // Map from hashed (x, y, layer) key to a cell that was part of a completed path (the memo store).
    // NOTE(review): count() never clears this map, so entries survive from one call to the next.
    private HashMap<Integer, int[]> completedPath = new HashMap<Integer, int[]>();

    /**
     * Entry point: finds every 'A' and starts a recursive search from it.
     * Prints the elapsed time in milliseconds to standard output.
     *
     * @param grid     rows of the letter grid
     * @param ofLength number of cells in a complete path, counting the 'A'
     * @return the value of the paths counter after all searches
     */
    public int count(String[] grid, int ofLength) {
        this.grid = grid;
        // The starting 'A' is found by brute force below, so only the next n-1 letters remain.
        this.ofLength = ofLength - 1;
        // Reset the number of completed runs.
        paths = 0;

        // Holds the cells stepped onto so far; one list is shared by every recursive call.
        // NOTE(review): ArrayList rejects a negative initial capacity, so ofLength <= 0 throws here.
        List<int[]> fullPath = new ArrayList<int[]>(ofLength - 1);

        // Just a timer to compare optimizations.
        long startTime = System.currentTimeMillis();

        // Loop over the rows, finding each 'A' in turn.
        for (int j = 0; j < grid.length; j++) {
            for (int index = grid[j].indexOf('A'); index >= 0; index = grid[j].indexOf('A', index + 1)) {

                // Start the recursion at the 'A' itself as layer 0; fullPath is still empty here.
                recursiveFind(fullPath, 0, index, j);

            }

        }
        System.out.println(System.currentTimeMillis() - startTime);
        return paths;
    }

    /**
     * Visits the cell at column x, row y as step number layer of a path.
     *
     * @param fullPath cells stepped onto after the 'A'; index i holds the cell visited at layer i + 1
     * @param layer    number of steps taken from the starting 'A'
     * @param x        column of the current cell
     * @param y        row of the current cell
     */
    private void recursiveFind(List<int[]> fullPath, int layer, int x, int y) {
        // Hash key for (x, y, layer).
        // NOTE(review): this equals 31 * (x + 2*y + 3*layer), so different triples collide,
        // e.g. (x=2, y=0) and (x=0, y=1) on the same layer.
        int key = 31 * (x) + 62 * (y) + 93 * layer;

        // Once 1000000000 paths have been counted, stop counting; the empty branch just unwinds.
        if (paths >= 1_000_000_000) {

        // The author reports that this branch never fires in practice; it is the intended memo hit.
        // NOTE(review): a hit adds a single path, regardless of how many completions follow this cell.
        } else if (completedPath.containsKey(key)) {
            paths++;
            // Record every cell on the current path except the last one.
            for (int i = 0; i < fullPath.size() - 1; i++) {
                int mkey = 31 * fullPath.get(i)[0] + 62 * fullPath.get(i)[1] + 93 * (i);
                if (!completedPath.containsKey(mkey)) {
                    completedPath.put(mkey, fullPath.get(i));
                }
            }

        }
        // A full run: save the path taken into the memo map, then increase paths.
        else if (layer == ofLength) {

            // NOTE(review): fullPath.get(i) is the cell visited at layer i + 1, but the key below
            // uses i, while the lookup key above uses the layer itself; saved keys are off by one layer.
            for (int i = 0; i < fullPath.size() - 1; i++) {
                int mkey = 31 * fullPath.get(i)[0] + 62 * fullPath.get(i)[1] + 93 * (i);
                if (!completedPath.containsKey(mkey)) {
                    completedPath.put(mkey, fullPath.get(i));
                }
            }

            paths++;

        }


// Everything with memoStack is a disabled optimization that increased performance marginally.
//      else if (memoStack.containsKey(key)) {
//          for (int[] path : memoStack.get(key)) {
//              recursiveFind(fullPath,layer + 1, path[0], path[1]);
//          }
//      } 

        else {

            // NOTE(review): the column bound comes from row 0 only, so rows are assumed equally long.
            int xBound = grid[0].length();
            int yBound = grid.length;

            // ArrayList<int[]> newPaths = new ArrayList<int[]>();
            // Scratch (x, y) pair, reused for every neighbour and cloned before it is stored.
            int[] pair = new int[2];

            // Check the 8 adjacent cells (skipping the current one), reject out-of-bounds
            // positions, and recurse into each neighbour that holds the next character.
            for (int dx = -1; dx <= 1; ++dx) {
                for (int dy = -1; dy <= 1; ++dy) {
                    if (dx != 0 || dy != 0) {
                        if ((x + dx < xBound && y + dy < yBound) && (x + dx >= 0 && y + dy >= 0)) {
                            if (grid[y].charAt(x) + 1 == grid[y + dy].charAt(x + dx)) {

                                pair[0] = x + dx;
                                pair[1] = y + dy;
                                // newPaths.add(pair.clone());
                                // Drop everything recorded at this depth or deeper, so fullPath
                                // holds only the path up to the current call.
                                fullPath.subList(layer, fullPath.size()).clear();
                                // pair is reused across iterations, so store a copy.
                                fullPath.add(pair.clone());
                                // Recurse into the neighbour as the next layer.
                                recursiveFind(fullPath, layer + 1, x + dx, y + dy);

                            }

                        }

                    }
                }
            }
            // memoStack.putIfAbsent(key, newPaths);

            // Memo thought: if layer, x and y match those of a successful run, a previous
            // run could be reused.

        }
    }

}

The issue is that my memoization never really gets used. The recursive calls mimic a depth-first search, e.g.:

     1
   / | \
  2  5  8
 /\  |\  |\
3 4  6 7 9 10

So saving a run will not overlap with another run in any performance-saving way, because the search reaches the bottom of the tree before returning back up the call stack. So the question is: how do I memoize this? Or, once I get a full run, how do I recurse back to the beginning of the tree so that the memoization I wrote works?

The test string that really kills the performance is { "ABCDEFGHIJKLMNOPQRSTUVWXYZ", "ABCDEFGHIJKLMNOPQRSTUVWXYZ", "ABCDEFGHIJKLMNOPQRSTUVWXYZ" }; for all paths of length 26 (Should return 1000000000)

PS. As a first-time poster, I would appreciate any comments about general code improvements or bad coding habits. Additionally, since I haven't posted before, let me know if this question was unclear, poorly formatted, too long, etc.

like image 494
Zachary Konovalov Avatar asked Mar 19 '19 21:03

Zachary Konovalov


2 Answers

I'm not sure what you're memoizing (perhaps you could explain it in words?) but there seem to be overlapping sub-problems here. If I understand correctly, except for "A", any specific instance of a letter can only be reached from a neighbouring previous letter in the alphabet. That means we can store the number of paths from each specific instance of a letter. When that specific instance is reached on subsequent occasions, we can avoid recursing into it.

Depth first search:

 d1 d2 d3 d4
   c1   c2
      b
    a1 a2

 .....f(c1) = f(d1) + f(d2) = 2
 .....f(c2) = f(d3) + f(d4) = 2
 ...f(b) = f(c1) + f(c2) = 4
 f(a1) = f(b) = 4
 f(a2) = f(b) = 4
like image 57
גלעד ברקן Avatar answered Sep 30 '22 13:09

גלעד ברקן


Well, I figured it out! Thanks in part to @גלעד ברקן's recommendation. I was originally just trying to make it work by saying that if any two paths shared a cell then it was a full path, so we wouldn't have to recurse further, which was a gross oversimplification. I ended up writing a little graph visualizer so that I could see exactly what I was looking at. (This is the first example from above: { "ABC", "CBZ", "CZC", "BZZ", "ZAA" } with length 3.)

L stands for layer- each layer corresponds to a letter ie layer 1 == 'A'

enter image description here

From this I determined that each node could save the number of completed paths that go from it. In the picture above, this means that node L[2]X1Y1 would be given the number 4, because any time you reach that node there are 4 complete paths.

Anyway, I memoized into an int[][], so the only thing I would still like to do from here is turn that into a hash map so that there isn't as much wasted space.

Here is the code that I came up with.

package practiceproblems;

import java.util.ArrayDeque;


public class AlphabetCount {

    private int ofLength;
    private int paths = 0;
    private String[] grid;

    //this is the array that we memoize.  could be hashmap
    private int[][] memoArray;// spec says it initalizes to zero so i am just going with it

    //entry point func
    public int count(String[] grid, int ofLength) {

        //initialize all the member vars
        memoArray = new int[grid[0].length()][grid.length];
        this.grid = grid;
        // this is minus 1 because we already found "A"
        this.ofLength = ofLength - 1;
        paths = 0;
        //saves the previous nodes visited.
        ArrayDeque<int[]> goodPathStack = new ArrayDeque<int[]>();


        long startTime = System.currentTimeMillis();
        for (int j = 0; j < grid.length; j++) {
            for (int index = grid[j].indexOf('A'); index >= 0; index = grid[j].indexOf('A', index + 1)) {
                //kinda wasteful to clone i would think... but easier because it stays in its stack frame
                recursiveFind(goodPathStack.clone(), 0, index, j);

            }

        }
        System.out.println(System.currentTimeMillis() - startTime);
        //if we have more than a bil then just return a bil
        return paths >= 1_000_000_000 ? 1_000_000_000 : paths;
    }

    //recursive func
    private void recursiveFind(ArrayDeque<int[]> fullPath, int layer, int x, int y) {

        //debugging
        System.out.println("Preorder " + layer + " " + (x) + " " + (y));

        //start pushing onto the path stack so that we know where we have been in a given recursion
        int[] pair = { x, y };
        fullPath.push(pair);

        if (paths >= 1_000_000_000) {
            return;

            //we found a full path 'A' thru length
        } else if (layer == this.ofLength) {

            paths++;
            //poll is just pop for deques apparently.
            // all of these paths start at 'A' which we find manually. so pop last.
            // all of these paths incluse the last letter which wouldnt help us because if
            // we find the last letter we already know we are at the end.
            fullPath.pollFirst();
            fullPath.pollLast();

            //this is where we save memoization info
            //each node on fullPath leads to a full path
            for (int[] p : fullPath) {
                memoArray[p[0]][p[1]]++;
            }
            return;

        } else if (memoArray[x][y] > 0) {

            //this is us using our memoization cache
            paths += memoArray[x][y];
            fullPath.pollLast();
            fullPath.pollFirst();
            for (int[] p : fullPath) {
                memoArray[p[0]][p[1]] += memoArray[x][y];
            }

        }

//      else if (memoStack.containsKey(key)) {
//          for (int[] path : memoStack.get(key)) {
//              recursiveFind(fullPath,layer + 1, path[0], path[1]);
//          }
//      } 

        else {

            int xBound = grid[0].length();
            int yBound = grid.length;

            //search in all 8 directions for a letter that comes after the one that you are on.
            for (int dx = -1; dx <= 1; ++dx) {
                for (int dy = -1; dy <= 1; ++dy) {
                    if (dx != 0 || dy != 0) {
                        if ((x + dx < xBound && y + dy < yBound) && (x + dx >= 0 && y + dy >= 0)) {
                            if (grid[y].charAt(x) + 1 == grid[y + dy].charAt(x + dx)) {

                                recursiveFind(fullPath.clone(), layer + 1, x + dx, y + dy);

                            }
                        }

                    }

                }
            }
        }

        // memoStack.putIfAbsent(key, newPaths);

        // memo thought! if one runs layer, x and y are the same then you can use a
        // previous run

    }

}

It works! And the time needed to complete the 1_000_000_000 paths was reduced by a lot, to under a second.

Hope this example can help anyone else who ended up stumped for days.

like image 44
Zachary Konovalov Avatar answered Sep 30 '22 13:09

Zachary Konovalov