
Issue with Keras backend flatten

Why does keras.backend.flatten not produce the expected dimensions? I have the following:

x is <tf.Tensor 'concat_8:0' shape=(?, 4, 8, 62) dtype=float32>

After:

keras.backend.flatten(x)

x becomes: <tf.Tensor 'Reshape_22:0' shape=(?,) dtype=float32>

Why is x not of shape (?, 4*8*62)?

EDIT-1

I get (?, ?) if I use batch_flatten (branch3x3 & branch5x5 below are tensors from previous convolutions):

x = Lambda(lambda v: K.concatenate([v[0], v[1]], axis=3))([branch3x3, branch5x5])
x = Lambda(lambda v: K.batch_flatten(v))(x)

Result of first Lambda is <tf.Tensor 'lambda_144/concat:0' shape=(?, 4, 8, 62) dtype=float32>

Result of second Lambda is <tf.Tensor 'lambda_157/Reshape:0' shape=(?, ?) dtype=float32>

EDIT-2

I tried batch_flatten but get an error downstream when I build the model output (using reshape instead of batch_flatten seems to work). branch3x3 is <tf.Tensor 'conv2d_202/Elu:0' shape=(?, 4, 8, 30) dtype=float32>, and branch5x5 is <tf.Tensor 'conv2d_203/Elu:0' shape=(?, 4, 8, 32) dtype=float32>:

from keras import backend as K
x = Lambda(lambda v: K.concatenate([v[0], v[1]], axis=3))([branch3x3, branch5x5])
x = Lambda(lambda v: K.batch_flatten(v))(x)
y = Conv1D(filters=2, kernel_size=4)(Input(shape=(4, 1)))
y = Lambda(lambda v: K.batch_flatten(v))(y)
z = Lambda(lambda v: K.concatenate([v[0], v[1]], axis=1))([x, y])
output = Dense(32, kernel_initializer=TruncatedNormal(), activation='linear')(z)
cnn = Model(inputs=[m1, m2], outputs=output)

The output statement results in the following error from the kernel_initializer: TypeError: Failed to convert object of type to Tensor. Contents: (None, 32). Consider casting elements to a supported type.

csankar69 asked Nov 24 '17 21:11

1 Answer

From the docstring of flatten:

def flatten(x):
    """Flatten a tensor.
    # Arguments
        x: A tensor or variable.
    # Returns
        A tensor, reshaped into 1-D
    """

So it turns a tensor with shape (batch_size, 4, 8, 62) into a 1-D tensor with shape (batch_size * 4 * 8 * 62,). That's why your new tensor has a 1-D shape (?,).
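The collapsing behavior can be reproduced with plain NumPy (an illustrative sketch of the reshape involved, not the Keras source): reshaping to (-1,) flattens every dimension, the batch dimension included.

```python
import numpy as np

# A batch of 5 samples with the same per-sample shape as in the question.
batch = np.zeros((5, 4, 8, 62), dtype=np.float32)

# K.flatten amounts to a reshape to (-1,): ALL dimensions collapse,
# including the batch dimension.
flat = batch.reshape(-1)
print(flat.shape)  # (9920,) == 5 * 4 * 8 * 62
```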

If you want to keep the first dimension, use batch_flatten:

def batch_flatten(x):
    """Turn a nD tensor into a 2D tensor with same 0th dimension.
    In other words, it flattens each data samples of a batch.
    # Arguments
        x: A tensor or variable.
    # Returns
        A tensor.
    """

EDIT: You see the shape being (?, ?) because the shape is determined dynamically at runtime. If you feed in a numpy array, you can easily verify that the shape is correct.

input_tensor = Input(shape=(4, 8, 62))
x = Lambda(lambda v: K.batch_flatten(v))(input_tensor)
print(x)

Tensor("lambda_1/Reshape:0", shape=(?, ?), dtype=float32)

model = Model(input_tensor, x)
out = model.predict(np.random.rand(32, 4, 8, 62))
print(out.shape)

(32, 1984)

EDIT-2:

From the error message, it seems that TruncatedNormal requires a fixed output shape from the previous layer, so the dynamic shape (None, None) produced by batch_flatten won't work.

I can think of two options:

  1. Provide manually computed output_shape to the Lambda layers:
x = Lambda(lambda v: K.concatenate([v[0], v[1]], axis=3))([branch3x3, branch5x5])
x_shape = (np.prod(K.int_shape(x)[1:]),)
x = Lambda(lambda v: K.batch_flatten(v), output_shape=x_shape)(x)

input_y = Input(shape=(4, 1))
y = Conv1D(filters=2, kernel_size=4)(input_y)
y_shape = (np.prod(K.int_shape(y)[1:]),)
y = Lambda(lambda v: K.batch_flatten(v), output_shape=y_shape)(y)

z = Lambda(lambda v: K.concatenate([v[0], v[1]], axis=1))([x, y])
output = Dense(32, kernel_initializer=TruncatedNormal(), activation='linear')(z)
cnn = Model(inputs=[m1, m2, input_y], outputs=output)
  2. Use the Flatten layer (which calls batch_flatten and computes the output shape internally):
x = Concatenate(axis=3)([branch3x3, branch5x5])
x = Flatten()(x)

input_y = Input(shape=(4, 1))
y = Conv1D(filters=2, kernel_size=4)(input_y)
y = Flatten()(y)

z = Concatenate(axis=1)([x, y])
output = Dense(32, kernel_initializer=TruncatedNormal(), activation='linear')(z)
cnn = Model(inputs=[m1, m2, input_y], outputs=output)

I'd prefer the latter, as it makes the code less cluttered. Also, note that:

  • You can replace the Lambda layer wrapping K.concatenate() with a Concatenate layer.
  • Remember to move Input(shape=(4, 1)) out of the Conv1D call and provide it in your Model(inputs=...) call.
Yu-Yang answered Nov 15 '22 15:11