
How to access Tensor shape within .map function?

I have a dataset of audio clips of varying lengths, and I want to crop all of them to 5-second windows (240000 samples at a 48000 Hz sample rate). After loading the .tfrecord, I decode each clip with:

audio, sr = tf.audio.decode_wav(image_data)

which gives me a Tensor whose length is the audio length. If this length is less than 240000, I would like to repeat the audio content until it reaches 240000. Right now I do this on ALL audios, inside a tf.data.Dataset.map() function:

audio = tf.tile(audio, [5])

Tiling by 5 is what it takes to bring my shortest audio up to the desired length.

For efficiency, though, I would like to apply the operation only to the elements that need it:

if audio.shape[0] < 240000:
  pad_num = tf.math.ceil(240000 / audio.shape[0]) #i.e. if the audio is 120000 long, the audio will repeat 2 times
  audio = tf.tile(audio, [pad_num])

But I can't access the shape property, since the length is dynamic and varies across the audios. I've tried tf.shape(audio), audio.shape, and audio.get_shape(), but I get values like None for the shape, which doesn't allow me to do the comparison.
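
The symptom can be reproduced with a small sketch. The generator, the tensor sizes, and the inspect helper below are made up for illustration and stand in for the real .tfrecord pipeline. The point is only that the function passed to map() is traced before any data flows, so the static length is unknown at that moment:

```python
import tensorflow as tf

# Collects the static shapes seen while tf.data traces the mapped function.
seen_static_shapes = []

def inspect(audio):
    # Hypothetical helper: this Python body runs during tracing,
    # not once per dataset element.
    seen_static_shapes.append(audio.shape.as_list())
    return audio

# Toy stand-in for the decoded audio: 1-D float tensors of unknown length.
ds = tf.data.Dataset.from_generator(
    lambda: iter([tf.zeros([100_000]), tf.zeros([300_000])]),
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.float32),
)
ds = ds.map(inspect)

# The static length recorded at trace time is None.
print(seen_static_shapes)

# Once elements are actually produced, each one has a concrete shape.
for item in ds:
    print(item.shape)
```

Because audio.shape[0] is None inside the mapped function, a plain Python comparison such as audio.shape[0] < 240000 cannot be evaluated there.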

Is it possible to do this?

Leonardo asked Dec 06 '25 05:12


1 Answer

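tf.shape(audio) returns the dynamic shape as a tensor, so you can compute the number of repetitions with TensorFlow operations inside the mapped function. You can use a function like this: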

import tensorflow as tf

def enforce_length(audio):
    # Target length in samples (5 s at 48 kHz)
    AUDIO_LEN = 240_000
    # Current length, read at run time as a tensor
    current_len = tf.shape(audio)[0]
    # Compute number of necessary repetitions: ceil(AUDIO_LEN / current_len)
    num_reps = AUDIO_LEN // current_len
    num_reps += tf.dtypes.cast((AUDIO_LEN % current_len) > 0, num_reps.dtype)
    # Do repetitions
    audio_rep = tf.tile(audio, [num_reps])
    # Trim to required size
    return audio_rep[:AUDIO_LEN]

# Test
examples = tf.data.Dataset.from_generator(lambda: iter([
    tf.zeros([100_000], tf.float32),
    tf.zeros([300_000], tf.float32),
    tf.zeros([123_456], tf.float32),
]), output_types=tf.float32, output_shapes=[None])
result = examples.map(enforce_length)
for item in result:
    print(item.shape)

Output:

(240000,)
(240000,)
(240000,)
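
If skipping the tile for clips that are already long enough matters, one possible variant wraps it in tf.cond. This is a sketch rather than a measured optimization: the name enforce_length_cond is hypothetical, it assumes 1-D, non-empty audio tensors, and whether it is actually faster than tiling unconditionally would need to be checked on real data:

```python
import tensorflow as tf

AUDIO_LEN = 240_000  # target length in samples (5 s at 48 kHz)

def enforce_length_cond(audio):
    # Hypothetical variant: tile only when the clip is shorter than AUDIO_LEN.
    # Assumes audio is 1-D and not empty (an empty clip would divide by zero).
    current_len = tf.shape(audio)[0]

    def tile_up():
        # Ceiling division: smallest repeat count that reaches AUDIO_LEN.
        num_reps = (AUDIO_LEN + current_len - 1) // current_len
        return tf.tile(audio, [num_reps])

    # Both branches return a 1-D float tensor, as tf.cond requires.
    audio = tf.cond(current_len < AUDIO_LEN, tile_up, lambda: audio)
    # Trim to the required size (a no-op if already exactly AUDIO_LEN).
    return audio[:AUDIO_LEN]

# Quick check on the same lengths as the test in the answer.
for n in (100_000, 300_000, 123_456):
    print(enforce_length_cond(tf.zeros([n], tf.float32)).shape)
```

It can be passed to examples.map(...) in the same way as enforce_length.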

jdehesa answered Dec 09 '25 06:12