I'm trying to get HashMap-type functionality to work with TensorFlow. I got it to work when keys and values are of int type, but when they are arrays it gives the error ValueError: Shapes (2,) and () are not compatible on the default_value) line:
import numpy as np
import tensorflow as tf
input_tensor = tf.constant([1, 1], dtype=tf.int64)
keys = tf.constant(np.array([[1, 1],[2, 2],[3, 3]]), dtype=tf.int64)
values = tf.constant(np.array([[4, 1],[5, 1],[6, 1]]), dtype=tf.int64)
default_value = tf.constant(np.array([1, 1]), dtype=tf.int64)
table = tf.contrib.lookup.HashTable(
tf.contrib.lookup.KeyValueTensorInitializer(keys, values),
default_value)
out = table.lookup(input_tensor)
with tf.Session() as sess:
    table.init.run()
    print(out.eval())
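Unfortunately, tf.contrib.lookup.HashTable only supports scalar keys and scalar values: the keys and values tensors must be one-dimensional, and default_value must be a scalar, which is why a shape (2,) default triggers the error above. Here's an implementation with tf.SparseTensors instead, which of course only works if your keys are integer tensors, since they are used as the SparseTensor's indices (int64).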
For the values I'm storing the two columns in two separate tensors, but if you have many columns, you might want to store them all in one large tensor and store the row indices as the values of a single tf.SparseTensor.
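The single-SparseTensor variant mentioned above can be sketched without TensorFlow to show the indexing scheme. This is a minimal pure-Python illustration under the assumption that a coordinate-to-row dictionary stands in for the SparseTensor (with -1 as its default) and a list of rows stands in for the "large tensor"; in TensorFlow the row fetch would presumably be a gather, which is not shown. The helper names and the third value column are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple

MISSING = -1  # default row index, playing the role of the index SparseTensor's default value

def build_row_index(keys: Sequence[Sequence[int]]) -> Dict[Tuple[int, ...], int]:
    """Map each integer key (as a coordinate tuple) to the row of `values` holding its entry."""
    return {tuple(k): row for row, k in enumerate(keys)}

def lookup_row(index: Dict[Tuple[int, ...], int],
               values: Sequence[Sequence[int]],
               key: Sequence[int],
               default_value: Sequence[int]) -> List[int]:
    """Return the whole value row for `key`, or `default_value` if the key is absent."""
    row = index.get(tuple(key), MISSING)
    # Check explicitly: values[-1] would silently return the last row in Python.
    if row == MISSING:
        return list(default_value)
    return list(values[row])

keys = [[1, 2], [3, 4], [5, 6]]
# Hypothetical third column, added only to show that the column count doesn't matter.
values = [[4, 1, 7], [5, 1, 8], [6, 1, 9]]
default_value = [1, 1, 1]

index = build_row_index(keys)
print(lookup_row(index, values, [1, 2], default_value))  # [4, 1, 7]
print(lookup_row(index, values, [1, 1], default_value))  # [1, 1, 1]
```

The advantage of this layout is that only one sparse structure is needed no matter how many value columns there are; the cost is an extra gather step after the index lookup.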
This code (tested):
import tensorflow as tf
lookup = tf.placeholder( shape = ( 2, ), dtype = tf.int64 )
default_value = tf.constant( [ 1, 1 ], dtype = tf.int64 )
input_tensor = tf.constant( [ 1, 1 ], dtype=tf.int64)
keys = tf.constant( [ [ 1, 2 ], [ 3, 4 ], [ 5, 6 ] ], dtype=tf.int64 )
values = tf.constant( [ [ 4, 1 ], [ 5, 1 ], [ 6, 1 ] ], dtype=tf.int64 )
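# One SparseTensor per value column; all of them use the keys as their indices
val0 = values[ :, 0 ]
val1 = values[ :, 1 ]
# dense_shape must be larger than every key coordinate (the keys here go up to 6)
st0 = tf.SparseTensor( keys, val0, dense_shape = ( 7, 7 ) )
st1 = tf.SparseTensor( keys, val1, dense_shape = ( 7, 7 ) )
# Slice out the single 1x1 cell at `lookup`; if the key is missing the slice is empty,
# so densifying it yields the default, and the reshape turns the 1x1 result into a scalar
x0 = tf.sparse_slice( st0, lookup, [ 1, 1 ] )
y0 = tf.reshape( tf.sparse_tensor_to_dense( x0, default_value = default_value[ 0 ] ), () )
x1 = tf.sparse_slice( st1, lookup, [ 1, 1 ] )
y1 = tf.reshape( tf.sparse_tensor_to_dense( x1, default_value = default_value[ 1 ] ), () )
# Reassemble the per-column scalars into the value vector
y = tf.stack( [ y0, y1 ], axis = 0 )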
with tf.Session() as sess:
    print( sess.run( y, feed_dict = { lookup : [ 1, 2 ] } ) )
    print( sess.run( y, feed_dict = { lookup : [ 1, 1 ] } ) )
will output:
[4 1]
[1 1]
as desired: it looks up the value [4, 1] for the key [1, 2], and returns the default value [1, 1] for the key [1, 1], which has no entry in the table.
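To see why the two runs print those results, here is a framework-free emulation of what each column does. It assumes each per-column SparseTensor behaves like a coordinate-to-entry map: the 1x1 slice at the looked-up coordinate holds the stored entry when the key exists and is empty otherwise, and densifying an empty slice gives that column's default. The function names are hypothetical; this illustrates the semantics and is not TensorFlow code.

```python
from typing import Dict, List, Sequence, Tuple

def build_column_maps(keys: Sequence[Sequence[int]],
                      values: Sequence[Sequence[int]]) -> List[Dict[Tuple[int, ...], int]]:
    """One coordinate -> entry map per value column, playing the role of st0, st1, ..."""
    n_cols = len(values[0])
    return [{tuple(k): row[c] for k, row in zip(keys, values)} for c in range(n_cols)]

def lookup_columns(column_maps: List[Dict[Tuple[int, ...], int]],
                   key: Sequence[int],
                   default_value: Sequence[int]) -> List[int]:
    """Per column, return the stored entry at `key` or that column's default (like y0, y1),
    then stack the results into one vector (like y)."""
    cell = tuple(key)
    return [col.get(cell, d) for col, d in zip(column_maps, default_value)]

keys = [[1, 2], [3, 4], [5, 6]]
values = [[4, 1], [5, 1], [6, 1]]
default_value = [1, 1]

maps = build_column_maps(keys, values)
print(lookup_columns(maps, [1, 2], default_value))  # [4, 1]
print(lookup_columns(maps, [1, 1], default_value))  # [1, 1]
```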