
Filter out non-zero values in a tensor

Suppose I have an array: input = np.array([[1,0,3,5,0,8,6]]), and I want to keep only the non-zero values, [1,3,5,8,6].

I know that you can use tf.where with a condition, but the returned value still has 0's in it. The output of the snippet below is [[[1 0 3 5 0 8 6]]]. I also don't understand why tf.where needs both x and y.

Is there any way I can get rid of the 0's in the resulting tensor?

import numpy as np
import tensorflow as tf

input = np.array([[1,0,3,5,0,8,6]])

X = tf.placeholder(tf.int32,[None,7])

zeros = tf.zeros_like(X)
index = tf.not_equal(X,zeros)
loc = tf.where(index,x=X,y=X)

with tf.Session() as sess:
    out = sess.run([loc],feed_dict={X:input})
    print np.array(out)
user3813674 asked Feb 12 '17 22:02



2 Answers

Casting numbers to bool identifies zeros as False. Then you can mask as usual. Example:

import tensorflow as tf

x = [1,0,2]
mask = tf.cast(x, dtype=tf.bool)  # [True, False, True]
nonzero_x = tf.boolean_mask(x, mask)  # [1, 2]
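
Applied to the rank-2 input from the question, the same mask gives a flat result, because tf.boolean_mask returns a rank-1 tensor when the mask has the same shape as the tensor. A minimal sketch, assuming TensorFlow 2.x with eager execution (the question itself uses the 1.x session API); the variable names are only illustrative:

```python
import tensorflow as tf

# Assumption: TensorFlow 2.x, eager execution enabled (the default).
x = tf.constant([[1, 0, 3, 5, 0, 8, 6]])

# Zero casts to False, every non-zero integer casts to True.
mask = tf.cast(x, dtype=tf.bool)

# The mask covers both dimensions of x, so the result is rank 1.
nonzero_x = tf.boolean_mask(x, mask)

print(nonzero_x.numpy())  # [1 3 5 8 6]
```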
Jorge Barrios answered Sep 29 '22 13:09



First create a boolean mask that identifies where your condition is true, then apply the mask to your tensor, as shown below. You can use tf.where with x and y if you want, but it returns a tensor with the same shape as the input, so without further work the best you could get is something like [[[1 -1 3 5 -1 8 6]]], with -1 standing in for a sentinel value that you would have to remove later. Calling tf.where with only the condition (no x and y) returns the indices of all elements where the condition is true, so an index-based solution is also possible if you prefer. The version below is the one I recommend for clarity.

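import numpy as np
import tensorflow as tf

input = np.array([[1,0,3,5,0,8,6]])
X = tf.placeholder(tf.int32,[None,7])
# Boolean tensors of all False / all True with the same shape as X
zeros = tf.cast(tf.zeros_like(X),dtype=tf.bool)
ones = tf.cast(tf.ones_like(X),dtype=tf.bool)
# Build the mask from the placeholder X (not the NumPy array) so the fed value is used
loc = tf.where(tf.not_equal(X,0),ones,zeros)
# Keep only the elements of X where the mask is True
result = tf.boolean_mask(X,loc)
with tf.Session() as sess:
    out = sess.run([result],feed_dict={X:input})
    print(np.array(out))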
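
The index-based route, and the difference between the two ways of calling tf.where, can be sketched as follows. This is a minimal sketch, assuming TensorFlow 2.x with eager execution rather than the 1.x placeholder/session API used in this answer; the names indices, nonzero and replaced are only illustrative.

```python
import tensorflow as tf

# Assumption: TensorFlow 2.x, eager execution enabled (the default).
x = tf.constant([[1, 0, 3, 5, 0, 8, 6]], dtype=tf.int32)

# With only a condition, tf.where returns an int64 tensor of shape
# [num_true, rank] holding the coordinates of every True entry.
indices = tf.where(tf.not_equal(x, 0))

# Gather the values stored at those coordinates; the result is rank 1.
nonzero = tf.gather_nd(x, indices)

# With x and y supplied, tf.where selects element-wise between them,
# so the output keeps the shape of the input and zeros become -1 here.
replaced = tf.where(tf.not_equal(x, 0), x, -tf.ones_like(x))

print(nonzero.numpy())   # [1 3 5 8 6]
print(replaced.numpy())  # [[ 1 -1  3  5 -1  8  6]]
```

The three-argument form can only substitute values, which is why it cannot drop the zeros on its own; the one-argument form plus tf.gather_nd gives the same values as tf.boolean_mask.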
The Puternerd answered Sep 29 '22 13:09
