Batched Cross Correlation in Tensorflow

I need to perform 2D cross-correlation (like MATLAB's xcorr2, https://www.mathworks.com/help/signal/ref/xcorr2.html) across a batch of pairs of feature maps.

For clarification:

Let X be B x W1 x H1 x C and Y be B x W2 x H2 x C.

The output I want has shape B x W2 x H2 x 1 (treating X as the "filter" passed over Y with SAME padding), where the i-th 1 x W2 x H2 x 1 slice of the output is the cross-correlation between X[i,:,:,:] and Y[i,:,:,:], e.g. something like

tf.nn.conv2d(Y[i:i+1], tf.expand_dims(X[i], -1), [1,1,1,1], padding='SAME')

Is there an efficient way to implement this operation?

Note: if X is 1 x W1 x H1 x C and we want to cross-correlate it with each of the B slices of Y, this is easy:

cross_corr = tf.nn.conv2d(
    Y, tf.transpose(X, perm=[1,2,3,0]), [1,1,1,1], padding='SAME')

which takes advantage of the fact that TensorFlow implements conv2d as cross-correlation, and of the fact that we can treat the smaller tensor as a filter after transposing it. This does NOT solve my issue, because I need to cross-correlate with B different filters, one per batch element.

Maybe conv3d is a possibility?

Note 2: matconvnet's vl_nnconv does this if the number of filter channels divides the number of input channels. Does TensorFlow have an equivalent?

C. Y. asked Mar 23 '17 01:03


2 Answers

You can use the tf.map_fn function, as I showed here.

I tried back then to use conv3d but didn't find a way to make it work.
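The per-pair mapping suggested above could be sketched roughly as follows. This is a minimal sketch, not the answerer's original code: the helper name batched_xcorr_map_fn is hypothetical, and it assumes TensorFlow 2.3 or later (for the fn_output_signature argument of tf.map_fn) and that X and Y share the same dtype.

```python
import tensorflow as tf


def batched_xcorr_map_fn(X, Y):
    """Cross-correlate each X[i] (used as a filter) with Y[i], SAME padding.

    X: [B, H1, W1, C], Y: [B, H2, W2, C]  ->  output: [B, H2, W2, 1]
    (hypothetical helper; the name is not from the original answer)
    """
    def single_pair(pair):
        x, y = pair
        kernel = tf.expand_dims(x, axis=-1)  # [H1, W1, C, 1]: one output channel
        image = tf.expand_dims(y, axis=0)    # [1, H2, W2, C]: batch of one
        out = tf.nn.conv2d(image, kernel, strides=[1, 1, 1, 1], padding='SAME')
        return out[0]                        # [H2, W2, 1]

    # map_fn unstacks X and Y along axis 0 and applies single_pair to each pair.
    return tf.map_fn(single_pair, (X, Y), fn_output_signature=Y.dtype)


X = tf.random.normal([4, 3, 3, 8])
Y = tf.random.normal([4, 16, 16, 8])
print(batched_xcorr_map_fn(X, Y).shape)
```

This keeps the code close to the single-pair conv2d call in the question, at the cost of issuing one convolution per batch element.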

Benoit Seguin answered Oct 17 '22 19:10


[TensorFlow 2.x version]

I am not sure if this method was around at the time this question was asked, but you can try tf.nn.convolution:

https://www.tensorflow.org/api_docs/python/tf/nn/convolution

It looks like the method supports multiple filters, and deals with N-D convolutions, where N must be 1, 2, or 3.
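As far as I can tell, tf.nn.convolution, like tf.nn.conv2d, takes one filter bank shared across the whole batch, so on its own it may not give a different filter per batch element. One possible workaround, which is my assumption and not something from the answers above, is to fold the batch into the channel axis and use tf.nn.depthwise_conv2d, then sum over the original channels. The helper name batched_xcorr_depthwise is hypothetical, and the sketch assumes statically known shapes.

```python
import tensorflow as tf


def batched_xcorr_depthwise(X, Y):
    """Per-example cross-correlation in a single depthwise convolution.

    X: [B, H1, W1, C] (filters), Y: [B, H2, W2, C]  ->  output: [B, H2, W2, 1]
    (hypothetical helper; assumes all dimensions are statically known)
    """
    B, H1, W1, C = X.shape.as_list()
    _, H2, W2, _ = Y.shape.as_list()

    # Fold the batch into channels: channel index b*C + c holds Y[b, :, :, c].
    y = tf.reshape(tf.transpose(Y, [1, 2, 0, 3]), [1, H2, W2, B * C])
    # One depthwise filter per (b, c) pair, in the same channel order.
    k = tf.reshape(tf.transpose(X, [1, 2, 0, 3]), [H1, W1, B * C, 1])

    # Each channel is correlated only with its own filter.
    out = tf.nn.depthwise_conv2d(y, k, strides=[1, 1, 1, 1], padding='SAME')

    # Unfold the channels and sum over C, as conv2d would for one output channel.
    out = tf.reshape(out, [H2, W2, B, C])
    out = tf.reduce_sum(out, axis=3)      # [H2, W2, B]
    out = tf.transpose(out, [2, 0, 1])    # [B, H2, W2]
    return tf.expand_dims(out, axis=-1)   # [B, H2, W2, 1]


X = tf.random.normal([4, 3, 3, 8])
Y = tf.random.normal([4, 16, 16, 8])
print(batched_xcorr_depthwise(X, Y).shape)
```

The result should match calling conv2d on each (Y[i], X[i]) pair separately, up to floating-point error, while running as one op instead of B.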

ComputerScientist answered Oct 17 '22 20:10