Using SparseTensor as a trainable variable?

I'm trying to use a SparseTensor to represent the weight variables in a fully-connected layer.
However, it seems that TensorFlow 0.8 doesn't allow a SparseTensor to be used as a tf.Variable.
Is there any way to work around this?

I've tried:

import tensorflow as tf

a = tf.constant(1)
b = tf.SparseTensor([[0, 0]], [1], [1, 1])  # indices, values, shape

print(a.__class__)  # shows <class 'tensorflow.python.framework.ops.Tensor'>
print(b.__class__)  # shows <class 'tensorflow.python.framework.ops.SparseTensor'>

tf.Variable(a)      # the variable is declared correctly
tf.Variable(b)      # fails

By the way, my ultimate goal in using SparseTensor is to permanently mask some of the connections of a dense layer, so that these pruned connections are ignored when gradients are computed and applied.

In my current MLP implementation, SparseTensor and the sparse form of the matmul op produce inference outputs successfully. However, the weights declared using SparseTensor are not updated as training proceeds.

Younghwan Oh asked May 03 '16 10:05

1 Answer

As a workaround, you can use a tf.Variable for the values of the sparse tensor. Up to TensorFlow v0.8 the variable can be passed directly; in v0.9, wrap it in tf.identity. The sparsity structure has to be pre-defined in that case, but the weights remain trainable.

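weights = tf.Variable(<initial-value>)  # one trainable value per non-zero entry

# Use one of the following, depending on the TensorFlow version:
sparse_var = tf.SparseTensor(<indices>, weights, <shape>)  # v0.8
sparse_var = tf.SparseTensor(<indices>, tf.identity(weights), <shape>)  # v0.9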
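To illustrate why this works, here is a minimal framework-free sketch in plain Python. It is not TensorFlow code: the helper names (sparse_matvec, grad_wrt_values) and the toy data are hypothetical and only stand in for the sparse matmul op and its gradient. The point it shows is that the indices stay fixed while gradient descent updates only the stored values, so pruned connections never receive a gradient.

```python
# Framework-free sketch (all names hypothetical, not TensorFlow API):
# a sparse weight matrix with a fixed sparsity pattern and trainable values.

def sparse_matvec(indices, values, shape, x):
    """Compute y = W x, where W is described by (indices, values, shape)."""
    rows, _cols = shape
    y = [0.0] * rows
    for (i, j), w in zip(indices, values):
        y[i] += w * x[j]
    return y


def grad_wrt_values(indices, x, dy):
    """Gradient of the loss w.r.t. each stored value: dL/dw_ij = dy_i * x_j."""
    return [dy[i] * x[j] for (i, j) in indices]


indices = [(0, 0), (1, 2)]  # fixed sparsity pattern, never modified
values = [0.5, -1.0]        # trainable weights, one per index
shape = (2, 3)

x = [1.0, 2.0, 3.0]         # toy input
target = [1.0, 0.0]         # toy target output
learning_rate = 0.05

for _ in range(200):
    y = sparse_matvec(indices, values, shape, x)
    # Gradient of 0.5 * sum((y - target)^2) with respect to y.
    dy = [yi - ti for yi, ti in zip(y, target)]
    grads = grad_wrt_values(indices, x, dy)
    # Only the stored values are updated; entries outside `indices`
    # do not exist, so they stay zero throughout training.
    values = [w - learning_rate * g for w, g in zip(values, grads)]

final_output = sparse_matvec(indices, values, shape, x)
```

After the loop, final_output should be close to target, while indices is unchanged. In the TensorFlow version above, weights plays the role of values here.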
T. Kipf answered Sep 22 '22 15:09