
Tensorflow: What does tf.nn.separable_conv2d do?

Tags:

tensorflow

I'm not quite sure what tf.nn.separable_conv2d does exactly. It seems that the pointwise_filter is a scaling factor for the different features when generating one pixel of the next layer, but I'm not sure whether my interpretation is correct. Is there a reference for this method, and what is its benefit?

tf.nn.separable_conv2d produces output of the same shape as tf.nn.conv2d, so I assumed I could replace tf.nn.conv2d with tf.nn.separable_conv2d. But the result when using tf.nn.separable_conv2d seems to be very bad: the network stops learning very early, and on the MNIST dataset the accuracy is no better than random guessing, ~10%.

I thought that if I set the pointwise_filter values to all 1.0 and made it non-trainable, I would get the same thing as tf.nn.conv2d. But not really: the accuracy is still ~10%.

But when tf.nn.conv2d is used with the same hyper-parameters, the accuracy can reach 99%. Why?

Also, it requires channel_multiplier * in_channels < out_channels. Why? What is the role of channel_multiplier here?

Thanks.

Edit:

I previously used a channel_multiplier of 1. Maybe that was a bad choice. After changing it to 2, the accuracy became much better. But what is the role of channel_multiplier? Why is 1 not a good value?

Asked May 07 '16 by Fei

2 Answers

In the regular 2D convolution performed over multiple input channels, the filter is as deep as the input and lets us freely mix channels to generate each element in the output. Depthwise convolutions don't do that - each channel is kept separate - hence the name depthwise. Here's a diagram to help explain how that works[1]:

[Diagram: Separable Conv2d]

If you look at the official documentation [2], you will find:

output[b, i, j, k] = sum_{di, dj, q, r}
    input[b, strides[1] * i + di, strides[2] * j + dj, q] *
    depthwise_filter[di, dj, q, r] *
    pointwise_filter[0, 0, q * channel_multiplier + r, k]

And here is some sample TensorFlow code to test it:

import tensorflow as tf
import numpy as np

width = 8
height = 8
batch_size = 100
filter_height = 3
filter_width = 3
in_channels = 3
channel_multiplier = 1
out_channels = 3

# Placeholders for the input and the two filters, fed with random values below.
input_tensor = tf.placeholder(tf.float32, shape=(batch_size, height, width, in_channels), name="input")
depthwise_filter = tf.placeholder(tf.float32, shape=(filter_height, filter_width, in_channels, channel_multiplier), name="depthwise_filter")
pointwise_filter = tf.placeholder(tf.float32, shape=(1, 1, channel_multiplier * in_channels, out_channels), name="pointwise_filter")

output = tf.nn.separable_conv2d(
    input_tensor,
    depthwise_filter,
    pointwise_filter,
    strides=[1, 1, 1, 1],
    padding='SAME',
)

with tf.Session() as sess:
    output_value = sess.run(output, feed_dict={
        input_tensor: np.random.rand(batch_size, height, width, in_channels),
        depthwise_filter: np.random.rand(filter_height, filter_width, in_channels, channel_multiplier),
        pointwise_filter: np.random.rand(1, 1, channel_multiplier * in_channels, out_channels)})
    # Same batch and spatial dimensions as the input, with out_channels channels: (100, 8, 8, 3)
    print(np.shape(output_value))
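The same computation can be reproduced by hand as a depthwise convolution followed by a 1x1 ("pointwise") convolution. A minimal sketch of that equivalence (assuming TF 1.x graph mode; shapes and values are chosen arbitrarily for illustration):

import tensorflow as tf
import numpy as np

x = np.random.rand(1, 8, 8, 3).astype(np.float32)       # NHWC input, in_channels = 3
dw = np.random.rand(3, 3, 3, 2).astype(np.float32)      # depthwise filter, channel_multiplier = 2
pw = np.random.rand(1, 1, 3 * 2, 4).astype(np.float32)  # 1x1 pointwise filter, out_channels = 4

# Fused op.
sep = tf.nn.separable_conv2d(x, dw, pw, strides=[1, 1, 1, 1], padding='SAME')

# Two explicit steps: per-channel (depthwise) convolution, then a 1x1 convolution to mix channels.
depth = tf.nn.depthwise_conv2d(x, dw, strides=[1, 1, 1, 1], padding='SAME')
manual = tf.nn.conv2d(depth, pw, strides=[1, 1, 1, 1], padding='SAME')

with tf.Session() as sess:
    sep_value, manual_value = sess.run([sep, manual])
    print(np.allclose(sep_value, manual_value))  # True (up to floating-point error)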

credit:

[1] https://eli.thegreenplace.net/2018/depthwise-separable-convolutions-for-machine-learning/

[2] https://www.tensorflow.org/api_docs/python/tf/nn/separable_conv2d

Answered Oct 20 '22 by Rohola Zandie

tf.nn.separable_conv2d() implements the so-called 'separable convolution' described on slide 26 and onwards of this talk.

The idea is that instead of convolving jointly across all channels of an image, you run a separate 2D convolution on each channel with a depth of channel_multiplier. The in_channels * channel_multiplier intermediate channels get concatenated together, and mapped to out_channels using a 1x1 convolution.

It's often an effective way to reduce the parametric complexity of early convolutions in a convnet, and can materially speed up training. channel_multiplier controls that complexity, and would typically be 4 to 8 for an RGB input. For a grayscale input, using it makes little sense.
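To make the complexity point concrete, here is a rough parameter count for a single layer (sizes chosen purely for illustration, biases ignored):

# Hypothetical layer: 3x3 filters, 32 input channels, 64 output channels.
filter_h, filter_w = 3, 3
in_channels, out_channels = 32, 64
channel_multiplier = 1

regular_params = filter_h * filter_w * in_channels * out_channels             # 3*3*32*64 = 18432
depthwise_params = filter_h * filter_w * in_channels * channel_multiplier     # 3*3*32*1  =   288
pointwise_params = 1 * 1 * (in_channels * channel_multiplier) * out_channels  # 1*1*32*64 =  2048
print(regular_params, depthwise_params + pointwise_params)                    # 18432 2336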

Answered Oct 20 '22 by Vincent Vanhoucke