python tensorflow tensorflow-datasets

How to get the string value out of a tf.Tensor whose dtype is string


I want to use tf.data.Dataset.list_files to feed my dataset.
Because the files are not images, I need to load them manually.
The problem is that tf.data.Dataset.list_files passes each file path as a tf.Tensor, and my Python code cannot handle tensors.

How can I get the string value out of a tf.Tensor whose dtype is string?

def load_audio_file(file_path):
  print("file_path: ", file_path)
  # I want to do something like: string_path = convert_tensor_to_string(file_path)

train_dataset = tf.data.Dataset.list_files(PATH + 'clean_4s_val/*.wav')
train_dataset = train_dataset.map(lambda x: load_audio_file(x))

Inside the function, file_path is Tensor("arg0:0", shape=(), dtype=string)

I am using TensorFlow 1.13.1 in eager mode.

Thanks in advance.


Solution

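  • You can use tf.py_func to wrap load_audio_file(). Dataset.map traces the mapped function as a graph even in eager mode, so the argument is a symbolic tensor there; tf.py_func runs the wrapped Python function on the actual value, which arrives as bytes.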

    import tensorflow as tf
    
    tf.enable_eager_execution()
    
    def load_audio_file(file_path):
        # tf.py_func passes the path as a bytes object; decode it to get a Python str
        print("file_path: ", bytes.decode(file_path), type(bytes.decode(file_path)))
        return file_path
    
    train_dataset = tf.data.Dataset.list_files('clean_4s_val/*.wav')
    train_dataset = train_dataset.map(lambda x: tf.py_func(load_audio_file, [x], [tf.string]))
    
    for one_element in train_dataset:
        print(one_element)
    
    Output:
    
    file_path:  clean_4s_val/1.wav <class 'str'>
    (<tf.Tensor: id=32, shape=(), dtype=string, numpy=b'clean_4s_val/1.wav'>,)
    file_path:  clean_4s_val/3.wav <class 'str'>
    (<tf.Tensor: id=34, shape=(), dtype=string, numpy=b'clean_4s_val/3.wav'>,)
    file_path:  clean_4s_val/2.wav <class 'str'>
    (<tf.Tensor: id=36, shape=(), dtype=string, numpy=b'clean_4s_val/2.wav'>,)
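
Once the path arrives as bytes, the manual loading step from the question can live inside the wrapped function. The following is only a minimal sketch, assuming PCM WAV files that the standard-library wave module can read; read_wav_frames and the demo file are hypothetical names used for illustration, not part of the original answer or of TensorFlow.

```python
import os
import tempfile
import wave


def read_wav_frames(file_path):
    """Return (sample_rate, raw_frame_bytes) for a WAV file.

    Accepts either bytes (what the wrapped function receives in the
    answer above) or an ordinary str.
    """
    if isinstance(file_path, bytes):
        file_path = file_path.decode("utf-8")
    with wave.open(file_path, "rb") as wav:
        sample_rate = wav.getframerate()
        frames = wav.readframes(wav.getnframes())
    return sample_rate, frames


# Self-contained demo: write a tiny 16-bit mono file, then read it back
# through a bytes path, as the wrapped function would see it.
demo_path = os.path.join(tempfile.mkdtemp(), "demo.wav")
with wave.open(demo_path, "wb") as wav:
    wav.setnchannels(1)
    wav.setsampwidth(2)      # 2 bytes per sample = 16-bit audio
    wav.setframerate(16000)
    wav.writeframes(b"\x00\x00" * 160)  # 160 silent samples

rate, frames = read_wav_frames(demo_path.encode("utf-8"))
print(rate, len(frames))
```

A function like this could be passed to tf.py_func in place of load_audio_file, with the output dtypes adjusted to match whatever it returns.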
    

    UPDATE for TF 2

    The above solution does not work in TF 2 (tested with 2.2.0), even after replacing tf.py_func with tf.py_function. It fails with:

    InvalidArgumentError: TypeError: descriptor 'decode' requires a 'bytes' object but received a 'tensorflow.python.framework.ops.EagerTensor'
    

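    As the error shows, tf.py_function hands the wrapped function an EagerTensor rather than a bytes object. To make it work in TF 2, make the following changes: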