
Input tensor <name> enters the loop with shape (), but has shape <unknown> after one iteration

I am trying to save a model by wrapping its greedy-decoding method in a tf.function.

The code is tested and works as expected in eager mode (debugging). However, it does not work in non-eager (graph) execution.

The method gets a namedtuple called Hyp which looks like this:

from collections import namedtuple

Hyp = namedtuple(
    'Hyp',
    field_names='score, yseq, encoder_state, decoder_state, decoder_output'
)

The while-loop gets invoked like this:

_, hyp = tf.while_loop(
    cond=condition_,
    body=body_,
    loop_vars=(tf.constant(0, dtype=tf.int32), hyp),
    shape_invariants=(
        tf.TensorShape([]),
        tf.nest.map_structure(get_shape_invariants, hyp),
    )
)
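
For reference, tf.nest.map_structure applies get_shape_invariants (defined in the Fibonacci example further below) to every tensor leaf in the namedtuple and returns a matching Hyp of TensorShape objects. A minimal standalone sketch, using a hypothetical two-field tuple Pair and a simplified get_shape_invariants, illustrates the mechanism:

from collections import namedtuple

import tensorflow as tf

Pair = namedtuple('Pair', field_names='score, yseq')


def get_shape_invariants(tensor):
    # Keep statically known dims, relax unknown dims to None.
    return tf.TensorShape(
        [s if isinstance(s, int) else None for s in tensor.shape.as_list()]
    )


pair = Pair(
    score=tf.constant(0.0),
    yseq=tf.constant([[0, 1]], dtype=tf.int32),
)

print(tf.nest.map_structure(get_shape_invariants, pair))
# Pair(score=TensorShape([]), yseq=TensorShape([1, 2]))

Inside a tf.function whose input_signature declares yseq as TensorSpec([1, None]), the static shape entering the loop is already (1, None), so the derived invariant is [1, None] and only the last axis may change size across iterations.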

This is the relevant part of body_:

def body_(i_, hypothesis_: Hyp):

    # ... some code collapsed ...

    def update_from_next_id_():
        return Hyp(
            # Update values ..
        )

    # The only place where I generate a new hypothesis_ namedtuple
    hypothesis_ = tf.cond(
        tf.not_equal(next_id, blank),
        true_fn=lambda: update_from_next_id_(),
        false_fn=lambda: hypothesis_
    )

    return i_ + 1, hypothesis_

What I am getting is a ValueError:

ValueError: Input tensor 'hypotheses:0' enters the loop with shape (), but has shape <unknown> after one iteration. To allow the shape to vary across iterations, use the shape_invariants argument of tf.while_loop to specify a less-specific shape.

What could be the problem here?

The following is how input_signature is defined for the tf.function I would like to serialize.

Here, self.greedy_decode_impl is the actual implementation (I know this indirection is a bit ugly); self.greedy_decode is what I actually call.

self.greedy_decode = tf.function(
    self.greedy_decode_impl,
    input_signature=(
        tf.TensorSpec([1, None, self.config.encoder.lstm_units], dtype=tf.float32),
        Hyp(
            score=tf.TensorSpec([], dtype=tf.float32),
            yseq=tf.TensorSpec([1, None], dtype=tf.int32),
            encoder_state=tuple(
                (tf.TensorSpec([1, lstm.units], dtype=tf.float32),
                 tf.TensorSpec([1, lstm.units], dtype=tf.float32))
                for (lstm, _) in self.encoder_network.lstm_stack
            ),
            decoder_state=tuple(
                (tf.TensorSpec([1, lstm.units], dtype=tf.float32),
                 tf.TensorSpec([1, lstm.units], dtype=tf.float32))
                for (lstm, _) in self.predict_network.lstm_stack
            ),
            decoder_output=tf.TensorSpec([1, None, self.config.decoder.lstm_units], dtype=tf.float32)
        ),
    )
)
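
As a side note, tf.function accepts arbitrarily nested structures of TensorSpec in input_signature, namedtuples included. A toy sketch (hypothetical names Toy and echo, unrelated to the model) confirms that part in isolation:

from collections import namedtuple

import tensorflow as tf

Toy = namedtuple('Toy', field_names='score, yseq')


@tf.function(
    input_signature=(
        Toy(
            score=tf.TensorSpec([], dtype=tf.float32),
            yseq=tf.TensorSpec([1, None], dtype=tf.int32),
        ),
    )
)
def echo(toy: Toy):
    # Inside the traced function, yseq has static shape (1, None).
    tf.print(tf.shape(toy.yseq))
    return toy


echo(Toy(score=tf.constant(1.0), yseq=tf.constant([[0, 1, 2]], dtype=tf.int32)))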

The implementation of greedy_decode_impl:

def greedy_decode_impl(self, encoder_outputs: tf.Tensor, hypotheses: Hyp, blank=0) -> Hyp:

    hyp = hypotheses

    encoder_outputs = encoder_outputs[0]

    def condition_(i_, *_):
        time_steps = tf.shape(encoder_outputs)[0]
        return tf.less(i_, time_steps)

    def body_(i_, hypothesis_: Hyp):

        encoder_output_ = tf.reshape(encoder_outputs[i_], shape=(1, 1, -1))

        join_out = self.join_network((encoder_output_, hypothesis_.decoder_output), training=False)

        logits = tf.squeeze(tf.nn.log_softmax(tf.squeeze(join_out)))
        next_id = tf.argmax(logits, output_type=tf.int32)
        log_prob = logits[next_id]
        next_id = tf.reshape(next_id, (1, 1))

        def update_from_next_id_():
            decoder_output_, decoder_state_ = self.predict_network(
                next_id,
                memory_states=hypothesis_.decoder_state,
                training=False
            )
            return Hyp(
                score=hypothesis_.score + log_prob,
                yseq=tf.concat([hypothesis_.yseq, next_id], axis=0),
                decoder_state=decoder_state_,
                decoder_output=decoder_output_,
                encoder_state=hypothesis_.encoder_state
            )

        hypothesis_ = tf.cond(
            tf.not_equal(next_id, blank),
            true_fn=lambda: update_from_next_id_(),
            false_fn=lambda: hypothesis_
        )

        return i_ + 1, hypothesis_

    _, hyp = tf.while_loop(
        cond=condition_,
        body=body_,
        loop_vars=(tf.constant(0, dtype=tf.int32), hyp),
        shape_invariants=(
            tf.TensorShape([]),
            tf.nest.map_structure(get_shape_invariants, hyp),
        )
    )

    return hyp

Why does it work in eager mode but not in graph (non-eager) mode?

According to the docs of tf.while_loop, a namedtuple should be fine to use as a loop variable.


Fibonacci example

To check whether this approach should work with a namedtuple at all, I implemented the Fibonacci sequence using the same mechanisms. To exercise the tf.cond as well, the loop stops appending new numbers once it reaches step n // 2.

As the output below shows, the approach works without Python side effects.

from collections import namedtuple

import tensorflow as tf

FibonacciStep = namedtuple('FibonacciStep', field_names='seq, prev_value')


def shape_list(x):
    """Return the shape of x, preferring static dims and falling back to dynamic ones."""
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]


def get_shape_invariants(tensor):
    """Build a tf.while_loop shape invariant: keep known dims, relax unknown dims to None."""
    shapes = shape_list(tensor)
    return tf.TensorShape([i if isinstance(i, int) else None for i in shapes])


def save_tflite(fp, concrete_fn):
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_fn])
    converter.experimental_new_converter = True
    # Allow fallback to TensorFlow ops for anything without a builtin TFLite kernel.
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    converter.optimizations = []
    tflite_model = converter.convert()
    with tf.io.gfile.GFile(fp, 'wb') as f:
        f.write(tflite_model)


@tf.function(
    input_signature=(
        tf.TensorSpec([], dtype=tf.int32),
        FibonacciStep(
            seq=tf.TensorSpec([1, None], dtype=tf.int32),
            prev_value=tf.TensorSpec([], dtype=tf.int32),
        )
    )
)
def fibonacci(n: tf.Tensor, fibo: FibonacciStep):

    def cond_(i_, *args):
        return tf.less(i_, n)

    def body_(i_, fibo_: FibonacciStep):

        prev_value = fibo_.seq[0, -1] + fibo_.prev_value

        def append_value():
            return FibonacciStep(
                seq=tf.concat([fibo_.seq, tf.reshape(prev_value, shape=(1, 1))], axis=-1),
                prev_value=fibo_.seq[0, -1]
            )

        fibo_ = tf.cond(
            tf.less_equal(i_, n // 2),
            true_fn=lambda: append_value(),
            false_fn=lambda: fibo_
        )

        return i_ + 1, fibo_

    _, fibo = tf.while_loop(
        cond=cond_,
        body=body_,
        loop_vars=(0, fibo),
        shape_invariants=(
            tf.TensorShape([]),
            tf.nest.map_structure(get_shape_invariants, fibo),
        )
    )

    return fibo


def main():

    n = tf.constant(10, dtype=tf.int32)
    fibo = FibonacciStep(
        seq=tf.constant([[0, 1]], dtype=tf.int32),
        prev_value=tf.constant(0, dtype=tf.int32),
    )

    fibo = fibonacci(n, fibo=fibo)
    fibo = fibonacci(n + 10, fibo=fibo)

    fp = '/tmp/fibonacci.tflite'
    concrete_fn = fibonacci.get_concrete_function()
    save_tflite(fp, concrete_fn)

    print(fibo.seq.numpy()[0].tolist())

    print('All done.')


if __name__ == '__main__':
    main()

Output:

[0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584]
All done.
Asked by Stefan Falk

1 Answer

Alright, it turns out that

tf.concat([hypothesis_.yseq, next_id], axis=0),

was supposed to be

tf.concat([hypothesis_.yseq, next_id], axis=-1),

To be fair, the error message does hint at where to look, but calling it "helpful" would be a stretch: I violated the TensorSpec by concatenating over the wrong axis, that's all, yet TensorFlow is not (yet) able to point directly at the affected tensor.
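
To make the failure mode concrete, here is a minimal standalone comparison with toy values (outside the model): with yseq declared as TensorSpec([1, None]), only the last axis may grow across iterations, but concatenating on axis 0 grows the first axis instead.

import tensorflow as tf

yseq = tf.constant([[0]], dtype=tf.int32)      # shape (1, 1); the invariant is [1, None]
next_id = tf.constant([[7]], dtype=tf.int32)   # shape (1, 1)

print(tf.concat([yseq, next_id], axis=-1).shape)  # (1, 2): compatible with [1, None]
print(tf.concat([yseq, next_id], axis=0).shape)   # (2, 1): first dim changed, violates [1, None]

In eager mode both variants execute just fine; in the graph-mode while loop, the axis=0 version changes the loop variable's shape away from the declared invariant, which is exactly what the ValueError reports.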

Answered by Stefan Falk

