
Variable tf.Variable has 'None' for gradient in TensorFlow Probability

I'm having trouble constructing a basic BNN in TFP. I'm new to TFP and BNNs in general, so I apologize if I've missed something simple.

I can train a basic NN in TensorFlow by doing the following:

model = keras.Sequential([
    keras.layers.Dense(units=100, activation='relu'),
    keras.layers.Dense(units=50, activation='relu'),
    keras.layers.Dense(units=5, activation='softmax')
])

model.compile(optimizer=optimizer, 
              loss=tf.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

history = model.fit(
    training_data.repeat(), 
    epochs=100, 
    steps_per_epoch=(X_train.shape[0]//1024),
    validation_data=test_data.repeat(), 
    validation_steps=2
)

However, I run into trouble when I try to implement a similar architecture with TFP DenseFlipout layers:

model = keras.Sequential([
    tfp.layers.DenseFlipout(units=100, activation='relu'),
    tfp.layers.DenseFlipout(units=10, activation='relu'),
    tfp.layers.DenseFlipout(units=5, activation='softmax')
])

model.compile(optimizer=optimizer, 
              loss=tf.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

history = model.fit(
    training_data.repeat(), 
    epochs=100, 
    steps_per_epoch=(X_train.shape[0]//1024),
    validation_data=test_data.repeat(), 
    validation_steps=2
)

I get the following ValueError:

ValueError: 
Variable <tf.Variable 'sequential_11/dense_flipout_15/kernel_posterior_loc:0' 
shape=(175, 100) dtype=float32> has `None` for gradient. 
Please make sure that all of your ops have a gradient defined (i.e. are differentiable). 
Common ops without gradient: K.argmax, K.round, K.eval.

I've done some googling and looked through the TFP docs, but I'm at a loss, so I thought I would share the issue. Have I missed something obvious?

Thanks in advance.

mackdelany asked Aug 14 '19 12:08



1 Answer

I suspect this is because you're using TensorFlow 2. Is that the case? TFP doesn't fully support it yet. If so, downgrading to TensorFlow 1.14 should get it working.
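
To confirm which TensorFlow series you are on before downgrading, something like the sketch below should do. Note that is_tf2 is a hypothetical helper written for this answer, not part of TensorFlow or TFP; the only library attribute it relies on is tf.__version__. It only reports the version, it does not prove that the version is the cause of the None gradient.

```python
def is_tf2(version: str) -> bool:
    """Return True if a version string such as '2.0.0-beta1' belongs to the 2.x series or later.

    Hypothetical helper for illustration; it only inspects the major version number.
    """
    major = int(version.split(".")[0])
    return major >= 2


try:
    import tensorflow as tf
    installed = tf.__version__
except ImportError:
    # TensorFlow is not available in this environment.
    installed = None

if installed is None:
    print("TensorFlow is not installed in this environment.")
elif is_tf2(installed):
    print(f"TensorFlow {installed}: 2.x series, so the downgrade suggestion applies.")
else:
    print(f"TensorFlow {installed}: 1.x series, so the cause is probably something else.")
```

If this reports a 2.x version, pinning the package to 1.14 with your package manager and rerunning the DenseFlipout model is the quickest way to test the theory.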

AKW answered Oct 13 '22 10:10
