Word Embedding: Building Semantic Search from Scratch in Java

Learner1067
12 min read · Jul 7, 2023


What is Semantic Search?

Semantic search is a type of search that aims to understand the meaning behind the user’s query and the content being searched, rather than relying solely on word matching. Results are served based on context, intent, and the relationships between words or concepts. It goes beyond traditional keyword-based search, which mostly works on the principle of word matching. The queries we run with RDBMS or NoSQL query engines are typical examples of keyword-based search.

Example of Word Semantic Search

How to find similar/related words:

The word “bank” is related to fund, currency, and financial.

Example of Sentence Semantic Search

Suppose you want to build a semantic search engine that retrieves relevant news articles from a collection of news articles, based on user queries.

User Query: “Find articles about the impact of climate change on wildlife populations.”

In this example, the semantic search engine would analyze the query and the articles to identify concepts like “climate change,” “wildlife populations,” and their relationship. It would understand that the user is interested in articles specifically discussing how climate change affects wildlife populations. The search engine would then retrieve articles that address this specific topic, even if they don’t contain the exact query keywords.

Word embedding: The building block of Semantic Search

Word embedding is the technique of converting a word into a meaningful vector that encodes its meaning, so that words that lie close to each other in the vector space are expected to be similar in meaning. Extending the previous example, words like bank, fund, and currency will be near each other in the vector space (visualize it as a multidimensional space). The screenshot below uses TensorFlow Projector to show the word bank and its nearby words in a multidimensional space.

Word Embedding allows us to represent words and documents as dense vectors in a high-dimensional space, capturing their semantic meaning. This has paved the way for various applications, including semantic search.
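To make “closer in vector space” concrete, here is a minimal sketch using cosine similarity, one common way to measure how close two vectors are: it compares their directions and returns 1.0 for vectors pointing the same way. The three-dimensional vectors are hand-picked toy values (an assumption for illustration), not the output of a trained model.

```java
import java.util.Map;

class CosineSimilarityDemo {
    // Cosine similarity = (a · b) / (|a| * |b|); 1.0 means the same direction.
    static double cosine(double[] a, double[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    public static void main(String[] args) {
        // Hypothetical toy embeddings, chosen by hand for illustration only.
        Map<String, double[]> emb = Map.of(
                "bank", new double[]{0.9, 0.8, 0.1},
                "currency", new double[]{0.8, 0.9, 0.2},
                "jungle", new double[]{0.1, 0.2, 0.9});
        System.out.printf("bank vs currency: %.3f%n", cosine(emb.get("bank"), emb.get("currency")));
        System.out.printf("bank vs jungle:   %.3f%n", cosine(emb.get("bank"), emb.get("jungle")));
    }
}
```

With these toy values, bank scores much closer to currency than to jungle, which is the property a trained embedding is expected to have for related words.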

Words are stored as vectors, but what is a vector?

In the context of mathematics, a vector is an object that has both a magnitude and a direction. A vector can be represented as a one-dimensional array of numbers. The image below depicts a vector in two dimensions (x, y); as an array, it can be represented as [12, 5]. In three-dimensional space, a vector can be represented as an array of three elements, for x, y, and z respectively.

Image credit: Mathisfun
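The “magnitude and direction” view and the “array of numbers” view describe the same object. A small sketch, assuming the 2-D vector [12, 5] from the image: the magnitude is the Euclidean length of the array, and the direction can be expressed as the angle from the positive x-axis.

```java
class VectorDemo {
    // Magnitude (Euclidean length) of a vector stored as an array.
    static double magnitude(double[] v) {
        double sum = 0;
        for (double x : v) sum += x * x;
        return Math.sqrt(sum);
    }

    // Direction of a 2-D vector as the angle from the positive x-axis, in degrees.
    static double angleDegrees(double x, double y) {
        return Math.toDegrees(Math.atan2(y, x));
    }

    public static void main(String[] args) {
        double[] v = {12, 5};
        System.out.println("magnitude = " + magnitude(v)); // sqrt(144 + 25), prints magnitude = 13.0
        System.out.printf("direction = %.2f degrees%n", angleDegrees(v[0], v[1]));
    }
}
```

The same `magnitude` method works unchanged for three-dimensional arrays, or for the hundreds of dimensions typical of word embeddings.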

In the context of computer science and data analysis: a vector refers to an ordered collection of numerical values or features that represent a particular object or data point.

Common algorithms for word embedding:

1. Word2Vec

2. GloVe

3. BERT

Sentences can also be embedded as vectors using the techniques below:

  1. Averaging the word embeddings produced by Word2Vec
  2. Recurrent Neural Networks (RNNs), such as Long Short-Term Memory (LSTM) networks
  3. Transformer models, such as BERT
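The first technique, averaging word embeddings, is simple enough to sketch directly. A minimal Java version, assuming the word vectors have already been trained (the 2-dimensional values below are made up) and that words missing from the vocabulary are simply skipped:

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;

class AverageEmbeddingDemo {
    // Sentence vector = element-wise mean of the vectors of the sentence's known words.
    static double[] averageEmbedding(List<String> tokens, Map<String, double[]> wordVectors, int dim) {
        double[] sum = new double[dim];
        int count = 0;
        for (String token : tokens) {
            double[] v = wordVectors.get(token);
            if (v == null) continue; // skip out-of-vocabulary words
            for (int i = 0; i < dim; i++) sum[i] += v[i];
            count++;
        }
        if (count > 0)
            for (int i = 0; i < dim; i++) sum[i] /= count;
        return sum; // all zeros if no word was known
    }

    public static void main(String[] args) {
        // Hypothetical 2-dimensional word vectors; "sherkhan" is deliberately missing.
        Map<String, double[]> vectors = Map.of(
                "king", new double[]{1.0, 0.0},
                "jungle", new double[]{0.0, 1.0});
        List<String> sentence = List.of("sherkhan", "king", "jungle");
        System.out.println(Arrays.toString(averageEmbedding(sentence, vectors, 2))); // prints [0.5, 0.5]
    }
}
```

Averaging ignores word order, which is one reason the RNN and transformer approaches in the list exist.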

Let’s Dive Deep into Word2Vec

Word2Vec is a natural language processing (NLP) technique published by Google in 2013. The Word2Vec algorithm uses a neural network model to learn word associations. It comes in two variants: Continuous Bag of Words (CBOW) and Skip-gram; this article focuses on Skip-gram. The neural network consists of an input layer, a hidden layer, and an output layer, but it is used in a slightly unusual way: the weights between the input layer and the hidden layer are the vector representations of the words, while the final output of the output layer is not a word embedding. We will walk through Skip-gram in detail with an example. Most major NLP frameworks already provide a Word2Vec implementation, so in practice there is no need to code it from scratch.

High-Level Gist of the Skip-gram Algorithm

1. Neural Network Architecture:

  • The skip-gram model consists of an input layer, a hidden layer (embedding layer), and an output layer.
  • The input layer and output layer have the same dimensions, which equal the size of the vocabulary.
  • The hidden layer represents the word embeddings, and its dimensions are typically much smaller than the size of the vocabulary.

2. Training the Neural Network:

  • For each training example, the skip-gram model takes an input word and tries to predict the surrounding context words.
  • The input word is converted into its one-hot encoded vector representation and is fed into the neural network.
  • The hidden layer of the neural network represents the word embeddings. The values in the hidden layer are the learned vector representations of the input word.
  • The output layer uses softmax activation to produce a probability distribution over all words in the vocabulary.
  • During training, the objective is to maximize the probability of correctly predicting the context words given the input word.
  • The model is trained using stochastic gradient descent and backpropagation, where the weights of the neural network are updated to minimize the prediction error.

3. Extracting Word Embeddings:

  • After training the model, the word embeddings are obtained from the weights between the input layer and the hidden layer of the neural network.
  • These learned vector representations capture semantic and syntactic relationships between words, as similar words are expected to have similar vector representations.
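The training steps above can be sketched without any library. This is a plain-Java illustration, not the Deeplearning4j code used later in the article, and it makes several assumptions: a full softmax output (no negative sampling), a linear hidden layer, a six-word vocabulary taken from two of the Jungle Book sentences used below, 4 hidden units, and a learning rate, epoch count, and random seed that are arbitrary choices. The exact numbers it prints depend on the seed.

```java
import java.util.Arrays;
import java.util.Random;

class SkipGramSgdDemo {
    static final String[] VOCAB = {"sherkhan", "king", "jungle", "mowgli", "brave", "boy"};
    static final int DIM = 4; // hidden-layer size = embedding dimension
    // (target, context) index pairs produced by a window of 2 over
    // "sherkhan king jungle" and "mowgli brave boy" (stop words already removed).
    static final int[][] PAIRS = {
        {0, 1}, {0, 2}, {1, 0}, {1, 2}, {2, 0}, {2, 1},
        {3, 4}, {3, 5}, {4, 3}, {4, 5}, {5, 3}, {5, 4}};

    static double[] softmax(double[] u) {
        double max = Double.NEGATIVE_INFINITY;
        for (double x : u) max = Math.max(max, x);
        double[] p = new double[u.length];
        double sum = 0;
        for (int j = 0; j < u.length; j++) {
            p[j] = Math.exp(u[j] - max); // subtracting the max avoids overflow
            sum += p[j];
        }
        for (int j = 0; j < u.length; j++) p[j] /= sum;
        return p;
    }

    // Forward pass: a one-hot input times w1 selects row `target` (linear hidden layer);
    // the output layer scores every vocabulary word and applies softmax.
    static double[] predict(double[][] w1, double[][] w2, int target) {
        double[] h = w1[target];
        double[] u = new double[VOCAB.length];
        for (int k = 0; k < DIM; k++)
            for (int j = 0; j < u.length; j++) u[j] += h[k] * w2[k][j];
        return softmax(u);
    }

    // Average cross-entropy loss, -log P(context | target), over all pairs.
    static double meanLoss(double[][] w1, double[][] w2) {
        double total = 0;
        for (int[] pair : PAIRS) total -= Math.log(predict(w1, w2, pair[0])[pair[1]]);
        return total / PAIRS.length;
    }

    // One stochastic gradient descent step (backpropagation) on a single pair.
    static void sgdStep(double[][] w1, double[][] w2, int target, int context, double lr) {
        double[] h = w1[target];               // same array as the embedding row
        double[] e = predict(w1, w2, target);
        e[context] -= 1.0;                     // dLoss/dScores = softmax output - one-hot(context)
        double[] dh = new double[DIM];
        for (int k = 0; k < DIM; k++)
            for (int j = 0; j < e.length; j++) {
                dh[k] += w2[k][j] * e[j];      // reads w2 before this element is updated
                w2[k][j] -= lr * h[k] * e[j];
            }
        for (int k = 0; k < DIM; k++) h[k] -= lr * dh[k]; // updates w1[target] in place
    }

    static double[][] randomMatrix(int rows, int cols, Random rng) {
        double[][] m = new double[rows][cols];
        for (double[] row : m)
            for (int j = 0; j < cols; j++) row[j] = rng.nextDouble() - 0.5;
        return m;
    }

    // Returns {w1, w2}; epochs = 0 gives the untrained starting point for the same seed.
    static double[][][] train(int epochs, double lr, long seed) {
        Random rng = new Random(seed);
        double[][] w1 = randomMatrix(VOCAB.length, DIM, rng); // input -> hidden (embeddings)
        double[][] w2 = randomMatrix(DIM, VOCAB.length, rng); // hidden -> output
        for (int epoch = 0; epoch < epochs; epoch++)
            for (int[] pair : PAIRS) sgdStep(w1, w2, pair[0], pair[1], lr);
        return new double[][][]{w1, w2};
    }

    public static void main(String[] args) {
        double[][][] before = train(0, 0.1, 42);
        double[][][] after = train(2000, 0.1, 42);
        System.out.printf("mean loss before: %.3f, after: %.3f%n",
                meanLoss(before[0], before[1]), meanLoss(after[0], after[1]));
        double[] p = predict(after[0], after[1], 0); // context distribution for "sherkhan"
        System.out.printf("P(king|sherkhan) = %.2f, P(jungle|sherkhan) = %.2f%n", p[1], p[2]);
        System.out.println("sherkhan embedding: " + Arrays.toString(after[0][0]));
    }
}
```

Because `h` is the array `w1[target]`, the last loop in `sgdStep` updates that word’s embedding directly; after training, row i of `w1` is the embedding of `VOCAB[i]`, and `w2` is discarded.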

Understanding the Skip-gram Model with an Example, and How to Implement It in Java

Skip-gram is a neural network architecture used in natural language processing, specifically in word embedding models like Word2Vec. It is designed to learn word representations by predicting the context words given a target word.

Let’s try to understand it with an example.

We will create embeddings for the sentences below, which describe characters in “The Jungle Book”.

Sherkhan is king of the Jungle.

Mowgli is a brave boy.

Bagira is strong.

Ballu is very helping.

Lali is very kind.

  1. Remove extra spaces and stop words such as “is”, “the”, and “a” (the code below also removes “of” and “very”). This is a simplified take on what Word2Vec calls subsampling: very high-frequency and very low-frequency words often provide little information, so words whose frequency is above or below a certain threshold may be subsampled or removed to speed up training. Stop words carry little meaning, so it is better to remove them.
  2. Create a map of all unique words; together they form the vocabulary. In the example above, there are 12 unique words.
  3. Skip-gram uses a sliding window to generate pairs of an input word and a context (output) word. Let’s set the sliding window size to 2. The screenshot below shows how the training dataset is generated for one sentence. Each value in the training input column is first converted to a one-hot encoding and then fed to the neural network. The training output column holds the expected output of the network and is used to train it.

4. One-hot Encoding: Each word in the vocabulary is represented as a one-hot encoded vector, in which all elements are zero except the one at the word’s index, which is set to one. There are 12 unique words in the vocabulary, so the one-hot encodings form a 12x12 matrix. Each row represents a single word and has exactly one element set to 1, with all the rest set to 0. The image below shows the unique words with their one-hot encodings. The order of the words is arbitrary: the example below starts with Sherkhan, but we could start with any word in the vocabulary. Encode all words in the training dataset generated in the previous step with one-hot encoding before feeding them to the neural network.
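Steps 1 to 4 can be sketched compactly for a single sentence. This is a simplified stand-alone illustration, not the article’s Tokenizer class: it uses the same stop words as the code below and a window of 2, but it emits pairs in left-to-right order, and for brevity its vocabulary is only this one sentence rather than all 12 words.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;

class SkipGramPairsDemo {
    // Lower-cases a sentence, splits it on whitespace (removing extra spaces) and drops stop words.
    static List<String> clean(String sentence, Set<String> stopWords) {
        List<String> tokens = new ArrayList<>();
        for (String w : sentence.toLowerCase().trim().split("\\s+"))
            if (!stopWords.contains(w)) tokens.add(w);
        return tokens;
    }

    // (target, context) pairs: every word within `window` positions of the target.
    static List<String[]> pairs(List<String> tokens, int window) {
        List<String[]> out = new ArrayList<>();
        for (int i = 0; i < tokens.size(); i++) {
            int from = Math.max(0, i - window);
            int to = Math.min(tokens.size() - 1, i + window);
            for (int j = from; j <= to; j++)
                if (j != i) out.add(new String[]{tokens.get(i), tokens.get(j)});
        }
        return out;
    }

    // One-hot vector of `word` for a fixed vocabulary order (assumes the word is in the vocabulary).
    static int[] oneHot(String word, List<String> vocab) {
        int[] v = new int[vocab.size()];
        v[vocab.indexOf(word)] = 1;
        return v;
    }

    public static void main(String[] args) {
        Set<String> stopWords = Set.of("the", "is", "a", "of", "very");
        List<String> tokens = clean("Sherkhan is king of the Jungle", stopWords);
        System.out.println(tokens); // prints [sherkhan, king, jungle]
        for (String[] p : pairs(tokens, 2))
            System.out.println(p[0] + " -> " + p[1]);
        for (String w : tokens) // here the vocabulary is just this sentence
            System.out.println(w + " = " + Arrays.toString(oneHot(w, tokens)));
    }
}
```

For this sentence, a window of 2 already covers every other word, so each of the three words is paired with the other two, giving six training pairs.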

So the neural network’s input layer will have 12 inputs. For simplicity, let’s use 4 neurons in the hidden layer. The output layer will again have 12 outputs, with SoftMax as its activation function. The weights between the input layer and the hidden layer are the word embeddings: for a given input word, the output of the hidden layer is that word’s embedding, as explained below.

The weights between the input layer and the hidden layer form a 12x4 matrix, as depicted below. These sample weights are the word embeddings for the 12 input words. For example, the embedding of sherkhan is the vector in the first row of the screenshot below.
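Why the rows of this matrix are the embeddings follows from how a one-hot vector multiplies a matrix: only the row at the position of the 1 survives. A minimal sketch, assuming a linear hidden layer (no activation function, as in the standard skip-gram formulation) and a hypothetical 3x4 weight matrix with made-up values in place of the article’s 12x4 one:

```java
import java.util.Arrays;

class OneHotLookupDemo {
    // Computes h = x · W for a row vector x (length V) and a V x N matrix W.
    static double[] multiply(double[] x, double[][] w) {
        double[] h = new double[w[0].length];
        for (int i = 0; i < w.length; i++)
            for (int j = 0; j < h.length; j++)
                h[j] += x[i] * w[i][j];
        return h;
    }

    public static void main(String[] args) {
        // Hypothetical weights: row 0 = sherkhan, row 1 = king, row 2 = jungle.
        double[][] w = {
            {0.1, 0.2, 0.3, 0.4},
            {0.5, 0.6, 0.7, 0.8},
            {0.9, 1.0, 1.1, 1.2}};
        double[] sherkhan = {1, 0, 0}; // one-hot encoding of "sherkhan"
        double[] hidden = multiply(sherkhan, w);
        System.out.println(Arrays.toString(hidden));     // prints [0.1, 0.2, 0.3, 0.4]
        System.out.println(Arrays.equals(hidden, w[0])); // prints true
    }
}
```

This is also why practical Word2Vec implementations skip the multiplication entirely and simply look up the row for the input word’s index.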

The training dataset is fed to the neural network. After training, the model should predict context words such as king with high probability when Sherkhan is fed as input. The learned weights between the input layer and the hidden layer are the word embeddings (vectors) for all words.

Java Code Walkthrough

  1. Tokenizer: implements one-hot encoding and creates the training input/output dataset from our Jungle Book example, using a skip-gram sliding window.
  2. WordEmbeddingSample: contains the neural network code that generates the word embeddings.

The screenshot below, taken in debug mode, shows the weights (word embeddings) of the neural network. The complete code is in the following sections.

The code below creates the one-hot encoded training dataset for our example using a skip-gram sliding window of 2. It can be executed by running the main method.

public static void main(String[] args) {
    String str = "Sherkhan is king of the Jungle";
    String str1 = "Mowgli is a brave boy";
    String str2 = "Bagira is strong";
    String str3 = "Ballu is very helping";
    String str4 = "Lali is very kind";

    List<String> words = new ArrayList<>();
    words.add(str);
    words.add(str1);
    words.add(str2);
    words.add(str3);
    words.add(str4);

    // Stop words removed before building the vocabulary
    HashSet<String> set = new HashSet<>();
    set.add("the");
    set.add("is");
    set.add("a");
    set.add("of");
    set.add("very");

    // Skip-gram sliding window of 2
    Tokenizer encoder = new Tokenizer(words, set, 2);
    List<Integer[][]> trainingSet = encoder.generateSkipGramsTrainingDataSet();
}

Tokenizer.java

The implementation below covers one-hot encoding and creates the training dataset using the sliding-window algorithm from skip-gram.

package org.ai.hope;

import java.util.*;

public class Tokenizer {

    private List<String> inputs;

    private HashSet<String> stopWords;

    private int skipGramWindow;

    // Vocabulary: every distinct word that survives stop-word removal.
    private HashSet<String> uniqueWords = new HashSet<>();

    public Tokenizer(List<String> ins, HashSet<String> words) {
        this.inputs = ins;
        this.stopWords = words;
    }

    public Tokenizer(List<String> ins, HashSet<String> words, int skipGramWindow) {
        this.inputs = ins;
        this.stopWords = words;
        this.skipGramWindow = skipGramWindow;
    }

    // Lower-cases each sentence, drops stop words and records the vocabulary.
    private List<String[]> removeOutStopWords() {
        List<String[]> list = new ArrayList<>();

        for (String input : inputs) {
            // Splitting on runs of whitespace removes extra spaces.
            String[] words = input.toLowerCase().trim().split("\\s+");
            List<String> kept = new ArrayList<>();
            for (String word : words) {
                if (!word.isEmpty() && !stopWords.contains(word)) {
                    kept.add(word);
                    uniqueWords.add(word);
                }
            }
            list.add(kept.toArray(new String[0]));
        }

        return list;
    }

    // Pairs every word with every other word of the same sentence.
    private List<String[]> createBigrams(List<String[]> words) {
        List<String[]> list = new ArrayList<>();

        for (String[] array : words) {
            for (int j = 0; j < array.length; j++) {
                for (int k = 0; k < array.length; k++) {
                    if (k != j) {
                        list.add(new String[] { array[j], array[k] });
                    }
                }
            }
        }

        return list;
    }

    // Skip-gram pairs: (target word, context word) for each context word
    // within `window` positions to the left or right of the target.
    private List<String[]> createSkipGramsTrainingDataSet(List<String[]> words, int window) {
        List<String[]> list = new ArrayList<>();

        for (String[] array : words) {
            for (int j = 0; j < array.length; j++) {
                int k = j - 1; // next position to the left
                int p = j + 1; // next position to the right
                for (int tWindow = window; tWindow > 0; tWindow--) {
                    if (k >= 0) {
                        list.add(new String[] { array[j], array[k] });
                        k = k - 1;
                    }
                    if (p < array.length) {
                        list.add(new String[] { array[j], array[p] });
                        p = p + 1;
                    }
                }
            }
        }

        return list;
    }

    // Prints each word with its one-hot vector (Arrays.toString makes the arrays readable).
    private static void printEncoding(HashMap<String, Integer[]> encoding) {
        encoding.forEach((word, vector) -> System.out.println(word + " -> " + Arrays.toString(vector)));
    }

    // Each element is {one-hot first word, one-hot second word}.
    public List<Integer[][]> generateBigramsEncoding() {
        List<Integer[][]> doubles = new ArrayList<>();
        List<String[]> temp = removeOutStopWords();
        List<String[]> grams = createBigrams(temp);

        HashMap<String, Integer[]> hashMap = oneHotEncoder();
        printEncoding(hashMap);

        for (String[] str : grams) {
            Integer[][] array = new Integer[2][];
            array[0] = hashMap.get(str[0]);
            array[1] = hashMap.get(str[1]);
            doubles.add(array);
        }
        return doubles;
    }

    // Each element is {one-hot target word, one-hot context word}.
    public List<Integer[][]> generateSkipGramsTrainingDataSet() {
        List<Integer[][]> doubles = new ArrayList<>();
        List<String[]> temp = removeOutStopWords();
        List<String[]> grams = createSkipGramsTrainingDataSet(temp, skipGramWindow);

        HashMap<String, Integer[]> hashMap = oneHotEncoder();
        printEncoding(hashMap);

        for (String[] str : grams) {
            Integer[][] array = new Integer[2][];
            array[0] = hashMap.get(str[0]);
            array[1] = hashMap.get(str[1]);
            doubles.add(array);
        }
        return doubles;
    }

    // Key 1 -> one-hot inputs (target words), key 2 -> one-hot outputs (context words).
    public HashMap<Integer, List<Integer[]>> generateSkipGramsTrainingData() {
        List<String[]> temp = removeOutStopWords();
        List<String[]> grams = createSkipGramsTrainingDataSet(temp, skipGramWindow);
        HashMap<String, Integer[]> hashMap = oneHotEncoder();
        printEncoding(hashMap);

        List<Integer[]> targetVectors = new ArrayList<>();
        List<Integer[]> contextVectors = new ArrayList<>();
        for (String[] str : grams) {
            targetVectors.add(hashMap.get(str[0]));  // input: target word
            contextVectors.add(hashMap.get(str[1])); // expected output: context word
        }

        HashMap<Integer, List<Integer[]>> dataSet = new HashMap<>();
        dataSet.put(1, targetVectors);
        dataSet.put(2, contextVectors);
        return dataSet;
    }

    // Assigns each vocabulary word a one-hot vector: all zeros except a single 1.
    public HashMap<String, Integer[]> oneHotEncoder() {
        int i = 0;
        HashMap<String, Integer[]> encoder = new HashMap<>();
        for (String key : uniqueWords) {
            Integer[] values = new Integer[uniqueWords.size()];
            Arrays.fill(values, 0);
            values[i] = 1;
            encoder.put(key, values);
            i = i + 1;
        }
        return encoder;
    }

    public static void main(String[] args) {
        String str = "Sherkhan is king of the Jungle";
        String str1 = "Mowgli is a brave boy";
        String str2 = "Bagira is strong";
        String str3 = "Ballu is very helping";
        String str4 = "Lali is very kind";

        List<String> words = new ArrayList<>();
        words.add(str);
        words.add(str1);
        words.add(str2);
        words.add(str3);
        words.add(str4);

        HashSet<String> set = new HashSet<>();
        set.add("the");
        set.add("is");
        set.add("a");
        set.add("of");
        set.add("very");

        Tokenizer encoder = new Tokenizer(words, set, 2);
        List<Integer[][]> trainingSet = encoder.generateSkipGramsTrainingDataSet();
        System.out.println("Number of skip-gram training pairs: " + trainingSet.size());
    }

}

WordEmbeddingSample.java

The code below uses the Deeplearning4j framework to define the neural network that learns the word embeddings of our example, and trains it on the dataset created by the code above.

package org.ai.hope;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Random;

import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class WordEmbeddingSample {

    public static final int seed = 12345;
    // Number of epochs (full passes over the data)
    public static final int nEpochs = 200;
    // Batch size: each epoch makes (number of training pairs / batchSize) parameter updates
    public static final int batchSize = 12;
    // Network learning rate
    public static final double learningRate = 0.01;
    public static final Random rng = new Random(seed);

    public static void main(String[] args) {
        DataSetIterator iterator = getTrainingData(batchSize, rng);
        final int numInputs = 12; // vocabulary size
        int nHidden = 4;          // embedding dimension

        MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder().seed(seed)
                .weightInit(WeightInit.XAVIER).updater(new Nesterovs(learningRate, 0.9)).list()
                // Linear hidden layer: its input weights are the word embeddings
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(nHidden)
                        .activation(Activation.IDENTITY).build())
                // Softmax over the vocabulary, trained with multi-class cross-entropy
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(nHidden).nOut(numInputs).build())
                .build());
        net.init();
        net.setListeners(new ScoreIterationListener(1));

        for (int i = 0; i < nEpochs; i++) {
            iterator.reset();
            net.fit(iterator);
        }

        // Layer-0 weight matrix (12 x 4): row i is the embedding of the word with one-hot index i
        INDArray embeddings = net.getParam("0_W");
        System.out.println(embeddings);
    }

    private static DataSetIterator getTrainingData(int batchSize, Random rand) {

        String str = "Sherkhan is king of the Jungle";
        String str1 = "Mowgli is a brave boy";
        String str2 = "Bagira is strong";
        String str3 = "Ballu is very helping";
        String str4 = "Lali is very kind";

        List<String> words = new ArrayList<>();
        words.add(str);
        words.add(str1);
        words.add(str2);
        words.add(str3);
        words.add(str4);

        HashSet<String> set = new HashSet<>();
        set.add("the");
        set.add("is");
        set.add("a");
        set.add("of");
        set.add("very");

        Tokenizer encoder = new Tokenizer(words, set, 2);
        HashMap<Integer, List<Integer[]>> trainingSet = encoder.generateSkipGramsTrainingData();

        List<Integer[]> inputs = trainingSet.get(1);  // one-hot target words
        List<Integer[]> outputs = trainingSet.get(2); // one-hot context words

        double[][] ins = new double[inputs.size()][];
        double[][] ous = new double[outputs.size()][];
        for (int i = 0; i < inputs.size(); i++) {
            double[] arrinput = new double[inputs.get(i).length];
            double[] arrop = new double[outputs.get(i).length];
            for (int j = 0; j < arrinput.length; j++) {
                // Treat unset (null) entries as 0
                arrinput[j] = inputs.get(i)[j] == null ? 0 : inputs.get(i)[j].intValue();
                arrop[j] = outputs.get(i)[j] == null ? 0 : outputs.get(i)[j].intValue();
            }
            ins[i] = arrinput;
            ous[i] = arrop;
        }

        // Features are the target words, labels are the context words
        INDArray features = Nd4j.create(ins);
        INDArray labels = Nd4j.create(ous);
        DataSet dataSet = new DataSet(features, labels);
        List<DataSet> listDs = dataSet.asList();
        Collections.shuffle(listDs, rand);
        return new ListDataSetIterator<>(listDs, batchSize);
    }

}

Note: This article covers how embeddings are created for words. In the next article, we will cover how cosine similarity and k-nearest neighbors are used for querying.

Disclaimer: This article projects only my personal view based on my findings and learnings.
