
What is the recommended way to do embeddings in jax?

Tags: jax

By this I mean the case where you have a categorical feature $X$ (assume it has already been converted to ints) and you want to embed it in some dimension using the features $A$, where $A$ is an arity × n_embed matrix.

What is the usual way to do this? Is using a for loop and vmap correct? I do not want something like jax.nn, but rather something more efficient, like:

https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding

For example, consider high arity and a low embedding dimension.

Is it jnp.take as in the flax.linen implementation here? https://github.com/google/flax/blob/main/flax/linen/linear.py#L624

mathtick asked Oct 28 '25 09:10

1 Answer

Indeed, the typical way to do this in pure JAX is with jnp.take. Given an array A of embeddings with shape (num_embeddings, num_features) and a categorical feature x of integers with shape (n,), the following performs the embedding lookup. It is a single vectorized gather, so no for loop or vmap is needed.

jnp.take(A, x, axis=0)  # shape: (n, num_features)
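
To make the lookup concrete, here is a minimal sketch in pure JAX. The sizes (arity 1000, n_embed 4), the index values, and the squared-sum loss are made up for illustration. It compares jnp.take with plain integer-array indexing and with a one-hot matrix product, and checks that the gradient only touches the rows that were looked up.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes for illustration: high arity, low embedding dimension.
arity, n_embed = 1000, 4

A = jax.random.normal(jax.random.PRNGKey(0), (arity, n_embed))  # embedding table
x = jnp.array([3, 17, 3, 999])  # integer-coded categories, shape (n,)

# Three ways to compute the same lookup, each of shape (n, n_embed).
emb_take = jnp.take(A, x, axis=0)           # gather rows of A
emb_index = A[x]                            # integer-array indexing, also a gather
emb_onehot = jax.nn.one_hot(x, arity) @ A   # dense product, builds an (n, arity) matrix

# The lookup is differentiable with respect to the table.
def loss(table, idx):
    # Toy loss (an assumption for this sketch): sum of squared embeddings.
    return jnp.sum(jnp.take(table, idx, axis=0) ** 2)

grad_A = jax.jit(jax.grad(loss))(A, x)      # same shape as A

# Rows of A that were never looked up receive a zero gradient.
touched_rows = jnp.any(grad_A != 0, axis=1)
```

With high arity and a low embedding dimension, the gather avoids materializing the (n, arity) one-hot matrix that the dense product needs, which is the reason to prefer jnp.take or A[x] here.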

If you are using Flax, the recommended way is the flax.linen.Embed module, which achieves the same effect:

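import flax.linen as nn

class Model(nn.Module):
  num_embeddings: int  # arity of the categorical feature
  num_features: int    # embedding dimension

  @nn.compact
  def __call__(self, x):
    # x: integer array of shape (n,); result has shape (n, num_features)
    emb = nn.Embed(self.num_embeddings, self.num_features)(x)
    return emb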
Jon Deaton answered Oct 31 '25 12:10
