tensorflow-datasets

Querying the structure of a TensorFlow Dataset: rows, columns, shape


I am setting up data (training, test, and validation) to load into a simple TensorFlow model (TensorFlow 2.0). I'd like to inspect the number of rows and columns in the concrete TensorFlow datasets I'm building.

If the dataset has the _tensors attribute set, I can use the technique from the question "How to get number of rows, columns/dimensions of tensorflow.data.Dataset?".

However, it looks like some Datasets do not have a _tensors attribute. For example:

import numpy as np
import pandas as pd
import tensorflow as tf

a = np.array([e for e in range(10)])
df = pd.DataFrame({'a': a, 'b': a, 'c': a})

target = df.pop('c')

dataset = tf.data.Dataset.from_tensor_slices((df.values, target.values))

print([x.get_shape().as_list() for x in dataset._tensors])  
# Works.  Gives:
#[[10, 2], [10]]

train_dataset = dataset.shuffle(len(df)).batch(1)

print([x.get_shape().as_list() for x in train_dataset._tensors]) 
# ** Fails.  Gives:
# AttributeError: 'BatchDataset' object has no attribute '_tensors'
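Since _tensors is a private attribute that only some dataset classes carry, a public alternative for the per-element shape (the "columns" part) is the Dataset.element_spec property, which every tf.data.Dataset exposes in the TF 2.x versions I've used. A minimal sketch, assuming element_spec is available in your version; element_shapes is a hypothetical helper name. Note that element_spec describes one element (here a (features, target) tuple), so it tells you the number of columns but not the number of rows.

```python
import numpy as np
import pandas as pd
import tensorflow as tf

a = np.array([e for e in range(10)])
df = pd.DataFrame({'a': a, 'b': a, 'c': a})
target = df.pop('c')

dataset = tf.data.Dataset.from_tensor_slices((df.values, target.values))
train_dataset = dataset.shuffle(len(df)).batch(1)

# Hypothetical helper: element_spec is public and works for TensorSliceDataset
# and BatchDataset alike. It returns one TensorSpec per component of an element.
def element_shapes(ds):
    return [spec.shape.as_list() for spec in ds.element_spec]

print(element_shapes(dataset))        # [[2], []]
print(element_shapes(train_dataset))  # [[None, 2], [None]]
```

The leading None in the batched shapes is the batch dimension, which tf.data leaves unknown because the last batch could be smaller than the requested size.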

I see that these are different types of datasets (TensorSliceDataset vs BatchDataset):

dataset
Out[109]: <TensorSliceDataset shapes: ((2,), ()), types: (tf.int32, tf.int32)>

train_dataset
Out[110]: <BatchDataset shapes: ((None, 2), (None,)), types: (tf.int32, tf.int32)>
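The None in the BatchDataset shapes comes from the possibility of a short final batch. A sketch showing that batch(..., drop_remainder=True) makes the batch dimension static; the arrays below are stand-ins that reproduce the values of df.values and target.values from the question, and the batch size of 4 is an arbitrary choice to make the remainder visible.

```python
import numpy as np
import tensorflow as tf

a = np.arange(10)
features = np.stack([a, a], axis=1)  # same values as df.values in the question
labels = a                           # same values as target.values
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

# Without drop_remainder the batch dimension is None, because the final batch
# may hold fewer than 4 rows (here: 4, 4, 2).
loose = dataset.batch(4)
# With drop_remainder=True every batch has exactly 4 rows, so the shape is static;
# the 2 leftover rows are discarded.
strict = dataset.batch(4, drop_remainder=True)

loose_shapes = [s.shape.as_list() for s in loose.element_spec]
strict_shapes = [s.shape.as_list() for s in strict.element_spec]
print(loose_shapes)                       # [[None, 2], [None]]
print(strict_shapes)                      # [[4, 2], [4]]
print(len(list(loose)), len(list(strict)))  # 3 2
```

Dropping the remainder trades away the last few rows for a fully known shape, so it is mainly useful when a model needs a fixed batch size.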

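It looks like the following gives the number of elements in a dataset. For train_dataset that is the number of batches, which equals the number of rows here only because the batch size is 1: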

print(len([e for e in dataset]))
#Gives:
#10

print(len([e for e in train_dataset]))
#Gives:
#10
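Building a list of every element just to count it iterates the whole pipeline. A sketch using tf.data.experimental.cardinality, which reads the element count from the dataset's metadata when it is known, plus a way to count rows (rather than batches) in a batched dataset. The batch size of 3 is an arbitrary choice so that batches and rows differ; the arrays are stand-ins for the question's DataFrame values.

```python
import numpy as np
import tensorflow as tf

a = np.arange(10)
features = np.stack([a, a], axis=1)  # same values as df.values in the question
labels = a                           # same values as target.values
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
train_dataset = dataset.shuffle(10).batch(3)  # 4 batches: 3 + 3 + 3 + 1 rows

# cardinality() returns a scalar tensor without iterating the data.
n_elements = int(tf.data.experimental.cardinality(dataset))
n_batches = int(tf.data.experimental.cardinality(train_dataset))

# Rows in a batched dataset: sum the leading (batch) dimension of each batch.
# This one does iterate over the data.
n_rows = sum(int(tf.shape(x)[0]) for x, _ in train_dataset)

print(n_elements, n_batches, n_rows)  # 10 4 10
```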

The following iterates over the enumerated objects:

r=0
for t in dataset:
    for e in t:
        r+=1
        tf.print('Row #{0}={1}'.format(r,e))

Output:

Row #1=[0 0]
Row #2=0
Row #3=[1 1]
Row #4=1
Row #5=[2 2]
...
Row #17=[8 8]
Row #18=8
Row #19=[9 9]
Row #20=9
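The 20 "rows" above appear because each dataset element is a (features, target) tuple, and the inner loop visits both parts separately. A sketch that unpacks the tuple in the for statement so each iteration corresponds to one row; the arrays are stand-ins with the same values as the question's df.values and target.values.

```python
import numpy as np
import tensorflow as tf

a = np.arange(10)
features = np.stack([a, a], axis=1)  # same values as df.values in the question
labels = a                           # same values as target.values
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

row_count = 0
# Unpack each (features, label) tuple directly: one iteration per row.
for row_number, (x, y) in enumerate(dataset, start=1):
    print('Row #{}: features={}, label={}'.format(row_number, x.numpy(), y.numpy()))
    row_count += 1

print(row_count)  # 10
```

The first line printed is "Row #1: features=[0 0], label=0".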

Is there a better approach?


Solution

  • Based on amish's response, I think a simpler way to get the "shape" of the datasets is:

    print(len(dataset), dataset)
    # Gives:
    # 10 <TensorSliceDataset shapes: ((2,), ()), types: (tf.int32, tf.int32)>
    
    print(len(train_dataset), train_dataset)
    # Gives:
    # 10 <BatchDataset shapes: ((None, 2), (None,)), types: (tf.int32, tf.int32)>
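One caveat with len(): it only works when tf.data knows the dataset's length ahead of time. For a dataset whose length is unknown (for example after filter()) or infinite (after repeat()), len() raises a TypeError instead. A minimal sketch of a fallback that returns None in those cases, assuming tf.data.experimental.cardinality and its UNKNOWN_CARDINALITY / INFINITE_CARDINALITY constants are available in your TensorFlow version; safe_len is a hypothetical helper name.

```python
import numpy as np
import tensorflow as tf

a = np.arange(10)
dataset = tf.data.Dataset.from_tensor_slices((np.stack([a, a], axis=1), a))

filtered = dataset.filter(lambda x, y: y > 2)  # size depends on the data: unknown
repeated = dataset.repeat()                    # never ends: infinite

def safe_len(ds):
    """Hypothetical helper: element count, or None when it can't be known statically."""
    n = int(tf.data.experimental.cardinality(ds))
    if n in (tf.data.experimental.UNKNOWN_CARDINALITY,
             tf.data.experimental.INFINITE_CARDINALITY):
        return None
    return n

print(safe_len(dataset), safe_len(filtered), safe_len(repeated))  # 10 None None
```

For a filtered dataset the only way to get the real count is to iterate it, as in the len([e for e in dataset]) approach from the question.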