I want to prepare the Omniglot dataset for n-shot learning. For that I need 5 samples from each of 10 classes (alphabets).
Code to reproduce
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
# Build the Omniglot dataset through TensorFlow Datasets and make sure the
# data has been downloaded and prepared on disk.
builder = tfds.builder("omniglot")
builder.download_and_prepare()

# Load the prepared data from disk as tf.data.Datasets, keyed by split name.
datasets = builder.as_dataset()
dataset = datasets['train']
test_dataset = datasets['test']
def resize(example):
    """Turn one raw dataset example into an (image, label, alphabet) triple.

    The image is resized to 28x28, converted from RGB to grayscale and
    divided by 255.  The label is an all-zero array of shape (51, 10), and
    the alphabet id is passed through unchanged from ``example['alphabet']``.
    """
    alphabet = example['alphabet']
    small = tf.image.resize(example['image'], [28, 28])
    gray = tf.image.rgb_to_grayscale(small)
    # All-zero placeholder: no class information is encoded in it.
    placeholder_label = np.zeros((51, 10))
    return gray / 255, placeholder_label, alphabet
def stack(image, label, alphabet):
    """Repackage a batched triple as ``((image, label), target)``.

    ``target`` is the last entry of ``label`` along its first axis;
    ``alphabet`` is accepted but not used.
    """
    inputs = (image, label)
    target = label[-1]
    return inputs, target
def filter_func(image, label, alphabet, allowed_alphabets=(2, 3, 4, 5)):
    """Keep only examples whose alphabet id is one of ``allowed_alphabets``.

    Intended for ``dataset.filter``; the dataset yields
    ``(image, label, alphabet)`` triples, so all three are accepted even
    though only ``alphabet`` is inspected.

    Args:
        image: Unused.
        label: Unused.
        alphabet: Alphabet id of the example; reshaped to a scalar.
        allowed_alphabets: Ids to keep.  Defaults to ``(2, 3, 4, 5)``.

    Returns:
        A scalar boolean tensor, True when ``alphabet`` equals any of the
        allowed ids.
    """
    alphabet = tf.reshape(alphabet, [])
    # Cast to the alphabet dtype so tf.equal compares like with like.
    allowed = tf.cast(allowed_alphabets, alphabet.dtype)
    # Broadcasting the scalar against the 1-D list of ids gives one boolean
    # per allowed id; the example is kept if any of them matches.
    return tf.reduce_any(tf.equal(alphabet, allowed))
# Input pipeline, applied in this order:
#   resize -> filter -> infinite stream -> shuffle -> batches of 51
#   (classes*samples + 1) -> stack the images together -> shuffle -> batches of 32.
dataset = (
    dataset
    .map(resize)            # correct size
    .filter(filter_func)    # keep only the examples accepted by filter_func
    .repeat()
    .shuffle(1024)
    .batch(51)
    .map(stack)
    .shuffle(buffer_size=1000)
    .batch(32)
)

# Each element is ((images, labels), target); print the shape of the images.
# The dataset is repeated without a count, so this loop does not end on its own.
for step, (inputs, _) in enumerate(tfds.as_numpy(dataset)):
    print(step, inputs[0].shape)
Now I want to filter the images in the dataset using the filter function. `tf.equal` only lets me filter by a single class; I want something like a "tensor in array" membership test.
Is there a way to do this with the filter function? Or is this the wrong approach, and is there a much simpler way?
I want to create a batch of 51 images and their corresponding labels, all drawn from the same N=10 classes. From every class I need K=5 different images, plus one additional image (the one I need to classify). Every batch of N*K+1 (51) images should come from 10 new random classes.
Thank you very much in advance.
To KEEP only specific labels use this predicate:
# Start again from the training split before applying the label filter.
dataset = datasets['train']
def predicate(x, allowed_labels=tf.constant([0, 1, 2])):
    """Tell whether the label of example ``x`` is one of ``allowed_labels``.

    Args:
        x: Example mapping; only ``x['label']`` is read.
        allowed_labels: Tensor of labels to keep.

    Returns:
        A scalar boolean tensor, True when the label equals at least one
        entry of ``allowed_labels``.
    """
    # Bring the label to the dtype of the allowed labels before comparing.
    label = tf.cast(x['label'], allowed_labels.dtype)
    matches = tf.equal(allowed_labels, label)
    return tf.reduce_any(matches)
# Drop every example the predicate rejects, then group the rest in batches of 20.
dataset = dataset.filter(predicate)
dataset = dataset.batch(20)

# Print the labels of each batch; only allowed labels should appear.
for batch in tfds.as_numpy(dataset):
    print(batch['label'])
# [1 0 0 1 2 1 1 2 1 0 0 1 2 0 1 0 2 2 0 1]
# [1 0 2 2 0 2 1 2 1 2 2 2 0 2 0 2 1 2 1 1]
# [2 1 2 1 0 1 1 0 1 2 2 0 2 0 1 0 0 0 0 0]
`allowed_labels` specifies the labels you want to keep. All labels that are not in this tensor will be filtered out.
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With