 

Design patterns for tensorflow models

Tags:

tensorflow

I created several simple models, mostly based on some tutorials. From what I've done, I feel that models are quite hard to reuse, and I feel that I need to create some structure with classes to encapsulate the models.

What are the 'standard' ways of structuring tensorflow models? Are there any coding conventions/best practices for this?

asked Apr 24 '17 by Konstantin Solomatov




1 Answer

Throughout the TensorFlow examples and tutorials, the prominent pattern for structuring model code is to split the model into three functions (a minimal sketch of such a module follows the list):

  • inference(inputs, ...) which builds the model
  • loss(logits, ...) which adds the loss on top of the logits
  • train(loss, ...) which adds training ops
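
A minimal sketch of what such a mymodel module could look like (the architecture and hyperparameters here are made up; TensorFlow 1.x API):

import tensorflow as tf

def inference(inputs, num_classes):
    # Build the forward pass and return the logits.
    hidden = tf.layers.dense(inputs, 128, activation=tf.nn.relu, name='hidden')
    return tf.layers.dense(hidden, num_classes, name='logits')

def loss(logits, labels):
    # Add the loss ops on top of the logits.
    return tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)

def train(loss, learning_rate=0.01):
    # Add the ops that perform one training step.
    return tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)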

When creating a model for training, your code would look something like this:

inputs = tf.placeholder(...)
logits = mymodel.inference(inputs, ...)
loss = mymodel.loss(logits, ...)
train = mymodel.train(loss, ...)
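
To actually run training, you would then execute the train op repeatedly in a session. A sketch, assuming the elided arguments include a labels placeholder and that input_batch and label_batch come from your own input pipeline:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(num_steps):  # num_steps is up to you
        sess.run(train, feed_dict={inputs: input_batch, labels: label_batch})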

This pattern is used in the CIFAR-10 tutorial, for example (code, tutorial).

One thing you might stumble over is that you cannot share (Python) variables between the inference and the loss functions. That is not a big issue, though, since TensorFlow provides Graph collections for exactly this use case, which also makes for a cleaner design (it forces you to group things logically). One major use case for this is regularization:
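
The mechanism itself is simple: a collection is a named list of graph elements that any function can append to and read from. A tiny illustration (the collection name and some_tensor here are made up):

some_tensor = tf.constant(0.5)  # stand-in for any tensor you want to share
# In one function, stash the tensor under a collection name...
tf.add_to_collection('my_penalties', some_tensor)
# ...and in another function, retrieve everything stored under that name.
penalties = tf.get_collection('my_penalties')  # -> [some_tensor]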

If you are using the layers module (e.g. tf.layers.conv2d), you already have what you need: all regularization penalties are added (source) to the collection tf.GraphKeys.REGULARIZATION_LOSSES by default. For example, when you do this:

conv1 = tf.layers.conv2d(
    inputs,
    filters=96,
    kernel_size=11,
    strides=4,
    activation=tf.nn.relu,
    kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
    kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=0.0005),  # scale is required; 0.0005 is just an example value
    name='conv1')

Your loss could look like this then:

def loss(logits, labels):
    softmax_loss = tf.losses.softmax_cross_entropy(
        onehot_labels=labels,
        logits=logits)

    regularization_loss = tf.add_n(tf.get_collection(
        tf.GraphKeys.REGULARIZATION_LOSSES))

    return tf.add(softmax_loss, regularization_loss)

If you are not using the layers module, you would have to populate the collection manually (just like in the linked source snippet). Basically you want to add the penalties to the collection using tf.add_to_collection:

tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, reg_penalty)
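
For example, reg_penalty could be an L2 penalty on a manually created weight variable (a sketch; the variable shape and the 0.01 weighting factor are made up):

weights = tf.get_variable('weights', shape=[784, 10])
reg_penalty = 0.01 * tf.nn.l2_loss(weights)  # sum of squares / 2, scaled by a hand-picked factor
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, reg_penalty)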

With this in place, you can compute the loss including the regularization penalties just like above.

answered Oct 22 '22 by thertweck