
Windowing a TensorFlow dataset without losing cardinality information?

tf.data.Dataset.window returns a new dataset whose elements are themselves datasets; each nested dataset holds the elements of one window of the desired size. If you have a dataset (say, Dataset.range(10)) and want a dataset of windows like [0 1 2] [1 2 3] ... [7 8 9], there's a trick to do that with window plus flat_map:

>>> d = tf.data.Dataset.range(10).window(3, shift=1, drop_remainder=True).flat_map(lambda x: x.batch(3))
>>> print(list(d))
[<tf.Tensor: shape=(3,), dtype=int64, numpy=array([0, 1, 2])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([2, 3, 4])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([3, 4, 5])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([4, 5, 6])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([5, 6, 7])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([6, 7, 8])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([7, 8, 9])>]

However, the flat_map causes the dataset to lose cardinality information:

>>> d.cardinality()
<tf.Tensor: shape=(), dtype=int64, numpy=-2>

(-2 is tf.data.UNKNOWN_CARDINALITY; see the question "Tensorflow 2.0: flat_map() to flatten Dataset of Dataset returns cardinality -2".)

I would like to create a dataset of such windows while retaining the cardinality information. One slight annoyance of working with datasets of unknown cardinality is that Keras training progress bars need to run through one full epoch before they can show an ETA. I tried .take(n_windows), where I calculate n_windows myself, but that still returned a dataset with UNKNOWN_CARDINALITY.

Is there some way to window a dataset without losing cardinality information?

Rai asked Dec 09 '25 17:12


1 Answer

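The main issue is that cardinality is computed statically, without running the pipeline. The number of elements a flat_map produces depends on what the mapped function returns for each input, so its cardinality cannot be inferred. You can refer to this issue.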

Since you know how the inputs and outputs of the flat_map relate, the solution is to set the cardinality yourself using tf.data.experimental.assert_cardinality.
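
For sliding windows, that relation is simple arithmetic. The sketch below is a hypothetical helper (expected_window_count is not part of TensorFlow) that computes how many full windows to expect when drop_remainder=True. It assumes the input length n is known, and it is only a cross-check for the value you pass to assert_cardinality:

```python
def expected_window_count(n, size, shift=1, stride=1):
    """Number of full windows over n elements, assuming drop_remainder=True.

    Hypothetical helper for illustration; not a TensorFlow API.
    """
    # Number of input elements spanned by a single window.
    span = (size - 1) * stride + 1
    if n < span:
        return 0
    # One window at offset 0, plus one more for every full shift that still fits.
    return (n - span) // shift + 1


print(expected_window_count(10, 3))           # 8
print(expected_window_count(10, 3, shift=2))  # 4
```

For Dataset.range(10) with windows of size 3 and shift 1, this gives 10 - 3 + 1 = 8.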

Here is an example of how to restore the cardinality after windowing:

import tensorflow as tf

ds = tf.data.Dataset.range(10)
print("Original cardinality -> ", ds.cardinality().numpy())
# Output:
# Original cardinality -> 10

ds = ds.window(3, shift=1, drop_remainder=True)
# The cardinality is still known at this point.
# With drop_remainder=True and shift=1, the window count is 10 - 3 + 1 = 8.
window_cardinality = ds.cardinality()
print("window cardinality ->", window_cardinality.numpy())
# Output:
# window cardinality -> 8

ds = ds.flat_map(lambda x: x.batch(3))
# After flat_map the inferred cardinality is lost.
print("flat cardinality ->", ds.cardinality().numpy())
# Output:
# flat cardinality -> -2

# Each window becomes exactly one batch (a 1:1 relation), so we can set the
# cardinality back to the window cardinality saved above.
ds = ds.apply(tf.data.experimental.assert_cardinality(window_cardinality))
print("dataset cardinality ->", ds.cardinality().numpy())
print("length of dataset ->", len(list(ds)))
# Output:
# dataset cardinality -> 8
# length of dataset -> 8

for idx, x in ds.enumerate():
    print(f"{idx} -> {x}")
# Output:
# 0 -> [0 1 2]
# 1 -> [1 2 3]
# 2 -> [2 3 4]
# 3 -> [3 4 5]
# 4 -> [4 5 6]
# 5 -> [5 6 7]
# 6 -> [6 7 8]
# 7 -> [7 8 9]
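
The same steps can be wrapped in a small reusable function. This is a minimal sketch, not an official API: sliding_windows is a hypothetical name, and it assumes eager execution and a dataset whose elements are single tensors (for tuple or dict elements, window returns a nested structure and the lambda would need adapting).

```python
import tensorflow as tf


def sliding_windows(ds, size, shift=1):
    """Return full windows of `size` elements, keeping cardinality when known.

    Hypothetical helper for illustration; not part of TensorFlow.
    """
    windows = ds.window(size, shift=shift, drop_remainder=True)
    # The window dataset still reports a static cardinality if the input did.
    n_windows = int(windows.cardinality().numpy())
    # Each window holds exactly `size` elements, so it becomes exactly one batch.
    batched = windows.flat_map(lambda w: w.batch(size, drop_remainder=True))
    # Negative values mean infinite (-1) or unknown (-2); nothing to assert then.
    if n_windows >= 0:
        batched = batched.apply(
            tf.data.experimental.assert_cardinality(n_windows))
    return batched


ds = sliding_windows(tf.data.Dataset.range(10), size=3)
print(ds.cardinality().numpy())
```

Reading the count from the window dataset before flat_map, rather than recomputing it by hand, keeps the asserted value in step with whatever TensorFlow inferred for the given size and shift.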

jorpilo answered Dec 12 '25 07:12


