python-3.x · tensorflow · tensorflow-federated · federated-learning · federated

How to access labels with TFF


I was following this Image classification tutorial and this Text Generation tutorial. I've implemented transfer learning with fine-tuning on my dataset, but I don't know how to access the labels when making predictions. I transformed my data into the right shape (tf.data.Dataset), so I am using the Keras model for predictions. For example, if I want to predict just one label: keras_model.predict(federated_train_data[0])

federated_train_data consists of the following elements:

(TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32, name=None),
 TensorSpec(shape=(None,), dtype=tf.int64, name=None))

The first TensorSpec describes the image batches and the second one the encoded labels.

My goal is to show both the true and the predicted label of an image, for example: (Predicted classes)

TL;DR: Is there a way to access just the labels when you have a tf.data.Dataset?


Solution

  • If federated_train_data is a tf.data.Dataset whose .element_spec property returns:

    (TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32, name=None),
     TensorSpec(shape=(None,), dtype=tf.int64, name=None))
    

    then we can iterate over the dataset directly:

    # Get the first batch
    first_batch = next(iter(federated_train_data)) 
    
    # Examine all batches
    for batch in federated_train_data:
      print(batch)
    

    From the .element_spec we know each batch is a 2-tuple of (features, labels), so we can get the labels using the second index:

    labels = first_batch[1]
    
    # Or unpack
    features, labels = first_batch
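
    If only the labels are needed (as in the TL;DR), a dataset that yields just the label tensors can be derived with .map. A minimal sketch (assuming the dataset is finite, so it can be fully materialized):

    import tensorflow as tf

    # Keep only the second element (labels) of each (features, labels) batch
    label_ds = federated_train_data.map(lambda features, labels: labels)

    # Optionally collect every label batch into a single tensor
    all_labels = tf.concat(list(label_ds), axis=0)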
    

    Combining this with the model predictions:

    for batch in federated_train_data:
      features, labels = batch
      predictions = keras_model.predict(features)
      # Now we have all three pieces: features, labels, and predictions.
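
    To show the true and the predicted label side by side (the stated goal), the predictions can be turned into class indices with argmax. This is a sketch that assumes keras_model outputs one score per class:

    import numpy as np

    for features, labels in federated_train_data:
      # Predictions have shape (batch_size, num_classes); argmax picks the class index
      predicted_classes = np.argmax(keras_model.predict(features), axis=1)
      true_classes = labels.numpy()
      print("True:", true_classes, "Predicted:", predicted_classes)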