 

Randomly sample from multiple tf.data.Datasets in Tensorflow

Suppose I have N tf.data.Datasets and a list of N probabilities (summing to 1). I would like to create a dataset whose examples are sampled from the N datasets with the given probabilities.

It should work for arbitrary probabilities, so a simple zip/concat/flat_map that takes a fixed number of examples from each dataset is probably not what I am looking for.
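To pin down the desired semantics independently of TensorFlow, here is a minimal plain-Python sketch over ordinary iterators (not tf.data). The helper name `sample_from_iterators` is hypothetical. For every output element a fresh source index `i` is drawn with probability `probs[i]`, so any probabilities work, not just ones that split a batch evenly.

```python
import itertools
import random
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def sample_from_iterators(
    sources: List[Iterator[T]],
    probs: Sequence[float],
    rng: random.Random,
) -> Iterator[T]:
    """Hypothetical helper: yield the next element of sources[i] with probability probs[i].

    Assumption of this sketch: it stops as soon as the chosen source is exhausted.
    """
    indices = range(len(sources))
    while True:
        # Draw one source index according to the (relative) weights.
        i = rng.choices(indices, weights=probs, k=1)[0]
        try:
            yield next(sources[i])
        except StopIteration:
            return


# Infinite sources, analogous to tf.data.Dataset.from_tensors(...).repeat()
sources = [itertools.repeat(c) for c in "abcd"]
rng = random.Random(0)
sample = list(itertools.islice(sample_from_iterators(sources, [0.1, 0.2, 0.3, 0.4], rng), 10_000))
print({c: sample.count(c) / len(sample) for c in "abcd"})
```

With 10,000 draws the printed frequencies should land close to 0.1, 0.2, 0.3 and 0.4.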

Is it possible to do this in TF? Thanks!

asked Apr 10 '18 09:04 by serycjon


1 Answer

If p is a list or Tensor of probabilities (or unnormalized relative probabilities), where p[i] is the probability that dataset i is chosen, you can combine tf.multinomial with tf.contrib.data.choose_from_datasets (both TensorFlow 1.x APIs):

import tensorflow as tf  # TensorFlow 1.x: uses tf.contrib, tf.multinomial and tf.Session

# create some datasets and their unnormalized probability of being chosen
datasets = [
    tf.data.Dataset.from_tensors(['a']).repeat(),
    tf.data.Dataset.from_tensors(['b']).repeat(),
    tf.data.Dataset.from_tensors(['c']).repeat(),
    tf.data.Dataset.from_tensors(['d']).repeat()]
p = [1., 2., 3., 4.]  # unnormalized

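# random choice function: tf.multinomial expects logits (unnormalized
# log-probabilities), so passing log(p) picks index i with probability p[i] / sum(p)
def get_random_choice(p):
  choice = tf.multinomial(tf.log([p]), 1)  # shape [1, 1]
  # choose_from_datasets needs scalar tf.int64 indices
  return tf.cast(tf.squeeze(choice), tf.int64)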

# assemble the "choosing" dataset
choice_dataset = tf.data.Dataset.from_tensors([0])  # create a dummy dataset
choice_dataset = choice_dataset.map(lambda x: get_random_choice(p))  # populate it with random choices
choice_dataset = choice_dataset.repeat()  # repeat; the map (and its random draw) re-runs on every pass

# obtain your combined dataset, assembled randomly from source datasets
# with the desired selection frequencies. 
combined_dataset = tf.contrib.data.choose_from_datasets(datasets, choice_dataset)
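`choose_from_datasets` itself is deterministic: each element of `choice_dataset` is an index saying which source to pull the next element from, and all the randomness comes from the choice stream. A plain-Python model of that behaviour (the helper `choose_from` is hypothetical; how TensorFlow handles an exhausted source is not modelled, this sketch simply stops):

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def choose_from(sources: List[Iterator[T]], choices: Iterable[int]) -> Iterator[T]:
    """Hypothetical model of choose_from_datasets: the k-th choice picks the source
    for output element k, and each source keeps its own read position."""
    for i in choices:
        try:
            yield next(sources[i])
        except StopIteration:
            return  # this sketch stops when a chosen source runs out


sources = [iter(["a1", "a2"]), iter(["b1", "b2", "b3"])]
print(list(choose_from(sources, [1, 0, 1, 1, 0])))
# ['b1', 'a1', 'b2', 'b3', 'a2']
```

In TensorFlow 2.x `tf.contrib` no longer exists; recent releases expose `tf.data.Dataset.choose_from_datasets`, and `tf.data.Dataset.sample_from_datasets` takes a `weights` list directly, which covers this use case without a hand-built choice dataset (older 2.x releases have both under `tf.data.experimental`). Check the documentation for your version.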

Note that the combined dataset needs an initializable iterator (a simple make_one_shot_iterator doesn't work here):

choice_iterator = combined_dataset.make_initializable_iterator()
choice = choice_iterator.get_next()
with tf.Session() as sess:
  sess.run(choice_iterator.initializer)
  # each element is a 1-element vector of bytes; decode for Python 3 printing
  print(''.join([sess.run(choice)[0].decode() for _ in range(20)]))

Example output (random, so it differs between runs):

ddbcccdcccbbddadcadb
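As a rough sanity check, the letter counts in that 20-character sample can be compared with the expected counts 20 * p[i] / sum(p), i.e. 2, 4, 6 and 8 for p = [1, 2, 3, 4]. With so few draws only approximate agreement is expected. A small plain-Python tally:

```python
from collections import Counter

sample = "ddbcccdcccbbddadcadb"  # the example output shown above
p = {"a": 1.0, "b": 2.0, "c": 3.0, "d": 4.0}  # unnormalized weights from the answer

observed = Counter(sample)
expected = {k: len(sample) * w / sum(p.values()) for k, w in p.items()}
for k in sorted(p):
    print(k, observed[k], expected[k])
```

This prints observed counts 2, 4, 7, 7 against expected 2.0, 4.0, 6.0, 8.0.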
answered Nov 15 '22 21:11 by eriophora