How to combine tensor elements with similar values in a specific column?

We have a 3D input_tensor of shape (batch_size, N, 2),

  • where:
    • batch_size = number of batches
    • N = number of predictions per batch
    • 2 = (label, score)

I want to sum the score values (second column) wherever the labels (first column) are the same within a batch. For example, given the tensor below with 3 batches and 4 predictions per batch, I want required_output_tensor as the result: in batch 0, label 2 appears with scores 0.7 and 0.8, so those rows combine into the single row [2., 1.5].

Condition: no for loops or tf.map_fn() in the answer, because tf.map_fn() is slow on GPU with TF 2.x. You can take a look at my sample code here, which works on a 2D tensor; applying it per batch with tf.map_fn() would look roughly like the sketch below.
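
For context, here is a minimal sketch of that kind of per-batch approach (hypothetical, since the linked sample code isn't shown here; aggregate_2d is an illustrative name):

import tensorflow as tf

def aggregate_2d(preds):  # preds: (N, 2) rows of (label, score)
    labels, idx = tf.unique(preds[:, 0])
    sums = tf.math.unsorted_segment_sum(preds[:, 1], idx, tf.size(labels))
    return tf.stack([labels, sums], axis=1)

# The pattern to avoid: mapping the 2D function over the batch dimension.
# Note tf.map_fn would also need uniform per-batch output shapes here (or a
# ragged fn_output_signature), another reason to prefer a vectorized answer.
# result = tf.map_fn(aggregate_2d, input_tensor)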

input_tensor = tf.constant([
    [
        [2., 0.7],
        [1., 0.1],
        [3., 0.4],
        [2., 0.8],
    ],
    [
        [2., 0.7],
        [1., 0.1],
        [1., 0.4],
        [4., 0.8],
    ],
    [
        [3., 0.7],
        [1., 0.1],
        [3., 0.4],
        [4., 0.8],
    ]
])

required_output_tensor = [
    [
        [2., 1.5],
        [1., 0.1],
        [3., 0.4],
    ],
    [
        [2., 0.7],
        [1., 0.5],
        [4., 0.8],
    ],
    [
        [3., 1.1],
        [1., 0.1],
        [4., 0.8],
    ]
]

EDIT: I can see how we would end up with a ragged tensor. In that case, I'm fine with choosing the top-k elements per batch, where k = min(size(smallest_batch)), or it can be hard-coded to topk=2; a sketch of what I mean follows.
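
For instance, assuming a hypothetical aggregated tensor agg of shape (batch, C, 2) like the ones the answers below produce, the hard-coded variant could look like this:

topk = 2
top_scores, top_idx = tf.math.top_k(agg[..., 1], k=topk)  # (batch, topk)
top_rows = tf.gather(agg, top_idx, batch_dims=1)          # (batch, topk, 2)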

EDIT 2: Adding an additional input tensor to try the proposed solutions on:

additional_input_tensor = tf.constant([
    [
        [2., 0.5],
        [1., 0.1],
        [3., 0.4],
        [2., 0.5],
    ],
    [
        [22., 0.7],
        [11., 0.2],
        [11., 0.3],
        [44., 0.8],
    ],
    [
        [3333., 0.7],
        [1111., 0.1],
        [4444., 0.4],
        [5555., 0.8],
    ],
    [
        [2., 0.9],
        [1., 0.2],
        [5., 0.3],
        [2., 0.9],
    ]
])
asked Oct 16 '22 by Snehal

2 Answers

It's not exactly what you asked for, but if you know the number of classes and you don't want a ragged tensor, you can use one-hot encodings to add up the scores for each class:

import tensorflow as tf

input_tensor = tf.constant([
    [
        [2., 0.7],
        [1., 0.1],
        [3., 0.4],
        [2., 0.8],
    ],
    [
        [2., 0.7],
        [1., 0.1],
        [1., 0.4],
        [4., 0.8],
    ],
    [
        [3., 0.7],
        [1., 0.1],
        [3., 0.4],
        [4., 0.8],
    ]
])


number_of_classes = 5

#first split the labels from scores
labels = tf.expand_dims(input_tensor[:,:,0], axis=-1)
scores = tf.expand_dims(input_tensor[:,:,1], axis=-1)

#get a one-hot encoding for the labels
#the way you do this would likely depend on your specific labels
#the way I do it here is not very robust (maybe use half open intervals instead)
class_indices = tf.reshape(tf.range(number_of_classes, dtype=tf.float32), shape=(1,1,number_of_classes))
one_hots = tf.cast(tf.equal(class_indices, labels), tf.float32)
print(one_hots.shape)  # (batch, N, number_of_classes)

#now multiply the one hots by the scores, and add all together
scored_one_hots = scores * one_hots
scores_per_index = tf.reduce_sum(scored_one_hots, axis=1) # (batch, number_of_classes) 
# where the second index denotes the class and contains the score for that class

# now finish up by combining these scores with the labels
# edit: of course this part too depends on how you actually did the encoding
batch_size = input_tensor.shape[0]
ordered_labels = tf.repeat(tf.expand_dims(tf.range(number_of_classes, dtype=tf.float32), axis=0), batch_size, axis=0)

result = tf.stack([ordered_labels, scores_per_index], axis=2)
print(result)

Printed result:

(3, 4, 5)
tf.Tensor(
[[[0.  0. ]
  [1.  0.1]
  [2.  1.5]
  [3.  0.4]
  [4.  0. ]]

 [[0.  0. ]
  [1.  0.5]
  [2.  0.7]
  [3.  0. ]
  [4.  0.8]]

 [[0.  0. ]
  [1.  0.1]
  [2.  0. ]
  [3.  1.1]
  [4.  0.8]]], shape=(3, 5, 2), dtype=float32)

The way you make your one-hots would depend on the specifics of the labels (tf.equal might not be the best choice; you can use comparisons, etc.). For instance, see the sketch below.
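
If the labels are known to be non-negative integers stored as floats (an assumption the question doesn't guarantee), a more robust encoding is to cast them and use tf.one_hot:

# Assumes labels are whole numbers in [0, number_of_classes); illustrative only.
labels_int = tf.cast(input_tensor[:, :, 0], tf.int32)        # (batch, N)
one_hots = tf.one_hot(labels_int, depth=number_of_classes)   # (batch, N, C)
scores_per_index = tf.reduce_sum(one_hots * scores, axis=1)  # (batch, C)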

answered Oct 22 '22 by simon


This question is not well defined in general: the input groups can contain different numbers of distinct id values, so the result would not be a dense tensor. You could use a ragged tensor, although that can be limiting. One option is to produce a result where every output group contains every id, with the scores of ids that do not appear in the corresponding input group set to zero. This is how you could do that:

import tensorflow as tf

input_tensor = tf.constant([
    [
        [2., 0.7],
        [1., 0.1],
        [3., 0.4],
        [2., 0.8],
    ],
    [
        [2., 0.7],
        [1., 0.1],
        [1., 0.4],
        [4., 0.8],
    ],
    [
        [3., 0.7],
        [1., 0.1],
        [3., 0.4],
        [4., 0.8],
    ]
])
# Take input tensor shape
s = tf.shape(input_tensor)
# Flatten first dimensions
flat = tf.reshape(input_tensor, (-1, 2))
# Find unique id values
group_ids, group_idx = tf.unique(flat[:, 0], out_idx=s.dtype)
# Shift id indices per group in the input
num_groups = tf.reduce_max(group_idx) + 1
group_shift = tf.tile(tf.expand_dims(num_groups * tf.range(s[0]), 1), (1, s[1]))
group_idx_shift = group_idx + tf.reshape(group_shift, (-1,))
# Aggregate per group in the input
num_groups_shift = num_groups * s[0]
# Either use unsorted_segment_sum...
# group_sum = tf.math.unsorted_segment_sum(flat[:, 1], group_idx_shift, num_groups_shift)
# ...or, equivalently, bincount (only one of the two is needed):
group_sum = tf.math.bincount(group_idx_shift, weights=flat[:, 1],
                             minlength=num_groups_shift)
# Reshape and concatenate
group_sum_res = tf.reshape(group_sum, (s[0], num_groups))
group_ids_res = tf.tile(tf.expand_dims(group_ids, 0), (s[0], 1))
result = tf.stack([group_ids_res, group_sum_res], axis=-1)
# Sort results
result_s = tf.argsort(group_sum_res, axis=-1, direction='DESCENDING')
result_sorted = tf.gather_nd(result, tf.expand_dims(result_s, axis=-1), batch_dims=1)
print(result_sorted.numpy())
# [[[2.  1.5]
#   [3.  0.4]
#   [1.  0.1]
#   [4.  0. ]]
# 
#  [[4.  0.8]
#   [2.  0.7]
#   [1.  0.5]
#   [3.  0. ]]
# 
#  [[3.  1.1]
#   [4.  0.8]
#   [1.  0.1]
#   [2.  0. ]]]
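
Since each batch row is already sorted by score in descending order, the top-k fallback from the question's first edit reduces to a slice (here hard-coding k to 2, as the question allows):

topk = 2
print(result_sorted[:, :topk].numpy())
# [[[2.  1.5]
#   [3.  0.4]]
#
#  [[4.  0.8]
#   [2.  0.7]]
#
#  [[3.  1.1]
#   [4.  0.8]]]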

EDIT:

Here is an alternative using a ragged tensor output instead:

import tensorflow as tf

input_tensor = tf.constant([...])
# Same as before
s = tf.shape(input_tensor)
flat = tf.reshape(input_tensor, (-1, 2))
group_ids, group_idx = tf.unique(flat[:, 0], out_idx=s.dtype)
num_groups = tf.reduce_max(group_idx) + 1
group_shift = tf.tile(tf.expand_dims(num_groups * tf.range(s[0]), 1), (1, s[1]))
group_idx_shift = group_idx + tf.reshape(group_shift, (-1,))
# Apply unique again to find ids per batch
group_ids2_ref, group_idx2 = tf.unique(group_idx_shift)
group_ids2 = tf.gather(group_ids, group_ids2_ref % num_groups)
# Also can use unsorted_segment_sum here if preferred
group_sum = tf.math.bincount(group_idx2, weights=flat[:, 1])
# Count number of elements in each output group
out_sizes = tf.math.bincount(group_ids2_ref // num_groups, minlength=s[0])
# Make ragged result
group_sum_r = tf.RaggedTensor.from_row_lengths(group_sum, out_sizes)
group_ids_r = tf.RaggedTensor.from_row_lengths(group_ids2, out_sizes)
result = tf.stack([group_ids_r, group_sum_r], axis=-1)
print(*result.to_list(), sep='\n')
# [[2.0, 1.5], [1.0, 0.10000000149011612], [3.0, 0.4000000059604645]]
# [[2.0, 0.699999988079071], [1.0, 0.5], [4.0, 0.800000011920929]]
# [[3.0, 1.100000023841858], [1.0, 0.10000000149011612], [4.0, 0.800000011920929]]
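
If downstream code needs a dense tensor after all, one option (an addition, not part of the original approach) is to pad the ragged result with zero rows:

dense_result = result.to_tensor()  # shorter batches get padded with [0., 0.] rows
print(dense_result.shape)          # (3, 3, 2) here, since every batch has 3 unique labels
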
answered Oct 22 '22 by jdehesa