Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Handling StackOverflow in Java for Trampoline

I would like to implement a trampoline in Java by returning a thunk whenever I hit a StackOverflowError. Are there any guarantees about the StackOverflowError? For example, if the only things I do after catching the StackOverflowError are creating objects on the heap and returning from functions, will I be fine?

If the above sounds vague, I have added some code for computing even/odd in a tail-recursive manner in continuation-passing style, returning a delayed thunk whenever the stack overflows. The code works on my machine, but does Java guarantee that it will always work?

/**
 * Mutually recursive even/odd test written in continuation-passing style and
 * trampolined by catching {@link StackOverflowError}.
 *
 * <p>{@link #even} and {@link #odd} call each other until {@code n} reaches 0.
 * When a call fails with {@code StackOverflowError}, the frame that catches it
 * returns a delayed {@link Thunk} instead of a result; {@link Thunk#force()}
 * then resumes the computation from that point on a shallow stack.
 */
public class CPS {

    /**
     * Either a finished result ({@code isDelayed == false}) or a postponed
     * computation ({@code isDelayed == true}) whose next step is produced by
     * {@link #compute()}.
     */
    public static class Thunk {
        /** The final answer, or the intermediate value for a delayed thunk. */
        final Object r;
        /** Continuation attached to a delayed thunk; {@code null} for a finished one. */
        final Continuation c;
        /** {@code true} while {@link #compute()} still has to be called. */
        final boolean isDelayed;

        /**
         * Creates a finished thunk.
         *
         * @param answer the final result returned by {@link #force()}
         */
        public Thunk(Object answer) {
            isDelayed = false;
            r = answer;
            c = null;
        }

        /**
         * Creates a delayed thunk. Callers are expected to override
         * {@link #compute()}; with the default implementation (which returns
         * {@code this}) {@link #force()} would never terminate.
         *
         * @param intermediate value stored in {@code r}
         * @param cont continuation stored in {@code c}
         */
        public Thunk(Object intermediate, Continuation cont) {
            r = intermediate;
            c = cont;
            isDelayed = true;
        }

        /**
         * Runs the trampoline: keeps calling {@link #compute()} until a
         * finished thunk is reached.
         *
         * @return the {@code r} value of the first non-delayed thunk
         */
        public Object force() {
            Thunk t = this;
            while (t.isDelayed) {
                t = t.compute();
            }
            return t.r;
        }

        /**
         * Performs the next step of a delayed computation.
         *
         * @return the next thunk; the default implementation returns {@code this}
         */
        public Thunk compute() {
            return this;
        }
    }

    /** The identity continuation: wraps the result in a finished {@link Thunk}. */
    public static class Continuation {
        /**
         * @param result the value to deliver
         * @return a finished thunk holding {@code result}
         */
        public Thunk apply(Object result) {
            return new Thunk(result);
        }
    }

    /**
     * Decides whether {@code n} is even by counting down to 0 through
     * {@link #odd}.
     *
     * @param n number to test; counted down by one per call
     * @param c continuation that receives {@code true} when 0 is reached here
     * @return a finished thunk, or a delayed thunk that re-enters
     *         {@code even(n, c)} if a {@code StackOverflowError} was caught
     */
    public static Thunk even(final int n, final Continuation c) {
        try {
            if (n == 0) {
                return c.apply(true);
            }
            return odd(n - 1, c);
        } catch (StackOverflowError overflow) {
            // Out of stack: hand back a thunk so force() can restart from here.
            // Inside the anonymous subclass, `c` names the inherited Thunk field,
            // which holds the same continuation that is passed to the constructor.
            return new Thunk(n, c) {
                @Override
                public Thunk compute() {
                    return even(n, c);
                }
            };
        }
    }

