
Can I train a Tensorflow keras model with complex input/output?

I am trying to train a very simple model that has only one convolution layer.

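from tensorflow.keras.layers import Input, Conv1D
from tensorflow.keras.models import Model

def kernel_model(filters=1, kernel_size=3):
    # One bias-free 1-D convolution over a length-250, single-channel signal
    input_layer = Input(shape=(250, 1))
    conv_layer = Conv1D(filters=filters, kernel_size=kernel_size,
                        padding='same', use_bias=False)(input_layer)
    model = Model(inputs=input_layer, outputs=conv_layer)
    return model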
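For reference, here is a plain-Python sketch of what this one-layer model computes on a single example: a bias-free, single-filter 1-D cross-correlation with 'same' zero padding (Keras convolution layers slide the kernel without flipping it). The function name `conv1d_same` is hypothetical, one input channel is represented as a flat list, and the padding split assumes TensorFlow's convention of putting any extra zero on the right. The arithmetic is well defined for complex inputs and weights. The sketch does not show which dtypes TensorFlow's own Conv1D accepts.

```python
def conv1d_same(x, w):
    """Bias-free, single-filter 1-D cross-correlation with 'same' zero padding.

    x: one input channel (e.g. 250 complex samples); w: kernel of length kernel_size.
    The output has the same length as x, like Conv1D(padding='same').
    """
    k = len(w)
    pad_before = (k - 1) // 2          # for even k, the extra zero goes on the right
    pad_after = k - 1 - pad_before
    padded = [0j] * pad_before + list(x) + [0j] * pad_after
    # Cross-correlation: the kernel is slid over the input without being flipped
    return [sum(w[j] * padded[n + j] for j in range(k)) for n in range(len(x))]


x = [1 + 1j, 2, 1j]
print(conv1d_same(x, [0, 1, 0]))   # identity kernel returns x unchanged
print(conv1d_same(x, [1j, 0, 0]))  # takes the left neighbour (zero at n=0) times 1j
```

With a length-250 input the function returns 250 outputs, matching the (250, 1) output shape that padding='same' gives for a single filter.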

However, the input (X), the predicted output (y_pred) and the true output (y_true) are all complex numbers. When I call model.fit(X, y_true), I get this error:

TypeError: Gradients of complex tensors must set grad_ys (y.dtype = tf.complex64)

Does that mean I have to write the back-propagation by hand?
What should I do to solve this problem? Thanks.

Raindrop asked Nov 15 '19 at 00:11




1 Answer

Your DNN minimizes the loss function through back-propagation. To minimize something, its values need an ordering: you must be able to say that one loss is smaller than another. Complex numbers are not ordered, while real numbers are. So in general you need a loss function L: Complex -> Reals.
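To make the ordering argument concrete, here is a small plain-Python sketch (no TensorFlow; the helper names `squared_magnitude` and `is_ordered` are made up for illustration). Python refuses to compare two complex numbers, while the squared magnitude of each one is an ordinary real number that can be compared, and therefore minimized.

```python
def squared_magnitude(z: complex) -> float:
    """Map a complex error to a real, non-negative number: |z|^2 = Re(z)^2 + Im(z)^2."""
    return z.real ** 2 + z.imag ** 2


def is_ordered(a, b) -> bool:
    """Return True if Python can evaluate a < b, False if the type defines no ordering."""
    try:
        _ = a < b
        return True
    except TypeError:
        return False


e1, e2 = 3 + 4j, 1 - 1j                                          # two hypothetical complex errors
print(is_ordered(e1, e2))                                        # False: complex values cannot be compared
print(squared_magnitude(e1), squared_magnitude(e2))              # 25.0 2.0
print(is_ordered(squared_magnitude(e1), squared_magnitude(e2)))  # True
```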

Change your complex-valued loss function from a plain square:

error = K.cast(K.mean(K.square(y_pred_propgation - y_true)),tf.complex64)

to the real-valued squared magnitude ||.||^2 of the complex error:

error = K.mean(K.square(K.abs(y_true-y_pred)))
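A plain-Python analogue of the two losses shows why the magnitude matters (hypothetical helper names; lists of Python complex numbers stand in for tensors). For a purely imaginary error the plain square comes out as a negative real number even though the prediction is wrong, so "smaller" does not mean "closer to the target". The mean squared magnitude is real, never negative, and 0 only when every prediction matches its target.

```python
def plain_square_loss(y_true, y_pred):
    """Mean of (y_pred - y_true)**2 without a magnitude: the result stays complex."""
    diffs = [p - t for t, p in zip(y_true, y_pred)]
    return sum(d * d for d in diffs) / len(diffs)


def magnitude_square_loss(y_true, y_pred):
    """Analogue of K.mean(K.square(K.abs(y_true - y_pred))): real and >= 0."""
    diffs = [t - p for t, p in zip(y_true, y_pred)]
    return sum(abs(d) ** 2 for d in diffs) / len(diffs)


y_true = [0j]
y_pred = [1j]  # wrong by a purely imaginary error of size 1
print(plain_square_loss(y_true, y_pred))      # complex result with real part -1.0
print(magnitude_square_loss(y_true, y_pred))  # 1.0
```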
Martin Thøgersen answered Oct 13 '22 at 21:10