
Tensorflow: Convolutions with different filter for each sample in the mini-batch

I would like to have a 2d convolution with a filter which depends on the sample in the mini-batch in tensorflow. Any ideas how one could do that, especially if the number of samples per mini-batch is not known?

Concretely, I have input data inp of the form MB x H x W x Channels, and I have filters F of the form MB x fh x fw x Channels x OutChannels.

It is assumed that

inp = tf.placeholder('float', [None, H, W, channels_img], name='img_input').

I would like to do tf.nn.conv2d(inp, F, strides = [1,1,1,1]), but this is not allowed because F cannot have a mini-batch dimension. Any idea how to solve this problem?

asked Feb 06 '17 by patapouf_ai


3 Answers

You could use tf.map_fn as follows:

import tensorflow as tf  # TF 1.x API

inp = tf.placeholder(tf.float32, [None, h, w, c_in])
kernels = tf.placeholder(tf.float32, [None, fh, fw, c_in, c_out])

def single_conv(tupl):
    x, kernel = tupl
    return tf.nn.conv2d(x, kernel, strides=(1, 1, 1, 1), padding='VALID')

# Map single_conv over the batch: each sample (expanded to a batch of 1)
# is convolved with its own kernel, then the singleton dim is squeezed out.
# Assume kernels shape is [tf.shape(inp)[0], fh, fw, c_in, c_out]
batch_wise_conv = tf.squeeze(tf.map_fn(
    single_conv, (tf.expand_dims(inp, 1), kernels), dtype=tf.float32),
    axis=1
)

It is important to specify dtype for tf.map_fn. In effect, this solution builds one 2D convolution operation per sample in the batch.
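For reference, the per-sample convolution that the map_fn solution computes can be sketched in plain NumPy (a toy direct implementation, assuming 'VALID' padding and stride 1; the function name is illustrative, not a TensorFlow API):

```python
import numpy as np

def per_sample_conv2d(inp, filters):
    """Per-sample 'VALID' conv: inp (MB,H,W,Cin), filters (MB,fh,fw,Cin,Cout)."""
    MB, H, W, Cin = inp.shape
    _, fh, fw, _, Cout = filters.shape
    out = np.zeros((MB, H - fh + 1, W - fw + 1, Cout))
    for b in range(MB):  # each sample is convolved with its own kernel
        for i in range(H - fh + 1):
            for j in range(W - fw + 1):
                patch = inp[b, i:i+fh, j:j+fw, :]  # (fh, fw, Cin)
                out[b, i, j, :] = np.tensordot(
                    patch, filters[b], axes=([0, 1, 2], [0, 1, 2]))
    return out

x = np.random.rand(2, 5, 5, 3)
k = np.random.rand(2, 2, 2, 3, 4)
print(per_sample_conv2d(x, k).shape)  # (2, 4, 4, 4)
```

This is far too slow for real use, but it makes the intended semantics explicit and can be used to check the TensorFlow graph on small inputs.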

answered Oct 19 '22 by Jos van de Wolfshaar


I think the proposed trick is actually not right. With a tf.nn.conv3d() layer, the input gets convolved along the depth (= actual batch) dimension AND then summed over the resulting feature maps. With padding='SAME' the resulting number of outputs happens to equal the batch size, so one gets fooled!

EDIT: I think a possible way to do a convolution with different filters for the different mini-batch elements involves 'hacking' a depthwise convolution. Assuming batch size MB is known:

inp = tf.placeholder(tf.float32, [MB, H, W, channels_img])

# F has shape (MB, fh, fw, channels, out_channels)
# REM: with the notation in the question, we need: channels_img==channels

F = tf.transpose(F, [1, 2, 0, 3, 4])
F = tf.reshape(F, [fh, fw, channels*MB, out_channels])

inp_r = tf.transpose(inp, [1, 2, 0, 3]) # shape (H, W, MB, channels_img)
inp_r = tf.reshape(inp_r, [1, H, W, MB*channels_img])

out = tf.nn.depthwise_conv2d(
          inp_r,
          filter=F,
          strides=[1, 1, 1, 1],
          padding='VALID') # here no requirement about padding being 'VALID', use whatever you want. 
# Now out shape is (1, H, W, MB*channels*out_channels)

out = tf.reshape(out, [H, W, MB, channels, out_channels]) # careful about the order of depthwise conv out_channels!
out = tf.transpose(out, [2, 0, 1, 3, 4])
out = tf.reduce_sum(out, axis=3)

# out shape is now (MB, H, W, out_channels)

In case MB is unknown, it should be possible to determine it dynamically using tf.shape() (I think)
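The folding trick above can be checked numerically. The NumPy sketch below mimics depthwise_conv2d (channel-multiplier semantics with channel-major output ordering, 'VALID' padding, stride 1 — a toy stand-in for the TensorFlow op, written by me for verification) and confirms that the transpose/reshape/sum pipeline reproduces an ordinary per-sample convolution:

```python
import numpy as np

def depthwise_conv2d(x, f):
    """x: (1,H,W,C), f: (fh,fw,C,M) -> (1,H-fh+1,W-fw+1,C*M), channel-major."""
    _, H, W, C = x.shape
    fh, fw, _, M = f.shape
    out = np.zeros((1, H - fh + 1, W - fw + 1, C * M))
    for c in range(C):
        for m in range(M):
            for i in range(H - fh + 1):
                for j in range(W - fw + 1):
                    out[0, i, j, c * M + m] = np.sum(
                        x[0, i:i+fh, j:j+fw, c] * f[:, :, c, m])
    return out

MB, H, W, C, Cout, fh, fw = 2, 4, 4, 3, 2, 2, 2
inp = np.random.rand(MB, H, W, C)
F = np.random.rand(MB, fh, fw, C, Cout)

# The trick: fold the batch into the channel dimension.
F_r = F.transpose(1, 2, 0, 3, 4).reshape(fh, fw, MB * C, Cout)
inp_r = inp.transpose(1, 2, 0, 3).reshape(1, H, W, MB * C)

out = depthwise_conv2d(inp_r, F_r)
Ho, Wo = H - fh + 1, W - fw + 1
out = out.reshape(Ho, Wo, MB, C, Cout).transpose(2, 0, 1, 3, 4).sum(axis=3)

# Reference: direct per-sample convolution.
ref = np.zeros((MB, Ho, Wo, Cout))
for b in range(MB):
    for i in range(Ho):
        for j in range(Wo):
            ref[b, i, j] = np.tensordot(inp[b, i:i+fh, j:j+fw], F[b],
                                        axes=([0, 1, 2], [0, 1, 2]))
print(np.allclose(out, ref))  # True
```

The ordering works out because both the folded channel index (b*C + c) and the depthwise output index ((b*C + c)*Cout + cout) unflatten consistently in the final reshape.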

answered Oct 19 '22 by drasros


The accepted answer is slightly wrong in how it treats the spatial dimensions, which shrink under padding = "VALID" (it treats them as if padding = "SAME"). In the general case the code therefore crashes on a shape mismatch. I attach the corrected code, with both scenarios handled.

inp = tf.placeholder(tf.float32, [MB, H, W, channels_img])

# F has shape (MB, fh, fw, channels, out_channels)
# REM: with the notation in the question, we need: channels_img==channels

F = tf.transpose(F, [1, 2, 0, 3, 4])
F = tf.reshape(F, [fh, fw, channels*MB, out_channels])

inp_r = tf.transpose(inp, [1, 2, 0, 3]) # shape (H, W, MB, channels_img)
inp_r = tf.reshape(inp_r, [1, H, W, MB*channels_img])

padding = "VALID" #or "SAME"
out = tf.nn.depthwise_conv2d(
          inp_r,
          filter=F,
          strides=[1, 1, 1, 1],
          padding=padding)
# With "VALID", out shape is now (1, H-fh+1, W-fw+1, MB*channels*out_channels);
# with "SAME" it is (1, H, W, MB*channels*out_channels)

if padding == "SAME":
    out = tf.reshape(out, [H, W, MB, channels, out_channels])
if padding == "VALID":
    out = tf.reshape(out, [H-fh+1, W-fw+1, MB, channels, out_channels])
out = tf.transpose(out, [2, 0, 1, 3, 4])
out = tf.reduce_sum(out, axis=3)

# out shape is now (MB, H-fh+1, W-fw+1, out_channels) with "VALID",
# or (MB, H, W, out_channels) with "SAME"
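The spatial sizes used in the two reshape branches follow the standard stride-1 output-size formulas. A quick stdlib sanity check (helper name is mine, for illustration only):

```python
import math

def conv_out_size(n, f, padding, stride=1):
    """Output length along one spatial dimension of a 2D convolution."""
    if padding == "VALID":
        return math.ceil((n - f + 1) / stride)  # kernel must fit entirely
    if padding == "SAME":
        return math.ceil(n / stride)            # input is zero-padded
    raise ValueError(padding)

print(conv_out_size(28, 5, "VALID"))  # 24
print(conv_out_size(28, 5, "SAME"))   # 28
```

So with "VALID" the out tensor must be reshaped to [H-fh+1, W-fw+1, ...], exactly as done above.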
answered Oct 19 '22 by Žiga Sajovic