 

Manually set the trainable_variables weights in Tensorflow 2

I need to set the values of trainable_variables in a TensorFlow model myself, instead of using an optimizer. Is there a function or another way to do it? Below is example code; I want to set the values of mnist_model.trainable_variables.

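import tensorflow as tf

for epoch in range(0, 1):
    with tf.GradientTape() as tape:
        prediction = mnist_model(mnist_images, training=True)
        loss_value = loss(mnist_labels, prediction)

    variables = mnist_model.trainable_variables
    loss_history.append(loss_value.numpy())
    # One gradient per trainable variable, in the same order as `variables`.
    grads = tape.gradient(loss_value, variables)
    # Here I want to update `variables` myself instead of using an optimizer.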
asked by fpi, Oct 26 '25 07:10


1 Answer

model.trainable_variables returns a list of the model's trainable variables. Printing one shows its name and shape, for example:

<tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 16) dtype=float32>
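To inspect every variable at once, you can loop over the list. A minimal sketch, assuming TensorFlow 2.x with tf.keras; the Sequential model below is a hypothetical stand-in for your own model. The printed names vary between Keras versions, so rely on the list position and the shape rather than the name:

```python
import tensorflow as tf

# Hypothetical stand-in model; any built Keras model can be inspected the same way.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(16, (3, 3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
])

# Print the index, name and shape of every trainable variable, in list order.
for i, v in enumerate(model.trainable_variables):
    print(i, v.name, tuple(v.shape))
```

For this model the shapes are (3, 3, 1, 16) and (16,) for the convolution's kernel and bias, followed by (10816, 10) and (10,) for the dense layer.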

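Using this shape, you can overwrite a variable's value with its .assign() method; the new value must have the same shape as the variable. You need to build your model first, with build() or by calling it once on sample input; otherwise TensorFlow has not created the variables yet and trainable_variables is empty.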
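Both points can be checked on a small model. A sketch, assuming TensorFlow 2.x; TinyModel is a hypothetical subclassed model used only for illustration. It builds the model by calling it once on a sample batch, derives each new value from the variable's own shape, and shows that a value with the wrong shape is rejected:

```python
import tensorflow as tf

class TinyModel(tf.keras.Model):
    """Hypothetical minimal subclassed model used only for illustration."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)

    def call(self, x):
        return self.dense(x)

model = TinyModel()
n_before = len(model.trainable_variables)  # 0: the variables do not exist yet

# Calling the model once on a sample batch creates its variables.
model(tf.zeros((1, 3)))
kernel, bias = model.trainable_variables  # kernel has shape (3, 4), bias (4,)

# Build each new value from the variable's own shape so the shapes always match.
kernel.assign(tf.fill(tuple(kernel.shape), 0.5))
bias.assign(tf.zeros(tuple(bias.shape)))

# A value with the wrong shape is rejected instead of being reshaped.
try:
    kernel.assign(tf.zeros((4, 3)))
    shape_mismatch_rejected = False
except (ValueError, tf.errors.InvalidArgumentError):
    shape_mismatch_rejected = True
```

After the assignment every output of model(tf.ones((2, 3))) is 3 × 0.5 + 0 = 1.5, which is an easy way to confirm the new values are in use.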

model.trainable_variables[0].assign(tf.fill((3, 3, 1, 16), .12345))
Out[3]: 
<tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 16) dtype=float32, numpy=
array([[[[0.12345, 0.12345, 0.12345, 0.12345, 0.12345, 0.12345,
          0.12345, 0.12345, 0.12345, 0.12345, 0.12345, 0.12345,
          0.12345, 0.12345, 0.12345, 0.12345]],
        [[0.12345, 0.12345, 0.12345, 0.12345, 0.12345, 0.12345,
          0.12345, 0.12345, 0.12345, 0.12345, 0.12345, 0.12345,
          0.12345, 0.12345, 0.12345, 0.12345]]
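The same .assign() call works on every entry of the list, so all trainable variables can be overwritten, and later restored, in a loop. A sketch under the same assumptions, with a hypothetical two-layer model and an arbitrary fill value of 0.1. Keras also provides model.get_weights() and model.set_weights(), which read and write all weights as NumPy arrays, including non-trainable ones such as BatchNormalization's moving statistics:

```python
import numpy as np
import tensorflow as tf

# Hypothetical two-layer model, built by calling it once on a sample batch.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(2),
])
model(tf.zeros((1, 3)))

# Snapshot the current values as explicit NumPy copies so they can be restored.
saved = [v.numpy().copy() for v in model.trainable_variables]

# Overwrite every trainable variable with an array of matching shape and dtype.
for var in model.trainable_variables:
    var.assign(np.full(tuple(var.shape), 0.1, dtype=np.float32))
all_filled = all(np.allclose(v.numpy(), 0.1) for v in model.trainable_variables)

# Restore the snapshot; zip pairs values and variables by list position.
for var, value in zip(model.trainable_variables, saved):
    var.assign(value)
```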

Full working example:

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Conv2D, MaxPool2D, Dropout, Flatten

class CNN(Model):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = Conv2D(filters=16, kernel_size=(3, 3), strides=(1, 1))
        self.maxp1 = MaxPool2D(pool_size=(2, 2))
        self.flat1 = Flatten()
        self.dens1 = Dense(64, activation='relu')
        self.drop1 = Dropout(5e-1)
        self.dens3 = Dense(10)

    def call(self, x, training=None, **kwargs):
        x = self.conv1(x)
        x = self.maxp1(x)
        x = self.flat1(x)
        x = self.dens1(x)
        x = self.drop1(x, training=training)  # dropout is applied only in training mode
        x = self.dens3(x)
        return x

model = CNN()

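# Create the variables; before this, model.trainable_variables is empty.
model.build(input_shape=(1, 28, 28, 1))

# trainable_variables[0] is the kernel of self.conv1, shape (3, 3, 1, 16).
print(model.trainable_variables[0])

# Overwrite it in place with a tensor of the same shape.
model.trainable_variables[0].assign(tf.fill((3, 3, 1, 16), .12345))

print(model.trainable_variables[0])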

Original (randomly initialized) weights, output abridged:

<tf.Variable 'conv2d_2/kernel:0' shape=(3, 3, 1, 16) dtype=float32, numpy=
array([[[[-0.18103004, -0.18038717, -0.04171562, -0.14022854,
          -0.00918788,  0.07348467,  0.07931305, -0.03991133,
           0.12809007, -0.11934308,  0.11453925,  0.02502337,
          -0.165835  , -0.14841306,  0.1911544 , -0.09917622]],
        [[-0.0496967 ,  0.13865136, -0.17599788, -0.18716624,
          -0.03473145, -0.02006209, -0.00364855, -0.03497578,
           0.05207129,  0.07728194, -0.11234754,  0.09303482,
           0.17245303, -0.07428543, -0.19278058,  0.15201278]]]],
      dtype=float32)>

Edited weights, output abridged:

<tf.Variable 'conv2d_6/kernel:0' shape=(3, 3, 1, 16) dtype=float32, numpy=
array([[[[0.12345, 0.12345, 0.12345, 0.12345, 0.12345, 0.12345,
          0.12345, 0.12345, 0.12345, 0.12345, 0.12345, 0.12345,
          0.12345, 0.12345, 0.12345, 0.12345]],
        [[0.12345, 0.12345, 0.12345, 0.12345, 0.12345, 0.12345,
          0.12345, 0.12345, 0.12345, 0.12345, 0.12345, 0.12345,
          0.12345, 0.12345, 0.12345, 0.12345]]]], dtype=float32)>
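Applied to the training loop from the question, .assign_sub() lets you update the variables yourself instead of calling optimizer.apply_gradients(). The sketch below uses plain gradient descent as a stand-in for whatever update rule you actually need; the toy data (y = 3x - 1), the one-unit Dense model and the learning rate of 0.1 are all assumptions made for illustration, assuming TensorFlow 2.x:

```python
import tensorflow as tf

tf.random.set_seed(0)

# Hypothetical toy regression data: y = 3x - 1.
x = tf.random.uniform((64, 1), -1.0, 1.0)
y = 3.0 * x - 1.0

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
loss_fn = tf.keras.losses.MeanSquaredError()
learning_rate = 0.1  # assumed step size, chosen for this toy problem
loss_history = []

for step in range(200):
    with tf.GradientTape() as tape:
        prediction = model(x, training=True)
        loss_value = loss_fn(y, prediction)

    variables = model.trainable_variables
    grads = tape.gradient(loss_value, variables)

    # Plain gradient descent written by hand: var <- var - learning_rate * grad.
    for var, grad in zip(variables, grads):
        var.assign_sub(learning_rate * grad)

    loss_history.append(float(loss_value))
```

Any other rule (clipping, momentum, or values computed elsewhere) fits in the same place: compute the new value for each variable and write it with assign, assign_add or assign_sub.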
answered by Nicolas Gervais, Oct 28 '25 20:10