    /**
     * Decides whether {@code n} is odd by counting down to 0 through
     * {@link #even}.
     *
     * @param n number to test; counted down by one per call
     * @param c continuation that receives {@code false} when 0 is reached here
     * @return a finished thunk, or a delayed thunk that re-enters
     *         {@code odd(n, c)} if a {@code StackOverflowError} was caught
     */
    public static Thunk odd(final int n, final Continuation c) {
        try {
            if (n == 0) {
                return c.apply(false);
            }
            return even(n - 1, c);
        } catch (StackOverflowError overflow) {
            // Same recovery as in even(): resume later from this n.
            return new Thunk(n, c) {
                @Override
                public Thunk compute() {
                    return odd(n, c);
                }
            };
        }
    }

    /** Prints whether 100001 is even. */
    public static void main(String args[]) {
        System.out.println(even(100001, new Continuation()).force());
    }

}
like image 224
Steven Obua Avatar asked Dec 24 '10 18:12

Steven Obua


2 Answers

I tried the following implementation possibilities: A) With thunks (see code CPS below) B) Without thunks as suggested by chris (see code CPS2 below) C) With thunks with the stack overflow replaced by a depth check (see code CPS3 below)

In each case I checked whether 100,000,000 is an even number. The check took: A) about 2 seconds, B) about 17 seconds, C) about 0.2 seconds.

So returning from a long chain of functions is much faster than throwing an exception that unwinds that chain. Also, instead of waiting for a stack overflow, it is much faster to just record the recursion depth and unwind at depth 1000.

Code for CPS:

/**
 * Variant A: even/odd in continuation-passing style, where a caught
 * {@link StackOverflowError} is turned into a delayed {@link Thunk} that the
 * trampoline in {@link Thunk#force()} resumes later.
 */
public class CPS {

    /** A finished answer, or a postponed step supplied by overriding {@link #compute()}. */
    public static class Thunk {
        /** The answer of a finished thunk; {@code null} for a delayed one. */
        final Object r;
        /** Whether {@link #compute()} still has to run. */
        final boolean isDelayed;

        /** Creates a delayed thunk with no answer yet. */
        public Thunk() {
            this.r = null;
            this.isDelayed = true;
        }

        /**
         * Creates a finished thunk.
         *
         * @param answer the value {@link #force()} will return
         */
        public Thunk(Object answer) {
            this.r = answer;
            this.isDelayed = false;
        }

        /**
         * Steps through delayed thunks until a finished one appears.
         *
         * @return the answer stored in the first finished thunk
         */
        public Object force() {
            Thunk current = this;
            for (;;) {
                if (!current.isDelayed) {
                    return current.r;
                }
                current = current.compute();
            }
        }

        /**
         * Produces the next thunk of a delayed computation.
         *
         * @return {@code this} unless overridden
         */
        public Thunk compute() {
            return this;
        }
    }

    /** Continuation that simply wraps its argument in a finished {@link Thunk}. */
    public static class Continuation {
        /**
         * @param result value to deliver
         * @return a finished thunk carrying {@code result}
         */
        public Thunk apply(Object result) {
            return new Thunk(result);
        }
    }

    /**
     * Evenness test; delegates to {@link #odd} for {@code n - 1}.
     *
     * @param n value being counted down
     * @param c receives {@code true} if the countdown ends in this method
     * @return a finished thunk, or a delayed one re-entering {@code even(n, c)}
     *         when the stack overflowed
     */
    public static Thunk even(final int n, final Continuation c) {
        try {
            return (n != 0) ? odd(n - 1, c) : c.apply(true);
        } catch (StackOverflowError overflow) {
            // Stack exhausted: defer the rest of the countdown.
            return new Thunk() {
                @Override
                public Thunk compute() {
                    return even(n, c);
                }
            };
        }
    }

    /**
     * Oddness test; delegates to {@link #even} for {@code n - 1}.
     *
     * @param n value being counted down
     * @param c receives {@code false} if the countdown ends in this method
     * @return a finished thunk, or a delayed one re-entering {@code odd(n, c)}
     *         when the stack overflowed
     */
    public static Thunk odd(final int n, final Continuation c) {
        try {
            return (n != 0) ? even(n - 1, c) : c.apply(false);
        } catch (StackOverflowError overflow) {
            // Stack exhausted: defer the rest of the countdown.
            return new Thunk() {
                @Override
                public Thunk compute() {
                    return odd(n, c);
                }
            };
        }
    }

