Tensorflow: stack all row pairs from a tensor

Given a tensor t=[[1,2], [3,4]], I need to produce ts=[[1,2,1,2], [1,2,3,4], [3,4,1,2], [3,4,3,4]]. That is, I need to concatenate every ordered pair of rows (each row paired with every row, including itself). Important: the tensor has shape [None, 2], i.e. the first dimension is variable.
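To pin down the expected output, here is a small NumPy reference sketch (not TensorFlow code, and not part of the original question). It loops over every ordered pair `(i, j)`, with `i` varying slowest, and concatenates `t[i]` with `t[j]`. The helper name `all_row_pairs_reference` is hypothetical.

```python
import itertools

import numpy as np


def all_row_pairs_reference(t):
    """Hypothetical reference helper: concatenate every ordered pair of rows.

    Row (i, j) of the result is [t[i], t[j]], with i varying slowest.
    """
    t = np.asarray(t)
    pairs = [np.concatenate([t[i], t[j]])
             for i, j in itertools.product(range(len(t)), repeat=2)]
    # The reshape also covers an empty input: 0 rows -> shape (0, 2 * width).
    return np.array(pairs).reshape(-1, 2 * t.shape[1])


print(all_row_pairs_reference([[1, 2], [3, 4]]))
# [[1 2 1 2]
#  [1 2 3 4]
#  [3 4 1 2]
#  [3 4 3 4]]
```

A loop like this is only useful for checking a vectorized TensorFlow version on small inputs.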

I have tried:

  • Using a tf.while_loop to generate a list of indices idx=[[0, 0], [0, 1], [1, 0], [1, 1]], then tf.gather(t, idx). This works but is messy, and I'm not sure how gradients are handled.
  • Two for loops iterating over tf.unstack(t), appending stacked rows to a buffer, then tf.stack(buffer). This does not work when the first dimension is variable.
  • Looking to broadcasting for inspiration. For instance, given x=tf.expand_dims(t, 0), y=tf.expand_dims(t, 1) and s=tf.reshape(tf.add(x, y), [-1, 2]), s will be [[2, 4], [4, 6], [4, 6], [6, 8]], i.e. the sum of every row combination. But how can I stack instead of sum? I've been stuck on this for 2 days :)
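The broadcasting idea can produce concatenation instead of a sum: addition broadcasts implicitly, but concatenation does not, so both expanded copies have to be broadcast to a common shape `[N, N, 2]` first and then joined on the last axis. A NumPy sketch, assuming `np.broadcast_to` as a stand-in; TensorFlow 2.x also provides `tf.broadcast_to`, which may not exist in the TensorFlow version used in this question.

```python
import numpy as np

t = np.array([[1, 2], [3, 4]])
n, width = t.shape

# Same expansions as in the question: x[0, j] = t[j], y[i, 0] = t[i].
x = np.expand_dims(t, 0)  # shape (1, N, 2)
y = np.expand_dims(t, 1)  # shape (N, 1, 2)

# Broadcast both explicitly to (N, N, 2), since concatenation needs equal shapes.
target = (n, n, width)
left = np.broadcast_to(y, target)   # left[i, j] = t[i]
right = np.broadcast_to(x, target)  # right[i, j] = t[j]

# Concatenate on the last axis -> (N, N, 4), then flatten the pair grid.
ts = np.concatenate([left, right], axis=-1).reshape(-1, 2 * width)
print(ts)
# [[1 2 1 2]
#  [1 2 3 4]
#  [3 4 1 2]
#  [3 4 3 4]]
```

Putting `y` (the per-row copy) on the left is what yields the question's ordering; swapping the two operands gives `[t[j], t[i]]` instead.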
Asked by hyperio on May 25 '18 at 20:05




2 Answers

Solution with tf.meshgrid() and some reshaping:

import tensorflow as tf  # TensorFlow 1.x API; on TF 2.x use tf.compat.v1 with eager execution disabled
import numpy as np

t = tf.placeholder(tf.int32, [None, 2])
num_rows, size_row = tf.shape(t)[0], tf.shape(t)[1] # actual dynamic dimensions

# Getting pair indices using tf.meshgrid (default 'xy' indexing):
idx_range = tf.range(num_rows)
pair_indices = tf.stack(tf.meshgrid(idx_range, idx_range))  # shape [2, N, N]
pair_indices = tf.transpose(pair_indices, perm=[1, 2, 0])    # shape [N, N, 2]; entry [i, j] = [j, i]

# Finally gathering the rows accordingly:
res = tf.reshape(tf.gather(t, pair_indices), (-1, size_row * 2))

with tf.Session() as sess:
    print(sess.run(res, feed_dict={t: np.array([[1,2], [3,4], [5,6]])}))
    # [[1 2 1 2]
    #  [3 4 1 2]
    #  [5 6 1 2]
    #  [1 2 3 4]
    #  [3 4 3 4]
    #  [5 6 3 4]
    #  [1 2 5 6]
    #  [3 4 5 6]
    #  [5 6 5 6]]
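With `tf.meshgrid`'s default `'xy'` indexing, row `(i, j)` of the result is `[t[j], t[i]]`, which is why the output above lists `[3 4 1 2]` before `[1 2 3 4]`. Passing `indexing='ij'` to `tf.meshgrid` should give the order requested in the question. A NumPy sketch comparing both orderings, assuming `np.meshgrid` (same `indexing` convention) and `np.take` standing in for `tf.gather`; the helper `pairs_via_meshgrid` is hypothetical.

```python
import numpy as np


def pairs_via_meshgrid(t, indexing):
    """Hypothetical helper mirroring the tf.meshgrid + tf.gather answer."""
    idx = np.arange(t.shape[0])
    # Stack the two index grids with the grid axis last: shape (N, N, 2).
    pair_indices = np.stack(np.meshgrid(idx, idx, indexing=indexing), axis=-1)
    # Gather rows -> (N, N, 2, row_size), then flatten each pair into one row.
    return np.take(t, pair_indices, axis=0).reshape(-1, 2 * t.shape[1])


t = np.array([[1, 2], [3, 4], [5, 6]])
xy = pairs_via_meshgrid(t, "xy")  # row (i, j) is [t[j], t[i]]
ij = pairs_via_meshgrid(t, "ij")  # row (i, j) is [t[i], t[j]]
print(xy[:3].tolist())  # [[1, 2, 1, 2], [3, 4, 1, 2], [5, 6, 1, 2]]
print(ij[:3].tolist())  # [[1, 2, 1, 2], [1, 2, 3, 4], [1, 2, 5, 6]]
```

`np.stack(..., axis=-1)` is equivalent to stacking on axis 0 and transposing with `perm=[1, 2, 0]`, as the TensorFlow code does.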

Solution computing the Cartesian product of the row indices explicitly (this gives the row order shown in the question):

import tensorflow as tf
import numpy as np

t = tf.placeholder(tf.int32, [None, 2])
num_rows, size_row = tf.shape(t)[0], tf.shape(t)[1] # actual dynamic dimensions

# Getting pair indices by computing the indices' Cartesian product:
row_idx = tf.range(num_rows)
row_idx_a = tf.expand_dims(tf.tile(tf.expand_dims(row_idx, 1), [1, num_rows]), 2)  # [N, N, 1]; entry [i, j] = i
row_idx_b = tf.expand_dims(tf.tile(tf.expand_dims(row_idx, 0), [num_rows, 1]), 2)  # [N, N, 1]; entry [i, j] = j
pair_indices = tf.concat([row_idx_a, row_idx_b], axis=2)  # [N, N, 2]; entry [i, j] = [i, j]

# Finally gathering the rows accordingly:
res = tf.reshape(tf.gather(t, pair_indices), (-1, size_row * 2))

with tf.Session() as sess:
    print(sess.run(res, feed_dict={t: np.array([[1,2], [3,4], [5,6]])}))
    # [[1 2 1 2]
    #  [1 2 3 4]
    #  [1 2 5 6]
    #  [3 4 1 2]
    #  [3 4 3 4]
    #  [3 4 5 6]
    #  [5 6 1 2]
    #  [5 6 3 4]
    #  [5 6 5 6]]
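The same Cartesian-product indices can be built with one repeat and one tile instead of the two expand/tile/expand chains: for N = 3 the first index list is `[0, 0, 0, 1, 1, 1, 2, 2, 2]` and the second is `[0, 1, 2, 0, 1, 2, 0, 1, 2]`. A NumPy sketch of that variant, assuming `np.repeat`/`np.tile`/`np.take`; in TensorFlow 2.x the counterparts would be `tf.repeat`, `tf.tile` and `tf.gather`. The helper name is hypothetical.

```python
import numpy as np


def pairs_via_repeat_tile(t):
    """Hypothetical helper: row i * N + j of the result is [t[i], t[j]]."""
    n = t.shape[0]
    rows = np.arange(n)
    first = np.repeat(rows, n)  # each row index repeated n times
    second = np.tile(rows, n)   # the whole index range repeated n times
    # Gather both halves and join them side by side: shape (N * N, 2 * row_size).
    return np.concatenate([np.take(t, first, axis=0),
                           np.take(t, second, axis=0)], axis=1)


t = np.array([[1, 2], [3, 4], [5, 6]])
print(pairs_via_repeat_tile(t).tolist())
# [[1, 2, 1, 2], [1, 2, 3, 4], [1, 2, 5, 6], [3, 4, 1, 2], [3, 4, 3, 4],
#  [3, 4, 5, 6], [5, 6, 1, 2], [5, 6, 3, 4], [5, 6, 5, 6]]
```

Because the gathers produce 2-D slices directly, no final reshape is needed.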
Answered by benjaminplanche on Oct 23 '22 at 18:10


This can be achieved with:

tf.concat([tf.tile(tf.expand_dims(t,1), [1, tf.shape(t)[0], 1]), tf.tile(tf.expand_dims(t,0), [tf.shape(t)[0], 1, 1])], axis=2)

Detailed steps:

import tensorflow as tf
import numpy as np

t = tf.placeholder(tf.int32, shape=[None, 2])

# Repeat each row of t along a new axis 1: shape [N, N, 2], d[i, j] = t[i]
d = tf.tile(tf.expand_dims(t, 1), [1, tf.shape(t)[0], 1])
# Output:
# [[[1 2] [1 2]]
#  [[3 4] [3 4]]]

# Repeat the entire input t along a new axis 0: shape [N, N, 2], e[i, j] = t[j]
e = tf.tile(tf.expand_dims(t, 0), [tf.shape(t)[0], 1, 1])
# Output:
# [[[1 2] [3 4]]
#  [[1 2] [3 4]]]

# Concatenate along the last axis: shape [N, N, 4], f[i, j] = [t[i], t[j]]
f = tf.concat([d, e], axis=2)

with tf.Session() as sess:
    print(sess.run(f, {t: np.asarray([[1, 2], [3, 4]])}))
# Output:
# [[[1 2 1 2]
#   [1 2 3 4]]
#  [[3 4 1 2]
#   [3 4 3 4]]]
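The result `f` above has shape `[N, N, 4]`; the flat layout from the question needs one final reshape, e.g. `tf.reshape(f, [-1, 2 * tf.shape(t)[1]])`. A NumPy sketch of the same tile-and-concat steps plus that reshape, assuming `np.tile`/`np.concatenate` behave like `tf.tile`/`tf.concat` here; the helper name is hypothetical, and the example uses rows of width 3 to show the approach is not tied to width 2.

```python
import numpy as np


def pairs_via_tile_concat(t):
    """Hypothetical helper mirroring the tf.tile + tf.concat answer."""
    n = t.shape[0]
    # d[i, j] = t[i]: repeat each row n times along a new axis 1.
    d = np.tile(np.expand_dims(t, 1), (1, n, 1))
    # e[i, j] = t[j]: repeat the whole input n times along a new axis 0.
    e = np.tile(np.expand_dims(t, 0), (n, 1, 1))
    f = np.concatenate([d, e], axis=2)    # shape (n, n, 2 * row_size)
    return f.reshape(-1, 2 * t.shape[1])  # shape (n * n, 2 * row_size)


t = np.array([[1, 2, 3], [4, 5, 6]])  # rows of width 3
print(pairs_via_tile_concat(t).tolist())
# [[1, 2, 3, 1, 2, 3], [1, 2, 3, 4, 5, 6], [4, 5, 6, 1, 2, 3], [4, 5, 6, 4, 5, 6]]
```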
Answered by vijay m on Oct 23 '22 at 19:10