 

How to select specific columns from a TensorFlow dataset?

I am training my TensorFlow model on data from a CSV file preprocessed by tf.data.Dataset. However, I want the model to fork into three branches, each corresponding to a different set of CSV columns, and model.fit requires a separate dataset for each output. All columns of the CSV file need to undergo the same preprocessing, so the most efficient way of preparing the data would be to load the whole file, process it, and then split the dataset into three parts. However, I am struggling to find a way of doing so.

I hoped that dataset.map would allow me to select some columns using the following operation:

dset = dset.map(lambda x: x[[1, 2, 3, 7]])

but it seems that TensorFlow interprets it as x[1][2][3][7] instead.

The only working way of creating separate datasets that I've found was to do it from the beginning:

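import tensorflow as tf

y = []
for cls, keys in output_classes.items():
    # Re-read and parse the CSV once per output, keeping only that output's columns
    tmp = tf.data.experimental.CsvDataset(data_path, [tf.int32] * len(keys), select_cols=keys)
    [...]
    y.append(tmp)
y = tf.data.Dataset.zip(tuple(y))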

Unfortunately, this produces a lot of unnecessary overhead, since the file is read and parsed once per output, and it immensely slows down training.

Is there a way of splitting a tf.data.Dataset object by a subset of features?

asked Nov 25 '19 at 14:11 by Ginterhauser




2 Answers

Try tf.gather:

import tensorflow as tf
tf.gather(tf.constant([1, 2, 3, 4]), [1, 2, 3])
# returns a tensor with values [2, 3, 4]

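If you have higher-dimensional data, pass the axis argument to tf.gather to choose which dimension to index, or use tf.gather_nd to index several dimensions at once.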
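A sketch of both cases on a small, made-up 2-D tensor (the values are only for illustration): with `axis=1`, `tf.gather` selects whole columns from every row, while `tf.gather_nd` picks individual elements by their full (row, column) index.

```python
import tensorflow as tf

# A 2-D tensor: 2 rows (samples) x 4 columns (features).
t = tf.constant([[1, 2, 3, 4],
                 [5, 6, 7, 8]])

# tf.gather with axis=1 keeps columns 0 and 2 of every row.
cols = tf.gather(t, [0, 2], axis=1)

# tf.gather_nd takes full (row, column) index pairs instead.
picked = tf.gather_nd(t, [[0, 1], [1, 3]])

print(cols.numpy().tolist())
print(picked.numpy().tolist())
```

For selecting feature columns, `tf.gather(..., axis=...)` is usually the simpler fit; `tf.gather_nd` is for scattered, element-level lookups.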

answered Oct 20 '22 at 13:10 by tornikeo


This worked for me: I applied tornikeo's tf.gather to each element of the dataset with .map().

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([[1, 2, 3, 4],
                                              [5, 6, 7, 8]])
# Each element is one row; keep only columns 0 and 2.
dataset_filter = dataset.map(lambda x: tf.gather(x, [0, 2], axis=0))
result = list(dataset_filter.as_numpy_iterator())
print(result)

# Outputs: [array([1, 3], dtype=int32), array([5, 7], dtype=int32)]
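A minimal sketch of the single-pass approach the question asks about, combining both answers: the CSV is parsed once, its columns are stacked into one vector (where the shared preprocessing could go), and a single map then splits each element into a tuple of column groups. The CSV contents, `GROUPS`, `NUM_COLS`, and `split_columns` are hypothetical, made up for illustration. Gathering along `axis=-1` lets the same function work before or after `batch()`.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical CSV: 4 rows x 8 integer columns, no header.
rows = [[i * 10 + j for j in range(8)] for i in range(4)]
path = os.path.join(tempfile.mkdtemp(), "data.csv")
with open(path, "w") as f:
    for row in rows:
        f.write(",".join(str(v) for v in row) + "\n")

# Hypothetical column groups, one per model branch/output.
GROUPS = ([0, 4], [1, 2, 3, 7], [5, 6])
NUM_COLS = 8

# Read and parse the file once; each element is a tuple of scalar tensors.
dset = tf.data.experimental.CsvDataset(path, [tf.int32] * NUM_COLS)
# Stack the scalars into one vector so later ops see a single tensor.
dset = dset.map(lambda *cols: tf.stack(cols))
# ... shared preprocessing on the full vector would go here ...
dset = dset.batch(2)


def split_columns(x):
    # axis=-1 indexes the column axis for both batched and unbatched data.
    return tuple(tf.gather(x, idx, axis=-1) for idx in GROUPS)


split = dset.map(split_columns)
for a, b, c in split.take(1):
    print(a.numpy().tolist(), b.numpy().tolist(), c.numpy().tolist())
```

If the groups are the targets of a multi-output Keras model, mapping to `(x, split_columns(x))` instead should give model.fit one dataset of `(inputs, (target1, target2, target3))` pairs rather than three separately read datasets.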
answered Oct 20 '22 at 13:10 by theudbald