
Possible to use Rank Correlation as cost function in TensorFlow?

I'm working with extremely noisy data occasionally peppered with outliers, so I'm relying mostly on correlation as a measure of accuracy in my NN.

Is it possible to explicitly use something like rank correlation (the Spearman correlation coefficient) as my cost function? Up to now, I've relied mostly on MSE as a proxy for correlation.

I have three major stumbling blocks right now:

1) The notion of ranking becomes much fuzzier with mini-batches.

2) How do you perform rankings dynamically? Won't TensorFlow raise a gradient error, or be unable to track how a change in a weight or bias affects the cost?

3) How do you determine the size of the tensors you're looking at during runtime?

For example, the code below is what I'd like to roughly do if I were to just use correlation. In practice, length needs to be passed in rather than determined at runtime.

length = tf.shape(x)[1] ## Example code. This line not meant to work.

# Negated covariance term; the whole difference must be negated, not just the first product.
original_loss = -1 * (length * tf.reduce_sum(tf.multiply(x, y)) - tf.reduce_sum(x) * tf.reduce_sum(y))
divisor = tf.sqrt(
  (length * tf.reduce_sum(tf.square(x)) - tf.square(tf.reduce_sum(x))) *
  (length * tf.reduce_sum(tf.square(y)) - tf.square(tf.reduce_sum(y)))
)
original_loss = tf.truediv(original_loss, divisor)
Ryan K. asked Nov 08 '22 12:11


1 Answer

Here is the code for the Spearman correlation:

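# samples is the number of elements in the batch; both tensors are assumed to be 1-D.
# top_k(...).indices is the sort order (an argsort), not the rank of each element.
# Applying top_k to the negated order inverts the permutation and yields the ranks.
predictions_order = tf.nn.top_k(predictions_batch, k=samples, sorted=True, name='prediction_order').indices
predictions_rank = tf.nn.top_k(-predictions_order, k=samples, sorted=True, name='prediction_rank').indices
real_order = tf.nn.top_k(real_outputs_batch, k=samples, sorted=True, name='real_order').indices
real_rank = tf.nn.top_k(-real_order, k=samples, sorted=True, name='real_rank').indices
rank_diffs = predictions_rank - real_rank
rank_diffs_squared_sum = tf.reduce_sum(rank_diffs * rank_diffs)
six = tf.constant(6)
one = tf.constant(1.0)
# Spearman's rho = 1 - 6 * sum(d^2) / (n^3 - n), valid when there are no ties.
numerator = tf.cast(six * rank_diffs_squared_sum, dtype=tf.float32)
divider = tf.cast(samples * samples * samples - samples, dtype=tf.float32)
spearman_batch = one - numerator / divider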

The problem with the Spearman correlation is that it requires a sorting algorithm (top_k in my code), so there is no way to turn it into a usable loss value. Sorting has no useful derivative: the resulting ranks are integers that do not change under small changes to the inputs. You can use an ordinary (Pearson) correlation instead, but I think there is little mathematical difference between that and mean squared error.
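To make the "no derivative" point concrete, here is a minimal pure-Python sketch (not TensorFlow, and not the code above). The helper names ranks and spearman are illustrative only, and ties are assumed not to occur. It shows that nudging one prediction slightly leaves the coefficient unchanged, so a finite-difference gradient would be zero.

```python
def ranks(values):
    """Rank of each element (0 = smallest). Ties are not handled in this sketch."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0] * len(values)
    for rank, index in enumerate(order):
        result[index] = rank
    return result


def spearman(a, b):
    """Spearman's rho via 1 - 6 * sum(d^2) / (n^3 - n), valid without ties."""
    n = len(a)
    d_squared = sum((ra - rb) ** 2 for ra, rb in zip(ranks(a), ranks(b)))
    return 1.0 - 6.0 * d_squared / (n ** 3 - n)


targets = [0.1, 0.4, 0.35, 0.8]
predictions = [0.2, 0.5, 0.1, 0.9]

# Nudge the first prediction by a small amount; the ordering stays the same.
nudged = [predictions[0] + 1e-3] + predictions[1:]

rho_before = spearman(predictions, targets)
rho_after = spearman(nudged, targets)

# The coefficient is piecewise constant, so this slope is exactly zero.
finite_difference = (rho_after - rho_before) / 1e-3
```

The coefficient only jumps when two predictions swap places, which is why gradient descent gets no signal from it.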

I am working on this right now for images. From what I have read in papers, the usual way to bring ranking into the loss function is to compare 2 or 3 images at a time (wherever I say images, you can substitute anything you want to rank).

Comparing two elements:

[Equation images not preserved in the source.]

Where N is the total number of elements and α is a margin value. I took this equation from the paper "Photo Aesthetics Ranking Network with Attributes and Content Adaptation".
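As a rough illustration of the idea of comparing two elements with a margin, here is a pure-Python sketch of a generic pairwise hinge loss. It is an assumption about the general shape of such a loss, not the exact formula from the paper, and the function name pairwise_margin_loss and the default margin are made up for this example.

```python
from itertools import combinations


def pairwise_margin_loss(predictions, targets, margin=0.1):
    """Mean hinge penalty over all pairs whose targets differ."""
    total, count = 0.0, 0
    for i, j in combinations(range(len(targets)), 2):
        if targets[i] == targets[j]:
            continue  # no preferred order for tied targets
        sign = 1.0 if targets[i] > targets[j] else -1.0
        # Penalize the pair unless the prediction gap has the right sign
        # and is at least `margin` wide.
        total += max(0.0, margin - sign * (predictions[i] - predictions[j]))
        count += 1
    return total / count if count else 0.0


targets = [1.0, 2.0, 3.0]
well_ordered = pairwise_margin_loss([0.1, 0.5, 0.9], targets)
reversed_order = pairwise_margin_loss([0.9, 0.5, 0.1], targets)
```

Unlike a rank computed by sorting, this penalty changes smoothly with the predictions whenever a pair is misordered or too close, so it can be minimized by gradient descent.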

You can also use losses with 3 elements, where you compare two elements of similar rank against a third of different rank:

[Equation image not preserved in the source.]

But in this equation you also need to add the direction of the ranking; there are more details in the paper "Will People Like Your Image?". That paper uses a vector encoding instead of a real value, but you can do the same with a single number.
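A three-element comparison on single numbers might look like the sketch below. This is a generic triplet hinge loss written as an assumption, not the equation from the paper. It leaves out the ranking-direction term mentioned above, and triplet_margin_loss and its margin default are hypothetical names and values.

```python
def triplet_margin_loss(anchor, similar, different, margin=0.2):
    """Hinge on scalar scores: the similar item should sit closer to the
    anchor than the different item does, by at least `margin`."""
    near = abs(anchor - similar)
    far = abs(anchor - different)
    return max(0.0, margin + near - far)


# The similar item is close and the different item is far: no penalty.
satisfied = triplet_margin_loss(anchor=0.80, similar=0.75, different=0.20)

# The "similar" item ended up farther away than the "different" one: penalized.
violated = triplet_margin_loss(anchor=0.80, similar=0.30, different=0.70)
```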

In the case of images, the comparison between images makes more sense when those images are related. So it is a good idea to run a clustering algorithm to create (maybe?) 10 clusters, so you can use elements of the same cluster to make comparisons instead of very different things. This will help the network as the inputs are related somehow and not completely different.
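One possible way to restrict comparisons to related items is sketched below in pure Python. It assumes the cluster ids have already been produced by whatever clustering algorithm you choose. The function sample_pairs_within_clusters and its parameters are hypothetical.

```python
import random
from collections import defaultdict


def sample_pairs_within_clusters(cluster_ids, pairs_per_cluster, seed=0):
    """Return (i, j) index pairs whose items share a cluster id."""
    rng = random.Random(seed)
    members = defaultdict(list)
    for index, cluster in enumerate(cluster_ids):
        members[cluster].append(index)

    pairs = []
    for indices in members.values():
        if len(indices) < 2:
            continue  # a singleton cluster has nothing to compare against
        for _ in range(pairs_per_cluster):
            # rng.sample draws two distinct indices from the same cluster
            pairs.append(tuple(rng.sample(indices, 2)))
    return pairs


# Assumed output of a clustering step: one cluster id per training example.
cluster_ids = [0, 0, 1, 1, 1, 2]
pairs = sample_pairs_within_clusters(cluster_ids, pairs_per_cluster=3)
```

Each sampled pair can then be fed to a two-element ranking loss, so the network only ever compares items from the same cluster.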

As a side note, you should decide what matters more to you: the final rank order or the rank value. If it is the value, go with mean squared error; if it is the rank order, you can use the losses described above. You can even combine them.

How do you determine the size of the tensors you're looking at during runtime?

tf.shape(tensor) returns a tensor holding the shape, evaluated at run time. You can then use tf.gather(tf.shape(tensor), index), or simply index it as tf.shape(tensor)[index], to get the dimension you want.
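A small sketch of reading a dimension at run time, assuming TensorFlow 2.x (the original question predates it). The function name batch_dims is illustrative. The input signature leaves both dimensions unknown so that the sizes really are resolved when the function is called.

```python
import tensorflow as tf


@tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
def batch_dims(x):
    shape = tf.shape(x)          # 1-D int32 tensor, evaluated at run time
    rows = tf.gather(shape, 0)   # equivalent to shape[0]
    cols = tf.gather(shape, 1)   # equivalent to shape[1]
    # Cast before mixing with float tensors, e.g. in a correlation formula.
    return rows, tf.cast(cols, tf.float32)


rows, cols = batch_dims(tf.zeros([4, 7]))
```

The cast matters because tf.shape yields integers, and multiplying an int32 tensor by a float32 tensor raises a type error.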

jorgemf answered Nov 14 '22 21:11