tensorflow-datasets, tensorflow-federated, federated-learning

Running Out of RAM using FilePerUserClientData


I have a problem with training using tff.simulation.datasets.FilePerUserClientData: I quickly run out of RAM after 5-6 rounds with 10 clients per round, and RAM usage increases steadily with each round. I tried to narrow it down and realized that the issue is not the actual iterative process but the creation of the client datasets. Simply calling create_tf_dataset_for_client(client) in a loop causes the problem.

So this is a minimal version of my code:

import tensorflow as tf
import tensorflow_federated as tff
import numpy as np
import pickle

BATCH_SIZE = 16
EPOCHS = 2
MAX_SEQUENCE_LEN = 20
NUM_ROUNDS = 100
CLIENTS_PER_ROUND = 10

def decode_fn(record_bytes):
    return tf.io.parse_single_example(
        record_bytes,
        {"x": tf.io.FixedLenFeature([MAX_SEQUENCE_LEN], dtype=tf.string),
         "y": tf.io.FixedLenFeature([MAX_SEQUENCE_LEN], dtype=tf.string)}
    )

def dataset_fn(path):
    return tf.data.TFRecordDataset([path]).map(decode_fn).padded_batch(BATCH_SIZE).repeat(EPOCHS)

def sample_client_data(data, client_ids, sampling_prob):
    # Bernoulli-sample clients, so the number selected varies per round.
    clients_total = len(client_ids)
    x = np.random.uniform(size=clients_total)
    sampled_ids = [client_ids[i] for i in range(clients_total) if x[i] < sampling_prob]
    return [data.create_tf_dataset_for_client(client) for client in sampled_ids]
    
with open('users.pkl', 'rb') as f:
    users = pickle.load(f)
    
train_client_ids = users["train"]
client_id_to_train_file = {i: "reddit_leaf_tf/" + i for i in train_client_ids}

train_data = tff.simulation.datasets.FilePerUserClientData(
    client_ids_to_files=client_id_to_train_file,
    dataset_fn=dataset_fn
)

sampling_prob = CLIENTS_PER_ROUND / len(train_client_ids)

for round_num in range(0, NUM_ROUNDS):
    print('Round {r}'.format(r=round_num))
    participants_data = sample_client_data(train_data, train_client_ids, sampling_prob)
    print("Round Completed")

I am using tensorflow-federated 0.19.0.

Is there something wrong with the way I create the client datasets or is it somehow expected that the RAM from the previous round is not freed?


Solution

  • schmana@ noticed this occurs when the cardinality of the CLIENTS placement (the number of client datasets) changes from round to round. This causes an executor cache to fill up, as documented in https://github.com/tensorflow/federated/issues/1215.

    A workaround in the immediate term is to call the following at the start or end of every round:

    tff.framework.get_context_stack().current.executor_factory.clean_up_executors()
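
    Besides clearing the executors each round, the cache growth can also be avoided at the source by sampling a fixed number of clients per round, so the CLIENTS cardinality never changes. A minimal sketch of such a sampler (the `sample_fixed_clients` helper is hypothetical, not part of TFF; it only needs NumPy):

    ```python
    import numpy as np

    def sample_fixed_clients(client_ids, num_clients, seed=None):
        """Sample exactly `num_clients` distinct client ids each round.

        Keeping the sample size constant keeps the CLIENTS cardinality
        fixed across rounds, so the executor stack is reused instead of
        a new cache entry being created per round.
        """
        rng = np.random.default_rng(seed)
        return list(rng.choice(client_ids, size=num_clients, replace=False))
    ```

    In the question's loop this would replace the Bernoulli-style `sample_client_data` sampling (which yields a different number of clients each round) with e.g. `sample_fixed_clients(train_client_ids, CLIENTS_PER_ROUND)` before calling `create_tf_dataset_for_client` on each id. Note this changes the sampling distribution slightly, which matters if the per-round Bernoulli sampling is required (e.g. for differential privacy accounting).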