I'm trying to learn tensorflow and I'm getting the following error: logits and labels must be broadcastable: logits_size=[32,1] labels_size=[16,1]
The code runs fine with the following input:
self.input = np.ones((500, 784))
self.y = np.ones((500, 1))
However, when I add an extra dimension, the error is thrown:
self.input = np.ones((500, 2, 784))
self.y = np.ones((500, 1))
Here is the code that builds the graph:
self.x = tf.placeholder(tf.float32, shape=[None] + self.config.state_size)
self.y = tf.placeholder(tf.float32, shape=[None, 1])
# network architecture
d1 = tf.layers.dense(self.x, 512, activation=tf.nn.relu, name="dense1")
d2 = tf.layers.dense(d1, 1, name="dense2")
with tf.name_scope("loss"):
    self.cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=self.y, logits=d2))
    self.train_step = tf.train.AdamOptimizer(self.config.learning_rate).minimize(
        self.cross_entropy, global_step=self.global_step_tensor)
correct_prediction = tf.equal(tf.argmax(d2, 1), tf.argmax(self.y, 1))
self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
Could someone explain to me why this is happening and how I can fix it?
logits is the name typically given to the output of the network; these are your predictions. A size of [32, 1] tells me the loss op is seeing 32 rows of logits, each of size 1. Your labels are sized [16, 1], which is to say you're providing 16 labels. The number of labels you're providing is in conflict with the number of logits; they must be the same.
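The doubling can be seen with shapes alone, no TensorFlow required. A minimal NumPy sketch (the weight values are placeholders; only the shapes matter) of what a dense layer does to a 3-D input for a batch of 16:

```python
import numpy as np

batch = np.ones((16, 2, 784))   # a batch of 16 inputs, each with the extra dimension
w = np.ones((784, 1))           # dense-layer weights mapping 784 features -> 1 output

# A dense layer contracts only the last axis, so the middle axis survives:
logits = batch @ w
print(logits.shape)             # (16, 2, 1)

# Flattening the leading axes to rows is where the "doubling" appears:
flat = logits.reshape(-1, 1)
print(flat.shape)               # (32, 1) -- against only 16 labels
```

So 16 samples produce 32 logit rows, exactly the [32, 1] vs [16, 1] mismatch in the error.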
The extra dimension in the input is what causes this: tf.layers.dense only transforms the last axis, so an input of shape [16, 2, 784] produces logits of shape [16, 2, 1], which the softmax cross-entropy op then flattens into 32 rows of logits against your 16 labels. You need to either drop the extra dimension or flatten it away before the dense layers. Also, since this looks like MNIST, your self.y should be shaped [500, 10], not [500, 1]: your labels need to be in one-hot encoding format (and dense2 should have 10 outputs to match). E.g. a single label of shape [1, 10] for digit 3 would be [[0,0,0,1,0,0,0,0,0,0]], not in digit representation, e.g. [3], as you seem to have it set up in your sanity test here.
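As a sketch of that one-hot conversion (assuming 10 MNIST classes; indexing into np.eye is one common idiom, and the example digits are made up):

```python
import numpy as np

digits = np.array([3, 0, 7])      # integer labels, shape (3,)
one_hot = np.eye(10)[digits]      # one-hot labels, shape (3, 10)
print(one_hot[0])                 # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
```

In the graph itself you could equally use tf.one_hot, or keep integer labels and switch the loss to tf.nn.sparse_softmax_cross_entropy_with_logits.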