 

How to find the closest word to a vector using BERT

I am trying to get the textual representation (or the closest word) of a given word embedding using BERT. Basically, I am trying to get functionality similar to what gensim provides:

>>> your_word_vector = array([-0.00449447, -0.00310097, 0.02421786, ...], dtype=float32)
>>> model.most_similar(positive=[your_word_vector], topn=1)

So far, I have been able to generate contextual word embeddings using bert-as-service, but I can't figure out how to get the closest words to such an embedding. I have used a pre-trained BERT model (uncased_L-12_H-768_A-12) and haven't done any fine-tuning.
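
A bert-as-service setup that produces per-token vectors looks roughly like this (a sketch of the kind of setup described, not my exact code; the server is started separately on the model above with pooling disabled):

# rough sketch of a bert-as-service invocation (assumed, not exact);
# the server is assumed to be started with:
#   bert-serving-start -model_dir uncased_L-12_H-768_A-12 -pooling_strategy NONE
from bert_serving.client import BertClient

bc = BertClient()
# with pooling disabled, the result has shape (num_sentences, max_seq_len, 768),
# i.e. one 768-dimensional contextual vector per token
vecs = bc.encode(['It is a power bank.'])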

Asked Jan 22 '20 by vishalaksh


1 Answer

TL;DR

Following Jindřich's answer, I implemented a context-aware nearest neighbor searcher. The full code is available in my GitHub gist.

It requires a BERT-like model (I use bert-embeddings) and a corpus of sentences (I took a small one from here), processes each sentence, and stores contextual token embeddings in an efficiently searchable data structure (I use KDTree, but feel free to choose FAISS or HNSW or whatever).

Examples

The model is constructed as follows:

# preparing the model
storage = ContextNeighborStorage(sentences=all_sentences, model=bert)
storage.process_sentences()
storage.build_search_index()
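
For completeness, bert and all_sentences can be prepared roughly as follows (a sketch assuming the bert-embedding package; any model that returns a (tokens, per-token embeddings) pair per sentence works the same way):

# hedged sketch: preparing the inputs (assumes pip install bert-embedding)
from bert_embedding import BertEmbedding

bert = BertEmbedding()              # a pre-trained uncased BERT base model by default
all_sentences = [                   # your corpus; placeholder sentences for illustration
    'It is a power bank.',
    'The bank approved the loan last week.',
    'The Amazon rainforest spans several countries.',
]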

Then it can be queried for the contextually most similar words, like this:

# querying the model
distances, neighbors, contexts = storage.query(
    query_sent='It is a power bank.', query_word='bank', k=5)

In this example, the nearest neighbor would be the word "bank" in the sentence "Finally, there’s a second version of the Duo that incorporates a 2000mAH power bank, the Flip Power World.".

If, however, we look for the same word in another context, like

distances, neighbors, contexts = storage.query(
    query_sent='It is an investment bank.', query_word='bank', k=5)

then the nearest neighbor will be in the sentence "The bank also was awarded a 5-star, Superior Bauer rating for Dec. 31, 2017, financial data."

If we don't want to retrieve the word "bank" or its derivatives, we can filter them out

distances, neighbors, contexts = storage.query(
     query_sent='It is an investment bank.', query_word='bank', k=5, filter_same_word=True)

and then the nearest neighbor will be the word "finance" in the sentence "Cahal is Vice Chairman of Deloitte UK and Chairman of the Advisory Corporate Finance business from 2014 (previously led the business from 2005).".

Application in NER

One of the cool applications of this approach is interpretable named entity recognition. We can fill the search index with IOB-labeled examples, and then use retrieved examples to infer the right label for the query word.

For example, for the word "Amazon" in "Bezos announced that its two-day delivery service, Amazon Prime, had surpassed 100 million subscribers worldwide.", the nearest neighbor comes from "Expanded third-party integration including Amazon Alexa, Google Assistant, and IFTTT.".

But for "Amazon" in "The Atlantic has sufficient wave and tidal energy to carry most of the Amazon's sediments out to sea, thus the river does not form a true delta", the nearest neighbor comes from "And, this year our stories are the work of traveling from Brazil’s Iguassu Falls to a chicken farm in Atlanta".

So if these neighbors were labeled, we could infer that in the first context "Amazon" is an ORGanization, but in the second one it is a LOCation.
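
As a rough sketch of that idea, assuming a lookup function label_of(context, token) that returns the gold IOB tag of a token within a labeled corpus sentence, the label of a query word can be inferred by majority vote over its retrieved neighbors:

# hedged sketch: majority-vote label inference over retrieved neighbors
# (label_of is a hypothetical lookup into an IOB-labeled corpus)
from collections import Counter

def infer_label(storage, label_of, query_sent, query_word, k=10):
    # retrieve the k contextually closest tokens and the sentences they occur in
    _, neighbors, contexts = storage.query(
        query_sent=query_sent, query_word=query_word, k=k)
    # each retrieved neighbor votes with its gold label; return the most common tag
    votes = Counter(label_of(ctx, tok) for tok, ctx in zip(neighbors, contexts))
    return votes.most_common(1)[0][0]

# e.g. infer_label(storage, label_of, 'Amazon Prime passed 100 million subscribers.', 'amazon')
# would be expected to return an ORG-type tag, given neighbors like the ones above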

The code

Here is the class that does this work:

import numpy as np
from sklearn.neighbors import KDTree
from tqdm.auto import tqdm


class ContextNeighborStorage:
    def __init__(self, sentences, model):
        self.sentences = sentences
        self.model = model

    def process_sentences(self):
        result = self.model(self.sentences)

        self.sentence_ids = []
        self.token_ids = []
        self.all_tokens = []
        all_embeddings = []
        for i, (toks, embs) in enumerate(tqdm(result)):
            for j, (tok, emb) in enumerate(zip(toks, embs)):
                self.sentence_ids.append(i)
                self.token_ids.append(j)
                self.all_tokens.append(tok)
                all_embeddings.append(emb)
        all_embeddings = np.stack(all_embeddings)
        # we normalize embeddings so that Euclidean distance is equivalent to cosine distance
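        # (for unit vectors, ||a - b||^2 = 2 - 2*cos(a, b), so the nearest neighbor by
        # Euclidean distance is also the nearest by cosine similarity)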
        self.normed_embeddings = (all_embeddings.T / (all_embeddings**2).sum(axis=1) ** 0.5).T

    def build_search_index(self):
        # this takes some time
        self.indexer = KDTree(self.normed_embeddings)

    def query(self, query_sent, query_word, k=10, filter_same_word=False):
        toks, embs = self.model([query_sent])[0]

        found = False
        for tok, emb in zip(toks, embs):
            if tok == query_word:
                found = True
                break
        if not found:
            raise ValueError('The query word {} is not a single token in sentence {}'.format(query_word, toks))
        emb = emb / sum(emb**2)**0.5

        if filter_same_word:
            initial_k = max(k, 100)
        else:
            initial_k = k
        di, idx = self.indexer.query(emb.reshape(1, -1), k=initial_k)
        distances = []
        neighbors = []
        contexts = []
        for i, index in enumerate(idx.ravel()):
            token = self.all_tokens[index]
            if filter_same_word and (query_word in token or token in query_word):
                continue
            distances.append(di.ravel()[i])
            neighbors.append(token)
            contexts.append(self.sentences[self.sentence_ids[index]])
            if len(distances) == k:
                break
        return distances, neighbors, contexts

Answered by David Dale