
Keras: How to slice tensor using information from another tensor?

I am trying to implement a custom loss function and have come across this problem. The custom loss function will look something like this:

def customLoss(z):
    y_pred = z[0] 
    y_true = z[1]
    features = z[2] 
    ...
    return loss

In my situation, y_pred and y_true are actually greyscale images. The features contained in z[2] consist of a pair of locations (x, y) where I would like to compare y_pred and y_true. These locations depend on the input training sample, so they are passed as inputs when the model is defined. So my question is: how do I use the tensor features to index into the tensors y_pred and y_true?
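To pin down what such a loss should compute, here is a plain-loop NumPy reference under assumed shapes. It assumes `y_pred` and `y_true` are a batch of greyscale images of shape `(B, H, W)`, and that `features` holds `K` integer `(row, col)` locations per sample, with shape `(B, K, 2)`. The function name `reference_loss` and these shapes are hypothetical. The sketch is only meant as a ground truth to check a vectorized version against.

```python
import numpy as np

def reference_loss(y_pred, y_true, features):
    """Hypothetical reference: absolute error at per-sample (row, col) locations.

    Assumed shapes: y_pred, y_true -> (B, H, W); features -> (B, K, 2) integers.
    Returns an array of shape (B, K).
    """
    batch_size, num_points, _ = features.shape
    out = np.zeros((batch_size, num_points), dtype=y_true.dtype)
    for b in range(batch_size):
        for k in range(num_points):
            row, col = features[b, k]
            out[b, k] = abs(y_true[b, row, col] - y_pred[b, row, col])
    return out

# Usage: two 3x3 "images" that differ at a single pixel (sample 1, row 2, col 0).
y_true = np.zeros((2, 3, 3), dtype=np.float32)
y_pred = y_true.copy()
y_pred[1, 2, 0] = 0.5
features = np.array([[[0, 0], [1, 1]], [[2, 0], [0, 2]]])
print(reference_loss(y_pred, y_true, features))  # 0.5 at [1, 0], zeros elsewhere
```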

asked Jun 12 '18 at 15:06 by theQman





1 Answer

If you are using TensorFlow as the backend, tf.gather_nd() should do the trick (as far as I can tell, Keras does not have an exact equivalent yet):

from keras import backend as K
import tensorflow as tf

def customLoss(z):
    y_pred = z[0]
    y_true = z[1]
    features = z[2]

    # Gathering values according to 2D indices:
    y_true_feat = tf.gather_nd(y_true, features)
    y_pred_feat = tf.gather_nd(y_pred, features)

    # Computing loss (to be replaced):
    loss = K.abs(y_true_feat - y_pred_feat)
    return loss

# Demonstration:
y_true = K.constant([[[0, 0, 0], [1, 1, 1]], [[2, 2, 2], [3, 3, 3]]])
y_pred = K.constant([[[0, 0, -1], [1, 1, 1]], [[0, 2, 0], [3, 3, 0]]])
coords = K.constant([[0, 1], [1, 0]], dtype="int64")

loss = customLoss([y_pred, y_true, coords])

tf_session = K.get_session()
print(loss.eval(session=tf_session))
# [[ 0.  0.  0.]
#  [ 2.  0.  2.]]
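To show what tf.gather_nd() does in the demonstration above, here is a NumPy emulation. The helper `gather_nd_np` is hypothetical and is not a TensorFlow or NumPy API. Each length-D row of the index array addresses the first D dimensions of `params`, and the remaining dimensions are returned whole.

For a batch of images of shape `(B, H, W)` with per-sample `(row, col)` locations, the batch index has to be prepended to every location. The hypothetical `batched_locations` helper sketches that step. Recent TensorFlow releases also accept a `batch_dims` argument in tf.gather_nd() for this case.

```python
import numpy as np

def gather_nd_np(params, indices):
    """Hypothetical NumPy emulation of tf.gather_nd (with batch_dims=0).

    indices has shape (..., D); each length-D row indexes the first D
    dimensions of params, and the trailing dimensions are kept.
    """
    indices = np.asarray(indices)
    # Move the last axis to the front: D index arrays for advanced indexing.
    return params[tuple(np.moveaxis(indices, -1, 0))]

def batched_locations(features):
    """Hypothetical helper: turn (B, K, 2) (row, col) pairs into (B, K, 3)
    (batch, row, col) triples that can address a (B, H, W) array."""
    batch_size, num_points, _ = features.shape
    batch_idx = np.broadcast_to(np.arange(batch_size)[:, None, None],
                                (batch_size, num_points, 1))
    return np.concatenate([batch_idx, features], axis=-1)

# Same values as the demonstration above:
y_true = np.array([[[0, 0, 0], [1, 1, 1]], [[2, 2, 2], [3, 3, 3]]], dtype=np.float32)
y_pred = np.array([[[0, 0, -1], [1, 1, 1]], [[0, 2, 0], [3, 3, 0]]], dtype=np.float32)
coords = np.array([[0, 1], [1, 0]])
print(np.abs(gather_nd_np(y_true, coords) - gather_nd_np(y_pred, coords)))
# [[0. 0. 0.]
#  [2. 0. 2.]]

# Per-sample locations into a batch of two 3x3 "images":
images = np.arange(2 * 3 * 3).reshape(2, 3, 3)
features = np.array([[[0, 0], [2, 1]], [[1, 2], [0, 0]]])
print(gather_nd_np(images, batched_locations(features)))
# [[ 0  7]
#  [14  9]]
```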

Note 1: Keras does, however, have K.gather(), which only works with 1D indices (it gathers along the first axis). If you want to use native Keras only, you can still flatten your tensors and indices to apply it:

def customLoss(z):
    y_pred = z[0]
    y_true = z[1]
    # Coordinates of shape (K, 2), cast to int32 so they can be combined with K.shape():
    features = K.cast(z[2], "int32")

    y_shape = K.shape(y_true)

    # Reshaping y_pred & y_true from (N, M, ...) to (N*M, ...):
    y_shape_flat = K.concatenate([K.stack([y_shape[0] * y_shape[1]]), y_shape[2:]], axis=0)
    y_true_flat = K.reshape(y_true, y_shape_flat)
    y_pred_flat = K.reshape(y_pred, y_shape_flat)

    # Transforming each 2D coordinate (i, j) into the 1D index i * M + j:
    features_flat = features[:, 0] * y_shape[1] + features[:, 1]

    # Gathering the values:
    y_true_feat = K.gather(y_true_flat, features_flat)
    y_pred_feat = K.gather(y_pred_flat, features_flat)

    # Computing loss (to be replaced):
    loss = K.abs(y_true_feat - y_pred_feat)
    return loss
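The flattening trick in Note 1 relies on a row-major identity. Element `(i, j)` of an `(N, M, ...)` array is row `i * M + j` of the same array reshaped to `(N*M, ...)`. TensorFlow's reshape, like NumPy's default, is row-major. Below is a quick NumPy check of that identity. The shapes and values are arbitrary and chosen only for illustration.

```python
import numpy as np

# Arbitrary (N, M, C) array and (K, 2) coordinates, chosen for illustration.
N, M, C = 4, 5, 3
y = np.arange(N * M * C).reshape(N, M, C)
coords = np.array([[0, 1], [3, 4], [2, 0]])

# Reshape (N, M, C) -> (N*M, C); row-major (C order) is NumPy's default.
y_flat = y.reshape(N * M, C)
# Map each (i, j) pair to the flat row index i * M + j.
flat_idx = coords[:, 0] * M + coords[:, 1]

gathered_flat = y_flat[flat_idx]              # analogous to K.gather on the flat tensor
gathered_2d = y[coords[:, 0], coords[:, 1]]   # analogous to tf.gather_nd on the original
print(np.array_equal(gathered_flat, gathered_2d))  # True

# NumPy's own helper computes the same flat indices:
print(np.ravel_multi_index((coords[:, 0], coords[:, 1]), (N, M)))  # [ 1 19 10]
```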

Note 2: To answer your question in the comments: with TensorFlow as the backend, tensors can be sliced the NumPy way:

x = K.constant([[[0, 1, 2], [3, 4, 5]], [[0, 0, 0], [0, 0, 0]]])
sess = K.get_session()

# When it comes to slicing, TF tensors work like numpy arrays:
x_slice = x[0, 0:2, 0:3]
print(x_slice.eval(session=sess))
# [[ 0.  1.  2.]
#  [ 3.  4.  5.]]

# This also works if your indices are tensors (TF calls tf.strided_slice() under the hood):
coords_range_per_dim = K.constant([[0, 2], [0, 3]], dtype="int32")
x_slice = x[0,
            coords_range_per_dim[0][0]:coords_range_per_dim[0][1],
            coords_range_per_dim[1][0]:coords_range_per_dim[1][1]
           ]
print(x_slice.eval(session=sess))
# [[ 0.  1.  2.]
#  [ 3.  4.  5.]]
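Note 2 writes out one `start:stop` pair per dimension. When the bounds live in a `(D, 2)` array, a tuple of Python `slice` objects can be built from it instead. The NumPy sketch below shows this for arrays. The helper name `slice_by_bounds` is hypothetical. The values mirror Note 2.

```python
import numpy as np

def slice_by_bounds(x, leading_index, bounds):
    """Hypothetical helper: x[leading_index, b[0][0]:b[0][1], b[1][0]:b[1][1], ...].

    bounds has shape (D, 2); row d holds the [start, stop) range for the
    d-th dimension after the leading integer index.
    """
    slices = tuple(slice(int(start), int(stop)) for start, stop in bounds)
    return x[(leading_index,) + slices]

# Same values as in Note 2:
x = np.array([[[0, 1, 2], [3, 4, 5]], [[0, 0, 0], [0, 0, 0]]], dtype=np.float32)
coords_range_per_dim = np.array([[0, 2], [0, 3]])
print(slice_by_bounds(x, 0, coords_range_per_dim))
# [[0. 1. 2.]
#  [3. 4. 5.]]

# A narrower window: x[0, 1:2, 1:3]
print(slice_by_bounds(x, 0, np.array([[1, 2], [1, 3]])))
# [[4. 5.]]
```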
answered Sep 30 '22 at 13:09 by benjaminplanche
