I am experimenting with TensorFlow Federated, simulating a training process with the FedAvg algorithm. This is my current setup:
import collections

import tensorflow as tf
import tensorflow_federated as tff

# get_uncompiled_model() and X_train are defined earlier in my script.
def model_fn():
    # Wrap a Keras model for use with TensorFlow Federated.
    # For the federated procedure, the model must be uncompiled.
    keras_model = get_uncompiled_model()
    return tff.learning.models.functional_model_from_keras(
        keras_model,
        loss_fn=tf.keras.losses.BinaryCrossentropy(),
        input_spec=(
            tf.TensorSpec(shape=[None, X_train.shape[1]], dtype=tf.float32),
            tf.TensorSpec(shape=[None], dtype=tf.int32)
        ),
        metrics_constructor=collections.OrderedDict(
            accuracy=tf.keras.metrics.BinaryAccuracy,
            precision=tf.keras.metrics.Precision,
            recall=tf.keras.metrics.Recall,
            false_positives=tf.keras.metrics.FalsePositives,
            false_negatives=tf.keras.metrics.FalseNegatives,
            true_positives=tf.keras.metrics.TruePositives,
            true_negatives=tf.keras.metrics.TrueNegatives
        )
    )
# client_optimizer and server_optimizer are defined elsewhere.
trainer = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn=model_fn(),
    client_optimizer_fn=client_optimizer,
    server_optimizer_fn=server_optimizer
)
I want to use custom weights to aggregate the clients' updates instead of weighting them by their number of samples. I know that tff.learning.algorithms.build_weighted_fed_avg() has a parameter called client_weighting, but the only values it accepts come from tff.learning.ClientWeighting, which is an enum.
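For reference, these are the only built-in options I can find (sketching the same trainer call):

trainer = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn=model_fn(),
    client_optimizer_fn=client_optimizer,
    server_optimizer_fn=server_optimizer,
    # Default: weight each client update by its number of examples.
    client_weighting=tff.learning.ClientWeighting.NUM_EXAMPLES,
    # The only alternative: give every client the same weight.
    # client_weighting=tff.learning.ClientWeighting.UNIFORM,
)

Neither of these lets me supply my own per-client weights.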
So the only way to do this seems to be implementing a custom tff.aggregators.WeightedAggregationFactory. I've tried following this tutorial, which explains how to write an unweighted aggregator, but I cannot manage to turn it into a weighted one. This is what I've tried:
@tff.tensorflow.computation
def custom_weighted_aggregate(values, weights):
    # Normalize client weights
    total_weight = tf.reduce_sum(weights)
    normalized_weights = weights / total_weight
    # Compute weighted sum of client updates
    weighted_sum = tf.nest.map_structure(
        lambda v: tf.reduce_sum(normalized_weights * v, axis=0),
        values
    )
    return weighted_sum
class CustomWeightedAggregator(tff.aggregators.WeightedAggregationFactory):

    def __init__(self):
        pass

    def create(self, value_type, weight_type):
        @tff.federated_computation
        def initialize():
            return tff.federated_value(0.0, tff.SERVER)

        @tff.federated_computation(
            initialize.type_signature.result,
            tff.FederatedType(value_type, tff.CLIENTS),
            tff.FederatedType(weight_type, tff.CLIENTS)
        )
        def next(state, value, weight):
            aggregate_value = tff.federated_map(
                custom_weighted_aggregate, (value, weight))
            return tff.templates.MeasuredProcessOutput(
                state, aggregate_value, tff.federated_value((), tff.SERVER)
            )

        return tff.templates.AggregationProcess(initialize, next)

    @property
    def is_weighted(self):
        return True
But I get the following error:
AggregationPlacementError: The "result" attribute of return type of next_fn must be placed at SERVER, but found {<float32[7],float32,float32[1],float32>}@CLIENTS.
To do cross-device reductions in TFF, we must use TFF's special intrinsic symbols. Essentially, these 'register' certain reductions (e.g., the reduce_sum above) as special, so that they can later be identified as the ones the user intended to mean 'this is a reduction that should go cross-device now'.

In TFF, pure TensorFlow logic always 'runs locally' rather than being extracted to run cross-device. This means that the tff.tensorflow.computation you have above (custom_weighted_aggregate) really expresses a per-client reduction, not a cross-client reduction.
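For example (a minimal sketch, not your exact types: I'm assuming each client holds a float32[7] update and a scalar float32 weight, and a recent TFF release whose type constructors accept NumPy dtypes), the placement difference looks like this:

import numpy as np
import tensorflow_federated as tff

values_at_clients = tff.FederatedType(
    tff.TensorType(np.float32, [7]), tff.CLIENTS)
weights_at_clients = tff.FederatedType(
    tff.TensorType(np.float32), tff.CLIENTS)

@tff.federated_computation(values_at_clients, weights_at_clients)
def intrinsic_weighted_mean(value, weight):
    # federated_mean is one of the registered intrinsics, so the reduction
    # happens across devices and the result is placed at SERVER.
    return tff.federated_mean(value, weight=weight)

# intrinsic_weighted_mean.type_signature is roughly:
# ({float32[7]}@CLIENTS,{float32}@CLIENTS -> float32[7]@SERVER)

By contrast, a tff.federated_map of TensorFlow code leaves one result per client, placed at CLIENTS, which is exactly what the error message is complaining about.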
One way you might express such a thing, if your values and weights are placed at clients, could be the one captured in this implementation. Alternatively, I believe that implementation should be directly usable through this symbol, whose create method should return an aggregation process to which you can pass custom weights.
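To make this concrete, here is a rough sketch of my own (not the linked implementation) of a weighted factory built on the federated_mean intrinsic. It assumes the per-client weight is a scalar float and the client values are structures of float tensors, which is what federated_mean expects:

class IntrinsicWeightedMeanFactory(tff.aggregators.WeightedAggregationFactory):
    """Weighted mean of client values via the federated_mean intrinsic."""

    def create(self, value_type, weight_type):
        @tff.federated_computation
        def initialize():
            # A plain weighted mean needs no server state.
            return tff.federated_value((), tff.SERVER)

        @tff.federated_computation(
            initialize.type_signature.result,
            tff.FederatedType(value_type, tff.CLIENTS),
            tff.FederatedType(weight_type, tff.CLIENTS)
        )
        def next_fn(state, value, weight):
            # The intrinsic performs the cross-device reduction, so the
            # result is SERVER-placed, as the AggregationProcess contract
            # requires.
            weighted_mean = tff.federated_mean(value, weight=weight)
            measurements = tff.federated_value((), tff.SERVER)
            return tff.templates.MeasuredProcessOutput(
                state, weighted_mean, measurements)

        return tff.templates.AggregationProcess(initialize, next_fn)

Since create returns a tff.templates.AggregationProcess, you can also drive it directly: build it with your value and weight types, call initialize once, and then pass whatever per-client weights you like to each next call.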