
TF2 / Keras slice tensor using [:, :, 0]

In TF 2.0 Beta I'm trying:

import tensorflow as tf

x = tf.keras.layers.Input(shape=(240, 2), dtype=tf.float32)
print(x.shape) # (None, 240, 2)
a = x[:, :, 0]
print(a.shape) # <unknown>

In TF 1.x I could do:

x = tf1.placeholder(tf1.float32, (None, 240, 2))
a = x[:, :, 0]

and it would work fine. How do I achieve this in TF 2.0? I think

tf.split(x, 2, axis=2)

may work, but I'd like to use slicing rather than hard-coding the 2 (the size of axis 2).

Nic asked Aug 10 '19 02:08



1 Answer

The difference is that the object returned from Input is a symbolic Keras tensor attached to the model's input layer, rather than something directly analogous to a TF 1.x placeholder. So the x in your TF 2.0 code above is part of a Keras model definition, whereas the x in your TF 1.x code is a placeholder for a tensor.

You can define a slicing layer to perform the operation. There are out-of-the-box layers available but, for a simple slice like this, a Lambda layer is super easy to read and perhaps comes closest to the way you are used to slicing in TF 1.x.

Something like this:

input_lyr = tf.keras.layers.Input(shape=(240, 2), dtype=tf.float32)
sliced_lyr = tf.keras.layers.Lambda(lambda x: x[:,:,0])

which you can use in your Keras model like this:

model = tf.keras.models.Sequential([
    input_lyr,
    sliced_lyr,
    # ...
    # <other layers>
    # ...
])
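
For comparison, here is a minimal sketch of the same Lambda slice using the functional API instead of Sequential, which makes it easy to inspect the shape of the sliced output. The variable names (inputs, sliced, batch, result) are only illustrative, and the exact way shapes are printed may differ between TensorFlow versions:

```python
import tensorflow as tf

# Symbolic model input; the batch dimension is left unspecified.
inputs = tf.keras.layers.Input(shape=(240, 2), dtype=tf.float32)

# Wrap the slice in a Lambda layer so Keras tracks it as part of the model.
sliced = tf.keras.layers.Lambda(lambda t: t[:, :, 0])(inputs)
print(sliced.shape)  # symbolic shape, batch dimension unknown

model = tf.keras.Model(inputs=inputs, outputs=sliced)

# Run the model on a concrete batch of 8 examples.
batch = tf.random.uniform((8, 240, 2))
result = model(batch)
print(result.shape)  # concrete shape for this batch
```

The symbolic output keeps an unknown batch dimension, while calling the model on a real batch yields a fully defined shape.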

Of course the above is specific to a Keras model. If, instead, you have an ordinary tensor rather than a Keras model input, then the slicing works exactly as before. Something like this:

my_tensor = tf.random.uniform((8,240,2))
sliced = my_tensor[:,:,0]

print(my_tensor.shape)
print(sliced.shape)

outputs:

(8, 240, 2)
(8, 240)

as expected
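
As a rough sketch of how the slice compares with the alternatives raised in the question, the snippet below takes channel 0 of the last axis in a few equivalent ways on an ordinary tensor. Only the tf.split variant needs the size of axis 2 spelled out, and it also needs a tf.squeeze to drop the leftover axis. The variable names are illustrative only:

```python
import tensorflow as tf

my_tensor = tf.random.uniform((8, 240, 2))

# Plain slicing, as in the question.
reference = my_tensor[:, :, 0]

# Ellipsis form: works regardless of how many leading axes there are.
via_ellipsis = my_tensor[..., 0]

# tf.gather with a scalar index removes the gathered axis.
via_gather = tf.gather(my_tensor, 0, axis=2)

# tf.split keeps axis 2 with size 1, so it has to be squeezed away,
# and the number of pieces (2) must be given explicitly.
first_piece, _ = tf.split(my_tensor, 2, axis=2)
via_split = tf.squeeze(first_piece, axis=2)

for name, t in [("ellipsis", via_ellipsis),
                ("gather", via_gather),
                ("split", via_split)]:
    same = bool(tf.reduce_all(tf.equal(reference, t)))
    print(name, tuple(t.shape), same)
```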

Stewart_R answered Oct 12 '22 14:10