Is it possible to train a generative model (e.g. a variational autoencoder with a custom loss calculation) with TensorFlow's tpu_estimator?
Simplified version of my VAE:
    def model_fn(features, labels, mode, params):
        # Encoder layers
        x = layers.Input()
        h = conv1D()(x)

        # Bottleneck layer
        z_mean = Dense()(h)
        z_log_var = Dense()(h)

        def sampling(args):
            z_mean_, z_log_var_ = args
            epsilon = tf.random_normal()
            return z_mean_ + tf.exp(z_log_var_ / 2) * epsilon

        z = Lambda(sampling, name='lambda')([z_mean, z_log_var])

        # Decoder layers
        h = Dense()(z)
        x_decoded = TimeDistributed(Dense(activation='softmax'))(h)

        # VAE
        vae = tf.keras.models.Model(x, x_decoded)

        # VAE loss
        def vae_loss(x, x_decoded_mean):
            x = flatten(x)
            x_decoded_mean = flatten(x_decoded_mean)
            xent_loss = binary_crossentropy(x, x_decoded_mean)
            kl_loss = mean(1 + z_log_var - square(z_mean) - exp(z_log_var))
            return xent_loss + kl_loss

        optimizer = tf.train.AdamOptimizer()
        optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
        train_op = optimizer.minimize(vae_loss, global_step=tf.train.get_global_step())

        return tpu_estimator.TPUEstimatorSpec(mode=mode, loss=vae_loss, train_op=train_op)
The TPU configuration initializes and the dataset is loaded properly with my input_fn, but I get the following error, which is triggered by the custom loss function:
VAE_LOSS() error:
      File "TPUest.py", line 107, in model_fn
        train_op = optimizer.minimize(vae_loss, global_step=tf.train.get_global_step())
      File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/optimizer.py", line 414, in minimize
        grad_loss=grad_loss)
      File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py", line 84, in compute_gradients
        loss *= scale
    TypeError: unsupported operand type(s) for *=: 'function' and 'float'
optimizer.minimize expects a loss Tensor, but what you have passed is a Python function (one that, given appropriate inputs, would evaluate to the Tensor you want). See https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer#minimize
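To see why the traceback ends in a TypeError, here is a minimal plain-Python sketch that needs no TensorFlow. The `scale_loss` helper is a hypothetical stand-in for the `loss *= scale` line shown in the traceback, and the toy `vae_loss` is not the real VAE loss; it only illustrates the difference between passing a function and passing the value it returns.

```python
def scale_loss(loss, scale=0.125):
    # Hypothetical stand-in for the `loss *= scale` step in the traceback.
    loss *= scale
    return loss


def vae_loss(x, x_decoded_mean):
    # Toy squared-error loss on plain floats (illustration only).
    return (x - x_decoded_mean) ** 2


# Passing the function object itself fails the same way as in the traceback.
try:
    scale_loss(vae_loss)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# Passing the value obtained by calling the function works.
scaled = scale_loss(vae_loss(3.0, 1.0))

print(error_message)
print(scaled)
```

The first call fails because a function object cannot be multiplied by a float; the second succeeds because the function has already been evaluated to a number.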
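What you need to do is explicitly construct the loss Tensor in the code above, for example by calling vae_loss on the model's input and decoded output, and pass that Tensor (rather than the function) to both optimizer.minimize and TPUEstimatorSpec. During execution, the data will then be propagated from your input layer down to this loss calculation.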
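As a rough sketch of that fix, the loss can be written as a function that takes tensors and returns a scalar Tensor, which is then built once inside model_fn. The name `build_vae_loss`, the `eps` clipping constant, and the use of `features` as the model input are assumptions made for illustration. It also assumes a TensorFlow version that provides `tf.math.log` (older 1.x releases spell it `tf.log`). The KL term here uses the conventional -0.5 factor, which the simplified snippet in the question omits.

```python
import tensorflow as tf


def build_vae_loss(x, x_decoded, z_mean, z_log_var, eps=1e-7):
    """Return a scalar loss Tensor (not a function) for a VAE."""
    # Clip predictions so the logarithms below stay finite.
    p = tf.clip_by_value(x_decoded, eps, 1.0 - eps)
    # Binary cross-entropy, averaged over all elements.
    xent_loss = -tf.reduce_mean(
        x * tf.math.log(p) + (1.0 - x) * tf.math.log(1.0 - p))
    # KL divergence to a unit Gaussian, with the conventional -0.5 factor.
    kl_loss = -0.5 * tf.reduce_mean(
        1.0 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
    return xent_loss + kl_loss


# Inside model_fn the Tensor would be built once and reused, for example
# (hypothetical wiring, assuming `features` is the model input):
#   loss = build_vae_loss(features, x_decoded, z_mean, z_log_var)
#   train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
#   return tpu_estimator.TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op)
```

The key point is that `loss` is the result of calling the function, so both optimizer.minimize and TPUEstimatorSpec receive a Tensor.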