Selectively zero weights in TensorFlow?

Let's say I have an NxM weight variable weights and a constant NxM matrix mask of 1s and 0s.

If a layer of my network is defined like this (with other layers similarly defined):

masked_weights = mask*weights
layer1 = tf.nn.relu(tf.matmul(layer0, masked_weights) + biases1)

Will this network behave as if the corresponding 0s in mask were zeros in weights during training (i.e. as if the connections represented by those weights had been removed from the network entirely)?

If not, how can I achieve this goal in TensorFlow?

asked Jul 09 '16 by Will Bolden

1 Answer

The answer is yes. Since masked_weights = mask * weights, the chain rule gives d(loss)/d(weights) = mask * d(loss)/d(masked_weights), so the gradient is zero exactly where the mask is zero and those weights never receive updates. The following experiment demonstrates this.

The implementation is:

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(None, 3))
weights = tf.get_variable("weights", [3, 2])
bias = tf.get_variable("bias", [2])
mask = tf.constant(np.asarray([[0, 1], [1, 0], [0, 1]], dtype=np.float32)) # constant mask

masked_weights = tf.multiply(weights, mask) # zero out the masked connections
y = tf.nn.relu(tf.nn.bias_add(tf.matmul(x, masked_weights), bias))
loss = tf.losses.mean_squared_error(tf.constant(np.asarray([[1, 1]], dtype=np.float32)), y)

weights_grad = tf.gradients(loss, weights) # gradient w.r.t. the underlying (unmasked) variable

sess = tf.Session()
sess.run(tf.global_variables_initializer())
print("Masked weights=\n", sess.run(masked_weights))
data = np.random.rand(1, 3)

print("Graident of weights\n=", sess.run(weights_grad, feed_dict={x: data}))
sess.close()

After running the code above, you will see that the gradients are masked as well. In my run, they were:

Gradient of weights
= [array([[ 0.        , -0.40866762],
       [ 0.34265977, -0.        ],
       [ 0.        , -0.35294518]], dtype=float32)]
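
To see the practical consequence of the masked gradient, here is a minimal follow-up sketch (same TensorFlow 1.x API as the answer; the variable names and the plain-SGD optimizer are illustrative assumptions, not part of the original answer). It runs one training step and checks that the masked entries of weights are left untouched:

import numpy as np
import tensorflow as tf

mask_np = np.asarray([[0, 1], [1, 0], [0, 1]], dtype=np.float32)

x = tf.placeholder(tf.float32, shape=(None, 3))
weights = tf.get_variable("weights_sgd", [3, 2]) # fresh names to avoid clashing with the graph above
bias = tf.get_variable("bias_sgd", [2])
mask = tf.constant(mask_np)

masked_weights = tf.multiply(weights, mask)
y = tf.nn.relu(tf.nn.bias_add(tf.matmul(x, masked_weights), bias))
loss = tf.losses.mean_squared_error(tf.constant([[1.0, 1.0]]), y)

# Plain SGD applies the raw (already masked) gradient, so masked
# entries receive an update of exactly zero.
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    before = sess.run(weights)
    sess.run(train_op, feed_dict={x: np.random.rand(1, 3)})
    after = sess.run(weights)
    # Entries where mask == 0 are identical before and after the step.
    print("Masked entries unchanged:", np.array_equal(before[mask_np == 0], after[mask_np == 0]))

One caveat: the masked entries of weights keep whatever values they were initialized with; they simply never influence the forward pass and never receive a gradient. If you want the stored variable itself to hold zeros, you can run sess.run(weights.assign(weights * mask)) once after initialization. Note that an optimizer which adds weight decay on top of the raw gradient would still modify the masked entries, which is why plain gradient descent is used above.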
answered Oct 06 '22 by Tengerye