I am trying to use a custom embedding model in LangChain with ChromaDB. I can't seem to find a way to use the base embedding class without having to use some other provider (like OpenAIEmbeddings or HuggingFaceEmbeddings). Am I missing something?
On the LangChain page it says that the base Embeddings class in LangChain provides two methods: one for embedding documents and one for embedding a query. So I figured there must be a way to create another class on top of this one and override/implement those methods with our own. But how do I do that?
I tried to somehow use the base embeddings class but am unable to create a new embedding object/class on top of it.
In order to use custom embeddings with something like LangChain, your class needs to provide the embed_documents and embed_query methods; routines such as Chroma.from_documents call them internally and will fail otherwise. Like so...
```python
from typing import List

from sentence_transformers import SentenceTransformer
from langchain_community.vectorstores import Chroma

class MyEmbeddings:
    def __init__(self, model):
        self.model = SentenceTransformer(model, trust_remote_code=True)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.model.encode(t).tolist() for t in texts]

    def embed_query(self, query: str) -> List[float]:
        # encode() on a single string returns a 1-D vector, so this
        # yields the flat List[float] that LangChain expects here
        # (encoding [query] would wrap it in an extra list).
        return self.model.encode(query).tolist()

# ...
embeddings = MyEmbeddings("your model name")  # e.g. "sentence-transformers/all-MiniLM-L6-v2"
chromadb = Chroma.from_documents(
    documents=your_docs,
    embedding=embeddings,
)
```
You can create your own class and implement the methods such as embed_documents. If you want to adhere strictly to typing, you can extend the Embeddings base class (from langchain_core.embeddings.embeddings import Embeddings) and implement its abstract methods; the class is defined in langchain_core.
Below is a small working custom embedding class I used with semantic chunking.
```python
from typing import List

from sentence_transformers import SentenceTransformer
from langchain_experimental.text_splitter import SemanticChunker

class MyEmbeddings:
    def __init__(self):
        self.model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.model.encode(t).tolist() for t in texts]

embeddings = MyEmbeddings()
splitter = SemanticChunker(embeddings)
```
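Downstream consumers of this interface all reduce to the same idea: embed the query, embed the candidates, and compare vectors. A self-contained sketch with hand-written toy vectors (standing in for real model.encode output) shows cosine-similarity ranking, which is how a vector store picks the closest document:

```python
import math
from typing import List

def cosine(a: List[float], b: List[float]) -> float:
    # Cosine similarity: dot product normalized by vector lengths.
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)

# Toy, hand-written vectors standing in for embed_documents(...) output.
doc_vectors = {
    "cats are pets": [0.9, 0.1, 0.0],
    "stock markets fell": [0.0, 0.2, 0.9],
}
query_vector = [0.8, 0.2, 0.1]  # standing in for embed_query("my cat")

# Rank stored documents by similarity to the query vector.
best = max(doc_vectors, key=lambda d: cosine(doc_vectors[d], query_vector))
print(best)  # cats are pets
```

This is also why embed_query and embed_documents must produce vectors from the same model: comparing vectors from different embedding spaces gives meaningless similarities.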