
What is the right way to manipulate the shape of a tensor when there are unknown elements in it?

Let's say I have a tensor of shape (None, None, None, 32) and I want to reshape it to (None, None, 32), where the new middle dimension is the product of the two middle dimensions of the original. What is the right way to do so?

Mehran asked Oct 09 '19 12:10


1 Answer

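import keras.backend as K

def flatten_pixels(x):
    # K.shape gives the runtime shape as an integer tensor, so dimensions
    # that are None at definition time have real values here.
    shape = K.shape(x)
    # Slices (rather than single indexes) keep each piece 1-D so the
    # pieces can be concatenated into one shape tensor.
    new_shape = K.concatenate([
        shape[0:1],               # batch dimension
        shape[1:2] * shape[2:3],  # product of the two middle dimensions
        shape[3:4],               # last dimension (32 in the question)
    ])

    return K.reshape(x, new_shape)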

Use it in a Lambda layer:

from keras.layers import Lambda

model.add(Lambda(flatten_pixels))

A little knowledge:

  • K.shape returns the "current" shape of the tensor, i.e. the shape of the actual data. It is itself a Tensor holding an int value for every dimension. Its values only exist while the model is running, so it can't be used in the model definition, only in runtime calculations.
  • K.int_shape returns the "definition" shape of the tensor as a tuple. Variable dimensions appear in it as None.
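
To make the "definition" shape side concrete, here is a minimal sketch in plain Python of what happens to the static shape tuple (the kind K.int_shape returns) under the reshape above. The helper name flattened_static_shape is hypothetical and not part of Keras; it assumes a 4-D input shape.

```python
def flattened_static_shape(input_shape):
    """Static shape after merging dimensions 1 and 2 of a 4-D shape tuple.

    Hypothetical helper for illustration; not a Keras function.
    """
    batch, rows, cols, channels = input_shape
    # If either merged dimension is unknown (None), the product is unknown too.
    merged = None if rows is None or cols is None else rows * cols
    return (batch, merged, channels)


print(flattened_static_shape((None, None, None, 32)))  # (None, None, 32)
print(flattened_static_shape((None, 28, 28, 32)))      # (None, 784, 32)
```

If the backend can't infer the output shape of the Lambda layer on its own, a function like this could be passed as the layer's output_shape argument, which accepts a function of the input shape.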
Daniel Möller answered Nov 18 '22 03:11