I'm working on implementing a semantic segmentation network in TensorFlow, and I'm trying to figure out how to write out summary images of the labels during training. I want to encode the images in the same style as the class segmentation annotations used in the PASCAL VOC dataset.
For example, let's assume I have a network that trains with a batch size of 1 on 4 classes. The network's final predictions have shape [1, 3, 3, 4].
Essentially, I want to take the output predictions and run them through argmax to get a tensor containing the most likely class at each point in the output:
[[[0, 1, 3],
[2, 0, 1],
[3, 1, 2]]]
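A minimal NumPy sketch of that reduction, using random logits with the shape above (illustrative values, not real network predictions). The most likely class at each pixel is the index of the largest score along the class axis, so the reduction is an argmax over the last axis:

```python
import numpy as np

# Toy logits with the question's shape: [batch, height, width, num_classes].
batch_size, height, width, num_classes = 1, 3, 3, 4
rng = np.random.default_rng(0)
logits = rng.random((batch_size, height, width, num_classes))

# Index of the highest score along the class axis = most likely class per pixel.
labels = np.argmax(logits, axis=-1)
print(labels.shape)  # (1, 3, 3)
```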
The annotated images use a color palette of 256 colors to encode labels. I have a tensor containing all the color triples:
[[ 0, 0, 0],
[128, 0, 0],
[ 0, 128, 0],
[128, 128, 0],
[ 0, 0, 128],
...
[224, 224, 192]]
How could I obtain a tensor of shape [1, 3, 3, 3] (a single 3x3 color image) that indexes into the color palette using the values obtained from argmax?
[[palette[0], palette[1], palette[3]],
[palette[2], palette[0], palette[1]],
[palette[3], palette[1], palette[2]]]
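For reference, the lookup being asked for is exactly NumPy integer-array indexing. A small sketch using only the first five palette rows listed above (the real palette has 256 rows):

```python
import numpy as np

# First five rows of the VOC palette listed above.
palette = np.array([[0, 0, 0],
                    [128, 0, 0],
                    [0, 128, 0],
                    [128, 128, 0],
                    [0, 0, 128]], dtype=np.uint8)

# Class map produced by argmax, shape [1, 3, 3].
labels = np.array([[[0, 1, 3],
                    [2, 0, 1],
                    [3, 1, 2]]])

# Integer-array indexing replaces every label with its palette row,
# so the result has shape labels.shape + (3,) = (1, 3, 3, 3).
image = palette[labels]
print(image.shape)  # (1, 3, 3, 3)
```

Any pure TensorFlow solution has to reproduce this result.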
I could easily wrap some NumPy and PIL code in tf.py_func, but I'm wondering if there is a pure TensorFlow way of obtaining this result.
EDIT:
For those curious, this is the solution I got using just NumPy. It works quite well, but I still dislike the use of tf.py_func:
```python
import numpy as np
import tensorflow as tf

def voc_colormap(N=256):
    # Bit 3j of a label becomes bit (7 - j) of R, bit 3j+1 of G, bit 3j+2 of B.
    bitget = lambda val, idx: ((val & (1 << idx)) != 0)
    cmap = np.zeros((N, 3), dtype=np.uint8)
    for i in range(N):
        r = g = b = 0
        c = i
        for j in range(8):
            r |= (bitget(c, 0) << 7 - j)
            g |= (bitget(c, 1) << 7 - j)
            b |= (bitget(c, 2) << 7 - j)
            c >>= 3
        cmap[i, :] = [r, g, b]
    return cmap

VOC_COLORMAP = voc_colormap()

def grayscale_to_voc(input, name="grayscale_to_voc"):
    # Runs the NumPy lookup below as a graph op (TF 1.x API).
    return tf.py_func(grayscale_to_voc_impl, [input], tf.uint8, stateful=False, name=name)

def grayscale_to_voc_impl(input):
    # Fancy indexing maps each label to its RGB triple;
    # np.squeeze drops size-1 axes such as a batch dimension of 1.
    return np.squeeze(VOC_COLORMAP[input])
```
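As a side note, the per-entry loop in voc_colormap() can be vectorized over all labels at once. The sketch below (voc_colormap_vectorized is a hypothetical name, not part of any library) applies the same rule, bit 3j+ch of a label becomes bit 7-j of channel ch, and reproduces the palette rows listed in the question:

```python
import numpy as np

def voc_colormap_vectorized(n=256):
    # Hypothetical vectorized rewrite of voc_colormap(): same bit interleaving,
    # but computed for all n labels at once instead of one label at a time.
    ids = np.arange(n)
    cmap = np.zeros((n, 3), dtype=np.uint8)
    for j in range(8):
        for ch in range(3):  # 0 = R, 1 = G, 2 = B
            bit = (ids >> (3 * j + ch)) & 1
            cmap[:, ch] |= (bit << (7 - j)).astype(np.uint8)
    return cmap

cmap = voc_colormap_vectorized()
print(cmap[:5].tolist())  # [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128]]
```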
You can use tf.gather_nd(), but you need to reshape the palette and the argmax indices to obtain the desired image, for example:
```python
import tensorflow as tf
import numpy as np
import PIL.Image as Image

# We can load the palette from some random image in the PASCAL VOC dataset
# (getpalette() returns a flat list [r0, g0, b0, r1, g1, b1, ...])
palette = Image.open('.../VOC2012/SegmentationClass/2007_000032.png').getpalette()

# We build a random logits tensor of the requested size
batch_size = 1
height = width = 3
num_classes = 4
np.random.seed(1234)
logits = np.random.random_sample((batch_size, height, width, num_classes))
logits_argmax = np.argmax(logits, axis=3)  # shape = (1, 3, 3)
# array([[[3, 3, 0],
#         [1, 3, 1],
#         [0, 2, 0]]])

sess = tf.InteractiveSession()  # TF 1.x API
image = tf.gather_nd(
    params=tf.reshape(palette, [-1, 3]),  # flat list reshaped to RGB rows: (num_colors, 3)
    indices=tf.reshape(logits_argmax, [batch_size, -1, 1]))  # (1, 9, 1) -> result (1, 9, 3)
image = tf.cast(tf.reshape(image, [batch_size, height, width, 3]), tf.uint8)
sess.run(image)
# array([[[[128, 128,   0],
#          [128, 128,   0],
#          [  0,   0,   0]],
#
#         [[128,   0,   0],
#          [128, 128,   0],
#          [128,   0,   0]],
#
#         [[  0,   0,   0],
#          [  0, 128,   0],
#          [  0,   0,   0]]]], dtype=uint8)
```
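tf.gather also works here and avoids the reshapes: it indexes along the first axis of the palette, and its output shape is the index tensor's shape plus the trailing RGB axis. A small eager-mode sketch (TF 2.x) with a hypothetical five-color palette taken from the rows listed in the question:

```python
import numpy as np
import tensorflow as tf

# Hypothetical 5-color palette: the first rows of the VOC colormap.
palette = tf.constant([[0, 0, 0],
                       [128, 0, 0],
                       [0, 128, 0],
                       [128, 128, 0],
                       [0, 0, 128]], dtype=tf.uint8)

# Label map from the question, shape [1, 3, 3].
labels = tf.constant([[[0, 1, 3],
                       [2, 0, 1],
                       [3, 1, 2]]], dtype=tf.int64)

# Output shape = labels.shape + palette.shape[1:] = [1, 3, 3, 3].
image = tf.gather(palette, labels)
print(image.numpy().shape)  # (1, 3, 3, 3)

# Same result as NumPy integer-array indexing on the palette.
print(np.array_equal(image.numpy(), palette.numpy()[labels.numpy()]))  # True
```

With the variables from the answer, the equivalent would be tf.cast(tf.gather(tf.reshape(palette, [-1, 3]), logits_argmax), tf.uint8), which should also give a [batch_size, height, width, 3] tensor.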
The resulting tensor can be fed directly to tf.summary.image(). If the network's output has a lower resolution than its input, you may want to upsample it before writing the summary.
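One way to do that upsampling is nearest-neighbor resizing, which copies palette colors instead of blending them (bilinear resizing would produce colors that belong to no class). A TF 2.x sketch; the 96x96 target size, the "predicted_labels" summary name and the temporary log directory are assumptions for illustration:

```python
import tempfile
import tensorflow as tf

# Hypothetical small palette and label map, as in the question.
palette = tf.constant([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0]], dtype=tf.uint8)
labels = tf.constant([[[0, 1, 3], [2, 0, 1], [3, 1, 2]]], dtype=tf.int64)
image = tf.gather(palette, labels)  # [1, 3, 3, 3], uint8

# Nearest-neighbor resize keeps every pixel an exact palette color.
upsampled = tf.image.resize(image, [96, 96], method="nearest")
upsampled = tf.cast(upsampled, tf.uint8)  # keeps the dtype explicit for the summary

# TF 2.x summary API: write the image under a default file writer.
logdir = tempfile.mkdtemp()
writer = tf.summary.create_file_writer(logdir)
with writer.as_default():
    written = tf.summary.image("predicted_labels", upsampled, step=0)
writer.flush()
```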