
Keras Sequential without providing input shape

I currently have a Keras model that looks like this:

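import tensorflow as tf
from tensorflow import keras

# `labels` is the list of class labels, defined elsewhere
model = keras.Sequential()
model.add(keras.layers.Dense(100, activation=tf.nn.relu))
model.add(keras.layers.Dense(100, activation=tf.nn.relu))
model.add(keras.layers.Dense(len(labels), activation=tf.nn.softmax))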

The Keras documentation tells me:

The model needs to know what input shape it should expect. For this reason, the first layer in a Sequential model (and only the first, because following layers can do automatic shape inference) needs to receive information about its input shape

However, the model as it is actually trains fine, without errors, even though I never specified the shape of the inputs.

How does it know what shape to expect? What is the default behaviour if I don't provide an input shape? How will it affect my model?

Edit: this is using tf.keras, i.e. the Keras API as implemented in TensorFlow.

asked Sep 10 '19 13:09 by Migwell


2 Answers

Nice observation. I believe the Keras documentation should be updated. When the input shape is not provided, Keras infers it from the argument x of Model.fit and only then builds the whole model. Concretely, this is what happens:

  1. When adding Keras layers in the Sequential model, since the argument input_shape (and, by extension, batch_input_shape) is never set, the attribute Model.inputs remains None (see Sequential.add).
  2. Then, in Model.fit, they check whether Model.inputs has been set (see Model.fit and Model._standardize_user_data) and, when it hasn't, they infer the input shape from the provided input array.
  3. Finally, in Model._set_inputs, they build the whole model with the inferred input_shape (see Model._set_inputs).

This can be verified by printing some weights (e.g. print(model.layers[0].get_weights())) before fitting the model. When neither input_shape nor batch_input_shape is provided to the first layer of the model, the returned list of weights is empty, because the model has not been built yet.
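
The three steps above can be mimicked with a toy, framework-free sketch. This is not Keras's actual code: ToyDense and ToySequential are hypothetical names made up for illustration, and only weight shapes are tracked rather than real weight arrays. It simply shows how a model can stay unbuilt until fit sees data, then infer the input size and build every layer in order.

```python
# Toy illustration of deferred ("lazy") building.
# ToyDense and ToySequential are hypothetical, not Keras classes.


class ToyDense:
    def __init__(self, units):
        self.units = units
        self.kernel_shape = None  # unknown until the input size is known
        self.bias_shape = None

    @property
    def built(self):
        return self.kernel_shape is not None

    def build(self, input_dim):
        # Weight shapes depend on the size of the incoming feature axis.
        self.kernel_shape = (input_dim, self.units)
        self.bias_shape = (self.units,)
        return self.units  # feature size handed to the next layer

    def get_weight_shapes(self):
        return [self.kernel_shape, self.bias_shape] if self.built else []


class ToySequential:
    def __init__(self, layers):
        self.layers = layers
        self.input_dim = None  # step 1: no input shape given, nothing is built

    def fit(self, x):
        if self.input_dim is None:
            # Step 2: infer the feature size from the provided data.
            self.input_dim = len(x[0])
            # Step 3: build every layer, chaining output size to input size.
            dim = self.input_dim
            for layer in self.layers:
                dim = layer.build(dim)
        # (A real fit would run the training loop here.)


model = ToySequential([ToyDense(100), ToyDense(100), ToyDense(3)])
print(model.layers[0].get_weight_shapes())  # [] because nothing is built yet

x = [[0.0] * 8 for _ in range(16)]  # 16 samples with 8 features each
model.fit(x)
print(model.layers[0].get_weight_shapes())  # [(8, 100), (100,)]
```

The same before/after check on a real tf.keras model is the print(model.layers[0].get_weights()) call mentioned above.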

answered Oct 15 '22 20:10 by rvinas


The documentation page "The Sequential model" was updated on 2020/04/12. Its section "Specifying the input shape in advance" addresses this question.

When you instantiate a Sequential model without an input shape, it isn't "built": it has no weights (and calling model.weights results in an error stating just this). The weights are created when the model first sees some input data:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


model1 = keras.Sequential(
    [
        layers.Dense(2, activation="relu"),
        layers.Dense(3, activation="relu"),
        layers.Dense(4),
    ]
)  # No weights at this stage!

for layer in model1.layers:
    print(layer.weights)  # Empty

# At this point, you can't do this:
# model1.weights

# You also can't do this:
# model1.summary()

# Call the model on a test input
x = tf.ones((1, 4))
y = model1(x)

# Once a model is "built", you can call its summary() method to display its contents:

model1.summary()

You could start your model by passing an Input object to your model, so that it knows its input shape from the start:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

model2 = keras.Sequential()
model2.add(keras.Input(shape=(4,)))
model2.add(layers.Dense(2, activation="relu"))
model2.add(layers.Dense(3, activation="relu"))
model2.add(layers.Dense(4))

for layer in model2.layers:
    print(layer.weights)

print(model2.weights)

model2.summary()

A simple alternative is to pass an input_shape argument to the first layer of a new model:

model3 = keras.Sequential()
model3.add(layers.Dense(2, activation="relu", input_shape=(4,)))

Models built with a predefined input shape like this always have weights (even before seeing any data) and always have a defined output shape.

Finally, the documentation says:

In general, it's a recommended best practice to always specify the input shape of a Sequential model in advance if you know what it is.

Back to your question:

How does it know what shape to expect?

If no input shape is defined, the model is built to match the shape of the first data it sees.

x = tf.ones((1, 4))
y = model1(x)

model1.summary()

# Output
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_2 (Dense)              (1, 2)                    10        
_________________________________________________________________
dense_3 (Dense)              (1, 3)                    9         
_________________________________________________________________
dense_4 (Dense)              (1, 4)                    16        
=================================================================
Total params: 35
Trainable params: 35
Non-trainable params: 0
_________________________________________________________________
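
# Starting again from a freshly created (unbuilt) model1, a different input shape
# produces different weights. An already built model1 would reject this input.
x = tf.ones((3, 5, 10))
y = model1(x)
model1.summary()

# Output: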
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_2 (Dense)              (3, 5, 2)                 22        
_________________________________________________________________
dense_3 (Dense)              (3, 5, 3)                 9         
_________________________________________________________________
dense_4 (Dense)              (3, 5, 4)                 16        
=================================================================
Total params: 47
Trainable params: 47
Non-trainable params: 0
_________________________________________________________________
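
The "Param #" columns in the two summaries can be checked by hand. A Dense layer acts on the last axis of its input only, so it holds input_dim * units kernel entries plus units bias entries, and the leading (3, 5) dimensions in the second example do not change the count. A small sketch of that arithmetic, with helper names (dense_params, count_params) that are made up for illustration:

```python
def dense_params(input_dim, units):
    # Kernel has input_dim * units entries, bias has units entries.
    return input_dim * units + units


def count_params(last_axis, layer_units=(2, 3, 4)):
    # Walk through the stack, feeding each layer's units to the next as input size.
    counts = []
    dim = last_axis
    for units in layer_units:
        counts.append(dense_params(dim, units))
        dim = units
    return counts


first = count_params(4)    # input of shape (1, 4)
second = count_params(10)  # input of shape (3, 5, 10)
print(first, sum(first))    # [10, 9, 16] 35
print(second, sum(second))  # [22, 9, 16] 47
```

Only the first layer's count differs between the two runs, because it is the only layer whose input size comes from the data.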

What is the default behavior if I don't provide an input shape? How will it affect my model?

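If you don't specify the input shape in advance, the model stays unbuilt until it first sees data: until then it has no weights and you can't call model.summary(). Once it is called or fitted on data, it is built for that input shape and behaves like a model whose input shape was declared up front.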

answered Oct 15 '22 19:10 by Ynjxsjmh