I want to get the shape of the input data passed to an Input layer with shape (None,), and use it as the loop count of a for loop.
Here's part of my code:
import keras

lst_rfrm = []
Inpt_lyr = keras.Input(shape=(None,))
for k in range(tm_stp):  # tm_stp is the loop count I don't know
    # take sub_len values starting at position k, adding an axis to concatenate on
    F = keras.layers.Lambda(lambda x, i, j: x[:, None, j:j + i],
                            arguments={'i': sub_len, 'j': k})
    tmp_rfrm = F(Inpt_lyr)
    lst_rfrm.append(tmp_rfrm)
cnctnt_lyr = keras.layers.Concatenate(axis=1)(lst_rfrm)
# defining other layers ...
Because the input shape is (None,), I don't know what to pass to range() in the for loop (tm_stp in the code above). How can I get the shape of the input layer, i.e. of the data passed to it, in this situation? Any help is deeply appreciated.
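To make the difficulty concrete, here is a NumPy sketch (not Keras) of what the loop above computes once the data actually exists. It assumes, as the answer below guesses, that the goal is every full window of length sub_len; under that assumption tm_stp has to be length - sub_len + 1, a number that only exists at run time. loop_over_length is a hypothetical helper name.

```python
import numpy as np

def loop_over_length(x, sub_len):
    """NumPy version of the question's loop (hypothetical helper).

    Assumes every full window of length sub_len is wanted, so the loop
    count tm_stp depends on x.shape[1], which is only known at run time.
    """
    tm_stp = x.shape[1] - sub_len + 1  # number of full windows
    # x[:, None, k:k + sub_len] has shape (batch, 1, sub_len), like the Lambda
    pieces = [x[:, None, k:k + sub_len] for k in range(tm_stp)]
    return np.concatenate(pieces, axis=1)  # (batch, tm_stp, sub_len)

x = np.arange(10).reshape(2, 5)
print(loop_over_length(x, 3).shape)  # (2, 3, 3)
```

Inside a Keras graph with shape (None,), x.shape[1] is None when the layers are built, which is why this count cannot be computed there.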
You can try a different type of loop. It seems you are trying to build sliding windows, right? You don't know the length to iterate over, but you do know the window size and how much to cut from the borders, so you can loop over the window size instead.
This function builds the slices following that principle:
windowSize = sub_len

def getWindows(x):
    borderCut = windowSize - 1  # lost length in the length dimension
    leftCut = range(windowSize)  # start of sequence
    rightCut = [i - borderCut for i in leftCut]  # end of sequence - negative
    rightCut[-1] = None  # because it can't be zero for slicing
    croppedSequences = K.stack([x[:, l:r] for l, r in zip(leftCut, rightCut)], axis=-1)
    return croppedSequences
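The same principle can be checked outside Keras. The sketch below is a NumPy port of getWindows (the name get_windows_np is hypothetical, and np.stack stands in for K.stack). The loop runs window_size times whatever the sequence length, and the output has shape (batch, length - window_size + 1, window_size).

```python
import numpy as np

def get_windows_np(x, window_size):
    """NumPy port of getWindows: loops over window offsets, not over length."""
    border_cut = window_size - 1                    # length lost at the borders
    left_cut = range(window_size)                   # start of each shifted copy
    right_cut = [i - border_cut for i in left_cut]  # negative ends
    right_cut[-1] = None                            # an end of 0 would give an empty slice
    # Each x[:, l:r] has length L - border_cut; stacking them on the last
    # axis puts the window_size consecutive values of each window together.
    return np.stack([x[:, l:r] for l, r in zip(left_cut, right_cut)], axis=-1)

for length in (5, 8, 13):  # any length works with the same loop count
    x = np.arange(2 * length).reshape(2, length)
    print(length, get_windows_np(x, 3).shape)
```

With stride 1, these are the same windows the question's loop would produce if tm_stp were length - window_size + 1, but the number of slices never depends on the length.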
Running test:
from keras.layers import *
from keras.models import Model
import keras.backend as K
import numpy as np
windowSize = 3
batchSize = 5
randomLength = np.random.randint(5,10)
inputData = np.arange(randomLength * batchSize).reshape((batchSize, randomLength))
def getWindows(x):
    borderCut = windowSize - 1
    leftCut = range(windowSize)
    rightCut = [i - borderCut for i in leftCut]
    rightCut[-1] = None
    croppedSequences = K.stack([x[:, l:r] for l, r in zip(leftCut, rightCut)], axis=-1)
    return croppedSequences
inputs = Input((None,))
outputs = Lambda(getWindows)(inputs)
model = Model(inputs, outputs)
preds = model.predict(inputData)
for i, (inData, pred) in enumerate(zip(inputData, preds)):
    print('sample: ', i)
    print('input sequence: ', inData)
    print('output sequence: \n', pred, '\n\n')
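To check the printed predictions against an independent reference, NumPy's sliding_window_view (available from NumPy 1.20) returns the windows getWindows should produce. expected_windows is a hypothetical helper, and the fixed demo array only stands in for the random-length inputData.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20

def expected_windows(input_data, window_size):
    """Reference for the test's preds (hypothetical helper): window t of
    sample b is input_data[b, t:t + window_size]."""
    return sliding_window_view(input_data, window_size, axis=1)

demo = np.arange(12).reshape(2, 6)      # fixed length 6 for illustration
print(expected_windows(demo, 3).shape)  # (2, 4, 3)
print(expected_windows(demo, 3)[1])
```

With preds and inputData from the test run, np.testing.assert_array_equal(preds, expected_windows(inputData, windowSize)) should pass; preds comes back as floats, but assert_array_equal compares values, so the dtype difference should not matter.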