
Slicing tensors in tensorflow using argmax

I want to build a dynamic loss function in TensorFlow. I want to calculate the energy of a signal's FFT, specifically within a window of size 3 around the most dominant peak. I have not been able to implement this in TF: it throws stride-related errors such as InvalidArgumentError (see above for traceback): Expected begin, end, and strides to be 1D equal size tensors, but got shapes [1,64], [1,64], and [1] instead.

My code is this:

self.spec = tf.fft(self.signal)
self.spec_mag = tf.complex_abs(self.spec[:,1:33])
self.argm = tf.cast(tf.argmax(self.spec_mag, 1), dtype=tf.int32)
self.frac = tf.reduce_sum(self.spec_mag[self.argm-1:self.argm+2], 1)

Since I use a batch size of 64 and each signal also has 64 samples, the shape of self.signal is (64,64). I only want the AC components of the FFT. Because the signal is real-valued, half of the spectrum is sufficient, so the shape of self.spec_mag is (64,32).

The peak of each FFT is located at self.argm, which has a shape (64,1).

Now I want to calculate the energy of the 3 elements around the peak via: self.spec_mag[self.argm-1:self.argm+2].

However, when I run the code and try to obtain the value of self.frac, I get multiple errors.

canonball asked Feb 08 '18 08:02


1 Answer

It seems you were missing an index when accessing argm. Here is a fixed version for a signal of shape (1, 64).

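import tensorflow as tf
import numpy as np

# Tensor.eval() needs a default session (TensorFlow 1.x graph mode)
sess = tf.InteractiveSession()

# One random signal of 64 samples, converted to complex for the FFT
x = np.random.rand(1, 64)
xt = tf.constant(value=x, dtype=tf.complex64)

signal = xt
print('signal', signal.shape)
print('signal', signal.eval())

spec = tf.fft(signal)
print('spec', spec.shape)
print('spec', spec.eval())

# Magnitude of bins 1..32 (skips the DC bin at index 0)
spec_mag = tf.abs(spec[:,1:33])
print('spec_mag', spec_mag.shape)
print('spec_mag', spec_mag.eval())

# Index of the peak in each row
argm = tf.cast(tf.argmax(spec_mag, 1), dtype=tf.int32)
print('argm', argm.shape)
print('argm', argm.eval())

# Index row 0 so the slice bounds are scalars, then sum the 3 bins around the peak
frac = tf.reduce_sum(spec_mag[0][(argm[0]-1):(argm[0]+2)], 0)
print('frac', frac.shape)
print('frac', frac.eval())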
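
One edge case in the slice above is worth checking: if the peak falls in the first bin, argm[0]-1 evaluates to -1. TensorFlow slicing follows Python's convention for negative indices, so the window would presumably come back empty rather than truncated. The plain-Python sketch below illustrates the behaviour on a list and shows one way to clamp the bounds. Note that clamped_window and the toy values are hypothetical and not part of the code above.

```python
def clamped_window(values, peak, half_width=1):
    """Return the values within half_width of peak, truncated at the edges."""
    lo = max(peak - half_width, 0)                # never go below index 0
    hi = min(peak + half_width + 1, len(values))  # never run past the end
    return values[lo:hi]

spec_mag_row = [9.0, 1.0, 1.0, 1.0]  # toy magnitudes with the peak at index 0
peak = 0

naive = spec_mag_row[peak - 1:peak + 2]       # slice [-1:2] selects nothing
clamped = clamped_window(spec_mag_row, peak)  # slice [0:2]

print(naive)    # []
print(clamped)  # [9.0, 1.0]
```

In the TensorFlow graph, tf.maximum and tf.minimum could play the role of max and min when building the slice bounds.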

And here is the expanded version for a signal of shape (batch, m, n):

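import tensorflow as tf
import numpy as np

# Tensor.eval() needs a default session (TensorFlow 1.x graph mode)
sess = tf.InteractiveSession()

# Shape (batch, m, n) = (1, 1, 64), converted to complex for the FFT
x = np.random.rand(1, 1, 64)
xt = tf.constant(value=x, dtype=tf.complex64)

signal = xt
print('signal', signal.shape)
print('signal', signal.eval())

# The FFT is taken over the innermost dimension
spec = tf.fft(signal)
print('spec', spec.shape)
print('spec', spec.eval())

# Magnitude of bins 1..32 (skips the DC bin at index 0)
spec_mag = tf.abs(spec[:, :, 1:33])
print('spec_mag', spec_mag.shape)
print('spec_mag', spec_mag.eval())

# Index of the peak along the last axis
argm = tf.cast(tf.argmax(spec_mag, 2), dtype=tf.int32)
print('argm', argm.shape)
print('argm', argm.eval())

# Index element [0][0] so the slice bounds are scalars, then sum the 3 bins around the peak
frac = tf.reduce_sum(spec_mag[0][0][(argm[0][0]-1):(argm[0][0]+2)], 0)
print('frac', frac.shape)
print('frac', frac.eval())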

You may need to adjust the function names, since I edited this code with a newer version of TensorFlow.
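
Both versions above index a single row, whereas the question uses a batch of 64 rows, each with its own peak. One possible way to avoid per-row slicing altogether is to build a boolean mask that marks the bins near each row's peak and sum the masked magnitudes. The sketch below uses NumPy as a stand-in to show the idea; peak_window_sum is a hypothetical helper, and porting it to TensorFlow (for example with tf.range, tf.abs, tf.cast and tf.reduce_sum) is an assumption that has not been tested here.

```python
import numpy as np

def peak_window_sum(spec_mag, half_width=1):
    """Sum each row's magnitudes within half_width bins of that row's peak."""
    n_bins = spec_mag.shape[1]
    argm = np.argmax(spec_mag, axis=1)  # peak index per row, shape (batch,)
    bins = np.arange(n_bins)            # bin indices, shape (n_bins,)
    # mask[b, k] is True when bin k lies inside row b's window
    mask = np.abs(bins[None, :] - argm[:, None]) <= half_width
    return np.sum(spec_mag * mask, axis=1)  # shape (batch,)

# Batch of 64 real signals with 64 samples each, as in the question
rng = np.random.default_rng(0)
signal = rng.random((64, 64))
spec_mag = np.abs(np.fft.fft(signal, axis=1)[:, 1:33])  # AC bins 1..32

frac = peak_window_sum(spec_mag)
print(frac.shape)  # (64,)
```

Because the mask is computed from bin distances, a peak at the first or last bin simply yields a truncated window instead of an empty slice.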

Ali Mert Ceylan answered Sep 22 '22 17:09