
Cascade multiple RNN models for N-dimensional output

I'm having some difficulty with chaining together two models in an unusual way.

I am trying to replicate the following flowchart:

[Figure: Cascaded RNN 2-D]

For clarity, at each timestep of Model[0] I am attempting to generate an entire time series from IR[i] (Intermediate Representation), fed as a repeated input to Model[1]. The purpose of this scheme is that it allows the generation of a ragged 2-D time series from a 1-D input, while both allowing the second model to be omitted when the output for that timestep is not needed, and not requiring Model[0] to constantly "switch modes" between accepting input and generating output.
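A minimal NumPy sketch of the flowchart's data flow may make the intent concrete. It assumes both models are plain tanh RNN cells with random weights (hypothetical stand-ins for Model[0] and Model[1], not Keras models) and that Model[1] starts each inner sequence from a zero state. A step count of 0 skips Model[1] for that timestep, which is what makes the result ragged.

```python
import numpy as np

rng = np.random.default_rng(0)
input_size, ir_size, output_size = 3, 5, 2

# Hypothetical weights for two tanh RNN cells standing in for Model[0] and Model[1]
Wx_a = rng.standard_normal((input_size, ir_size))
Wh_a = rng.standard_normal((ir_size, ir_size))
Wx_b = rng.standard_normal((ir_size, output_size))
Wh_b = rng.standard_normal((output_size, output_size))


def model_a(xs):
    """Run Model[0] over a 1-D sequence, returning IR of shape (T_a, ir_size)."""
    h = np.zeros(ir_size)
    irs = []
    for x in xs:
        h = np.tanh(x @ Wx_a + h @ Wh_a)
        irs.append(h)
    return np.stack(irs)


def model_b(ir, steps):
    """Feed the same IR vector `steps` times; return Out[i] of shape (steps, output_size)."""
    h = np.zeros(output_size)  # assumption: fresh zero state per IR[i]
    outs = []
    for _ in range(steps):
        h = np.tanh(ir @ Wx_b + h @ Wh_b)
        outs.append(h)
    return np.array(outs).reshape(steps, output_size)


def cascade(xs, steps_per_t):
    ir = model_a(xs)
    # One inner sequence per outer timestep; 0 steps skips Model[1] entirely
    return [model_b(ir[i], n) for i, n in enumerate(steps_per_t)]


out = cascade(rng.standard_normal((4, input_size)), steps_per_t=[3, 0, 1, 2])
print([o.shape for o in out])  # [(3, 2), (0, 2), (1, 2), (2, 2)]
```

The result is a Python list rather than a tensor, which is the simplest way to hold ragged outputs; a padded tensor is the alternative when a fixed shape is required.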

I assume a custom training loop will be required, and I already have a custom training loop for handling statefulness in the first model (the previous version only had a single output at each timestep). As depicted, the second model should have reasonably short outputs (able to be constrained to fewer than 10 timesteps).

But at the end of the day, while I can wrap my head around what I want to do, I'm not nearly adroit enough with Keras and/or Tensorflow to actually implement it. (In fact, this is my first non-toy project with the library.)

I have unsuccessfully searched the literature for similar schemes to parrot, or for example code to fiddle with. And I don't even know if this idea is possible within TF/Keras.

I already have the two models working in isolation. (As in I've worked out the dimensionality, and done some training with dummy data to get garbage outputs for the second model, and the first model is based off of a previous iteration of this problem and has been fully trained.) If I have Model[0] and Model[1] as python variables (let's call them model_a and model_b), then how would I chain them together to do this?

Edit to add:

If this is all unclear, perhaps listing the dimensions of each input and output will help:

Input: (batch_size, model_a_timesteps, input_size)
IR: (batch_size, model_a_timesteps, ir_size)

IR[i] (after duplication): (batch_size, model_b_timesteps, ir_size)
Out[i]: (batch_size, model_b_timesteps, output_size)
Out: (batch_size, model_a_timesteps, model_b_timesteps, output_size)
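To make the shape bookkeeping concrete, here is a NumPy sketch with made-up sizes (all hypothetical). It duplicates one IR slice along a new time axis and stacks the per-step outputs into Out; random numbers stand in for Model[1]'s actual outputs.

```python
import numpy as np

batch_size, model_a_timesteps, model_b_timesteps = 2, 3, 4
ir_size, output_size = 5, 6

ir = np.random.randn(batch_size, model_a_timesteps, ir_size)

outs = []
for i in range(model_a_timesteps):
    # IR[i]: (batch_size, ir_size) -> duplicated: (batch_size, model_b_timesteps, ir_size)
    ir_i = np.repeat(ir[:, i:i + 1, :], model_b_timesteps, axis=1)
    # Placeholder for Model[1](ir_i): (batch_size, model_b_timesteps, output_size)
    out_i = np.random.randn(batch_size, model_b_timesteps, output_size)
    outs.append(out_i)

# Stack along a new axis 1 -> (batch_size, model_a_timesteps, model_b_timesteps, output_size)
out = np.stack(outs, axis=1)
print(out.shape)  # (2, 3, 4, 6)
```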

asked Jul 29 '20 at 15:07 by OmnipotentEntity


1 Answer

As this question has multiple major parts, I've dedicated a Q&A to the core challenge: stateful backpropagation. This answer focuses on implementing the variable output step length.


Description:

  • As validated in Case 5 of the linked Q&A, we can take a bottom-up approach: first feed the complete input to model_a (A), then feed its outputs as input to model_b (B), this time one step at a time.
  • Note that we must chain B's output steps per A's input step, not between A's input steps; i.e., in your diagram, gradients should flow between Out[0][1] and Out[0][0], but not between Out[2][0] and Out[0][1].
  • For computing the loss it won't matter whether we use a ragged or padded tensor; we must, however, use a padded tensor when writing to a TensorArray.
  • The loop logic in the code below is general; specific attribute handling and hidden-state passing, however, are hard-coded for simplicity, but can be rewritten for generality.
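A small NumPy check of the ragged-vs-padded loss point, on made-up data: mean squared error over the ragged outputs equals a masked MSE over the zero-padded tensor, as long as padded positions are excluded from the average.

```python
import numpy as np

rng = np.random.default_rng(1)
steps_at_t = [3, 4, 2]
units = 4
longest_step = max(steps_at_t)

# Ragged predictions/labels: one (steps, units) array per outer step
preds = [rng.standard_normal((n, units)) for n in steps_at_t]
labels = [rng.standard_normal((n, units)) for n in steps_at_t]


def pad(seqs):
    """Zero-pad a list of (n, units) arrays to (len(seqs), longest_step, units)."""
    out = np.zeros((len(seqs), longest_step, units))
    for t, s in enumerate(seqs):
        out[t, :len(s)] = s
    return out


# Ragged loss: concatenate valid steps only
ragged_mse = np.mean((np.concatenate(preds) - np.concatenate(labels)) ** 2)

# Padded loss: mask out the padding before averaging
mask = np.arange(longest_step)[None, :] < np.array(steps_at_t)[:, None]  # (3, 4)
sq_err = (pad(preds) - pad(labels)) ** 2
masked_mse = sq_err[mask].mean()

print(np.isclose(ragged_mse, masked_mse))  # True
```

With both sides zero-padded, padded positions contribute zero error, so a summed loss is unchanged; only a mean needs the mask to keep the same denominator.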

Code: at bottom.


Example:

  • Here we predefine the number of iterations of B per input from A, but we can implement any arbitrary stopping logic. For example, we can take a Dense layer's output from B as a hidden state and check whether its L2-norm exceeds a threshold.
  • Per the above, if longest_step is unknown to us, we can simply set it to a fixed maximum, which is common for NLP and other tasks with a STOP token.
    • Alternatively, we may write to a separate TensorArray at each of A's input steps with dynamic_size=True; see "Point of uncertainty" below.
  • A valid concern is: how do we know gradients flow correctly? We've validated them both vertically and horizontally in the linked Q&A, but that didn't cover multiple output steps per input step, for multiple input steps. See below.
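A sketch of such arbitrary stopping logic in NumPy: keep stepping a hypothetical tanh cell on the repeated input until the L2 norm of its hidden state exceeds a threshold, with a hard cap playing the role of longest_step. The threshold, cap, and weights are all assumptions for illustration, not part of the answer's code.

```python
import numpy as np


def run_until_stop(x, W_x, W_h, threshold, max_steps):
    """Step a tanh RNN cell on the repeated input x until ||h||_2 > threshold or max_steps is hit."""
    h = np.zeros(W_h.shape[0])
    outputs = []
    for _ in range(max_steps):
        h = np.tanh(x @ W_x + h @ W_h)
        outputs.append(h)
        if np.linalg.norm(h) > threshold:  # stop signal, analogous to a STOP token
            break
    return np.stack(outputs)


rng = np.random.default_rng(2)
W_x = rng.standard_normal((3, 4))
W_h = rng.standard_normal((4, 4))
x = rng.standard_normal(3)

seq = run_until_stop(x, W_x, W_h, threshold=1.0, max_steps=10)
print(seq.shape)
```

The cap guarantees termination and gives the padding length, which is why a fixed maximum is needed even when stopping is data-dependent.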

Point of uncertainty: I'm not entirely sure whether gradients interact between e.g. Out[0][1] and Out[2][0]. I did, however, verify that gradients will not flow horizontally if we write to separate TensorArrays for B's outputs per A's inputs (case 2); reimplementing cases 4 & 5 that way, grads differ for both models, including the lower one (model_a), which makes a complete single horizontal pass.

Thus we must write to a unified TensorArray. Since there are no ops leading from e.g. IR[1] to Out[0][1], I can't see how TF would trace a gradient between them, so it seems we're safe. Note, however, that in the example below, using steps_at_t=[1]*6 will make gradients flow horizontally in both models, as we're writing to a single TensorArray and passing hidden states.
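The "no ops from IR[1] to Out[0][1]" argument can be sanity-checked outside TF with a finite-difference probe on a toy NumPy cascade. It assumes a hypothetical tanh cell for B whose hidden state is carried across outer steps, as in the answer's code. Perturbing IR[1] leaves every Out[0][j] bit-for-bit unchanged, while perturbing IR[0] does change Out[1], because the carried state is an op path from IR[0] to Out[1].

```python
import numpy as np

rng = np.random.default_rng(3)
units = 4
W_x = rng.standard_normal((units, units)) * 0.5
W_h = rng.standard_normal((units, units)) * 0.5


def cascade_b(ir, steps_at_t):
    """Run B over each IR[i] repeated steps_at_t[i] times, carrying B's state across outer steps."""
    h = np.zeros(units)
    outs = []
    for ir_i, n in zip(ir, steps_at_t):
        out_i = []
        for _ in range(n):
            h = np.tanh(ir_i @ W_x + h @ W_h)
            out_i.append(h)
        outs.append(np.stack(out_i))
    return outs


ir = rng.standard_normal((3, units))
steps_at_t = [2, 3, 1]
base = cascade_b(ir, steps_at_t)

eps = 1e-6
ir_pert1 = ir.copy()
ir_pert1[1] += eps  # perturb IR[1] only
ir_pert0 = ir.copy()
ir_pert0[0] += eps  # perturb IR[0] only

out1 = cascade_b(ir_pert1, steps_at_t)
out0 = cascade_b(ir_pert0, steps_at_t)

print(np.array_equal(base[0], out1[0]))      # True: Out[0] does not depend on IR[1]
print(not np.array_equal(base[1], out0[1]))  # True: Out[1] depends on IR[0] via carried state
```

This only demonstrates data dependency; in TF itself the corresponding check is comparing `tape.gradient` results, as done in the linked cases.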

The examined case is confounded, however, by B being stateful at all steps; if we lift this requirement, we might not need to write to a unified TensorArray for all of Out[0], Out[1], etc., but we must still test against something we know works, which is no longer as straightforward.


Example [code]:

import numpy as np
import tensorflow as tf

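#%%# Make data & models, then run a forward pass ############################
# (MultiStatefulNetwork is defined under "Methods" below)
x0 = y0 = tf.constant(np.random.randn(2, 3, 4), dtype=tf.float32)
msn = MultiStatefulNetwork(batch_shape=(2, 3, 4), steps_at_t=[3, 4, 2])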
msn = MultiStatefulNetwork(batch_shape=(2, 3, 4), steps_at_t=[3, 4, 2])

#%%#############################################
with tf.GradientTape(persistent=True) as tape:
    outputs = msn(x0)
    # shape: (3, 4, 2, 4) = (outer_steps, longest_step, batch_size, units),
    # 0-padded; we can pad labels accordingly.
    # Note the (2, 4) model_b output shape, which is a single timestep slice;
    # model_b is a *slice model*. Be careful when implementing logic that is,
    # or isn't, intended to be stateful.
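Since outputs come back time-major and zero-padded as (outer_steps, longest_step, batch_size, units), labels have to be laid out the same way. A NumPy sketch, assuming labels arrive as a ragged list with one (steps, batch_size, units) array per outer step (a hypothetical format); `pad_labels` is an illustrative helper that also returns a mask usable for a masked loss.

```python
import numpy as np


def pad_labels(ragged_labels, longest_step):
    """Stack ragged (steps_i, batch, units) labels into (outer_steps, longest_step, batch, units)."""
    outer_steps = len(ragged_labels)
    _, batch_size, units = ragged_labels[0].shape
    padded = np.zeros((outer_steps, longest_step, batch_size, units), dtype=np.float32)
    mask = np.zeros((outer_steps, longest_step), dtype=bool)
    for t, y in enumerate(ragged_labels):
        padded[t, :len(y)] = y
        mask[t, :len(y)] = True
    return padded, mask


steps_at_t = [3, 4, 2]
ragged = [np.ones((n, 2, 4), dtype=np.float32) for n in steps_at_t]
labels, mask = pad_labels(ragged, longest_step=max(steps_at_t))
print(labels.shape)  # (3, 4, 2, 4)
print(mask.sum())    # 9
```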

Methods:

Not the cleanest or most optimal code, but it works; there's room for improvement.

More importantly, I implemented this in Eager mode and have no idea how it'll work in Graph mode; making it work for both can be quite tricky. If needed, just run in Graph and compare all values as done in the "cases".

# ideally we won't `import tensorflow` at all; kept for code simplicity
import tensorflow as tf
from tensorflow.python.util import nest
from tensorflow.python.ops import array_ops, tensor_array_ops
from tensorflow.python.framework import ops

from tensorflow.keras.layers import Input, SimpleRNN, SimpleRNNCell
from tensorflow.keras.models import Model

#######################################################################
class MultiStatefulNetwork():
    def __init__(self, batch_shape=(2, 6, 4), steps_at_t=[]):
        self.batch_shape=batch_shape
        self.steps_at_t=steps_at_t

        self.batch_size = batch_shape[0]
        self.units = batch_shape[-1]
        self._build_models()

    def __call__(self, inputs):
        outputs = self._forward_pass_a(inputs)
        outputs = self._forward_pass_b(outputs)
        return outputs

    def _forward_pass_a(self, inputs):
        return self.model_a(inputs, training=True)

    def _forward_pass_b(self, inputs):
        return model_rnn_outer(self.model_b, inputs, self.steps_at_t)

    def _build_models(self):
        ipt = Input(batch_shape=self.batch_shape)
        out = SimpleRNN(self.units, return_sequences=True)(ipt)
        self.model_a = Model(ipt, out)

        ipt  = Input(batch_shape=(self.batch_size, self.units))
        sipt = Input(batch_shape=(self.batch_size, self.units))
        # units must equal the state input's width so the state can be fed back
        out, state = SimpleRNNCell(self.units)(ipt, sipt)
        self.model_b = Model([ipt, sipt], [out, state])

        self.model_a.compile('sgd', 'mse')
        self.model_b.compile('sgd', 'mse')


def inner_pass(model, inputs, states):
    return model_rnn(model, inputs, states)


def model_rnn_outer(model, inputs, steps_at_t=[2, 2, 4, 3]):
    def outer_step_function(inputs, states):
        x, steps = inputs
        x = array_ops.expand_dims(x, 0)
        x = array_ops.tile(x, [steps, *[1] * (x.ndim - 1)])  # repeat steps times
        output, new_states = inner_pass(model, x, states)
        return output, new_states

    (outer_steps, steps_at_t, longest_step, outer_t, initial_states,
     output_ta, input_ta) = _process_args_outer(model, inputs, steps_at_t)

    def _outer_step(outer_t, output_ta_t, *states):
        current_input = [input_ta.read(outer_t), steps_at_t.read(outer_t)]
        output, new_states = outer_step_function(current_input, tuple(states))

        # pad if shorter than longest_step: e.g. if model_b ran for 2 steps but
        # the longest entry in `steps_at_t` is 4, pad `output` along dim0 from
        # (2, batch_size, units) to (4, batch_size, units).
        # checking directly on `output` is more reliable than from `steps_at_t`
        output = tf.cond(
            tf.math.less(output.shape[0], longest_step),
            lambda: tf.pad(output, [[0, longest_step - output.shape[0]],
                                    *[[0, 0]] * (output.ndim - 1)]),
            lambda: output)

        output_ta_t = output_ta_t.write(outer_t, output)
        return (outer_t + 1, output_ta_t) + tuple(new_states)

    final_outputs = tf.while_loop(
        body=_outer_step,
        loop_vars=(outer_t, output_ta) + initial_states,
        cond=lambda outer_t, *_: tf.math.less(outer_t, outer_steps))

    output_ta = final_outputs[1]
    outputs = output_ta.stack()
    return outputs


def _process_args_outer(model, inputs, steps_at_t):
    def swap_batch_timestep(input_t):
        # Swap the batch and timestep dim for the incoming tensor.
        # (samples, timesteps, channels) -> (timesteps, samples, channels)
        # iterating dim0 to feed (samples, channels) slices expected by RNN
        axes = list(range(len(input_t.shape)))
        axes[0], axes[1] = 1, 0
        return array_ops.transpose(input_t, axes)

    inputs = nest.map_structure(swap_batch_timestep, inputs)

    assert inputs.shape[0] == len(steps_at_t)
    outer_steps = array_ops.shape(inputs)[0]  # model_a_steps
    longest_step = max(steps_at_t)
    steps_at_t = tensor_array_ops.TensorArray(
        dtype=tf.int32, size=len(steps_at_t)).unstack(steps_at_t)

    # assume single-input network, excluding states which are handled separately
    input_ta = tensor_array_ops.TensorArray(
        dtype=inputs.dtype,
        size=outer_steps,
        element_shape=tf.TensorShape(model.input_shape[0]),
        tensor_array_name='outer_input_ta_0').unstack(inputs)

    # TensorArray is used to write outputs at every timestep, but does not
    # support RaggedTensor; thus we must make TensorArray such that column length
    # is that of the longest outer step, and pad model_b's outputs accordingly
    element_shape = tf.TensorShape((longest_step, *model.output_shape[0]))

    # overall shape: (outer_steps, longest_step, *model_b.output_shape)
    # for every input / at each step we write in dim0 (outer_steps)
    output_ta = tensor_array_ops.TensorArray(
        dtype=model.output[0].dtype,
        size=outer_steps,
        element_shape=element_shape,
        tensor_array_name='outer_output_ta_0')

    outer_t = tf.constant(0, dtype='int32')
    # model_b's state input is its second input: (batch_size, units)
    initial_states = (tf.zeros(model.input_shape[1], dtype='float32'),)

    return (outer_steps, steps_at_t, longest_step, outer_t, initial_states,
            output_ta, input_ta)


def model_rnn(model, inputs, states):
    def step_function(inputs, states):
        output, new_states = model([inputs, *states], training=True)
        return output, new_states

    initial_states = states
    input_ta, output_ta, time, time_steps_t = _process_args(model, inputs)

    def _step(time, output_ta_t, *states):
        current_input = input_ta.read(time)
        output, new_states = step_function(current_input, tuple(states))

        flat_state = nest.flatten(states)
        flat_new_state = nest.flatten(new_states)
        for state, new_state in zip(flat_state, flat_new_state):
            if isinstance(new_state, ops.Tensor):
                new_state.set_shape(state.shape)

        output_ta_t = output_ta_t.write(time, output)
        new_states = nest.pack_sequence_as(initial_states, flat_new_state)
        return (time + 1, output_ta_t) + tuple(new_states)

    final_outputs = tf.while_loop(
        body=_step,
        loop_vars=(time, output_ta) + tuple(initial_states),
        cond=lambda time, *_: tf.math.less(time, time_steps_t))

    new_states = final_outputs[2:]
    output_ta = final_outputs[1]
    outputs = output_ta.stack()
    return outputs, new_states


def _process_args(model, inputs):
    time_steps_t = tf.constant(inputs.shape[0], dtype='int32')

    # assume single-input network (excluding states)
    input_ta = tensor_array_ops.TensorArray(
        dtype=inputs.dtype,
        size=time_steps_t,
        tensor_array_name='input_ta_0').unstack(inputs)

    # assume single-output network (excluding states)
    output_ta = tensor_array_ops.TensorArray(
        dtype=model.output[0].dtype,
        size=time_steps_t,
        element_shape=tf.TensorShape(model.output_shape[0]),
        tensor_array_name='output_ta_0')

    time = tf.constant(0, dtype='int32', name='time')
    return input_ta, output_ta, time, time_steps_t
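To follow the suggestion of comparing values across Eager and Graph (or against any other implementation), a plain NumPy mirror of what model_rnn_outer computes may help: tile each outer input steps_at_t[t] times, run B step by step while carrying its hidden state across outer steps, zero-pad to longest_step, and stack. It assumes B is a SimpleRNN-style cell h = tanh(x @ W + h @ U + b); `reference_outer` is a hypothetical helper, and extracting W, U, b from model_b (e.g. via `get_weights()`) is left as an assumption to verify.

```python
import numpy as np


def reference_outer(a_outputs, steps_at_t, W, U, b):
    """NumPy mirror of model_rnn_outer for a SimpleRNN-style cell.

    a_outputs: (batch, outer_steps, units), i.e. model_a's batch-major output.
    Returns (outer_steps, longest_step, batch, units), zero-padded.
    """
    x = np.transpose(a_outputs, (1, 0, 2))  # same role as swap_batch_timestep
    outer_steps, batch, _ = x.shape
    units = U.shape[0]
    longest_step = max(steps_at_t)
    out = np.zeros((outer_steps, longest_step, batch, units), dtype=x.dtype)
    h = np.zeros((batch, units), dtype=x.dtype)  # B's state, carried across outer steps
    for t in range(outer_steps):
        for s in range(steps_at_t[t]):  # same input x[t] repeated steps_at_t[t] times
            h = np.tanh(x[t] @ W + h @ U + b)
            out[t, s] = h
    return out


rng = np.random.default_rng(4)
a_out = rng.standard_normal((2, 3, 4))
W, U, b = rng.standard_normal((4, 4)), rng.standard_normal((4, 4)), np.zeros(4)
ref = reference_outer(a_out, [3, 4, 2], W, U, b)
print(ref.shape)  # (3, 4, 2, 4)
```

If the TF loop behaves as intended, feeding it the same model_a output and model_b weights should give arrays matching `ref` up to float tolerance.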
answered Nov 05 '22 at 08:11 by OverLordGoldDragon