
Converting tokens to word vectors effectively with TensorFlow Transform

I would like to use TensorFlow Transform to convert tokens to word vectors during my training, validation, and inference phases.

I followed this StackOverflow post and implemented the initial conversion from tokens to vectors. The conversion works as expected, and I obtain a vector of length EMB_DIM for each token.

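import numpy as np
import tensorflow as tf  # TF 1.x: tf.contrib was removed in TF 2.x

EMB_DIM = 10

def load_pretrained_glove():
    # Stand-in for GloVe: 4 tokens, each with a random EMB_DIM-dimensional vector
    tokens = ["a", "cat", "plays", "piano"]
    return tokens, np.random.rand(len(tokens), EMB_DIM)

# sample string tensor
string_tensor = tf.constant(["plays", "piano", "unknown_token", "another_unknown_token"])

pretrained_vocab, pretrained_embs = load_pretrained_glove()

# Map each token to its row index; out-of-vocabulary tokens get index len(vocab)
vocab_lookup = tf.contrib.lookup.index_table_from_tensor(
    mapping=tf.constant(pretrained_vocab),
    default_value=len(pretrained_vocab))
string_tensor = vocab_lookup.lookup(string_tensor)

# define the word embedding (variable names are arbitrary)
pretrained_embs = tf.get_variable(
    name="embs_pretrained",
    initializer=tf.constant_initializer(np.asarray(pretrained_embs), dtype=tf.float32),
    shape=pretrained_embs.shape,
    trainable=False)

# One extra row shared by all unknown tokens
unk_embedding = tf.get_variable(
    name="embs_only_unk",
    shape=[1, EMB_DIM],
    initializer=tf.random_uniform_initializer(-0.04, 0.04),
    trainable=False)

embeddings = tf.cast(tf.concat([pretrained_embs, unk_embedding], axis=0), tf.float32)
word_vectors = tf.nn.embedding_lookup(embeddings, string_tensor)

with tf.Session() as sess:
    # Both the lookup table and the variables must be initialized before use
    sess.run([tf.tables_initializer(), tf.global_variables_initializer()])
    print(sess.run(word_vectors))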
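The lookup above boils down to two steps: map each token to a row index (sending every out-of-vocabulary token to one extra row), then gather rows from the concatenated embedding matrix. Here is a minimal NumPy sketch of that logic without TensorFlow; the function name `embed_tokens` is hypothetical, and the random matrix stands in for GloVe as in the question.

```python
import numpy as np

EMB_DIM = 10

def embed_tokens(tokens, vocab, pretrained_embs, unk_embedding):
    """Map tokens to row indices (OOV -> len(vocab)), then gather embedding rows."""
    index = {tok: i for i, tok in enumerate(vocab)}
    # Plays the role of index_table_from_tensor(default_value=len(vocab))
    ids = np.array([index.get(tok, len(vocab)) for tok in tokens])
    # Plays the role of tf.concat([pretrained_embs, unk_embedding], axis=0)
    embeddings = np.concatenate([pretrained_embs, unk_embedding], axis=0).astype(np.float32)
    # Plays the role of tf.nn.embedding_lookup(embeddings, ids)
    return ids, embeddings[ids]

vocab = ["a", "cat", "plays", "piano"]
pretrained = np.random.rand(len(vocab), EMB_DIM)
unk = np.random.uniform(-0.04, 0.04, size=(1, EMB_DIM))

ids, vectors = embed_tokens(["plays", "piano", "unknown_token", "another_unknown_token"],
                            vocab, pretrained, unk)
print(ids)            # [2 3 4 4]
print(vectors.shape)  # (4, 10)
```

Because every unknown token shares the single extra row, `vectors[2]` and `vectors[3]` are identical.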

When I refactor the code to run as a TensorFlow Transform graph, I get the conversion error below.

import pprint
import tempfile
import numpy as np
import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam.impl as beam_impl
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import dataset_schema


EMB_DIM = 10

def load_pretrained_glove():
    tokens = ["a", "cat", "plays", "piano"]
    return tokens, np.random.rand(len(tokens), EMB_DIM)

def embed_tensor(string_tensor, trainable=False):
    """Convert a list of strings into a list of indices, then into EMB_DIM vectors."""

    pretrained_vocab, pretrained_embs = load_pretrained_glove()

    vocab_lookup = tf.contrib.lookup.index_table_from_tensor(
        mapping=tf.constant(pretrained_vocab),
        default_value=len(pretrained_vocab))
    string_tensor = vocab_lookup.lookup(string_tensor)

    pretrained_embs = tf.get_variable(
        name="embs_pretrained",
        initializer=tf.constant_initializer(np.asarray(pretrained_embs), dtype=tf.float32),
        shape=pretrained_embs.shape,
        trainable=trainable)
    unk_embedding = tf.get_variable(
        name="embs_only_unk",
        shape=[1, EMB_DIM],
        initializer=tf.random_uniform_initializer(-0.04, 0.04),
        trainable=False)

    embeddings = tf.cast(tf.concat([pretrained_embs, unk_embedding], axis=0), tf.float32)
    # NOTE: string_tensor is a SparseTensor here, because tf.string_split returns one
    return tf.nn.embedding_lookup(embeddings, string_tensor)

def preprocessing_fn(inputs):
    # Split each sentence on spaces into tokens
    input_string = tf.string_split(inputs['sentence'], delimiter=" ")
    return {'word_vectors': tft.apply_function(embed_tensor, input_string)}

raw_data = [{'sentence': 'This is a sample sentence'},]
raw_data_metadata = dataset_metadata.DatasetMetadata(dataset_schema.Schema({
  'sentence': dataset_schema.ColumnSchema(
      tf.string, [], dataset_schema.FixedColumnRepresentation())
}))

with beam_impl.Context(temp_dir=tempfile.mkdtemp()):
    transformed_dataset, transform_fn = (  # pylint: disable=unused-variable
        (raw_data, raw_data_metadata) | beam_impl.AnalyzeAndTransformDataset(
            preprocessing_fn))

    transformed_data, transformed_metadata = transformed_dataset  # pylint: disable=unused-variable

pprint.pprint(transformed_data)

Error Message

TypeError: Failed to convert object of type <class 'tensorflow.python.framework.sparse_tensor.SparseTensor'> to Tensor.
Contents: SparseTensor(indices=Tensor("StringSplit:0", shape=(?, 2), dtype=int64), values=Tensor("hash_table_Lookup:0", shape=(?,), dtype=int64), dense_shape=Tensor("StringSplit:2", shape=(2,), dtype=int64)). Consider casting elements to a supported type.
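The three components named in the error are how a SparseTensor stores the ragged output of `tf.string_split`: `indices` holds one `[row, position]` pair per token (hence shape `(?, 2)`), `values` holds one entry per token (here already replaced by the lookup ids), and `dense_shape` holds `[num_rows, max_tokens_per_row]` (hence shape `(2,)`). The pure-Python sketch below builds those components for whitespace splitting. `sparse_split` is a hypothetical helper that only mimics the layout; it is not the TensorFlow implementation.

```python
def sparse_split(sentences, delimiter=" "):
    """Return (indices, values, dense_shape) in the layout tf.string_split produces."""
    indices, values = [], []
    max_len = 0
    for row, sentence in enumerate(sentences):
        tokens = [t for t in sentence.split(delimiter) if t]  # skip empty tokens
        for pos, tok in enumerate(tokens):
            indices.append([row, pos])  # one [row, position] pair per token
            values.append(tok)
        max_len = max(max_len, len(tokens))
    return indices, values, [len(sentences), max_len]

indices, values, dense_shape = sparse_split(["This is a sample sentence", "a cat"])
print(indices)      # [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [1, 0], [1, 1]]
print(values)       # ['This', 'is', 'a', 'sample', 'sentence', 'a', 'cat']
print(dense_shape)  # [2, 5]
```

Rows of different lengths are the reason the result is sparse rather than a dense matrix: row 1 has no entries at positions 2 to 4.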


  1. Why would the TF Transform step require an additional conversion/casting?
  2. Is this approach of converting tokens to word vectors feasible? The word vectors might take up multiple gigabytes of memory. How does Apache Beam handle the vectors? If Beam runs in a distributed setup, would it require N × the vector memory, with N being the number of workers?
asked Jul 31 '18 at 05:07 by Tony Yotto


1 Answer

The SparseTensor-related error occurs because you are calling string_split, which returns a SparseTensor. Your test code does not call string_split, which is why the error only happens in your Transform code.
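One way around this is to scatter the sparse ids into a dense, padded id matrix before the lookup. In TF 1.x the corresponding calls would be `tf.sparse_tensor_to_dense(sp_input, default_value=...)` followed by `tf.nn.embedding_lookup`. The sketch below shows the idea in NumPy rather than TensorFlow. The helper name `densify_and_embed` is hypothetical, and reusing the unknown-token row (id 4) as the padding id is an assumption of this sketch; a dedicated padding row or a mask may be preferable.

```python
import numpy as np

def densify_and_embed(indices, ids, dense_shape, embeddings, pad_id):
    """Scatter sparse ids into a dense [rows, max_len] matrix, then look up rows."""
    dense_ids = np.full(dense_shape, pad_id, dtype=np.int64)  # padded positions get pad_id
    for (row, pos), token_id in zip(indices, ids):
        dense_ids[row, pos] = token_id
    # Fancy indexing gathers one EMB_DIM vector per id: shape [rows, max_len, EMB_DIM]
    return dense_ids, embeddings[dense_ids]

EMB_DIM = 10
embeddings = np.random.rand(5, EMB_DIM).astype(np.float32)  # 4 vocab rows + 1 unknown row

# Sparse ids for the sentences "a cat" and "plays": row 0 has 2 tokens, row 1 has 1
indices = [[0, 0], [0, 1], [1, 0]]
ids = [0, 1, 2]
dense_ids, vectors = densify_and_embed(indices, ids, [2, 2], embeddings, pad_id=4)
print(dense_ids.tolist())  # [[0, 1], [2, 4]]
print(vectors.shape)       # (2, 2, 10)
```

The result has one vector per token position, so downstream code needs to know which positions are padding.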

Regarding memory, you are correct: the embedding matrix must be loaded into each worker.
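To put a number on it: a dense float32 matrix takes vocab_size × dim × 4 bytes, and each worker holds its own copy, so the total grows linearly with the number of workers. Here is a quick sketch with hypothetical sizes (2,000,000 tokens at 300 dimensions, 8 workers). These are illustrative numbers, not figures from the question.

```python
def embedding_bytes(vocab_size, dim, bytes_per_value=4):
    """Size of a dense embedding matrix (4 bytes per value for float32)."""
    return vocab_size * dim * bytes_per_value

per_worker = embedding_bytes(2_000_000, 300)  # hypothetical vocabulary and dimension
num_workers = 8                               # hypothetical cluster size
total = per_worker * num_workers              # one full copy per worker

print(per_worker / 1e9)  # 2.4  (GB per worker)
print(total / 1e9)       # 19.2 (GB across all workers)
```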

answered Oct 15 '22 at 08:10 by Kester Tong
