 

Designing tail recursion using java 8

I was trying out the following example, provided in the talk, to understand tail recursion in Java 8.

import java.util.stream.Stream;

@FunctionalInterface
public interface TailCall<T> {
    TailCall<T> apply();

    default boolean isComplete() {
        return false;
    }

    default T result() {
        throw new Error("not implemented");
    }

    default T get() {
        return Stream.iterate(this, TailCall::apply).filter(TailCall::isComplete)
                                                .findFirst().get().result();
    }
}

Utility class for creating TailCall instances:

public class TailCalls {
    public static <T> TailCall<T> call(final TailCall<T> nextcall) {
        return nextcall;
    }

    public static <T> TailCall<T> done(final T value) {
        return new TailCall<T>() {
            @Override
            public boolean isComplete() {
                return true;
            }

            @Override
            public T result() {
                return value;
            }

            @Override
            public TailCall<T> apply() {
                throw new Error("not implemented.");
            }
        };
    }
}

Here is how the tail recursion is used:

public class Main {

    public static TailCall<Integer> factorial(int fact, int n) {
        if (n == 1) {
            return TailCalls.done(fact);
        } else {
            return TailCalls.call(factorial(fact * n, n-1));
        }
    }

    public static void main(String[] args) {
        System.out.println(factorial(1, 5).get());
    }
}

It works correctly, but I feel like TailCall::get isn't needed to compute the result. As I understand it, we can compute the result directly using:

System.out.println(factorial(1, 5).result());

instead of:

System.out.println(factorial(1, 5).get());

Please let me know if I am missing the gist of TailCall::get.

asked May 12 '17 at 11:05 by subhash kumar singh




1 Answer

There is a mistake in the example: it just performs plain recursion, without tail call optimization. You can see this by adding Thread.dumpStack() to the base case:

if (n == 1) {
    Thread.dumpStack();
    return TailCalls.done(fact);
}

The stack trace will look something like:

java.lang.Exception: Stack trace
    at java.lang.Thread.dumpStack(Thread.java:1333)
    at test.Main.factorial(Main.java:14)
    at test.Main.factorial(Main.java:18)
    at test.Main.factorial(Main.java:18)
    at test.Main.factorial(Main.java:18)
    at test.Main.factorial(Main.java:18)
    at test.Main.main(Main.java:8)

As you can see, there are multiple nested calls to factorial on the stack. This means plain recursion takes place, without the tail call optimization: in TailCalls.call(factorial(fact * n, n-1)), Java evaluates the argument factorial(...) before call even runs. In that case there is indeed no point in calling get, since the TailCall object you get back from factorial already contains the result.
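To make this concrete, here is a small self-contained sketch (not code from the talk) that keeps the question's eager factorial and adds two hypothetical counters, depth and maxDepth, to record how deeply the calls nest. The interface and the call/done helpers are nested in one class only so the file compiles on its own; they mirror TailCall and TailCalls from the question.

```java
import java.util.stream.Stream;

public class EagerTailCallDemo {

    @FunctionalInterface
    interface TailCall<T> {
        TailCall<T> apply();

        default boolean isComplete() { return false; }

        default T result() { throw new Error("not implemented"); }

        default T get() {
            return Stream.iterate(this, TailCall::apply)
                    .filter(TailCall::isComplete)
                    .findFirst().get().result();
        }
    }

    // Mirrors TailCalls.call: returns its argument unchanged.
    static <T> TailCall<T> call(TailCall<T> nextCall) {
        return nextCall;
    }

    // Mirrors TailCalls.done: a completed step that holds the final value.
    static <T> TailCall<T> done(T value) {
        return new TailCall<T>() {
            @Override public boolean isComplete() { return true; }
            @Override public T result() { return value; }
            @Override public TailCall<T> apply() { throw new Error("not implemented."); }
        };
    }

    // Hypothetical instrumentation: current and deepest nesting of factorial.
    static int depth = 0;
    static int maxDepth = 0;

    // Same shape as the question's factorial. The argument of call(...) is
    // evaluated first, so the recursive call finishes before call() runs.
    static TailCall<Integer> factorial(int fact, int n) {
        depth++;
        maxDepth = Math.max(maxDepth, depth);
        try {
            if (n == 1) {
                return done(fact);
            }
            return call(factorial(fact * n, n - 1));
        } finally {
            depth--;
        }
    }

    public static void main(String[] args) {
        TailCall<Integer> t = factorial(1, 5);
        System.out.println(t.isComplete()); // true: t is already the done(...) object
        System.out.println(t.result());     // 120, without calling get()
        System.out.println(maxDepth);       // 5 nested factorial frames
    }
}
```

Because isComplete() already returns true, get() here only returns the first element of Stream.iterate (the object itself) without ever calling apply(), which is why result() gives the same answer in the question's version.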


The right way to implement this is to return a new TailCall object that defers the actual call:

public static TailCall<Integer> factorial(int fact, int n) {
    if (n == 1) {
        return TailCalls.done(fact);
    }

    return () -> factorial(fact * n, n-1);
}
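A self-contained sketch of the corrected version, assuming the same TailCall interface and done helper as in the question (nested here so the file compiles alone). The extra sumTo method is a hypothetical example added only to show a long chain of one million steps being driven by get() instead of by nested calls. It also shows why result() alone no longer works: factorial now returns only the first deferred step.

```java
import java.util.stream.Stream;

public class LazyTailCallDemo {

    @FunctionalInterface
    interface TailCall<T> {
        TailCall<T> apply();

        default boolean isComplete() { return false; }

        default T result() { throw new Error("not implemented"); }

        default T get() {
            return Stream.iterate(this, TailCall::apply)
                    .filter(TailCall::isComplete)
                    .findFirst().get().result();
        }
    }

    // Mirrors TailCalls.done: a completed step that holds the final value.
    static <T> TailCall<T> done(T value) {
        return new TailCall<T>() {
            @Override public boolean isComplete() { return true; }
            @Override public T result() { return value; }
            @Override public TailCall<T> apply() { throw new Error("not implemented."); }
        };
    }

    // Corrected factorial: the recursive call is wrapped in a lambda,
    // so it only runs when get() calls apply() on the returned object.
    static TailCall<Integer> factorial(int fact, int n) {
        if (n == 1) {
            return done(fact);
        }
        return () -> factorial(fact * n, n - 1);
    }

    // Hypothetical deep example: sums 1..n with an accumulator, same pattern.
    static TailCall<Long> sumTo(long acc, long n) {
        if (n == 0) {
            return done(acc);
        }
        return () -> sumTo(acc + n, n - 1);
    }

    public static void main(String[] args) {
        TailCall<Integer> pending = factorial(1, 5);
        System.out.println(pending.isComplete()); // false: only a deferred step so far
        try {
            pending.result(); // default result() throws on an unfinished step
        } catch (Error e) {
            System.out.println("result() failed: " + e.getMessage());
        }
        System.out.println(pending.get()); // 120

        // One million steps, each run by get() after the previous one returned.
        System.out.println(sumTo(0, 1_000_000).get()); // 500000500000
    }
}
```

With this version get() is essential: it keeps calling apply() until it reaches the done(...) object, and only that final object has a usable result().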

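If you add the Thread.dumpStack() call here as well, the stack trace contains only one call to factorial. The earlier steps have already returned, and each next step is started from inside TailCall.get: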

java.lang.Exception: Stack trace
    at java.lang.Thread.dumpStack(Thread.java:1333)
    at test.Main.factorial(Main.java:14)
    at test.Main.lambda$0(Main.java:18)
    at java.util.stream.Stream$1.next(Stream.java:1033)
    at java.util.Spliterators$IteratorSpliterator.tryAdvance(Spliterators.java:1812)
    at java.util.stream.ReferencePipeline.forEachWithCancel(ReferencePipeline.java:126)
    at java.util.stream.AbstractPipeline.copyIntoWithCancel(AbstractPipeline.java:498)
    at java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:485)
    at java.util.stream.AbstractPipeline.wrapAndCopyInto(AbstractPipeline.java:471)
    at java.util.stream.FindOps$FindOp.evaluateSequential(FindOps.java:152)
    at java.util.stream.AbstractPipeline.evaluate(AbstractPipeline.java:234)
    at java.util.stream.ReferencePipeline.findFirst(ReferencePipeline.java:464)
    at test.TailCall.get(Main.java:36)
    at test.Main.main(Main.java:9)
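The frames below factorial in this trace all belong to TailCall.get, which is where the real work now happens: Stream.iterate keeps calling apply(), and findFirst stops at the first completed step. As a rough sketch (an equivalent alternative, not the code from the talk), the same get() can be written as a plain loop, which makes the trampoline explicit. The class name LoopTrampolineDemo is hypothetical.

```java
public class LoopTrampolineDemo {

    @FunctionalInterface
    interface TailCall<T> {
        TailCall<T> apply();

        default boolean isComplete() { return false; }

        default T result() { throw new Error("not implemented"); }

        // Same job as the Stream.iterate version: keep calling apply()
        // until a completed TailCall appears, then unwrap its result.
        default T get() {
            TailCall<T> current = this;
            while (!current.isComplete()) {
                current = current.apply();
            }
            return current.result();
        }
    }

    // A completed step that holds the final value.
    static <T> TailCall<T> done(T value) {
        return new TailCall<T>() {
            @Override public boolean isComplete() { return true; }
            @Override public T result() { return value; }
            @Override public TailCall<T> apply() { throw new Error("not implemented."); }
        };
    }

    // The corrected factorial from the answer.
    static TailCall<Integer> factorial(int fact, int n) {
        if (n == 1) {
            return done(fact);
        }
        return () -> factorial(fact * n, n - 1);
    }

    public static void main(String[] args) {
        System.out.println(factorial(1, 5).get()); // 120
    }
}
```

Each apply() call returns before the next one starts, so the stack depth stays constant however many steps the computation takes; the Stream.iterate version in the question behaves the same way, just through the stream machinery visible in the trace.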
answered Sep 30 '22 at 16:09 by Jorn Vernee