
Using a Tensorflow feature_column in a Keras model

How can a Tensorflow feature_column be used in conjunction with a Keras model?

For example, with a TensorFlow Estimator we can use a text embedding column from TensorFlow Hub:

import tensorflow as tf
import tensorflow_hub as hub

# TF 1.x-era API (tf.estimator, tf.train), as used when this was asked
embedded_text_feature_column = hub.text_embedding_column(
    key="sentence",
    module_spec="https://tfhub.dev/google/nnlm-en-dim128/1")

estimator = tf.estimator.DNNClassifier(
    hidden_units=[100],
    feature_columns=[embedded_text_feature_column],
    n_classes=2,
    optimizer=tf.train.AdamOptimizer(learning_rate=0.001))

However, I would like to use the TF Hub text_embedding_column as input to a Keras model. For example:

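import tensorflow as tf
from tensorflow.keras.layers import Dense

inputs = tf.keras.layers.Input(...)  # use embedding column here
net = tf.keras.layers.Flatten()(inputs)  # layers must be called on a tensor
net = Dense(100, activation='relu')(net)
net = Dense(2)(net)
model = tf.keras.Model(inputs=inputs, outputs=net)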

Is this possible?

avanwyk asked May 21 '18 19:05



2 Answers

The answer seems to be that you don't use feature columns. Keras comes with its own set of preprocessing functions for images and text, so you can use those.

So basically tf.feature_column is reserved for the high-level Estimator API, while the tf.keras.preprocessing utilities are used with tf.keras models.

Here is the section on text preprocessing in the Keras documentation: https://keras.io/preprocessing/text/

Here is another Stack Overflow post with an example of this approach:

Add Tensorflow pre-processing to existing Keras model (for use in Tensorflow Serving)
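A minimal sketch of the "preprocess inside Keras" approach. Note that it swaps in tf.keras.layers.TextVectorization (available as a core layer in TF 2.6 and later), a newer Keras preprocessing layer, in place of the older keras.preprocessing.text utilities linked above, and it does not use TF Hub at all. The corpus, max_tokens, sequence length, and embedding size are hypothetical values chosen only for illustration.

```python
import tensorflow as tf

# Tiny hypothetical corpus used only to build a vocabulary.
corpus = tf.constant(["the cat sat on the mat", "a dog ran in the park"])

# Keras preprocessing layer: raw strings -> padded integer token ids.
vectorize = tf.keras.layers.TextVectorization(
    max_tokens=1000, output_mode="int", output_sequence_length=8)
vectorize.adapt(corpus)  # learns the vocabulary from the corpus
vocab_size = len(vectorize.get_vocabulary())

# The model takes raw strings, so preprocessing travels with the model.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(1,), dtype="string"),
    vectorize,                                          # (batch, 8) token ids
    tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=16),
    tf.keras.layers.GlobalAveragePooling1D(),           # (batch, 16)
    tf.keras.layers.Dense(100, activation="relu"),
    tf.keras.layers.Dense(2),                           # 2 logits, as in the question
])

logits = model(tf.constant([["the cat ran"], ["a dog sat"]]))
```

Because the TextVectorization layer is part of the model, the exported model accepts raw strings, which is the same motivation as the Serving-related post above.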

krishnab answered Oct 14 '22 23:10



The Keras functional API is a viable way to do this, but if you want to use feature columns, this tutorial shows you how:

https://www.tensorflow.org/beta/tutorials/keras/feature_columns

Basically, it is the DenseFeatures layer that does the job:

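import tensorflow as tf
from tensorflow.keras import layers

# feature_columns: a list of tf.feature_column objects, built as in the tutorial
feature_layer = tf.keras.layers.DenseFeatures(feature_columns)

model = tf.keras.Sequential([
  feature_layer,  # turns a dict of raw features into one dense tensor
  layers.Dense(128, activation='relu'),
  layers.Dense(128, activation='relu'),
  layers.Dense(1, activation='sigmoid')
])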
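The same layer also fits the functional-API style used in the question. This is a minimal sketch assuming TF 2.x with the Keras 2 implementation of tf.keras (DenseFeatures is a Keras 2 layer and may not be available under Keras 3). The "age" and "color" columns and the vocabulary are hypothetical. Each raw feature gets its own tf.keras.Input keyed by the column name, and DenseFeatures turns that dict into a single rank-2 tensor, so the Flatten layer from the question is not needed.

```python
import tensorflow as tf

# Hypothetical feature columns, named only for illustration.
age = tf.feature_column.numeric_column("age")
color = tf.feature_column.categorical_column_with_vocabulary_list(
    "color", ["red", "green", "blue"])
feature_columns = [age, tf.feature_column.embedding_column(color, dimension=4)]

# One Keras Input per raw feature; dict keys must match the column keys.
inputs = {
    "age": tf.keras.Input(shape=(1,), name="age", dtype=tf.float32),
    "color": tf.keras.Input(shape=(1,), name="color", dtype=tf.string),
}

# DenseFeatures concatenates the column outputs: 1 numeric + 4 embedding dims.
net = tf.keras.layers.DenseFeatures(feature_columns)(inputs)
net = tf.keras.layers.Dense(100, activation="relu")(net)
outputs = tf.keras.layers.Dense(2)(net)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Call the model with a batch of two examples passed as a dict.
logits = model({
    "age": tf.constant([[25.0], [40.0]]),
    "color": tf.constant([["red"], ["blue"]]),
})
```

Swapping the embedding column for a TF Hub text embedding column would follow the same pattern, with a string Input keyed by that column's name.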
mdaoust answered Oct 15 '22 00:10
