
Sequence prediction LSTM neural network is falling behind

I'm trying to implement a guessing game where the user guesses a coin flip and a neural network tries to predict their guess (without hindsight knowledge, of course). The game is supposed to run in real time and adapt to the user. I've used Synaptic.js, as it seemed solid.

Yet I can't seem to get past a stumbling block: the neural network constantly trails behind with its guesses. For example, if the user presses

heads heads tail heads heads tail heads heads tail

it does recognize the pattern, but it lags behind by two moves:

tail heads heads tail heads heads tail heads heads

I've tried countless strategies:

  • train the network as the user clicks heads or tails, going along with the user
  • keep a log of user entries, clear the network memory, and retrain it with all entries up to the point of the guess
  • mix and match training with activations in a bunch of ways
  • switch to a perceptron and pass it a bunch of moves at once (worked worse than the LSTM)
  • a bunch of other things I've forgotten

Architecture:

  • 2 inputs: whether the user clicked heads or tails in the previous turn
  • 2 outputs: a prediction of what the user will click next (this becomes the input on the next turn)

I've tried 10-30 neurons in the hidden layers and a variety of training epoch counts, yet I constantly arrive at the same problem!

I'll post the BuckleScript code I'm using.

What am I doing wrong? Or are my expectations of predicting a user's guesses in real time simply unreasonable? Are there any alternative algorithms?

(* FFI binding to a Synaptic network object *)
class type _nnet = object
    method activate : float array -> float array
    method propagate : float -> float array -> unit
    method clone : unit -> _nnet Js.t
    method clear : unit -> unit
end [@bs]

type nnet = _nnet Js.t

external lstm : int -> int -> int -> nnet = "synaptic.Architect.LSTM" [@@bs.new]
external lstm_2 : int -> int -> int -> int -> nnet = "synaptic.Architect.LSTM" [@@bs.new]
external lstm_3 : int -> int -> int -> int -> int -> nnet = "synaptic.Architect.LSTM" [@@bs.new]
external perceptron : int -> int -> int -> nnet = "synaptic.Architect.Perceptron" [@@bs.new]

type id
  (** Abstract type for a DOM element *)
type dom
  (** Abstract type for the document object *)

external dom : dom = "document" [@@bs.val]

external get_by_id : dom -> string -> id =
  "getElementById" [@@bs.send]

external set_text : id -> string -> unit =
  "innerHTML" [@@bs.set]

(*THE CODE*)

let current_net = lstm 2 16 2
let training_momentum = 0.1
let training_epochs = 20
let training_memory = 16

let rec train_sequence_rec n the_array =
    if n > 0 then (
        current_net##propagate training_momentum the_array;
        train_sequence_rec (n - 1) the_array
    )

let print_arr prefix the_arr =
    print_endline (prefix ^ " " ^
        (Pervasives.string_of_float (Array.get the_arr 0)) ^ " " ^
        (Pervasives.string_of_float (Array.get the_arr 1)))

let blank_arr () = Array.make 2 0.0

(* false = heads, true = tails: whichever output is larger wins *)
let derive_guess_from_array the_arr =
    Array.get the_arr 0 < Array.get the_arr 1

let set_array_inp the_value the_arr =
    if the_value then
        Array.set the_arr 1 1.0
    else
        Array.set the_arr 0 1.0

let output_array the_value =
    let farr = blank_arr () in
    set_array_inp the_value farr;
    farr

let by_id the_id = get_by_id (dom) the_id

let update_prediction_in_ui the_value =
    let elem = by_id "status-text" in
    if not the_value then
        set_text elem "Predicted Heads"
    else
        set_text elem "Predicted Tails"

let inc_ref the_ref = the_ref := !the_ref + 1

let total_guesses_count = ref 0
let steve_won_count = ref 0

(* ring buffer of the user's last [training_memory] moves *)
let sequence = Array.make training_memory false
let seq_ptr = ref 0
let seq_count = ref 0

let push_seq the_value =
    Array.set sequence (!seq_ptr mod training_memory) the_value;
    inc_ref seq_ptr;
    if !seq_count < training_memory then
        inc_ref seq_count

let seq_start_offset () =
    (!seq_ptr - !seq_count) mod training_memory

let traverse_seq the_fun =
    let incr = ref 0 in
    let begin_at = seq_start_offset () in
    let next_i () = (begin_at + !incr) mod training_memory in
    let rec loop () =
        if !incr < !seq_count then (
            let cval = Array.get sequence (next_i ()) in
            the_fun cval;
            inc_ref incr;
            loop ()
        ) in
    loop ()

let first_in_sequence () =
    Array.get sequence (seq_start_offset ())

let last_in_sequence_n n =
    (* the move made n turns before the most recent one *)
    let curr = (!seq_ptr - n - 1) mod training_memory in
    if curr >= 0 && n < !seq_count then
        Array.get sequence curr
    else
        false

let last_in_sequence () = last_in_sequence_n 0

let perceptron_input last_n_fields =
    (* one-hot encode the last n moves; per move: slot 0 = no data yet,
       slot 1 = heads, slot 2 = tails *)
    let tot_fields = 3 * last_n_fields in
    let out_arr = Array.make tot_fields 0.0 in
    let rec loop count =
        if count < last_n_fields then (
            if count >= !seq_count then (
                Array.set out_arr (3 * count) 1.0;
            ) else (
                let curr = last_in_sequence_n count in
                let the_slot = if curr then 1 else 0 in
                Array.set out_arr (3 * count + 1 + the_slot) 1.0
            );
            loop (count + 1)
        ) in
    loop 0;
    out_arr

let steve_won () = inc_ref steve_won_count

let propagate_n_times the_output =
    let rec loop cnt =
        if cnt < training_epochs then (
            current_net##propagate training_momentum the_output;
            loop (cnt + 1)
        ) in
    loop 0

let print_prediction prev exp pred =
    print_endline ("Current training, previous: " ^ (Pervasives.string_of_bool prev) ^
        ", expected: " ^ (Pervasives.string_of_bool exp)
        ^ ", predicted: " ^ (Pervasives.string_of_bool pred))

let train_from_sequence () =
    (* retrain on the logged sequence, clearing the recurrent state first *)
    current_net##clear ();
    let previous = ref (first_in_sequence ()) in
    let count = ref 0 in
    print_endline "NEW TRAINING BATCH";
    traverse_seq (fun i ->
        let inp_arr = output_array !previous in
        let out_arr = output_array i in
        let act_res = current_net##activate inp_arr in
        print_prediction !previous i (derive_guess_from_array act_res);
        propagate_n_times out_arr;
        previous := i;
        inc_ref count
    )

let update_counts_in_ui () =
    let tot = by_id "total-count" in
    let won = by_id "steve-won-count" in
    set_text tot (Pervasives.string_of_int !total_guesses_count);
    set_text won (Pervasives.string_of_int !steve_won_count)

let train_sequence (the_value : bool) =
    train_from_sequence ();
    let last_guess = (last_in_sequence ()) in
    let before_train = current_net##activate (output_array last_guess) in
    let act_result = derive_guess_from_array before_train in
    (*side effects*)

    push_seq the_value;

    inc_ref total_guesses_count;
    if the_value = act_result then steve_won ();
    print_endline "CURRENT";
    print_prediction last_guess the_value act_result;
    update_prediction_in_ui act_result;
    update_counts_in_ui ()

let guess (user_guess : bool) =
    train_sequence user_guess

let () = ()
asked Mar 12 '17 by Vanilla Face


1 Answer

Clearing the network context before every training iteration is the fix

The problem in your code is that your network is trained circularly. Instead of training 1 > 2 > 3, RESET, 1 > 2 > 3, you're training the network 1 > 2 > 3 > 1 > 2 > 3. That makes your network believe that the value after 3 should be 1.
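
To illustrate, here is a minimal sketch of the two regimes using Synaptic-style activate/propagate/clear calls (the same methods the question binds); encode, rate, epochs and sequence are hypothetical stand-ins:

// Circular (wrong): the LSTM's context carries over from the end of the
// sequence back to its start, so it also learns the bogus "3 -> 1" step.
for (var epoch = 0; epoch < epochs; epoch++) {
  for (var i = 1; i < sequence.length; i++) {
    net.activate(encode(sequence[i - 1]));    // step the context forward
    net.propagate(rate, encode(sequence[i])); // teach the expected next move
  }
}

// Reset (right): wipe the recurrent state between passes, so every pass
// starts the sequence with a clean context.
for (var epoch = 0; epoch < epochs; epoch++) {
  net.clear();                                // <-- the crucial difference
  for (var i = 1; i < sequence.length; i++) {
    net.activate(encode(sequence[i - 1]));
    net.propagate(rate, encode(sequence[i]));
  }
}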

Second, there is no reason to use 2 output neurons. One is enough: an output of 1 means heads and 0 means tails. We'll simply round the output.
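
For example, with a hypothetical one-output net:

// 1 = heads, 0 = tails: round the single output to get the prediction
var prediction = Math.round(net.activate([previous])[0]);
var guess = prediction === 1 ? 'heads' : 'tails';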

Instead of Synaptic, I used Neataptic in this code; it is an improved version of Synaptic that adds extra functionality and genetic algorithms.

The code

The code is fairly simple. Condensed a little, it looks like this:

var network = new neataptic.Architect.LSTM(1, 12, 1);
var previous = null;
var trainingData = [];

// side is 1 for heads and 0 for tails
function onSideClick(side){
  if(previous != null){
    trainingData.push({ input: [previous], output: [side] });

    // Train on everything gathered so far
    network.train(trainingData, {
      log: 500,
      iterations: 5000,
      error: 0.03,
      clear: true,
      rate: 0.05,
    });

    // Replay the previous inputs to get the LSTM's context into the 'flow'
    var output;
    for(var i = 0; i < trainingData.length; i++){
      output = Math.round(network.activate(trainingData[i].input)[0]);
    }

    // Activate the network with the previous output, aka make a prediction
    var prediction = Math.round(network.activate([output])[0]);
    // show `prediction` in the UI here
  }

  previous = side;
}


The key to this code is clear: true. It makes sure the network knows it is starting from the first training sample on each pass rather than continuing from the last one. The size of the LSTM, the iteration count and the learning rate are all customisable.
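
As a quick way to exercise it without clicking, here is a hypothetical driver (not part of the original answer) that replays the question's repeating pattern and asks the network for a guess before each simulated click; the extra activation is harmless because clear: true resets the recurrent state on the next training call anyway:

// 1 = heads, 0 = tails: the question's "heads heads tail" pattern, repeated
var pattern = [1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0];
var hits = 0;
pattern.forEach(function (side, turn) {
  if (previous != null) {
    // predict before revealing the real click
    var predicted = Math.round(network.activate([previous])[0]);
    if (predicted === side) hits++;
    console.log('turn', turn, 'clicked', side, 'predicted', predicted);
  }
  onSideClick(side); // now train on the actual click
});
console.log('correct:', hits, 'of', pattern.length - 1);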

Success!

Please do note that it takes about two repetitions of the pattern for the network to learn it.


It does have problems with non-repetitive patterns, though.

answered Oct 13 '22 by Thomas Wagenaar