How to split a tensor column-wise in Keras to implement STFCN

I'd like to implement the Spatiotemporal Fully Convolutional Network (STFCN) in Keras. I need to feed each depth column of a 3D convolutional output, e.g. a tensor with shape (64, 16, 16), as input to a separate LSTM.

To make this clear, I have a (64 x 16 x 16) tensor of dimensions (channels, height, width). I need to split (explicitly or implicitly) the tensor into 16 * 16 = 256 tensors of shape (64 x 1 x 1).
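To make the target layout concrete, here is a minimal NumPy-only sketch (no Keras; the random array is an assumed stand-in for one convolutional output in (channels, height, width) order). It pulls out the 16 * 16 = 256 depth columns of shape (64, 1, 1) explicitly, and checks that stacking them matches a plain reshape of the spatial axes.

```python
import numpy as np

# stand-in for one conv output in (channels, height, width) order
feat = np.random.rand(64, 16, 16)

# explicit split: one (64, 1, 1) column per spatial position (h, w), row-major order
columns = [feat[:, h:h + 1, w:w + 1] for h in range(16) for w in range(16)]

# implicit split: flatten the two spatial axes together,
# giving 256 vectors of length 64 (column index = h * 16 + w)
flat = feat.reshape(64, 16 * 16).T   # shape (256, 64)

print(len(columns), columns[0].shape, flat.shape)  # 256 (64, 1, 1) (256, 64)
```

The implicit form is usually easier to feed to later layers, since all 256 columns live in one array instead of a Python list.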

Here is a diagram from the STFCN paper illustrating the Spatio-Temporal module. What I described above is the arrow between 'Spatial Features' and 'Spatio-Temporal Module', i.e. the connection between FCn and the Spatio-Temporal Module, which is the relevant part of the diagram.

How would this idea be best implemented in Keras?
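One common layout trick, sketched here under assumptions not stated in the question, avoids building 256 separate LSTM layers. Suppose the spatial features arrive per frame as (batch, T, 64, 16, 16), where T is a hypothetical number of frames. Folding the 256 spatial positions into the batch axis then lets a single LSTM with shared weights process every column's sequence. Whether STFCN shares LSTM weights across positions should be checked against the paper. The NumPy sketch below only shows the reshapes, with a random placeholder standing in for the LSTM output.

```python
import numpy as np

# hypothetical sizes: batch B, T frames, C channels, H x W spatial grid
B, T, C, H, W = 2, 5, 64, 16, 16
feats = np.random.rand(B, T, C, H, W)

# (B, T, C, H, W) -> (B, H, W, T, C): each spatial position gets its own sequence
seq = feats.transpose(0, 3, 4, 1, 2)
# fold the 256 positions into the batch axis -> (B * H * W, T, C),
# i.e. the (samples, timesteps, features) layout a Keras LSTM layer expects
lstm_in = seq.reshape(B * H * W, T, C)

# placeholder for an LSTM with return_sequences=True and 32 units
units = 32
lstm_out = np.random.rand(B * H * W, T, units)

# unfold back to one spatial map per frame: (B, T, units, H, W)
out_map = lstm_out.reshape(B, H, W, T, units).transpose(0, 3, 4, 1, 2)

print(lstm_in.shape, out_map.shape)  # (512, 5, 64) (2, 5, 32, 16, 16)
```

In a Keras model the same steps would need tf.transpose and tf.reshape inside a Lambda layer. The Reshape layer's target shape excludes the batch axis, so it cannot fold positions into the batch.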

jhuang asked Jan 05 '17 11:01


1 Answer

You can use tf.split from TensorFlow inside a Keras Lambda layer.

Use the Lambda layer to rearrange a tensor of shape (64, 16, 16) into (64, 1, 1, 256), i.e. the 256 spatial columns stacked along a new last axis, and then index out whichever columns you need.

import numpy as np
import tensorflow as tf
import keras.backend as K
from keras.models import Model
from keras.layers import Input, Lambda

# dummy input: batch of 3 samples, each (channels=64, height=16, width=16)
data = np.ones((3, 64, 16, 16))

# split out every spatial column and stack the 256 columns on a new last axis
def lambda_fun(x):
    x = K.expand_dims(x, 4)        # (batch, 64, 16, 16, 1)
    split1 = tf.split(x, 16, 2)    # 16 x (batch, 64, 1, 16, 1), one per row
    x = K.concatenate(split1, 4)   # (batch, 64, 1, 16, 16)
    split2 = tf.split(x, 16, 3)    # 16 x (batch, 64, 1, 1, 16), one per column
    x = K.concatenate(split2, 4)   # (batch, 64, 1, 1, 256)
    return x

# check that the splitting works
inputs = Input(shape=(64, 16, 16))
ll = Lambda(lambda_fun)(inputs)
model = Model(inputs=inputs, outputs=ll)
res = model.predict(data)
print(np.shape(res))    # (3, 64, 1, 1, 256)
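Because the two split/concatenate passes interleave the spatial axes, it helps to know which of the 256 slices belongs to which position before subsetting. The NumPy mirror below is an illustration, not part of the original answer, and `lambda_fun_np` is a hypothetical name. It follows the same expand/split/concatenate steps and shows that position (h, w) lands at last-axis index w * 16 + h. A single transpose and reshape reproduces that layout.

```python
import numpy as np

def lambda_fun_np(x):
    # NumPy mirror of lambda_fun: np.split/np.concatenate instead of tf.split/K.concatenate
    x = np.expand_dims(x, 4)                              # (batch, 64, 16, 16, 1)
    x = np.concatenate(np.split(x, 16, axis=2), axis=4)   # (batch, 64, 1, 16, 16)
    x = np.concatenate(np.split(x, 16, axis=3), axis=4)   # (batch, 64, 1, 1, 256)
    return x

data = np.random.rand(3, 64, 16, 16)
res = lambda_fun_np(data)

# same layout via one transpose + reshape: last-axis index is w * 16 + h
res_fast = data.transpose(0, 1, 3, 2).reshape(3, 64, 1, 1, 256)

# pick the depth column at spatial position (h, w) = (2, 5)
h, w = 2, 5
col = res[:, :, 0, 0, w * 16 + h]    # shape (3, 64)

print(res.shape, np.array_equal(res, res_fast), np.array_equal(col, data[:, :, h, w]))
# (3, 64, 1, 1, 256) True True
```

If row-major ordering (index h * 16 + w) is more convenient, drop the transpose and use `data.reshape(3, 64, 1, 1, 256)`. Inside the Lambda, tf.transpose and tf.reshape could likewise replace the split/concatenate chain. That is a suggestion, not something tested against the model above.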
Vadym B. answered Sep 27 '22 01:09