How do I merge two trained neural network weight matrices into one?

I have two identical neural networks running on two separate computers (to reduce the time taken to train the network), each having a subset of a complete data set (MNIST).

My question is: can I combine the two weight matrices of both networks into one matrix while retaining reasonable accuracy? I have seen several articles about 'batching' or 'stochastic gradient descent', but I don't think these apply to my situation.

If this is possible, could you also provide me with some pseudo code? Any input is valuable!

Thank you,

Gerrit Luimstra asked Apr 23 '18 18:04



1 Answer

In general if you combine the weights / biases entirely after training, this is unlikely to produce good results. However, there are ways to make it work.

Intuition about combining weights

Consider the following simple example: you have an MLP with one hidden layer. Two instances of the MLP produce identical outputs for identical inputs if one is obtained from the other by permuting the nodes in the hidden layer: the input->hidden weights (and hidden biases) are permuted the same way, and the hidden->output weights are permuted using the inverse permutation. In other words, even if there were no randomness in what the final network you end up with does, which hidden node corresponds to a particular feature would be random (determined by the noise of the initializations).
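A minimal NumPy sketch of this symmetry, assuming a ReLU hidden layer and a linear output layer (the answer does not specify activations; `mlp_forward` and `permute_hidden` are hypothetical helper names). Reordering the rows of `W1`/`b1` and indexing the columns of `W2` with the same index array (which, written as a matrix product, is multiplication by the inverse permutation matrix) leaves the network's output unchanged:

```python
import numpy as np

def mlp_forward(x, W1, b1, W2, b2):
    """One-hidden-layer MLP: ReLU hidden layer, linear output."""
    h = np.maximum(0.0, W1 @ x + b1)   # hidden activations, shape (n_hidden,)
    return W2 @ h + b2                 # outputs, shape (n_out,)

def permute_hidden(W1, b1, W2, perm):
    """Reorder hidden nodes: rows of W1/b1 by perm, columns of W2 to match."""
    return W1[perm, :], b1[perm], W2[:, perm]

rng = np.random.default_rng(0)
n_in, n_hidden, n_out = 4, 5, 3
W1 = rng.normal(size=(n_hidden, n_in))
b1 = rng.normal(size=n_hidden)
W2 = rng.normal(size=(n_out, n_hidden))
b2 = rng.normal(size=n_out)

perm = rng.permutation(n_hidden)
W1p, b1p, W2p = permute_hidden(W1, b1, W2, perm)

x = rng.normal(size=n_in)
y_original = mlp_forward(x, W1, b1, W2, b2)
y_permuted = mlp_forward(x, W1p, b1p, W2p, b2)
print(np.allclose(y_original, y_permuted))  # True
```

Both parameter sets are equally good solutions, which is why two independently trained networks have no reason to agree on which index holds which feature.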

If you train two MLPs like that on different data (or random subsets of the same data), they would likely end up with different permutations of the hidden nodes even if the initializations are all the same, because of the noise of the gradients during training.

Now, if a certain property of the input most strongly activates the i-th node of network A and the j-th node of network B (and in general i != j), averaging the weights of the i-th node of A with those of the i-th node of B (which corresponds to a different feature) is likely to decrease performance, or even produce a network that outputs nonsense.
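To illustrate, here is a hedged toy example under the same ReLU/linear assumption: network B is constructed as an exact copy of A with its hidden nodes reordered (a fixed cyclic shift, chosen for illustration), so both compute the same function. Averaging node i with node i changes the function; undoing the permutation first (`np.argsort(perm)` gives the inverse permutation) and then averaging recovers A exactly.

```python
import numpy as np

def mlp_forward(X, W1, b1, W2, b2):
    """Batched one-hidden-layer MLP; X has shape (batch, n_in)."""
    H = np.maximum(0.0, X @ W1.T + b1)   # ReLU hidden layer
    return H @ W2.T + b2                 # linear output layer

rng = np.random.default_rng(1)
n_in, n_hidden, n_out = 4, 8, 3
W1, b1 = rng.normal(size=(n_hidden, n_in)), rng.normal(size=n_hidden)
W2, b2 = rng.normal(size=(n_out, n_hidden)), rng.normal(size=n_out)

# Network B: same function as A, hidden nodes cyclically shifted by one.
perm = np.roll(np.arange(n_hidden), 1)
W1b, b1b, W2b = W1[perm], b1[perm], W2[:, perm]

X = rng.normal(size=(100, n_in))
y_a = mlp_forward(X, W1, b1, W2, b2)

# Naive averaging: node i of A is mixed with node i of B (a different feature).
y_naive = mlp_forward(X, (W1 + W1b) / 2, (b1 + b1b) / 2, (W2 + W2b) / 2, b2)

# Aligned averaging: undo B's permutation first, then average.
inv = np.argsort(perm)
y_aligned = mlp_forward(X, (W1 + W1b[inv]) / 2, (b1 + b1b[inv]) / 2,
                        (W2 + W2b[:, inv]) / 2, b2)

print(np.allclose(y_a, y_aligned))  # True
print(np.allclose(y_a, y_naive))    # False
```

In real training B is never an exact permutation of A, so the correct alignment is unknown; the two fixes described next either force it or estimate it.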

There are two possible fixes here - you can use either one, or both together. The idea is to either figure out which nodes match between the two networks, or to force the nodes to match.

Solution A: Train both networks on different data for a few iterations. Average the weights of both, and replace both networks with the average weights. Repeat. This makes the i-th node of each network learn the same feature as the matching node of the other network, since they can't ever diverge too far. They are frequently re-initialized from the average weights, so once the permutation is determined it is likely to stay stable.
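The train/average/broadcast cycle of Solution A can be sketched as follows, under loud simplifying assumptions: a linear least-squares model trained with plain minibatch SGD stands in for the MLP (so it only shows the mechanics, not the permutation issue), the workers are simulated one after another in a single process instead of on separate machines, and every name here (`sgd_steps`, `train_with_periodic_averaging`, `sync_every`) is hypothetical. `sync_every` is the number of minibatch steps each worker takes between averages.

```python
import numpy as np

def sgd_steps(w, X, y, lr, n_steps, batch_size, rng):
    """Minibatch SGD on mean-squared error for a linear model y ~ X @ w."""
    for _ in range(n_steps):
        idx = rng.integers(0, len(X), size=batch_size)
        Xb, yb = X[idx], y[idx]
        grad = 2.0 * Xb.T @ (Xb @ w - yb) / batch_size
        w = w - lr * grad
    return w

def train_with_periodic_averaging(shards, n_features, n_rounds=50,
                                  sync_every=20, lr=0.05, batch_size=16, seed=0):
    """Each worker trains on its own shard for `sync_every` steps, starting
    from the shared weights; then all workers are replaced by the average."""
    rng = np.random.default_rng(seed)
    w_avg = np.zeros(n_features)                 # common starting point
    for _ in range(n_rounds):
        local = [sgd_steps(w_avg.copy(), X, y, lr, sync_every, batch_size, rng)
                 for X, y in shards]             # one machine per shard in practice
        w_avg = np.mean(local, axis=0)           # average, then "broadcast"
    return w_avg

# Synthetic regression data split between two workers.
rng = np.random.default_rng(42)
true_w = np.array([1.5, -2.0, 0.5])
X = rng.normal(size=(1000, 3))
y = X @ true_w + 0.01 * rng.normal(size=1000)
shards = [(X[:500], y[:500]), (X[500:], y[500:])]

w = train_with_periodic_averaging(shards, n_features=3)
print(np.round(w, 2))
```

For a real network, `w` would be the full list of weight matrices and biases, averaged element-wise in the same way.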

A reasonable value for how often to average is somewhere between once per epoch and once every few minibatches. Learning is still faster than training one network on all the data sequentially, although not quite 2x faster with 2 networks. Communication overhead is much lower than when averaging weights (or gradients) after every minibatch. This can be run on different machines in a cluster: transferring weights is not prohibitive since it is relatively infrequent. Also, the number of networks trained at the same time (and of data splits) can be more than two: up to 10-20 works okay in practice.

(Hint: for better results, after every epoch, do a new random split of the data between the networks you're training)
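A small sketch of that re-split, assuming the data set is addressed by integer indices (`random_split` is a hypothetical helper name). Each epoch draws a fresh random partition of the indices into one shard per worker:

```python
import numpy as np

def random_split(n_examples, n_workers, rng):
    """Return a fresh random partition of example indices into n_workers shards."""
    perm = rng.permutation(n_examples)
    return np.array_split(perm, n_workers)

rng = np.random.default_rng(0)
for epoch in range(3):
    shards = random_split(10, 2, rng)
    # Each worker k would now train on X[shards[k]], y[shards[k]] this epoch.
    print(epoch, [s.tolist() for s in shards])
```

`np.array_split` tolerates sizes that do not divide evenly, so the shards differ by at most one example.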

This is similar in effect to "gradient aggregation" that was mentioned here, but aggregates a lot less often. You can think of it as "lazy aggregation".

Solution B: Try to figure out which hidden layer nodes match before averaging. Calculate some similarity metric on the weights (could be L2 or anything along those lines), and average the weights of pairs of most-similar nodes from the two networks. You can also do a weighted average of more than just a pair of nodes; for example you can average all nodes, or the k-most similar nodes, where the weights used are a function of the similarity.
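One way to sketch Solution B for a single hidden layer, under stated assumptions: each hidden node is described by its incoming weights plus bias, similarity is L2 distance, and pairs are chosen greedily (closest unmatched pair first) rather than optimally; an optimal one-to-one assignment such as `scipy.optimize.linear_sum_assignment` is a common alternative. The parameter layout (`W1`, `b1`, `W2`, `b2` in a dict) and the helper names are hypothetical.

```python
import numpy as np

def greedy_match(D):
    """Greedy one-to-one matching on a square distance matrix D.
    Returns m with m[i] = index of the B node paired with A node i."""
    D = np.array(D, dtype=float)      # copy so the caller's matrix is untouched
    m = np.empty(D.shape[0], dtype=int)
    for _ in range(D.shape[0]):
        i, j = np.unravel_index(np.argmin(D), D.shape)
        m[i] = j
        D[i, :] = np.inf              # A node i is taken
        D[:, j] = np.inf              # B node j is taken
    return m

def match_and_average(A, B):
    """A, B: dicts with W1 (hidden, in), b1 (hidden,), W2 (out, hidden), b2 (out,)."""
    fa = np.hstack([A["W1"], A["b1"][:, None]])
    fb = np.hstack([B["W1"], B["b1"][:, None]])
    D = np.linalg.norm(fa[:, None, :] - fb[None, :, :], axis=2)  # pairwise L2
    m = greedy_match(D)
    merged = {
        "W1": (A["W1"] + B["W1"][m]) / 2,
        "b1": (A["b1"] + B["b1"][m]) / 2,
        "W2": (A["W2"] + B["W2"][:, m]) / 2,  # outgoing weights follow the pairing
        "b2": (A["b2"] + B["b2"]) / 2,        # output nodes are not permuted
    }
    return merged, m

# Toy check: B is A with its hidden nodes shuffled plus a little noise.
rng = np.random.default_rng(0)
A = {"W1": rng.normal(size=(8, 4)), "b1": rng.normal(size=8),
     "W2": rng.normal(size=(3, 8)), "b2": rng.normal(size=3)}
perm = rng.permutation(8)
B = {"W1": A["W1"][perm] + 0.01 * rng.normal(size=(8, 4)),
     "b1": A["b1"][perm] + 0.01 * rng.normal(size=8),
     "W2": A["W2"][:, perm] + 0.01 * rng.normal(size=(3, 8)),
     "b2": A["b2"] + 0.01 * rng.normal(size=3)}

merged, m = match_and_average(A, B)
print(np.array_equal(m, np.argsort(perm)))  # True: the shuffle was undone
```

Two independently trained networks will not be this close, so in practice the matching is only approximate.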

For deep networks, you have to keep track of the pairings all the way up from the input, and permute the weights according to the highest-similarity pairings of the lower level before calculating similarity on the next level (or, if doing weighted averaging, propagate the weights).
This probably works for networks with a few layers, but I think that for very deep networks it is unlikely to work perfectly. It will still work okay for the first few layers, but by the time you get to the top of the network, tracking the permutations will likely fail to find good matching nodes.
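A hedged sketch of that bottom-up propagation, assuming a plain ReLU MLP stored as a list of `(W, b)` pairs with `W` of shape `(n_out, n_in)`, L2 similarity, and greedy matching (all helper names are hypothetical). Before comparing layer k, B's incoming weights are reordered to follow the pairing already chosen for layer k-1; the output layer is never permuted.

```python
import numpy as np

def greedy_match(D):
    """Greedy one-to-one matching; m[i] = B node paired with A node i."""
    D = np.array(D, dtype=float)
    m = np.empty(D.shape[0], dtype=int)
    for _ in range(D.shape[0]):
        i, j = np.unravel_index(np.argmin(D), D.shape)
        m[i] = j
        D[i, :] = np.inf
        D[:, j] = np.inf
    return m

def align_layers(layers_a, layers_b):
    """Reorder B's hidden nodes, layer by layer from the input, to line up with A."""
    aligned, prev = [], None
    last = len(layers_b) - 1
    for k, ((Wa, ba), (Wb, bb)) in enumerate(zip(layers_a, layers_b)):
        if prev is not None:
            Wb = Wb[:, prev]              # inputs now arrive in A's node order
        if k == last:                     # output layer: keep node order as-is
            aligned.append((Wb, bb))
            break
        fa = np.hstack([Wa, ba[:, None]])
        fb = np.hstack([Wb, bb[:, None]])
        D = np.linalg.norm(fa[:, None, :] - fb[None, :, :], axis=2)
        prev = greedy_match(D)
        aligned.append((Wb[prev], bb[prev]))
    return aligned

def average_layers(layers_a, layers_b):
    return [((Wa + Wb) / 2, (ba + bb) / 2)
            for (Wa, ba), (Wb, bb) in zip(layers_a, layers_b)]

def forward(layers, X):
    for W, b in layers[:-1]:
        X = np.maximum(0.0, X @ W.T + b)  # ReLU hidden layers
    W, b = layers[-1]
    return X @ W.T + b                    # linear output layer

# Toy check: B computes the same function as A, both hidden layers shuffled.
rng = np.random.default_rng(0)
sizes = [4, 6, 5, 3]
A = [(rng.normal(size=(o, i)), rng.normal(size=o))
     for i, o in zip(sizes[:-1], sizes[1:])]
p1, p2 = rng.permutation(6), rng.permutation(5)
B = [(A[0][0][p1], A[0][1][p1]),
     (A[1][0][p2][:, p1], A[1][1][p2]),
     (A[2][0][:, p2], A[2][1])]

merged = average_layers(A, align_layers(A, B))
X = rng.normal(size=(10, 4))
print(np.allclose(forward(merged, X), forward(A, X)))  # True
```

Here B is an exact permutation of A, so every layer matches perfectly; with real, independently trained networks, errors in a lower layer's pairing feed into the comparison at every layer above it.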

Another way to deal with deep networks (other than tracking the permutations up from the bottom) is to run both networks on a test dataset, record the activations of all nodes for each input, and then average the weights of nodes that have similar activation patterns (i.e., nodes that tend to be activated strongly by the same inputs). Again, this could be based on just averaging the most similar pair from A and B, or on a suitable weighted average of more than two nodes.
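A minimal sketch of activation-based matching for one hidden layer, assuming ReLU activations, Pearson correlation over a probe set as the similarity, and greedy one-to-one pairing (helper names are hypothetical; random inputs stand in for a real test set). Once the pairing `m` is known, averaging proceeds exactly as in weight-based matching.

```python
import numpy as np

def hidden_activations(X, W1, b1):
    """ReLU activations of one hidden layer, shape (n_samples, n_hidden)."""
    return np.maximum(0.0, X @ W1.T + b1)

def activation_match(Ha, Hb):
    """Pair each hidden node of A with the node of B whose activations over the
    probe set are most correlated (greedy, one-to-one). m[i] = B node for A node i."""
    Za = (Ha - Ha.mean(axis=0)) / (Ha.std(axis=0) + 1e-8)
    Zb = (Hb - Hb.mean(axis=0)) / (Hb.std(axis=0) + 1e-8)
    C = Za.T @ Zb / len(Ha)          # Pearson correlations, (n_hidden_a, n_hidden_b)
    m = np.empty(C.shape[0], dtype=int)
    for _ in range(C.shape[0]):
        i, j = np.unravel_index(np.argmax(C), C.shape)
        m[i] = j
        C[i, :] = -np.inf            # A node i is taken
        C[:, j] = -np.inf            # B node j is taken
    return m

# Toy check: B's hidden layer is A's, shuffled and slightly perturbed.
rng = np.random.default_rng(0)
n_in, n_hidden = 10, 8
W1a, b1a = rng.normal(size=(n_hidden, n_in)), 0.1 * rng.normal(size=n_hidden)
perm = rng.permutation(n_hidden)
W1b = W1a[perm] + 0.01 * rng.normal(size=(n_hidden, n_in))
b1b = b1a[perm] + 0.01 * rng.normal(size=n_hidden)

X_probe = rng.normal(size=(500, n_in))   # stand-in for a held-out test set
m = activation_match(hidden_activations(X_probe, W1a, b1a),
                     hidden_activations(X_probe, W1b, b1b))
print(np.array_equal(m, np.argsort(perm)))  # True
```

Because it only needs recorded activations, this comparison can be done independently at every layer, without first propagating permutations from the layers below.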

You can use this technique together with "Solution A" above, to average weights somewhat less often. You can also use the weighted averaging by node similarity to speed the convergence of "Solution A". In that case it is okay if the method in "Solution B" doesn't work perfectly, since the networks are replaced with the combined network once in a while - but the combined network could be better if it is produced by some matching method rather than simple averaging. Whether the extra calculations are worth it vs the reduced communication overhead in a cluster and faster convergence depends on your network architecture etc.

Alex I answered Nov 11 '22 15:11
