I am implementing federated learning with TensorFlow.js, but I am stuck on the federated averaging step. The idea is simple: get updated weights from multiple clients and average them on the server.
I have trained a model in the browser, got the updated weights via the model.getWeights() method, and sent the weights to the server for averaging.
//get weights from multiple clients (happens on the client side)
const w1 = model.getWeights(); //weights from client 1
const w2 = model.getWeights(); //weights from client 2

//calculate the average of the weights (server side)
const mean_weights = [];
const length = w1.length; // every client's weights array has the same length
for (let i = 0; i < length; i++) {
  const sum = w1[i].add(w2[i]);
  const mean = sum.div(2); // got confused here, how to calculate the mean of tensors??
  mean_weights.push(mean);
}

//apply the update to the model (both client side and server side)
model.setWeights(mean_weights);
So my question is: how do I calculate the mean of a tensor array? Also, is this the right approach to performing federated averaging with TensorFlow.js?
Yes, but be careful. You can average two tensors with tf.mean, as https://stackoverflow.com/users/5069957/edkeveked said. However, remember that axis=0 is written as just 0 in JavaScript.
Just to rewrite his code in a second way:
// a 2x3 tensor with rows [1, 2, 3] and [2, 3, 4]
const x = tf.tensor([1, 2, 3, 2, 3, 4], [2, 3]);
x.mean(0).print(); // mean along axis 0 prints [1.5, 2.5, 3.5]
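Applied to your weight arrays, a minimal sketch (assuming w1 and w2 are the arrays of tf.Tensors returned by model.getWeights() on each client; averageClientWeights is a hypothetical helper, not a TensorFlow.js API) could stack the matching tensors and take the mean along axis 0, which also generalizes to more than two clients:
// clientWeightSets is an array like [w1, w2, ...], one entry per client
function averageClientWeights(clientWeightSets) {
  const numTensors = clientWeightSets[0].length;
  const meanWeights = [];
  for (let i = 0; i < numTensors; i++) {
    // stack the i-th tensor from every client along a new axis 0,
    // then average over that axis
    const stacked = tf.stack(clientWeightSets.map((w) => w[i]));
    meanWeights.push(stacked.mean(0));
  }
  return meanWeights;
}

// usage: model.setWeights(averageClientWeights([w1, w2]));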
However, you asked if you're doing it right, and that depends on whether you are averaging as you go or not. There's an issue with a rolling average.
If you average (10, 20) and then fold in 30, you get 22.5, a different number than averaging (20, 30) and then folding in 10, which gives 17.5, and both differ from averaging all three at once, which gives 20.
A rolling average is not order-independent: the division step is what breaks associativity. So you'll need to either:
A: Store all model weights and calculate a new average each time based on all previous models
or
B: Add a weighting system to your federated average so that the most recent model does not disproportionately affect the result.
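To make the order-dependence concrete, here is a quick check in plain JavaScript using the numbers above:
const rolling = (a, b) => (a + b) / 2;

console.log(rolling(rolling(10, 20), 30)); // 22.5
console.log(rolling(rolling(20, 30), 10)); // 17.5
console.log((10 + 20 + 30) / 3);           // 20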
I recommend B. You can compute a weighted average by adjusting how much weight you give your existing model versus the incoming model.
In JavaScript you can do something simple like this to compute a weighted average between two values:
const modelVal1 = 0;
const modelVal2 = 1;
const weight1 = 0.5;
const weight2 = 1 - weight1;
const average = (modelVal1 * weight1) + (modelVal2 * weight2);
The above code is your common evenly weighted average, but as you adjust weight1, you rebalance the scales to shift the outcome in favor of modelVal1 or modelVal2.
Obviously, you'll need to convert the JavaScript I have shown into tensor mathematical functions, but that's trivial.
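As a rough sketch of that conversion (assuming serverWeights and clientWeights are arrays of tf.Tensors in the same order as model.getWeights(); weightedAverage is a hypothetical helper, not part of TensorFlow.js):
// alpha is the weight kept by the existing model; (1 - alpha) goes to the incoming update
function weightedAverage(serverWeights, clientWeights, alpha = 0.5) {
  return serverWeights.map((w, i) =>
    // element-wise: w * alpha + clientWeights[i] * (1 - alpha)
    w.mul(alpha).add(clientWeights[i].mul(1 - alpha))
  );
}

// usage: model.setWeights(weightedAverage(model.getWeights(), incomingWeights, 0.8));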
Iterate averaging (or weighted averaging) with decaying weights is often used in federated learning. See Iterate Averaging as Regularization for Stochastic Gradient Descent and Server Averaging for Federated Learning.
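As a rough illustration of the decaying-weight idea (the 1 / round schedule below is just one possible choice, not taken from either paper), the server can shrink the influence of each new client update as more rounds are aggregated:
// running server average: the update from round t gets weight 1 / t,
// which is equivalent to an exact mean over all updates seen so far
let serverWeights = null;
let round = 0;

function aggregate(clientWeights) {
  round += 1;
  if (serverWeights === null) {
    serverWeights = clientWeights;
    return serverWeights;
  }
  const alpha = 1 / round; // decays as more rounds are aggregated
  serverWeights = serverWeights.map((w, i) =>
    w.mul(1 - alpha).add(clientWeights[i].mul(alpha))
  );
  return serverWeights;
}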