 

How to improve the performance of the recursive method?

Tags: java, recursion

I'm learning data structures and algorithms, and I'm stuck on the following question.

I have to improve the performance of a recursive method by storing already computed values in memory.

The problem is that the non-improved version seems to be faster than the cached one.

Can someone help me out?

Syracuse numbers are a sequence of positive integers defined by the following rules:

syra(1) ≡ 1

syra(n) ≡ n + syra(n/2), if n mod 2 == 0

syra(n) ≡ n + syra((n*3)+1), otherwise

import java.util.HashMap;
import java.util.Map;

public class SyraLengthsEfficient {

    int counter = 0;
    public int syraLength(long n) {
        if (n < 1) {
            throw new IllegalArgumentException();
        }

        if (n < 500 && map.containsKey(n)) {
            counter += map.get(n);
            return map.get(n);
        } else if (n == 1) {
            counter++;
            return 1;
        } else if (n % 2 == 0) {
            counter++;
            return syraLength(n / 2);
        } else {
            counter++;
            return syraLength(n * 3 + 1);
        }
    }

    Map<Integer, Integer> map = new HashMap<Integer, Integer>();

    public int lengths(int n) {
        if (n < 1) {
            throw new IllegalArgumentException();
        }    
        for (int i = 1; i <= n; i++) {
            syraLength(i);
            if (i < 500 && !map.containsKey(i)) {
                map.put(i, counter);
            }
        }    
        return counter;
    }

    public static void main(String[] args) {
        System.out.println(new SyraLengthsEfficient().lengths(5000000));
    }
}

Here is the normal version that I wrote:

 public class SyraLengths{

        int total=1;
        public int syraLength(long n) {
            if (n < 1)
                throw new IllegalArgumentException();
            if (n == 1) {
                int temp=total;
                total=1;
                return temp;
            }
            else if (n % 2 == 0) {
                total++;
                return syraLength(n / 2);
            }
            else {
                total++;
                return syraLength(n * 3 + 1);
            }
        }

        public int lengths(int n){
            if(n<1){
                throw new IllegalArgumentException();
            }
            int total=0;
            for(int i=1;i<=n;i++){
                total+=syraLength(i);
            }

            return total;
        }

        public static void main(String[] args){
            System.out.println(new SyraLengths().lengths(5000000));
        }
       }

EDIT

This version is still slower than the non-enhanced version.

import java.util.HashMap;
import java.util.Map;

public class SyraLengthsEfficient {

    private Map<Long, Long> map = new HashMap<Long, Long>();

    public long syraLength(long n, long count) {

        if (n < 1)
            throw new IllegalArgumentException();

        if (!map.containsKey(n)) {
            if (n == 1) {
                count++;
                map.put(n, count);
            } else if (n % 2 == 0) {
                count++;
                map.put(n, count + syraLength(n / 2, 0));
            } else {
                count++;
                map.put(n, count + syraLength(3 * n + 1, 0));
            }
        }

        return map.get(n);

    }

    public int lengths(int n) {
        if (n < 1) {
            throw new IllegalArgumentException();
        }
        int total = 0;
        for (int i = 1; i <= n; i++) {
            // long temp = syraLength(i, 0);
            // System.out.println(i + " : " + temp);
            total += syraLength(i, 0);

        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println(new SyraLengthsEfficient().lengths(50000000));
    }
}

FINAL SOLUTION (marked as correct by the school's auto-marking system)

public class SyraLengthsEfficient {

    private int[] values = new int[10 * 1024 * 1024];

    public int syraLength(long n, int count) {

        if (n <= values.length && values[(int) (n - 1)] != 0) {
            return count + values[(int) (n - 1)];
        } else if (n == 1) {
            count++;
            values[(int) (n - 1)] = 1;
            return count;
        } else if (n % 2 == 0) {
            count++;
            if (n <= values.length) {
                values[(int) (n - 1)] = count + syraLength(n / 2, 0);
                return values[(int) (n - 1)];
            } else {
                return count + syraLength(n / 2, 0);
            }
        } else {
            count++;
            if (n <= values.length) {
                values[(int) (n - 1)] = count + syraLength(n * 3 + 1, 0);
                return values[(int) (n - 1)];
            } else {
                return count + syraLength(n * 3 + 1, 0);
            }
        }

    }

    public int lengths(int n) {
        if (n < 1) {
            throw new IllegalArgumentException();
        }
        int total = 0;
        for (int i = 1; i <= n; i++) {
            total += syraLength(i, 0);
        }
        return total;
    }

    public static void main(String[] args) {
        SyraLengthsEfficient s = new SyraLengthsEfficient();
        System.out.println(s.lengths(50000000));
    }

}

asked Jun 04 '12 15:06 by Timeless



1 Answer

Forget about the answers that say your code is inefficient because it uses a Map; that's not why it's slow. The real reason is that you're limiting the cache of calculated numbers to n < 500. Once you remove that restriction, things start to work pretty fast. Here's a proof of concept for you to fill in the details:

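// Cache of already computed values, keyed by n (no size limit).
private Map<Long, Long> map = new HashMap<Long, Long>();

public long syraLength(long n) {

    // Compute and store the value only the first time n is seen.
    if (!map.containsKey(n)) {
        if (n == 1)
            map.put(n, 1L);
        else if (n % 2 == 0)
            map.put(n, n + syraLength(n/2));
        else
            map.put(n, n + syraLength(3*n+1));
    }

    // Every later call for the same n is a single lookup.
    return map.get(n);

}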

If you want to read more about what's happening in the program and why it is so fast, take a look at the Wikipedia article on memoization.
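To make the effect of the cache limit concrete, here is a minimal sketch that counts how many steps actually have to be computed when the cache only covers small values versus the whole input range. The class name `CacheLimitDemo`, the `uncachedSteps` field, and the array-backed cache are assumptions made for this illustration; they are not taken from the question's code, and the method returns the number of terms in the sequence rather than their sum.

```java
class CacheLimitDemo {

    // cache[k] holds the sequence length starting at k; 0 means "not computed yet".
    private final int[] cache;

    // Number of steps that could not be answered from the cache.
    public long uncachedSteps = 0;

    public CacheLimitDemo(int cacheLimit) {
        cache = new int[cacheLimit + 1];
    }

    // Number of terms in the sequence starting at n (assumes n >= 1).
    public int length(long n) {
        if (n < cache.length && cache[(int) n] != 0) {
            return cache[(int) n];
        }
        uncachedSteps++;
        int result = (n == 1)
                ? 1
                : 1 + length(n % 2 == 0 ? n / 2 : 3 * n + 1);
        // Only values below the limit are stored; larger ones are recomputed every time.
        if (n < cache.length) {
            cache[(int) n] = result;
        }
        return result;
    }

    // Sum of the sequence lengths for 1..n.
    public long totalLength(int n) {
        long total = 0;
        for (int i = 1; i <= n; i++) {
            total += length(i);
        }
        return total;
    }

    public static void main(String[] args) {
        int n = 100_000;
        CacheLimitDemo small = new CacheLimitDemo(500); // mirrors the n < 500 restriction
        CacheLimitDemo large = new CacheLimitDemo(n);   // caches every start value
        System.out.println(small.totalLength(n) + " total, " + small.uncachedSteps + " computed steps");
        System.out.println(large.totalLength(n) + " total, " + large.uncachedSteps + " computed steps");
    }
}
```

Both instances should print the same total, while the instance with the larger cache reports fewer computed steps, since almost every start value soon reaches a number whose length is already stored.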

Also, I think you're misusing the counter variable: you increment it (++) when a value is calculated for the first time, but you add the cached value to it (+=) when a value is found in the map. That doesn't seem right to me, and I doubt it gives the expected result.
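One way to avoid the shared counter altogether is to let each call return the length of its own sequence and cache exactly that value. The following is a sketch under that assumption, not the asker's or the marking system's reference solution; the class name `SyraLengthMemo` and the `Map<Long, Integer>` cache are chosen here purely for illustration.

```java
import java.util.HashMap;
import java.util.Map;

class SyraLengthMemo {

    // Maps n to the number of terms in the sequence that starts at n.
    private final Map<Long, Integer> cache = new HashMap<>();

    public int syraLength(long n) {
        if (n < 1) {
            throw new IllegalArgumentException("n must be positive");
        }
        if (n == 1) {
            return 1;
        }
        Integer cached = cache.get(n);
        if (cached != null) {
            return cached;
        }
        long next = (n % 2 == 0) ? n / 2 : 3 * n + 1;
        // The length from n is one term (n itself) plus the length from the next term.
        int result = 1 + syraLength(next);
        cache.put(n, result);
        return result;
    }

    // Sum of the sequence lengths for 1..n, accumulated in a long to avoid int overflow.
    public long lengths(int n) {
        if (n < 1) {
            throw new IllegalArgumentException("n must be positive");
        }
        long total = 0;
        for (int i = 1; i <= n; i++) {
            total += syraLength(i);
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println(new SyraLengthMemo().lengths(10));
    }
}
```

Because the result of a call depends only on its argument, a cache hit and a fresh computation return the same number, so there is no separate counter that could be updated inconsistently.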

answered Sep 17 '22 13:09 by Óscar López