How to implement a custom keras layer that only keeps the top n values and zeros out all the rest?

I am trying to implement a custom Keras layer that will keep only the top N values of the input and convert all of the rest to zeros. I have one version that mostly works, but it leaves more than N values if there are ties. I would like to use a sort-based approach so that the layer always leaves exactly N non-zero values.

Here is the mostly working layer that leaves more than N values when there are ties:

import tensorflow as tf

tf_dtype = 'float32'

def top_n_filter_layer(input_data, n=2, tf_dtype=tf_dtype):

    # Works, but returns more than n values if there are ties:
    values_to_keep = tf.cast(tf.nn.top_k(input_data, k=n, sorted=True).values, tf_dtype)
    # Per-row minimum of the kept values, so each sample gets its own threshold:
    min_value_to_keep = tf.cast(tf.math.reduce_min(values_to_keep, axis=-1, keepdims=True), tf_dtype)
    mask = tf.math.greater_equal(tf.cast(input_data, tf_dtype), min_value_to_keep)
    zeros = tf.zeros_like(input_data)
    output = tf.where(mask, input_data, zeros)

    return output
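
To make the tie problem concrete, here is a minimal check (my own illustrative example, not from the original post; the input values are made up):

x = tf.constant([[1.0, 5.0, 5.0, 5.0, 2.0]])

with tf.Session() as sess:
    print(sess.run(top_n_filter_layer(x, n=2)))
# [[0. 5. 5. 5. 0.]]  <- three values survive because of the tie at 5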

Here is the sorting method I'm working on, but I'm getting stuck: tf.scatter_update complains about a rank mismatch:

from keras.layers import Input
import tensorflow as tf
import numpy as np

tf_dtype = 'float32'

def top_n_filter_layer(input_data, n=2, tf_dtype=tf_dtype):

    indices_to_keep = tf.argsort(input_data, axis=1, direction='DESCENDING', stable=True)
    indices_to_keep = tf.slice(indices_to_keep, [0,0], [-1, n])

    values_to_keep = tf.sort(input_data, axis=1, direction='DESCENDING')
    values_to_keep = tf.slice(values_to_keep, [0,0], [-1, n])

    zeros = tf.zeros_like(input_data, dtype=tf_dtype)

    zeros_variable = tf.Variable(0.0) # Since scatter_update requires _lazy_read
    zeros_variable = tf.assign(zeros_variable, zeros, validate_shape=False)

    output = tf.scatter_update(zeros_variable, indices_to_keep, values_to_keep)

    return output

tf.reset_default_graph()
np.random.seed(0)
input_data = np.random.uniform(size=(2,10))

input_layer = Input(shape=(10,))
output_data = top_n_filter_layer(input_layer)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    result = sess.run({'output': output_data}, feed_dict={input_layer:input_data})
    print(result)

Here is the traceback:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
   1658   try:
-> 1659     c_op = c_api.TF_FinishOperation(op_desc)
   1660   except errors.InvalidArgumentError as e:

InvalidArgumentError: Shapes must be equal rank, but are 2 and 3 for 'ScatterUpdate' (op: 'ScatterUpdate') with input shapes: [?,10], [?,2], [?,2].

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-10-598e009077f8> in <module>()
     27 
     28 input_layer = Input(shape=(10,))
---> 29 output_data = top_n_filter_layer(input_layer)
     30 
     31 with tf.Session() as sess:

<ipython-input-10-598e009077f8> in top_n_filter_layer(input_data, n, tf_dtype)
     18     zeros_variable = tf.assign(zeros_variable, zeros, validate_shape=False)
     19 
---> 20     output = tf.scatter_update(zeros_variable, indices_to_keep, values_to_keep)
     21 
     22     return output

/opt/conda/lib/python3.6/site-packages/tensorflow/python/ops/state_ops.py in scatter_update(ref, indices, updates, use_locking, name)
    297   if ref.dtype._is_ref_dtype:
    298     return gen_state_ops.scatter_update(ref, indices, updates,
--> 299                                         use_locking=use_locking, name=name)
    300   return ref._lazy_read(gen_resource_variable_ops.resource_scatter_update(  # pylint: disable=protected-access
    301       ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),

/opt/conda/lib/python3.6/site-packages/tensorflow/python/ops/gen_state_ops.py in scatter_update(ref, indices, updates, use_locking, name)
   1273   _, _, _op = _op_def_lib._apply_op_helper(
   1274         "ScatterUpdate", ref=ref, indices=indices, updates=updates,
-> 1275                          use_locking=use_locking, name=name)
   1276   _result = _op.outputs[:]
   1277   _inputs_flat = _op.inputs

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
    786         op = g.create_op(op_type_name, inputs, output_types, name=scope,
    787                          input_types=input_types, attrs=attr_protos,
--> 788                          op_def=op_def)
    789       return output_structure, op_def.is_stateful, op
    790 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    505                 'in a future version' if date is None else ('after %s' % date),
    506                 instructions)
--> 507       return func(*args, **kwargs)
    508 
    509     doc = _add_deprecated_arg_notice_to_docstring(

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in create_op(***failed resolving arguments***)
   3298           input_types=input_types,
   3299           original_op=self._default_original_op,
-> 3300           op_def=op_def)
   3301       self._create_op_helper(ret, compute_device=compute_device)
   3302     return ret

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in __init__(self, node_def, g, inputs, output_types, control_inputs, input_types, original_op, op_def)
   1821           op_def, inputs, node_def.attr)
   1822       self._c_op = _create_c_op(self._graph, node_def, grouped_inputs,
-> 1823                                 control_input_ops)
   1824 
   1825     # Initialize self._outputs.

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
   1660   except errors.InvalidArgumentError as e:
   1661     # Convert to ValueError for backwards compatibility.
-> 1662     raise ValueError(str(e))
   1663 
   1664   return c_op

ValueError: Shapes must be equal rank, but are 2 and 3 for 'ScatterUpdate' (op: 'ScatterUpdate') with input shapes: [?,10], [?,2], [?,2].
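
For reference, the rank mismatch happens because tf.scatter_update only scatters along the first dimension of the variable: indices of shape [?, 2] are treated as row indices, so the updates would need to be rank 3. One variable-free workaround I sketched (illustrative, not from the accepted answer) is to pair each column index with its row index and use tf.scatter_nd:

def top_n_filter_scatter_nd(input_data, n=2):
    # Per-row top values and their column indices, both shaped [batch, n].
    top_values, top_cols = tf.nn.top_k(input_data, k=n, sorted=False)

    # Build full (row, column) coordinates of shape [batch, n, 2] for scatter_nd.
    batch_size = tf.shape(input_data)[0]
    rows = tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, n])
    full_indices = tf.stack([rows, top_cols], axis=-1)

    # Scatter the kept values into an implicit zero tensor of the input's shape.
    return tf.scatter_nd(full_indices, top_values, tf.shape(input_data))

Like the one-hot approach below, this keeps exactly n values per row even when there are ties.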

@Vlad's answer below shows a working method based on one-hot encoding. Here is an example demonstrating it:

import tensorflow as tf
import numpy as np

tf.reset_default_graph()

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.InputLayer((10,)))

def top_n_filter_layer(input_data, n=2):

    topk = tf.nn.top_k(input_data, k=n, sorted=False)

    res = tf.reduce_sum(
        tf.one_hot(topk.indices,
                   input_data.get_shape().as_list()[-1]),
        axis=1)

    res *= input_data

    return res

model.add(tf.keras.layers.Lambda(top_n_filter_layer))

x_train = [[1,2,3,4,5,6,7,7,7,7]]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(model.output.eval({model.inputs[0]:x_train}))

# [[0. 0. 0. 0. 0. 0. 7. 7. 0. 0.]]
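
Note the tied input: even though four entries are tied at 7, tf.nn.top_k returns exactly k indices (the lower-index elements win ties), so exactly two values survive.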
asked Apr 12 '19 by cnash

1 Answer

Let's do it step by step:

  1. First we take the softmaxed output of the network and find its top k values and their indices.
  2. We create one-hot encoded vectors, each with a one at the location of one of the top k indices, and sum these k vectors to get a tensor of the original output shape with exactly k ones.
  3. Once we have a tensor with ones at the top k locations, we do an element-wise multiplication with the original softmaxed output of the network.

TensorFlow example for the top k=2 values:

import tensorflow as tf
import numpy as np

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(
    units=5, input_shape=(2, ), activation=tf.nn.softmax,
    kernel_initializer=tf.initializers.random_normal))

softmaxed = model.output # <-- take the *softmaxed* output
topk = tf.nn.top_k(softmaxed,    # <-- find its top k values and their indices
                   k=2,
                   sorted=False)

res = tf.reduce_sum(                                 # <-- create a one-hot encoded
    tf.one_hot(topk.indices,                         #     vectors out of top k indices
               softmaxed.get_shape().as_list()[-1]), #     and sum each k of them to
    axis=1)                                          #     create a single binary tensor

res *= softmaxed # <-- element-wise multiplication

x_train = [np.random.normal(size=(2, ))] # <-- train data

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(res.eval({model.inputs[0]:x_train})) # [[0.2 0.2 0.  0.  0. ]]
    print(softmaxed.eval({model.inputs[0]:x_train})) # [[0.2 0.2 0.2 0.2 0.2]]
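
One design note: topk.indices carries no gradient, so the one-hot mask behaves as a constant during backpropagation; gradients flow only through the element-wise multiplication into the k surviving softmax entries, and the zeroed positions receive zero gradient.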
answered Oct 23 '22 by Vlad