I’m not aware of any way to accomplish what you’re describing besides those you’ve ruled out (federated learning and mixtures of experts). Naively averaging the weights of models trained on disjoint datasets won’t work for LLMs, or for any DNN with one or more hidden layers (though it will for logistic or linear models). This sounds to me like an open research question.
ohmygad45
Here’s a simple intuition for why averaging the weights of a NN with one or more hidden layers won’t work: pick a hidden layer in your model and apply a permutation matrix to its weights along the input axis, and the inverse permutation to the previous layer’s weights (and bias) along the output axis. This just relabels the units feeding into that layer, so the model is obviously unchanged from an input/output perspective. Repeat with N different permutations, say the N cyclic shifts (where N is the input dimension of the hidden layer you picked). You now have N models that are identical from an input/output perspective. Now average those N models’ weights: the permutations average out, so every one of the relabeled units ends up with the same mean weight vector, they all compute the same value, and the layer collapses. The averaged-weights model is completely broken even though it’s the average of N models that are all “identical” from an input/output perspective. QED.
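To make that concrete, here’s a minimal NumPy sketch (the tiny two-layer ReLU MLP, the cyclic shifts, and names like `forward`, `W1`, `W2` are my illustrative choices, not anything from the thread): it builds N functionally identical copies of one model by relabeling its hidden units, checks they all produce the same output, then shows their weight-average does not.

```python
import numpy as np

rng = np.random.default_rng(0)

# A tiny MLP with one hidden layer: x -> relu(W1 x + b1) -> W2 h + b2
d_in, d_hidden, d_out = 4, 8, 3
W1, b1 = rng.normal(size=(d_hidden, d_in)), rng.normal(size=d_hidden)
W2, b2 = rng.normal(size=(d_out, d_hidden)), rng.normal(size=d_out)

def forward(W1, b1, W2, b2, x):
    h = np.maximum(0.0, W1 @ x + b1)  # hidden activations
    return W2 @ h + b2

x = rng.normal(size=d_in)
y_ref = forward(W1, b1, W2, b2, x)

# Build d_hidden copies that differ only by a cyclic relabeling of the
# hidden units: permute W1's rows and b1, apply the inverse permutation
# to W2's columns.
copies = []
for k in range(d_hidden):
    perm = np.roll(np.arange(d_hidden), k)
    copies.append((W1[perm], b1[perm], W2[:, perm], b2))

# Every copy computes exactly the same function as the original.
all_match = all(np.allclose(forward(*c, x), y_ref) for c in copies)
print("permuted copies match original:", all_match)  # True

# But their weight-average does not: averaging over all cyclic shifts
# gives every hidden unit the same mean weight vector, so the hidden
# layer collapses to d_hidden identical units.
W1_avg = np.mean([c[0] for c in copies], axis=0)
b1_avg = np.mean([c[1] for c in copies], axis=0)
W2_avg = np.mean([c[2] for c in copies], axis=0)
y_avg = forward(W1_avg, b1_avg, W2_avg, b2, x)
print("averaged model matches original:", np.allclose(y_avg, y_ref))  # False (almost surely, for random weights)
```

Running it prints True for the permuted copies and False for their average, which is the collapse described above.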