
Tensorflow hashtable lookup with arrays

I'm trying to get HashMap-style functionality working in TensorFlow. It works when the keys and values are ints, but when they are arrays I get ValueError: Shapes (2,) and () are not compatible, raised at the default_value) line of the HashTable constructor:

import numpy as np
import tensorflow as tf


input_tensor = tf.constant([1, 1], dtype=tf.int64)
keys = tf.constant(np.array([[1, 1],[2, 2],[3, 3]]),  dtype=tf.int64)
values = tf.constant(np.array([[4, 1],[5, 1],[6, 1]]),  dtype=tf.int64)
default_value = tf.constant(np.array([1, 1]),  dtype=tf.int64)

table = tf.contrib.lookup.HashTable(
        tf.contrib.lookup.KeyValueTensorInitializer(keys, values),
        default_value)

out = table.lookup(input_tensor)
with tf.Session() as sess:
    table.init.run()
    print(out.eval())
Mihkel L. asked May 13 '18 at 12:05



1 Answer

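Unfortunately, tf.contrib.lookup.HashTable only supports scalar keys and scalar values: the keys and values given to its initializer must be one-dimensional tensors with one element per entry, and default_value must be a scalar, which is why the two-element default_value raises the shape error. Here's an implementation with tf.SparseTensor instead, which of course only works if your keys are integer (int32 or int64) tensors. Since the keys become the sparse tensor's indices, they must also be non-negative and smaller than its dense_shape.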

For the values, I'm storing the two columns in two separate tensors. If you have many columns, you might instead store all the value rows in one large tensor and use a single tf.SparseTensor whose values are row indices into it.
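A minimal sketch of that many-columns variant, separate from the tested two-column code. It assumes TensorFlow 2.x (eager execution, tf.sparse.slice, tf.sparse.to_dense) rather than the 1.x API the answer uses, and the three-column values are hypothetical. The single sparse tensor holds, at each key position, the row index of that key's value; a missing key densifies to -1, which is then replaced by the default row:

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

keys = tf.constant([[1, 2], [3, 4], [5, 6]], dtype=tf.int64)
# Hypothetical values with three columns, one row per key.
values = tf.constant([[4, 1, 7], [5, 1, 8], [6, 1, 9]], dtype=tf.int64)
default_value = tf.constant([1, 1, 1], dtype=tf.int64)
missing = tf.constant(-1, dtype=tf.int64)  # marker for "no entry"

# One sparse tensor for all columns: the entry at (k0, k1) is the row index
# of that key in `values`. Keys must be non-negative, smaller than
# dense_shape, and listed in row-major order.
row_index = tf.SparseTensor(
    indices=keys,
    values=tf.cast(tf.range(tf.shape(keys)[0]), tf.int64),
    dense_shape=[7, 7])

def lookup(query):
    """Return the value row for a length-2 int64 key, or default_value."""
    # Take the 1x1 cell at `query`; it is empty if no key is stored there.
    # Assumes `query` lies inside dense_shape, as in the answer's code.
    cell = tf.sparse.slice(row_index, start=query, size=[1, 1])
    row = tf.reshape(tf.sparse.to_dense(cell, default_value=missing), [])
    # Gather with a safe index, then fall back to the default on a miss.
    found = tf.gather(values, tf.maximum(row, 0))
    return tf.where(row >= 0, found, default_value)

print(lookup(tf.constant([1, 2], dtype=tf.int64)).numpy())  # [4 1 7]
print(lookup(tf.constant([1, 1], dtype=tf.int64)).numpy())  # [1 1 1]
```

Adding a column then only changes `values` and `default_value`; the number of sparse tensors and slices stays at one.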

This code (tested):

import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session)

# The key to look up, fed at run time.
lookup = tf.placeholder(shape=(2,), dtype=tf.int64)
default_value = tf.constant([1, 1], dtype=tf.int64)

# Keys become sparse-tensor indices, so they must be non-negative and
# smaller than dense_shape; they are listed here in row-major order.
keys = tf.constant([[1, 2], [3, 4], [5, 6]], dtype=tf.int64)
values = tf.constant([[4, 1], [5, 1], [6, 1]], dtype=tf.int64)
val0 = values[:, 0]  # first column of every value
val1 = values[:, 1]  # second column of every value

# One sparse tensor per value column, indexed by the keys.
st0 = tf.SparseTensor(keys, val0, dense_shape=(7, 7))
st1 = tf.SparseTensor(keys, val1, dense_shape=(7, 7))

# Slice out the single 1x1 cell at `lookup`; if nothing is stored there,
# densifying the empty slice fills it with that column's default.
x0 = tf.sparse_slice(st0, lookup, [1, 1])
y0 = tf.reshape(tf.sparse_tensor_to_dense(x0, default_value=default_value[0]), ())
x1 = tf.sparse_slice(st1, lookup, [1, 1])
y1 = tf.reshape(tf.sparse_tensor_to_dense(x1, default_value=default_value[1]), ())

# Reassemble the two columns into one value vector.
y = tf.stack([y0, y1], axis=0)

with tf.Session() as sess:
    print(sess.run(y, feed_dict={lookup: [1, 2]}))
    print(sess.run(y, feed_dict={lookup: [1, 1]}))

will output:

[4 1]
[1 1]

as desired: it returns the value [4, 1] for the key [1, 2], and the default value [1, 1] for the key [1, 1], which has no entry in the table.
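Another way around HashTable's scalar-only restriction is to encode each two-element key as one int64 scalar, let the hash table map it to a row index, and gather the value row yourself. This is a minimal sketch, assuming TensorFlow 2.x (tf.lookup.StaticHashTable, eager execution) and assuming every key component is a non-negative integer below a hypothetical bound BASE, so that k0 * BASE + k1 cannot collide. Unlike the sparse-tensor version, keys are not limited by a dense_shape, and lookups can be batched:

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

BASE = 1000  # hypothetical bound: every key component must be in [0, BASE)

keys = tf.constant([[1, 2], [3, 4], [5, 6]], dtype=tf.int64)
values = tf.constant([[4, 1], [5, 1], [6, 1]], dtype=tf.int64)
default_value = tf.constant([1, 1], dtype=tf.int64)

def encode(k):
    """Collapse the last axis (k0, k1) into the scalar k0 * BASE + k1."""
    return k[..., 0] * BASE + k[..., 1]

# The table only sees scalars: encoded key -> row index into `values`.
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        encode(keys), tf.cast(tf.range(tf.shape(keys)[0]), tf.int64)),
    default_value=tf.constant(-1, dtype=tf.int64))

def lookup(query):
    """Look up one key of shape [2] or a batch of keys of shape [N, 2]."""
    row = table.lookup(encode(query))  # -1 where the key is missing
    found = tf.gather(values, tf.maximum(row, 0))  # safe index for misses
    hit = tf.expand_dims(row >= 0, -1)  # broadcast over the value columns
    return tf.where(hit, found, default_value)

print(lookup(tf.constant([1, 2], dtype=tf.int64)).numpy())  # [4 1]
print(lookup(tf.constant([[1, 1], [5, 6]], dtype=tf.int64)).numpy())
```

The trade-off is the encoding assumption: if key components can exceed BASE, or the product overflows int64, a different collision-free encoding is needed.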

Peter Szoldan answered Oct 12 '22 at 17:10
