Split output of a layer in Keras

Say, I have a layer with output dims (4, x, y). I want to split this into 4 separate (1, x, y) tensors, which I can use as input for 4 other layers.

What I'm essentially looking for is the opposite of the Merge layer. I know there's no split layer in Keras, but is there a simple way to do this?

asked Dec 04 '16 by Shayan RC


People also ask

How do you split a tensor in Keras?

Use a Lambda layer to split a tensor, for example of shape (64, 16, 16) into (64, 1, 1, 256), and then subset whichever indices you need.
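
A rough sketch of that idea (the input shape and slice boundaries below are made up purely for illustration):

from keras.layers import Input, Lambda

inp = Input(shape=(4, 8))                     # hypothetical (batch, 4, 8) input
first = Lambda(lambda t: t[:, 0:1, :])(inp)   # (batch, 1, 8) slice
rest = Lambda(lambda t: t[:, 1:, :])(inp)     # (batch, 3, 8) remainder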

What is the output of dense layer?

The output shape of a Dense layer is determined by the number of neurons/units specified in that layer. For example, if the input shape is (8,) and the number of units is 16, then the output shape is (16,).
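
For instance, a minimal check of that rule (the model below exists only to print the shape):

from keras.layers import Input, Dense
from keras.models import Model

inp = Input(shape=(8,))      # input shape (8,)
out = Dense(16)(inp)         # 16 units
model = Model(inp, out)
print(model.output_shape)    # (None, 16): batch dimension plus the 16 units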

What does Sequential do in Keras?

Keras offers a number of APIs you can use to define your neural network, including the Sequential API, which lets you create a model layer by layer for most problems. It's straightforward (just a simple list of layers), but it's limited to single-input, single-output stacks of layers.
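
A minimal single-input, single-output stack, just to illustrate (the layer sizes are arbitrary):

from keras.models import Sequential
from keras.layers import Dense

model = Sequential([
    Dense(32, activation='relu', input_shape=(8,)),  # first layer declares the input shape
    Dense(1, activation='sigmoid'),                  # subsequent layers are simply stacked
])
model.summary()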

What is Lambda layer in Keras?

The Lambda layer exists so that arbitrary expressions can be used as a Layer when constructing Sequential and Functional API models. Lambda layers are best suited for simple operations or quick experimentation. For more advanced use cases, subclass tf.keras.layers.Layer instead.
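
For example, wrapping a simple element-wise expression (chosen arbitrarily here) as a layer:

from keras.layers import Input, Lambda
from keras.models import Model

inp = Input(shape=(10,))
halved = Lambda(lambda t: t * 0.5)(inp)   # arbitrary expression used as a layer
model = Model(inp, halved)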


2 Answers

Are you looking for something like this?

import keras.backend as K
import numpy as np

val = np.random.random((4, 2, 3))
t = K.variable(value=val)
t1 = t[0, :, :]
t2 = t[1, :, :]
t3 = t[2, :, :]
t4 = t[3, :, :]

print('t1:\n', K.eval(t1))
print('t2:\n', K.eval(t2))
print('t3:\n', K.eval(t3))
print('t4:\n', K.eval(t4))
print('t:\n', K.eval(t))

It gives the following output:

t1:
 [[ 0.18787734  0.1085723   0.01127671]
 [ 0.06032621  0.14528386  0.21176969]]
t2:
 [[ 0.34292713  0.56848335  0.83797884]
 [ 0.11579451  0.21607392  0.80680907]]
t3:
 [[ 0.1908586   0.48186591  0.23439431]
 [ 0.93413448  0.535191    0.16410089]]
t4:
 [[ 0.54303145  0.78971165  0.9961108 ]
 [ 0.87826216  0.49061012  0.42450914]]
t:
 [[[ 0.18787734  0.1085723   0.01127671]
  [ 0.06032621  0.14528386  0.21176969]]

 [[ 0.34292713  0.56848335  0.83797884]
  [ 0.11579451  0.21607392  0.80680907]]

 [[ 0.1908586   0.48186591  0.23439431]
  [ 0.93413448  0.535191    0.16410089]]

 [[ 0.54303145  0.78971165  0.9961108 ]
  [ 0.87826216  0.49061012  0.42450914]]]

Note that t1, t2, t3 and t4 now have shape (2, 3).

print(K.eval(K.shape(t1))) # prints [2 3]

So, if you want to keep the 3d shape, you need to do the following:

# add back the leading axis so each slice keeps shape (1, 2, 3)
t1 = K.reshape(t[0, :, :], (1, 2, 3))
t2 = K.reshape(t[1, :, :], (1, 2, 3))
t3 = K.reshape(t[2, :, :], (1, 2, 3))
t4 = K.reshape(t[3, :, :], (1, 2, 3))

Now you get the split tensors with the correct dimensions.

print(K.eval(K.shape(t1))) # prints [1 2 3]

Hope this helps you solve your problem.

answered Oct 06 '22 by Wasi Ahmad


You can define Lambda layers to do the slicing for you:

from keras.layers import Lambda
from keras.backend import slice
.
.
# slice(x, start, size) takes a start index and a slice size along each axis (-1 keeps the rest)
x = Lambda(lambda x: slice(x, START, SIZE))(x)

For your specific example, try:

x1 = Lambda(lambda x: slice(x, (0, 0, 0), (1, -1, -1)))(x)
x2 = Lambda(lambda x: slice(x, (1, 0, 0), (1, -1, -1)))(x)
x3 = Lambda(lambda x: slice(x, (2, 0, 0), (1, -1, -1)))(x)
x4 = Lambda(lambda x: slice(x, (3, 0, 0), (1, -1, -1)))(x)
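
Continuing that snippet, the four slices can then feed four separate branches; the Dense layers and the final concatenate below are only placeholders for whatever your four downstream layers actually are:

from keras.layers import Dense, Flatten, concatenate

b1 = Dense(8)(Flatten()(x1))   # each slice gets its own branch
b2 = Dense(8)(Flatten()(x2))
b3 = Dense(8)(Flatten()(x3))
b4 = Dense(8)(Flatten()(x4))

merged = concatenate([b1, b2, b3, b4])   # recombine the branches, i.e. the Merge direction the question mentions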
answered Oct 06 '22 by mohaghighat