
Federated learning with tensorflowjs

I am implementing federated learning with TensorFlow.js, but I am stuck on the federated averaging step. The idea is simple: get updated weights from multiple clients and average them on the server.

I have trained a model in the browser, got the updated weights via the model.getWeights() method, and sent the weights to the server for averaging.


// get weights from multiple clients (happens on the client side)
w1 = model.getWeights(); // weights from client 1
w2 = model.getWeights(); // weights from client 2


// calculate the average of the weights (server-side)
var mean_weights = [];
let length = w1.length; // every client's weights array has the same length
for (var i = 0; i < length; i++) {
    let sum = w1[i].add(w2[i]);
    let mean = sum.divide(2); // got confused here, how do I calculate the mean of tensors?
    mean_weights.push(mean);
}

// apply the updates to the model (both client-side and server-side)
model.setWeights(mean_weights);

So my question is: how do I calculate the mean of a tensor array? Also, is this the right approach to perform federated averaging with TensorFlow.js?

Almighty asked Nov 07 '22 18:11


1 Answer

Yes, but be careful. You can average two tensors with tf.mean, as https://stackoverflow.com/users/5069957/edkeveked said. Just remember that in JavaScript, axis=0 is shortened to just 0.

Just to rewrite his code in a second way:

const x = tf.tensor([1, 2, 3, 2, 3, 4], [2, 3]);
x.mean(0).print(); // averages along axis 0: [1.5, 2.5, 3.5]
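Applied to the question's setting, where model.getWeights() returns one tensor per layer, the same axis-0 mean is taken per layer across all clients. Below is a plain-JavaScript sketch of the arithmetic, with flat arrays standing in for tensors; the function names are illustrative, not a TensorFlow.js API. With real tensors, the per-layer equivalent would be stacking the clients' tensors with tf.stack and calling .mean(0):

```javascript
// Elementwise mean of the same layer across all clients.
// layers: one flat numeric array per client, all the same length.
function averageLayer(layers) {
  return layers[0].map((_, j) =>
    layers.reduce((sum, layer) => sum + layer[j], 0) / layers.length
  );
}

// clientWeights: [client][layer][value]; returns one averaged array per layer.
// With TensorFlow.js tensors this would be, per layer i:
//   tf.stack(clientWeights.map(w => w[i])).mean(0)
function federatedAverage(clientWeights) {
  return clientWeights[0].map((_, i) =>
    averageLayer(clientWeights.map(w => w[i]))
  );
}
```

For example, two clients each contributing a single layer [1, 2] and [3, 4] average to [2, 3].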

However, you asked if you're doing it right, and that depends on if you're averaging as you go or not. There's an issue with a rolling average.

Example:

If you average 10 and 20 (getting 15) and then average the result with 30, you get 22.5. Averaging 20 and 30 first and then folding in 10 gives 17.5. Both differ from averaging all three at once, which gives 20.
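The arithmetic above can be checked in a couple of lines (rollPair is just an illustrative helper):

```javascript
// Rolling average: average the first pair, then fold in the newcomer
// with equal weight. Note the result depends on the order of arrival.
const rollPair = (a, b, c) => ((a + b) / 2 + c) / 2;

const order1 = rollPair(10, 20, 30);  // 22.5
const order2 = rollPair(20, 30, 10);  // 17.5
const allAtOnce = (10 + 20 + 30) / 3; // 20
```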

Averages are not order-independent once they have been rolled up; the division step is what destroys associativity. So you'll need to either:

A: Store all model weights and calculate a new average each time based on all previous models

or

B: Add a weighting system to your federated average so more recent models do not significantly affect the system.

Which makes sense?

I recommend B in the situation that you:

  1. Don't want to (or cannot) store every model and weight ever submitted.
  2. Know that some models have seen more valid data and should be weighted accordingly compared to blind models.
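One simple weighting scheme in the spirit of option B (an illustration, not from the answer; the helper name is hypothetical): track how many models the running average already contains and give the incoming model a 1/(n + 1) share, so a late arrival cannot crowd out everything averaged before it:

```javascript
// avg: current running average (flat array), built from n models so far.
// incoming: the new model's values. The old average keeps weight n/(n + 1),
// the newcomer gets 1/(n + 1), which reproduces the all-at-once mean.
function foldIntoAverage(avg, n, incoming) {
  return avg.map((v, i) => (v * n + incoming[i]) / (n + 1));
}

let avg = [10, 100];
avg = foldIntoAverage(avg, 1, [20, 200]); // [15, 150]
avg = foldIntoAverage(avg, 2, [30, 300]); // [20, 200], the all-at-once mean
```

Replacing 1/(n + 1) with a fixed constant instead turns this into an exponential moving average, which deliberately favors recent models.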

You can compute a weighted average by adjusting the denominator for your existing model vs. your incoming model.

In JavaScript you can do something simple like this to compute a weighted average between two values:

const modelVal1 = 0
const modelVal2 = 1

const weight1 = 0.5
const weight2 = 1 - weight1


const average = (modelVal1 * weight1) + (modelVal2 * weight2)

The code above is an evenly weighted average, but as you adjust weight1 you rebalance the scales, shifting the outcome in favor of modelVal1 or modelVal2.

Obviously, you'll need to convert the JavaScript shown here into tensor math functions, but that's trivial.

Iterate averaging (or weighted averaging) with decaying weights is often used in federated learning. See "Iterate Averaging as Regularization for Stochastic Gradient Descent" and "Server Averaging for Federated Learning".

Gant Laborde answered Nov 15 '22 13:11