Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

How to use the pre-trained ResNet50 in tensorflow?

Tags:

tensorflow

I want to build a network on top of a pre-trained network in TensorFlow, taking ResNet-50 as an example. The user ry has released a model, but I don't know how to build my own model from its checkpoint. The ResNet definition can be found in resnet.py (excerpt below). Can anyone help? Thank you very much!

def inference(x, is_training,
              num_classes=1000,
              num_blocks=(3, 4, 6, 3),  # defaults to 50-layer network
              use_bias=False,  # defaults to using batch norm
              bottleneck=True):
    """Build the ResNet forward graph on the input tensor ``x``.

    Args:
        x: input image tensor. NOTE(review): the ``[1, s, s, 1]`` strides used
            by ``conv``/``_max_pool`` suggest NHWC layout -- confirm.
        is_training: Python bool or boolean tensor. It is converted to a bool
            tensor, and ``bn`` uses it to choose between batch statistics and
            moving averages.
        num_classes: output width of the final ``fc`` layer. Pass ``None`` to
            skip the ``fc`` layer and return the pooled features instead.
        num_blocks: number of residual blocks in scale2..scale5 (any indexable
            4-element sequence). The default is the 50-layer layout.
        use_bias: if true, ``bn`` adds a learned bias instead of applying
            batch normalization.
        bottleneck: build three-conv bottleneck blocks (output width 4x the
            internal width) when true, two-conv blocks when false.

    Returns:
        The ``fc`` output when ``num_classes`` is not None. Otherwise, the
        mean of the last stage's activations over axes 1 and 2.
    """
    c = Config()
    c['bottleneck'] = bottleneck
    c['is_training'] = tf.convert_to_tensor(is_training,
                                            dtype='bool',
                                            name='is_training')
    c['ksize'] = 3
    c['stride'] = 1
    c['use_bias'] = use_bias
    c['fc_units_out'] = num_classes
    c['num_blocks'] = num_blocks
    c['stack_stride'] = 2

    with tf.variable_scope('scale1'):
        # Stem: a single 7x7 stride-2 convolution.
        c['conv_filters_out'] = 64
        c['ksize'] = 7
        c['stride'] = 2
        x = conv(x, c)
        x = bn(x, c)
        x = activation(x)

    with tf.variable_scope('scale2'):
        x = _max_pool(x, ksize=3, stride=2)
        c['num_blocks'] = num_blocks[0]
        # The max-pool already downsampled, so this stage keeps resolution.
        c['stack_stride'] = 1
        c['block_filters_internal'] = 64
        x = stack(x, c)

    with tf.variable_scope('scale3'):
        c['num_blocks'] = num_blocks[1]
        c['block_filters_internal'] = 128
        # NOTE(review): scale2 assigned stack_stride = 1 above, so this holds
        # only if Config scopes assignments to the current variable_scope
        # (falling back to the top-level value 2 here). Confirm against
        # Config's definition.
        assert c['stack_stride'] == 2
        x = stack(x, c)

    with tf.variable_scope('scale4'):
        c['num_blocks'] = num_blocks[2]
        c['block_filters_internal'] = 256
        x = stack(x, c)

    with tf.variable_scope('scale5'):
        c['num_blocks'] = num_blocks[3]
        c['block_filters_internal'] = 512
        x = stack(x, c)

    # post-net: average over axes 1 and 2 (spatial dims, assuming NHWC).
    x = tf.reduce_mean(x, reduction_indices=[1, 2], name="avg_pool")

    if num_classes is not None:
        with tf.variable_scope('fc'):
            x = fc(x, c)

    return x
def stack(x, c):
    """Chain ``c['num_blocks']`` residual blocks, each in scope 'blockN'.

    Scope numbering starts at 1. Only the first block receives
    ``c['stack_stride']`` as its stride; every later block uses stride 1.
    The stride reaches ``block`` through ``c['block_stride']``, which is set
    before entering the block's scope.
    """
    for idx in range(c['num_blocks']):
        c['block_stride'] = 1 if idx > 0 else c['stack_stride']
        with tf.variable_scope('block%d' % (idx + 1)):
            x = block(x, c)
    return x