    /** Times the evenness check for 100,000,000 and prints the elapsed milliseconds and result. */
    public static void main(String args[]) {
        final long start = System.currentTimeMillis();
        final Object result = even(100000000, new Continuation()).force();
        final long elapsed = System.currentTimeMillis() - start;
        System.out.println("time = " + elapsed + ", result = " + result);
    }

}

Code for CPS2:

/**
 * Variant B: even/odd in continuation-passing style without thunks. When a
 * {@link StackOverflowError} is caught, the remaining work is packaged in an
 * {@link Unwind} exception and thrown, so the whole call chain is unwound
 * up to {@link Unwind#force()}, which then resumes the computation.
 */
public class CPS2 {

    /**
     * An exception that carries the rest of a computation. Throwing it
     * abandons the current call chain; {@link #force()} catches it and runs
     * {@link #compute()} from a shallow stack.
     */
    public abstract static class Unwind extends RuntimeException {
        /**
         * Runs the computation carried by this object.
         *
         * @return the final result
         * @throws Unwind if the computation has to be resumed again
         */
        public abstract Object compute();

        /**
         * Trampoline loop: runs {@link #compute()} and, each time an
         * {@code Unwind} escapes, continues with that one instead.
         *
         * @return the first result that {@code compute()} returns normally
         */
        public Object force() {
            Unwind w = this;
            while (true) {
                try {
                    return w.compute();
                } catch (Unwind unwind) {
                    w = unwind;
                }
            }
        }
    }

    /** The identity continuation: returns its argument unchanged. */
    public static class Continuation {
        /**
         * @param result value to deliver
         * @return {@code result}
         */
        public Object apply(Object result) {
            return result;
        }
    }

    /**
     * Decides whether {@code n} is even by counting down to 0 through
     * {@link #odd}.
     *
     * @param n number to test; counted down by one per call
     * @param c continuation that receives {@code true} when 0 is reached here
     * @return whatever the continuation returns
     * @throws Unwind carrying {@code even(n, c)} if a
     *         {@code StackOverflowError} was caught in this frame
     */
    public static Object even(final int n, final Continuation c) {
        try {
            if (n == 0) {
                return c.apply(true);
            }
            return odd(n - 1, c);
        } catch (StackOverflowError overflow) {
            throw new Unwind() {
                @Override
                public Object compute() {
                    return even(n, c);
                }
            };
        }
    }

    /**
     * Decides whether {@code n} is odd by counting down to 0 through
     * {@link #even}.
     *
     * @param n number to test; counted down by one per call
     * @param c continuation that receives {@code false} when 0 is reached here
     * @return whatever the continuation returns
     * @throws Unwind carrying {@code odd(n, c)} if a
     *         {@code StackOverflowError} was caught in this frame
     */
    public static Object odd(final int n, final Continuation c) {
        try {
            if (n == 0) {
                return c.apply(false);
            }
            return even(n - 1, c);
        } catch (StackOverflowError overflow) {
            // Must be thrown, not returned: a returned Unwind would travel back
            // through the ordinary return path and force() would hand it out as
            // if it were the final answer instead of resuming the computation.
            throw new Unwind() {
                @Override
                public Object compute() {
                    return odd(n, c);
                }
            };
        }
    }

    /** Times the evenness check for 100,000,000 and prints the elapsed milliseconds and result. */
    public static void main(String args[]) {
        long time1 = System.currentTimeMillis();
        Unwind w = new Unwind() {
            @Override
            public Object compute() {
                return even(100000000, new Continuation());
            }
        };
        Object b = w.force();
        long time2 = System.currentTimeMillis();
        System.out.println("time = " + (time2 - time1) + ", result = " + b);
    }

}

Code for CPS3:

