When should we use the trainable_variables attribute of a tf.keras.Model instead of trainable_weights, and vice versa, in TF2?

I was studying how to do transfer learning in TF 2, and I noticed that this tutorial from TensorFlow uses the attribute trainable_variables to reference the trainable variables of a model, while this other tutorial from the Keras documentation uses the attribute trainable_weights of a tf.keras.Model.

I checked both attributes with a simple model, and they give me the same result.

import tensorflow as tf
print(tf.__version__)

inputs = tf.keras.layers.Input(shape=[64, 64, 3])

x = tf.keras.layers.Conv2D(128, kernel_size=3, strides=2)(inputs)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)

model = tf.keras.Model(inputs=inputs, outputs=x)

print("\nTrainable weights")
vars_model = [var.name for var in model.trainable_weights]
print(*vars_model, sep="\n")

print("\nTrainable variables")
vars_model = [var.name for var in model.trainable_variables]
print(*vars_model, sep="\n")

Output:

2.2.0

Trainable weights
conv2d/kernel:0
conv2d/bias:0
batch_normalization/gamma:0
batch_normalization/beta:0

Trainable variables
conv2d/kernel:0
conv2d/bias:0
batch_normalization/gamma:0
batch_normalization/beta:0
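Comparing names only shows that the two lists describe the same parameters. A slightly stronger check, sketched below under the assumption that TensorFlow 2.x is installed, compares the variable objects themselves with `is` and also covers the non-trainable counterparts (the BatchNormalization moving statistics). The helper `same_objects` and the smaller filter count are illustrative choices, not part of the original code. The last step freezes the model with `model.trainable = False`, the usual transfer-learning move, to show that both trainable views are emptied together.

```python
import tensorflow as tf

# Same kind of model as in the question, with fewer filters to keep it light.
inputs = tf.keras.layers.Input(shape=(64, 64, 3))
x = tf.keras.layers.Conv2D(8, kernel_size=3, strides=2)(inputs)
x = tf.keras.layers.BatchNormalization()(x)
model = tf.keras.Model(inputs=inputs, outputs=x)


def same_objects(a, b):
    """Illustrative helper: True if both lists hold the very same objects, in order."""
    return len(a) == len(b) and all(u is v for u, v in zip(a, b))


# Trainable: Conv2D kernel and bias, BatchNormalization gamma and beta.
print(len(model.trainable_weights),
      same_objects(model.trainable_weights, model.trainable_variables))

# Non-trainable: BatchNormalization moving_mean and moving_variance.
print(len(model.non_trainable_weights),
      same_objects(model.non_trainable_weights, model.non_trainable_variables))

# Freezing the whole model, as is typical for a transfer-learning base,
# moves every weight to the non-trainable side in both views.
model.trainable = False
print(len(model.trainable_weights), len(model.trainable_variables))
print(len(model.non_trainable_weights), len(model.non_trainable_variables))
```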

I checked this other issue and tried to follow the definitions of both attributes: trainable_variables seems to be defined here, and trainable_weights seems to be defined here and here, since tf.keras.Model also inherits from network.Network. The former seems to simply return the trainable_weights attribute. However, I am not sure that this holds in all cases.
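To see which definitions your own installation actually uses, rather than following links to one particular revision on GitHub, you can print the installed source with Python's standard `inspect` module. This is a sketch assuming TensorFlow 2.x is installed; where these properties live, and what their bodies contain, differs between releases, so the printed output is version-dependent.

```python
import inspect

import tensorflow as tf

print("TensorFlow", tf.__version__)

sources = {}
for name in ("trainable_variables", "trainable_weights"):
    # Looking the name up on the class (not an instance) yields the property
    # object itself, whose getter (fget) is a plain Python function.
    prop = getattr(tf.keras.Model, name)
    sources[name] = inspect.getsource(prop.fget)
    print(f"--- {name} (from {inspect.getsourcefile(prop.fget)})")
    print(sources[name])
```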

So, in which cases should we use trainable_variables over trainable_weights, and vice versa? And why?

asked Oct 19 '25 09:10 by K. Bogdan


1 Answer

They are both the same in TensorFlow 2.2.0. If you look at the source code of the base layer class, tf.keras.layers.Layer (click "View source on GitHub"), you will find the definitions below. This is the class from which all layers inherit.

  @property
  @doc_controls.do_not_generate_docs
  def trainable_variables(self):
    return self.trainable_weights

  @property
  @doc_controls.do_not_generate_docs
  def non_trainable_variables(self):
    return self.non_trainable_weights
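With the aliasing quoted above, the two names cannot drift apart: because trainable_variables is a property that calls self.trainable_weights at access time, a subclass that overrides only trainable_weights changes both. The toy classes below reproduce that pattern in plain Python; `BaseLayer` and `FrozenLayer` are hypothetical stand-ins, not real tf.keras classes.

```python
class BaseLayer:
    """Hypothetical stand-in mimicking the property aliasing in tf.keras.layers.Layer."""

    def __init__(self):
        self._trainable = ["kernel", "bias"]
        self._non_trainable = ["moving_mean"]

    @property
    def trainable_weights(self):
        return list(self._trainable)

    @property
    def non_trainable_weights(self):
        return list(self._non_trainable)

    # Aliases defined once on the base class: they delegate to whatever
    # the *_weights properties resolve to on the actual instance.
    @property
    def trainable_variables(self):
        return self.trainable_weights

    @property
    def non_trainable_variables(self):
        return self.non_trainable_weights


class FrozenLayer(BaseLayer):
    """Hypothetical subclass that overrides only the *_weights properties."""

    @property
    def trainable_weights(self):
        return []

    @property
    def non_trainable_weights(self):
        return self._trainable + self._non_trainable


plain = BaseLayer()
frozen = FrozenLayer()
print(plain.trainable_variables)        # ['kernel', 'bias']
print(frozen.trainable_variables)       # []
print(frozen.non_trainable_variables)   # ['kernel', 'bias', 'moving_mean']
```

Even though FrozenLayer never touches the *_variables properties, they follow its overrides automatically.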

Hope this answers your question. Happy Learning.