def block(x, c):
    """Build one residual unit: a conv/bn branch added to a shortcut branch.

    Reads ``c['bottleneck']``, ``c['block_filters_internal']`` and
    ``c['block_stride']`` (the last is set by ``stack``). Writes conv settings
    (``conv_filters_out``, ``ksize``, ``stride``) into ``c`` for ``conv``.
    Returns ``activation(residual + shortcut)``.

    NOTE(review): several sub-scopes depend on values set in an earlier
    sibling scope *not* being visible. For example, scope 'b' sets no
    ksize/stride after 'a' set ksize=1 and stride=block_stride, yet 'c'
    asserts stride == 1. Likewise, 'A' asserts ksize == 3 and 'B' asserts
    stride == 1. That only works if ``Config`` scopes assignments to the
    current ``tf.variable_scope``. With a plain dict these asserts would fail
    whenever block_stride != 1. Confirm against Config's definition.
    """
    filters_in = x.get_shape()[-1]

    # Bottleneck blocks widen the output to 4x the internal filter count.
    m = 4 if c['bottleneck'] else 1
    filters_out = m * c['block_filters_internal']

    shortcut = x  # branch 1

    c['conv_filters_out'] = c['block_filters_internal']

    if c['bottleneck']:
        with tf.variable_scope('a'):
            c['ksize'] = 1
            # Any downsampling for this block happens in its first conv.
            c['stride'] = c['block_stride']
            x = conv(x, c)
            x = bn(x, c)
            x = activation(x)

        with tf.variable_scope('b'):
            # Uses whatever ksize/stride Config resolves in this scope (see
            # NOTE(review) in the docstring).
            x = conv(x, c)
            x = bn(x, c)
            x = activation(x)

        with tf.variable_scope('c'):
            # Final conv expands to filters_out. No activation here: it is
            # applied after the shortcut is added.
            c['conv_filters_out'] = filters_out
            c['ksize'] = 1
            assert c['stride'] == 1
            x = conv(x, c)
            x = bn(x, c)
    else:
        with tf.variable_scope('A'):
            c['stride'] = c['block_stride']
            assert c['ksize'] == 3
            x = conv(x, c)
            x = bn(x, c)
            x = activation(x)

        with tf.variable_scope('B'):
            c['conv_filters_out'] = filters_out
            assert c['ksize'] == 3
            assert c['stride'] == 1
            x = conv(x, c)
            x = bn(x, c)

    with tf.variable_scope('shortcut'):
        # Project the shortcut with a 1x1 conv + bn when the channel count or
        # stride changes so shapes match for the addition; otherwise the
        # identity is used.
        if filters_out != filters_in or c['block_stride'] != 1:
            c['ksize'] = 1
            c['stride'] = c['block_stride']
            c['conv_filters_out'] = filters_out
            shortcut = conv(shortcut, c)
            shortcut = bn(shortcut, c)

    return activation(x + shortcut)


def bn(x, c):
    """Batch-normalize ``x`` over every axis except the last.

    When ``c['use_bias']`` is true, normalization is skipped entirely. A
    zero-initialized bias of shape ``x_shape[-1:]`` is added instead.

    Otherwise the function creates trainable ``beta``/``gamma`` and
    non-trainable ``moving_mean``/``moving_variance`` variables, all of shape
    ``x_shape[-1:]``. It normalizes with the current batch's statistics when
    ``c['is_training']`` is true and with the moving averages otherwise.
    """
    x_shape = x.get_shape()
    params_shape = x_shape[-1:]

    if c['use_bias']:
        bias = _get_variable('bias', params_shape,
                             initializer=tf.zeros_initializer)
        return x + bias


    # Reduce over all axes but the last -> one statistic per last-axis entry.
    axis = list(range(len(x_shape) - 1))

    beta = _get_variable('beta',
                         params_shape,
                         initializer=tf.zeros_initializer)
    gamma = _get_variable('gamma',
                          params_shape,
                          initializer=tf.ones_initializer)

    moving_mean = _get_variable('moving_mean',
                                params_shape,
                                initializer=tf.zeros_initializer,
                                trainable=False)
    moving_variance = _get_variable('moving_variance',
                                    params_shape,
                                    initializer=tf.ones_initializer,
                                    trainable=False)

    # The moving-average update ops are not dependencies of the returned
    # tensor. They are only added to UPDATE_OPS_COLLECTION, so they take
    # effect only when the caller (presumably the training loop) runs that
    # collection.
    mean, variance = tf.nn.moments(x, axis)
    update_moving_mean = moving_averages.assign_moving_average(moving_mean,
                                                               mean, BN_DECAY)
    update_moving_variance = moving_averages.assign_moving_average(
        moving_variance, variance, BN_DECAY)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_mean)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_variance)

    # Pick batch statistics while training, stored averages otherwise.
    mean, variance = control_flow_ops.cond(
        c['is_training'], lambda: (mean, variance),
        lambda: (moving_mean, moving_variance))

    # beta is the offset and gamma the scale in tf.nn.batch_normalization.
    x = tf.nn.batch_normalization(x, mean, variance, beta, gamma, BN_EPSILON)

    return x


