PyTorch's EmbeddingBag allows for efficient lookup-plus-reduce operations on variable-length collections of embedding indices. There are three reduction modes: "sum", "mean" and "max". With "sum", you can also provide per_sample_weights, giving you a weighted sum.
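For reference, a minimal sketch of how that looks in practice (the layer sizes, indices and weights here are arbitrary, made up for illustration):

```python
import torch
import torch.nn as nn

# An EmbeddingBag over bags of indices; mode="sum" accepts per_sample_weights.
bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum")

# Two bags packed into one flat index tensor; offsets mark where each bag starts.
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])  # bag 0 = indices[0:4], bag 1 = indices[4:8]

# One weight per index -> a weighted sum instead of a plain sum.
weights = torch.tensor([0.1, 0.2, 0.3, 0.4, 1.0, 1.0, 1.0, 1.0])
out = bag(indices, offsets, per_sample_weights=weights)
print(out.shape)  # torch.Size([2, 4]) -- one reduced vector per bag
```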
Why is per_sample_weights not allowed for the "max" operation? Looking at how it's implemented, I can only assume there is an issue with performing a "ReduceMean" or "ReduceMax" operation after a "Mul" operation. Could that have something to do with calculating gradients?
P.S.: It's easy enough to turn a weighted sum into a weighted average by dividing by the sum of the weights, but there is no equivalent trick to get a weighted version of "max".
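For example, a minimal sketch of that division trick (again with made-up sizes and values), which gives a weighted mean using only mode="sum":

```python
import torch
import torch.nn as nn

bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum")

indices = torch.tensor([1, 2, 4, 5])
offsets = torch.tensor([0])          # a single bag
weights = torch.tensor([0.1, 0.2, 0.3, 0.4])

# Weighted sum from mode="sum", then normalize by the total weight
# of the bag to obtain a weighted mean.
weighted_sum = bag(indices, offsets, per_sample_weights=weights)
weighted_mean = weighted_sum / weights.sum()
```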
The per_sample_weights argument was only implemented for mode='sum', not due to technical limitations, but because the developers found no use cases for a "weighted max":
"I haven't been able to find use cases for 'weighted mean' (which can be emulated via weighted sum) and 'weighted max'."
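If you do want something like a weighted max anyway, one possible workaround (not a PyTorch API, just a hand-rolled sketch of what such a reduction could mean) is to do the lookup with plain nn.Embedding, scale each looked-up vector by its weight, and take the max yourself:

```python
import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=10, embedding_dim=4)

indices = torch.tensor([1, 2, 4, 5])           # one bag of four indices
weights = torch.tensor([0.1, 0.2, 0.3, 0.4])

# Scale each looked-up vector by its weight, then reduce with max
# over the bag dimension -- the elementwise analogue of a weighted sum.
scaled = emb(indices) * weights.unsqueeze(1)   # shape (4, 4)
weighted_max, _ = scaled.max(dim=0)            # shape (4,)
```

This loses the fused-lookup-and-reduce efficiency that EmbeddingBag's built-in modes provide, but it is differentiable and works for arbitrary bag sizes if you loop or pad per bag.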