/**
 * Variant C: even/odd in continuation-passing style where the recursion depth
 * is tracked explicitly. Once the depth reaches the limit, a delayed
 * {@link Thunk} is returned and the trampoline in {@link Thunk#force()}
 * restarts the countdown at depth 0.
 */
public class CPS3 {

    /** Depth at which even/odd stop recursing and return a delayed thunk. */
    private static final int UNWIND_DEPTH = 1000;

    /** A finished answer, or a postponed step supplied by overriding {@link #compute()}. */
    public static class Thunk {
        /** The answer of a finished thunk; {@code null} for a delayed one. */
        final Object r;
        /** Whether {@link #compute()} still has to run. */
        final boolean isDelayed;

        /** Creates a delayed thunk with no answer yet. */
        public Thunk() {
            this.r = null;
            this.isDelayed = true;
        }

        /**
         * Creates a finished thunk.
         *
         * @param answer the value {@link #force()} will return
         */
        public Thunk(Object answer) {
            this.r = answer;
            this.isDelayed = false;
        }

        /**
         * Advances through delayed thunks until a finished one is found.
         *
         * @return the answer stored in the first finished thunk
         */
        public Object force() {
            Thunk step = this;
            while (step.isDelayed) {
                step = step.compute();
            }
            return step.r;
        }

        /**
         * Produces the next thunk of a delayed computation.
         *
         * @return {@code this} unless overridden
         */
        public Thunk compute() {
            return this;
        }
    }

    /** Continuation that wraps its argument in a finished {@link Thunk}. */
    public static class Continuation {
        /**
         * @param result value to deliver
         * @return a finished thunk carrying {@code result}
         */
        public Thunk apply(Object result) {
            return new Thunk(result);
        }
    }

    /**
     * Evenness test; delegates to {@link #odd} for {@code n - 1}.
     *
     * @param n value being counted down
     * @param c receives {@code true} if the countdown ends in this method
     * @param depth number of nested calls made since the last restart
     * @return a finished thunk, or, when {@code depth} has reached the limit,
     *         a delayed one that re-enters {@code even(n, c, 0)}
     */
    public static Thunk even(final int n, final Continuation c, final int depth) {
        if (depth < UNWIND_DEPTH) {
            return (n == 0) ? c.apply(true) : odd(n - 1, c, depth + 1);
        }
        // Deep enough: bounce back to the trampoline and start over at depth 0.
        return new Thunk() {
            @Override
            public Thunk compute() {
                return even(n, c, 0);
            }
        };
    }

    /**
     * Oddness test; delegates to {@link #even} for {@code n - 1}.
     *
     * @param n value being counted down
     * @param c receives {@code false} if the countdown ends in this method
     * @param depth number of nested calls made since the last restart
     * @return a finished thunk, or, when {@code depth} has reached the limit,
     *         a delayed one that re-enters {@code odd(n, c, 0)}
     */
    public static Thunk odd(final int n, final Continuation c, final int depth) {
        if (depth < UNWIND_DEPTH) {
            return (n == 0) ? c.apply(false) : even(n - 1, c, depth + 1);
        }
        // Deep enough: bounce back to the trampoline and start over at depth 0.
        return new Thunk() {
            @Override
            public Thunk compute() {
                return odd(n, c, 0);
            }
        };
    }

    /** Times the evenness check for 100,000,000 and prints the elapsed milliseconds and result. */
    public static void main(String args[]) {
        final long start = System.currentTimeMillis();
        final Object result = even(100000000, new Continuation(), 0).force();
        final long elapsed = System.currentTimeMillis() - start;
        System.out.println("time = " + elapsed + ", result = " + result);
    }

}
like image 65
Steven Obua Avatar answered Oct 30 '22 10:10

Steven Obua


That's an interesting way to jump up the stack. It seems to work, but is probably slower than the usual way to implement this technique, which is to throw an exception that is caught $BIGNUM layers up the call stack.

like image 22
Chris Jester-Young Avatar answered Oct 30 '22 10:10

Chris Jester-Young