def fc(x, c):
    """Fully connected layer computing ``x @ weights + biases``.

    The output width comes from ``c['fc_units_out']``. The input width is the
    static size of ``x``'s second dimension. Weights use a truncated-normal
    initializer (stddev FC_WEIGHT_STDDEV); biases start at zero.

    NOTE(review): ``weight_decay`` is passed FC_WEIGHT_STDDEV, the same
    constant used as the initializer's stddev. ``conv`` uses a dedicated
    CONV_WEIGHT_DECAY, so this looks like it should be a separate decay
    constant. Confirm before training.
    """
    fan_in = x.get_shape()[1]
    fan_out = c['fc_units_out']
    init = tf.truncated_normal_initializer(stddev=FC_WEIGHT_STDDEV)

    w = _get_variable('weights',
                      shape=[fan_in, fan_out],
                      initializer=init,
                      weight_decay=FC_WEIGHT_STDDEV)
    b = _get_variable('biases',
                      shape=[fan_out],
                      initializer=tf.zeros_initializer)
    return tf.nn.xw_plus_b(x, w, b)


def _get_variable(name,
                  shape,
                  initializer,
                  weight_decay=0.0,
                  dtype='float',
                  trainable=True):
    """Wrap ``tf.get_variable`` to add weight decay and collection membership.

    The variable is added to the global-variables collection and to
    ``RESNET_VARIABLES``. When ``weight_decay > 0``, an L2 regularizer with
    that scale is attached; otherwise no regularizer is used.
    """
    regularizer = (tf.contrib.layers.l2_regularizer(weight_decay)
                   if weight_decay > 0 else None)
    # GLOBAL_VARIABLES replaces the deprecated GraphKeys.VARIABLES alias;
    # both name the same collection.
    collections = [tf.GraphKeys.GLOBAL_VARIABLES, RESNET_VARIABLES]
    return tf.get_variable(name,
                           shape=shape,
                           initializer=initializer,
                           dtype=dtype,
                           regularizer=regularizer,
                           collections=collections,
                           trainable=trainable)


def conv(x, c):
    """2-D convolution with a freshly created, weight-decayed kernel.

    Kernel size, stride and output filter count are read from ``c['ksize']``,
    ``c['stride']`` and ``c['conv_filters_out']``. The input filter count is
    the static size of ``x``'s last dimension. The kernel uses a
    truncated-normal initializer (stddev CONV_WEIGHT_STDDEV) and weight decay
    CONV_WEIGHT_DECAY. SAME padding is used, and no bias is added.
    """
    k = c['ksize']
    s = c['stride']
    n_out = c['conv_filters_out']
    n_in = x.get_shape()[-1]

    kernel = _get_variable('weights',
                           shape=[k, k, n_in, n_out],
                           dtype='float',
                           initializer=tf.truncated_normal_initializer(
                               stddev=CONV_WEIGHT_STDDEV),
                           weight_decay=CONV_WEIGHT_DECAY)
    step = [1, s, s, 1]
    return tf.nn.conv2d(x, kernel, step, padding='SAME')


def _max_pool(x, ksize=3, stride=2):
    """Square max-pooling with SAME padding.

    The window is ``ksize x ksize`` and the stride is ``stride`` in both
    spatial positions (indices 1 and 2). Batch and last-axis
    window/stride are 1.
    """
    window = [1, ksize, ksize, 1]
    steps = [1, stride, stride, 1]
    return tf.nn.max_pool(x, ksize=window, strides=steps, padding='SAME')
like image 852
Lei Avatar asked Mar 03 '17 07:03

Lei


1 Answers

Basically, you should use the code that defines the model: build the graph with it, then restore the checkpoint into that graph. Here is how to do it for ResNet-50:

import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import resnet_v1

# Path to the pre-trained resnet_v1_50 checkpoint file.
checkpoint_path = './resnet_v1_50.ckpt'

# Create graph. A `None` batch dimension lets the same graph accept any
# number of 224x224 RGB images.
inputs = tf.placeholder(tf.float32, shape=[None, 224, 224, 3])
with slim.arg_scope(resnet_v1.resnet_arg_scope()):
    net, end_points = resnet_v1.resnet_v1_50(inputs, is_training=False)

saver = tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess, checkpoint_path)
    # If you don't know tensor names like this one, look them up in the model
    # definition file, or dump the graph to a .pbtxt file as described in
    # @civilman628's answer in the GitHub issue linked below.
    representation_tensor = sess.graph.get_tensor_by_name('resnet_v1_50/pool5:0')
    img = ...  # load image(s) here as an array of shape [N, 224, 224, 3]
    # Feed the placeholder object itself rather than relying on its
    # auto-generated name ('Placeholder:0'), which changes if other
    # placeholders are created first.
    features = sess.run(representation_tensor, feed_dict={inputs: img})

For more details, please see my (@parthg) answer on the related TensorFlow GitHub issue: #7172

like image 140
Parth Avatar answered Sep 19 '22 01:09

Parth