 

Assign indexed entry of Keras tensor

Tags: python, keras

I'm something of a Keras beginner, so apologies in advance for any gaps in my understanding.

I want to manually set some values of my Keras tensor according to indices stored in another tensor. I believe I understand how to read entries of a tensor using tf.gather_nd (my untested attempt is below), and I understand that only a variable, not a tensor, can have its values assigned.
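Since a Keras tensor cannot be assigned to, "setting" entries really means building a new tensor. The read half of the attempt is easy to sanity-check in NumPy, whose fancy indexing mirrors what tf.gather_nd does with an index array of (row, column) pairs: the pairs have to be stacked along axis=1 so the index has shape (batch_size, 2). A minimal sketch with made-up data (the arrays are illustrative assumptions, not taken from the question):

```python
import numpy as np

# Hypothetical toy data: a batch of 3 generator outputs, each of length 4.
gen_out = np.arange(12).reshape(3, 4)
indices_to_reset = np.array([[2], [0], [3]])  # one column index per row, shape (batch, 1)

# tf.gather_nd wants one (row, column) pair per element to fetch,
# i.e. an index array of shape (batch, 2): stack along axis=1.
rows = np.arange(gen_out.shape[0])
idx = np.stack([rows, indices_to_reset.reshape(-1)], axis=1)

# NumPy equivalent of tf.gather_nd(gen_out, idx): one element per (row, col) pair.
grabbed_entries = gen_out[idx[:, 0], idx[:, 1]]

print(idx.tolist())              # [[0, 2], [1, 0], [2, 3]]
print(grabbed_entries.tolist())  # [2, 4, 11]
```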

For context, this happens between the generation and discrimination stages of a GAN.

import tensorflow as tf
from keras import backend as K
from keras.layers import Input, Lambda

gen_out = generator(inputs)

indices_to_reset = Input(shape=(1,), dtype='int32')
new_values = Input(shape=(1,), dtype='int32')

batch_size = K.shape(gen_out)[0]

#pair every row number with the column index to reset in that row
idx_0 = K.arange(batch_size)
indices_to_reset = K.reshape(indices_to_reset, (-1,))

idx = K.stack((idx_0, indices_to_reset), axis=1)  #shape (batch_size, 2)

grabbed_entries = Lambda(lambda x: tf.gather_nd(gen_out, x))(idx)

# Doesn't work: Keras tensors don't support item assignment
# gen_out[:,indices_to_reset] = new_values

updated_gen_out = ???
Mose Wintner asked Feb 15 '19 22:02


1 Answer

It's a lot easier if you turn the indices into one_hot tensors and use K.switch:

(Keep all of these operations inside a Lambda layer; otherwise the results are not outputs of a Keras layer and building the Model will fail.)
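The one_hot step is what turns a per-row index into a full-width mask. A NumPy sketch of the shape changes, where indexing an identity matrix stands in for K.one_hot and a reshape stands in for K.batch_flatten (the data is made up for illustration):

```python
import numpy as np

size = 5                             # assumed row length of gen_out, as in the dummy example
indices = np.array([[0], [3], [1]])  # shape (batch, 1), like the indices_to_reset Input

# Stand-in for K.one_hot: (batch, 1) -> (batch, 1, size)
one_hot = np.eye(size)[indices]

# Stand-in for K.batch_flatten: keep the batch axis, flatten the rest -> (batch, size)
flat = one_hot.reshape(one_hot.shape[0], -1)

print(one_hot.shape)  # (3, 1, 5)
print(flat.shape)     # (3, 5)
print(flat.tolist())
```

Each row of flat has a single 1.0 at the column to overwrite, which is why multiplying it by the per-row value places that value in the right spot.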

from keras import backend as K
from keras.layers import Lambda

def replace_values(x):
    outs, indices, values = x

    #this is due to a strange bug between lambda and integers....
    indices = K.cast(indices, 'int32')

    #size is the length of each gen_out row (its last dimension)
    size = K.int_shape(outs)[-1]

    #create one_hot indices: (batch, 1) -> (batch, 1, size) -> (batch, size)
    one_hot_indices = K.one_hot(indices, size)
    one_hot_indices = K.batch_flatten(one_hot_indices)

    #have the desired values at their correct positions
    values_to_use = one_hot_indices * values

    #where values_to_use is 0, keep gen_out; elsewhere use the new value
    return K.switch(K.equal(values_to_use, 0), outs, values_to_use)


updated_gen_out = Lambda(replace_values)([gen_out, indices_to_reset, new_values])

Warning: new_values cannot be integers; they must have the same dtype as gen_out.
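To see the one_hot/switch logic without building a model, here is a NumPy mirror of replace_values. np.eye indexing stands in for K.one_hot plus K.batch_flatten and np.where for K.switch; the function name replace_values_np is hypothetical, and values is kept float to match the dtype of outs. With the first six rows of the dummy example's inputs, it reproduces the first six rows of its results:

```python
import numpy as np

def replace_values_np(outs, indices, values):
    """NumPy mirror of replace_values (illustration only, not the Keras code)."""
    size = outs.shape[-1]
    # K.one_hot + K.batch_flatten: (batch, 1) -> (batch, size)
    one_hot_indices = np.eye(size)[indices.reshape(-1)]
    # the new value at its target column, 0 everywhere else
    values_to_use = one_hot_indices * values
    # K.switch(K.equal(values_to_use, 0), outs, values_to_use)
    return np.where(values_to_use == 0, outs, values_to_use)

outs = np.arange(30, dtype=float).reshape(6, 5)
indices = np.array([0, 1, 2, 3, 4, 0]).reshape(6, 1)
values = np.arange(6, dtype=float).reshape(6, 1)

result = replace_values_np(outs, indices, values)
print(result)
```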


Dummy example:

import numpy as np
from keras import backend as K
from keras.layers import Input, Lambda
from keras.models import Model

size = 5
batch_size = 15

gen_out = Input((size,))
indices_to_reset = Input((1,), dtype='int32')
new_values = Input((1,))

def replace_values(x):
    outs, indices, values = x
    print(K.int_shape(outs))
    print(K.int_shape(indices))

    #this is due to a strange bug between lambda and integers....
    indices = K.cast(indices, 'int32')
    one_hot_indices = K.one_hot(indices, size)
    print(K.int_shape(one_hot_indices))
    one_hot_indices = K.batch_flatten(one_hot_indices)
    print(K.int_shape(one_hot_indices))

    values_to_use = one_hot_indices * values
    print(K.int_shape(values_to_use))

    return K.switch(K.equal(values_to_use, 0), outs, values_to_use)

updated_gen_out = Lambda(replace_values)([gen_out, indices_to_reset, new_values])

model = Model([gen_out,indices_to_reset,new_values], updated_gen_out)

gen_outs = np.arange(batch_size * size).reshape((batch_size, size))
indices = np.concatenate([np.arange(5)]*3, axis=0)
new_vals = np.arange(15).reshape((15,1))

print('\n\ngen outs')
print(gen_outs)

print('\n\nindices')
print(indices)

print('\n\nvalues')
print(new_vals)

print('\n\n results')
print(model.predict([gen_outs, indices, new_vals]))

Outputs:

(None, 5)
(None, 1)
(None, 1, 5)
(None, None)
(None, None)
(None, 5)
(None, 1)
(None, 1, 5)
(None, None)
(None, None)


gen outs
[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]
 [15 16 17 18 19]
 [20 21 22 23 24]
 [25 26 27 28 29]
 [30 31 32 33 34]
 [35 36 37 38 39]
 [40 41 42 43 44]
 [45 46 47 48 49]
 [50 51 52 53 54]
 [55 56 57 58 59]
 [60 61 62 63 64]
 [65 66 67 68 69]
 [70 71 72 73 74]]


indices
[0 1 2 3 4 0 1 2 3 4 0 1 2 3 4]


values
[[ 0]
 [ 1]
 [ 2]
 [ 3]
 [ 4]
 [ 5]
 [ 6]
 [ 7]
 [ 8]
 [ 9]
 [10]
 [11]
 [12]
 [13]
 [14]]


 results
[[ 0.  1.  2.  3.  4.]
 [ 5.  1.  7.  8.  9.]
 [10. 11.  2. 13. 14.]
 [15. 16. 17.  3. 19.]
 [20. 21. 22. 23.  4.]
 [ 5. 26. 27. 28. 29.]
 [30.  6. 32. 33. 34.]
 [35. 36.  7. 38. 39.]
 [40. 41. 42.  8. 44.]
 [45. 46. 47. 48.  9.]
 [10. 51. 52. 53. 54.]
 [55. 11. 57. 58. 59.]
 [60. 61. 12. 63. 64.]
 [65. 66. 67. 13. 69.]
 [70. 71. 72. 73. 14.]] 

Notice that in each row i, the entry at column indices[i] has been replaced by new_vals[i], so the replacements fall along the diagonal of each block of 5 rows.
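One caveat: because K.switch keeps gen_out wherever values_to_use == 0, a new value of exactly 0 is treated as "no replacement" and the original entry survives (row 0 above only looks right because its original value is also 0). A mask-based blend, outs * (1 - mask) + values * mask, avoids this. This is a swapped-in alternative rather than the answer's method, shown here as a NumPy sketch with made-up data (the function name replace_values_masked is hypothetical). The same expression uses only elementwise operators, so it should also work on Keras tensors inside the Lambda, with mask being the flattened one_hot_indices:

```python
import numpy as np

def replace_values_masked(outs, indices, values):
    """Mask-based variant (an alternative sketch, not the answer's code)."""
    size = outs.shape[-1]
    mask = np.eye(size)[indices.reshape(-1)]  # 1.0 at the target column of each row
    # Keep outs where mask == 0 and take the new value where mask == 1,
    # so a new value of 0 is written like any other value.
    return outs * (1 - mask) + values * mask

outs = np.arange(10, dtype=float).reshape(2, 5)
indices = np.array([[1], [3]])
values = np.array([[0.0], [-2.0]])

masked = replace_values_masked(outs, indices, values)
print(masked.tolist())  # [[0.0, 0.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, -2.0, 9.0]]

# The switch-style test for comparison: row 0 keeps its original 1.0 at column 1.
mask = np.eye(5)[indices.reshape(-1)]
switch_style = np.where(mask * values == 0, outs, mask * values)
print(switch_style.tolist())  # [[0.0, 1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, -2.0, 9.0]]
```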

Daniel Möller answered Oct 21 '22 02:10