
How can I pause/serialize a genetic algorithm in Encog?

How can I pause a genetic algorithm in Encog 3.4 (the version currently under development on GitHub)?

I am using the Java version of Encog.

I am trying to modify the Lunar example that comes with Encog. I want to pause/serialize the genetic algorithm and then continue/deserialize at a later stage.

When I call train.pause(); it simply returns null, which is unsurprising once you read the source: the method always returns null.

I would assume this is a common scenario: train a neural network, use it for some predictions, then continue training with the genetic algorithm as more data arrives, before resuming predictions. All of this without having to restart the training from the beginning.

Please note that I am not trying to serialize or persist a neural network but rather the entire genetic algorithm.

Asked Sep 28 '22 by Tmr
1 Answer

Not all trainers in Encog support simple pause/resume; those that do not, like this one, return null. The genetic algorithm trainer is much more complex than a simple propagation trainer, which does support pause/resume. To save the state of the genetic algorithm you must save the entire population, as well as the scoring function (which may or may not be serializable). I modified the Lunar Lander example to show how you might save and reload your population of neural networks to do this.
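The core pattern here is plain Java serialization: the population's state can be written and read back, but a non-serializable member (the scoring function) has to be detached before saving and reattached after loading, which is exactly why the example nulls the genome factory before saving and constructs a new PilotScore on load. A minimal stdlib-only sketch of that pattern, where the PopulationState class and its fields are hypothetical stand-ins for Encog's Population and scoring function:

```java
import java.io.*;

public class RoundTripSketch {

    // Hypothetical state object standing in for Encog's Population.
    static class PopulationState implements Serializable {
        private static final long serialVersionUID = 1L;
        double bestScore;
        // A scoring function is often NOT serializable; marking it
        // transient lets the rest of the object round-trip, but the
        // field comes back null and must be reattached after loading.
        transient Runnable scorer;
        PopulationState(double bestScore) { this.bestScore = bestScore; }
    }

    // Serialize the state to a byte array (stands in for SerializeObject.save).
    static byte[] save(PopulationState s) throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try (ObjectOutputStream oos = new ObjectOutputStream(bos)) {
            oos.writeObject(s);
        }
        return bos.toByteArray();
    }

    // Deserialize the state back (stands in for SerializeObject.load).
    static PopulationState load(byte[] data) throws IOException, ClassNotFoundException {
        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return (PopulationState) ois.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        PopulationState before = new PopulationState(42.0);
        before.scorer = () -> {};                  // non-serializable member attached
        PopulationState after = load(save(before));
        System.out.println(after.bestScore);       // serializable state survives the round trip
        System.out.println(after.scorer == null);  // transient member is gone: reattach it here
    }
}
```

This is the same shape as the Encog code below: save strips the non-serializable piece, load rebuilds it (the genome factory and a fresh PilotScore) before training resumes.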

You can see that it trains 50 iterations, then round-trips (load/saves) the genetic algorithm, then trains 50 more.

package org.encog.examples.neural.lunar;

import java.io.File;
import java.io.IOException;

import org.encog.Encog;
import org.encog.engine.network.activation.ActivationTANH;
import org.encog.ml.MLMethod;
import org.encog.ml.MLResettable;
import org.encog.ml.MethodFactory;
import org.encog.ml.ea.population.Population;
import org.encog.ml.genetic.MLMethodGeneticAlgorithm;
import org.encog.ml.genetic.MLMethodGenomeFactory;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.pattern.FeedForwardPattern;
import org.encog.util.obj.SerializeObject;

public class LunarLander {

    public static BasicNetwork createNetwork()
    {
        FeedForwardPattern pattern = new FeedForwardPattern();
        pattern.setInputNeurons(3);
        pattern.addHiddenLayer(50);
        pattern.setOutputNeurons(1);
        pattern.setActivationFunction(new ActivationTANH());
        BasicNetwork network = (BasicNetwork)pattern.generate();
        network.reset();
        return network;
    }

    public static void saveMLMethodGeneticAlgorithm(String file, MLMethodGeneticAlgorithm ga ) throws IOException
    {
        ga.getGenetic().getPopulation().setGenomeFactory(null);
        SerializeObject.save(new File(file),ga.getGenetic().getPopulation());   
    }

    public static MLMethodGeneticAlgorithm loadMLMethodGeneticAlgorithm(String filename) throws ClassNotFoundException, IOException {
        Population pop = (Population) SerializeObject.load(new File(filename));
        pop.setGenomeFactory(new MLMethodGenomeFactory(new MethodFactory(){
            @Override
            public MLMethod factor() {
                final BasicNetwork result = createNetwork();
                ((MLResettable)result).reset();
                return result;
            }},pop));

        MLMethodGeneticAlgorithm result = new MLMethodGeneticAlgorithm(new MethodFactory(){
            @Override
            public MLMethod factor() {
                return createNetwork();
            }},new PilotScore(),1);

        result.getGenetic().setPopulation(pop);

        return result;
    }


    public static void main(String args[])
    {
        BasicNetwork network = createNetwork();

        MLMethodGeneticAlgorithm train;


        train = new MLMethodGeneticAlgorithm(new MethodFactory(){
            @Override
            public MLMethod factor() {
                final BasicNetwork result = createNetwork();
                ((MLResettable)result).reset();
                return result;
            }},new PilotScore(),500);

        try {
            int epoch = 1;

            for(int i=0;i<50;i++) {
                train.iteration();
                System.out
                        .println("Epoch #" + epoch + " Score:" + train.getError());
                epoch++;
            } 
            train.finishTraining();

            // Round trip the GA and then train again
            LunarLander.saveMLMethodGeneticAlgorithm("/Users/jeff/projects/trainer.bin",train);
            train = LunarLander.loadMLMethodGeneticAlgorithm("/Users/jeff/projects/trainer.bin");

            // Train again
            for(int i=0;i<50;i++) {
                train.iteration();
                System.out
                        .println("Epoch #" + epoch + " Score:" + train.getError());
                epoch++;
            } 
            train.finishTraining();

        } catch(IOException ex) {
            ex.printStackTrace();
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }
        }


        System.out.println("\nHow the winning network landed:");
        network = (BasicNetwork)train.getMethod();
        NeuralPilot pilot = new NeuralPilot(network,true);
        System.out.println(pilot.scorePilot());
        Encog.getInstance().shutdown();
    }
}
Answered Oct 31 '22 by JeffHeaton