RankNet, LambdaRank TensorFlow Implementation — part III

Louis Kit Lung Law · Published in The Startup · Feb 8, 2021

In this blog, I will talk about how to speed up the training of RankNet; I will refer to this sped-up version as Factorised RankNet. Note that this technique was published in the paper Learning to Rank with Nonsmooth Cost Functions.

RankNet Training Process Examined

Typically, to train an ML model, we need to do two things:

  1. compute the cost function
  2. compute the gradient (the derivative of the cost with respect to the model’s weights)

And recall that in part I, we have the following equations:

Equation 1. RankNet’s cost function
Equation 2. Derivative of cost with respect to RankNet’s weights
Equation 3. Derivative of cost with respect to o_ij = o_i - o_j
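For convenience, here is the standard RankNet form of those equations (my restatement, so the notation may differ slightly from the figures in part I; o_i is the model’s score for document i and P̄_ij is the target probability that document i should rank above document j):

```latex
% Equation 1: RankNet's pairwise cost for documents i and j
C_{ij} = -\bar{P}_{ij}\, o_{ij} + \log\!\left(1 + e^{o_{ij}}\right),
\qquad o_{ij} = o_i - o_j

% Equation 2: chain rule through the score difference
\frac{\partial C_{ij}}{\partial w_k}
  = \frac{\partial C_{ij}}{\partial o_{ij}}
    \left(\frac{\partial o_i}{\partial w_k} - \frac{\partial o_j}{\partial w_k}\right)

% Equation 3: derivative of the cost w.r.t. o_ij (cheap to evaluate)
\frac{\partial C_{ij}}{\partial o_{ij}}
  = \frac{e^{o_{ij}}}{1 + e^{o_{ij}}} - \bar{P}_{ij}
  = \sigma(o_{ij}) - \bar{P}_{ij}
```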

Now let’s look at an example. Assume there are 4 documents d1, d2, d3, d4 for a query; this gives six pairs of documents: d1&d2, d1&d3, d1&d4, d2&d3, d2&d4, d3&d4. The training process would then look like this (note that stochastic gradient descent is used in RankNet):

Fig. 1, Stochastic gradient descent

As we can see, for a query with n documents (e.g. 4), n(n-1) back propagations (e.g. 12) are done to compute doi/dWk. Thus the training time scales almost quadratically with the mean number of documents per query.

But some terms are “repeated” multiple times (e.g. the one highlighted in blue above, do1/dWk), so could we combine them to reduce the number of back propagations needed?

Factorised RankNet

In order to speed up the training, we could switch from stochastic gradient descent to mini-batch gradient descent, where each batch corresponds to one query. The training process now looks like this:

Fig. 2, Mini-Batch gradient descent

But using mini-batches alone won’t give much of a speed-up, as we still need to compute n(n-1) back propagations for each query.

Let’s see how we could factorise the sum highlighted in Fig. 2.

Fig. 3, Factoring mini-batch gradient descent

While we still need to compute dCij/doij n(n-1) times, we now only need to compute doi/dWk n times!
And since the computation of dCij/doij is very cheap (refer to equation 3), the original O(n²) RankNet algorithm is effectively reduced to O(n).

To make this more general:

Fig. 4, Generalised the result of Fig. 3
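Written out in symbols, the generalised factorisation says the gradient for a query can be accumulated per document rather than per pair (this is my paraphrase of the result above; the antisymmetry dCji/doji = -dCij/doij is what lets each document’s coefficient collapse into a single sum):

```latex
\frac{\partial C}{\partial w_k}
  = \sum_{\{i,j\}} \frac{\partial C_{ij}}{\partial o_{ij}}
    \left(\frac{\partial o_i}{\partial w_k} - \frac{\partial o_j}{\partial w_k}\right)
  = \sum_{i} \lambda_i \frac{\partial o_i}{\partial w_k},
  \qquad
  \lambda_i = \sum_{j \ne i} \frac{\partial C_{ij}}{\partial o_{ij}}
```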

And the training process of Factorised RankNet looks like this:

Fig. 5, Factorised RankNet’s training

Implementation

Model Architecture

First, let’s define the model. Note that instead of taking a pair of documents as input as before, the model now takes one document at a time.

Fig. 6 Factorised RankNet’s model architecture
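The figure shows the exact code; as a rough, hedged sketch (the class name, layer sizes and activations below are my own assumptions), such a scoring model in TensorFlow/Keras could look like this:

```python
import tensorflow as tf

class RankNetScorer(tf.keras.Model):
    """Scores one document at a time; pairs are only formed when computing the lambdas."""

    def __init__(self, hidden_units=(64, 32)):
        super().__init__()
        self.hidden = [tf.keras.layers.Dense(u, activation="relu") for u in hidden_units]
        self.score = tf.keras.layers.Dense(1)  # single relevance score o_i per document

    def call(self, x):
        # x: [n_docs, n_features] — all documents of one query in a single forward pass
        for layer in self.hidden:
            x = layer(x)
        return self.score(x)  # [n_docs, 1]
```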

Epoch Looping

In the following function, for each epoch, each query is treated as one batch, and the documents under that query are used to train the model.
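A minimal sketch of that loop, assuming the training data has already been grouped per query into (features, labels) arrays (the fit name, the queries structure and the hyperparameters are illustrative; train_step is sketched in the next section):

```python
def fit(model, queries, epochs=10, learning_rate=1e-3):
    # queries: list of (features, labels) pairs, one entry per query,
    # where features is [n_docs, n_features] and labels is [n_docs]
    optimizer = tf.keras.optimizers.Adam(learning_rate)
    for epoch in range(epochs):
        for features, labels in queries:  # one query == one mini-batch
            train_step(model, optimizer, features, labels)
```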

Gradient Calculation

The following function calculates the gradient and applies gradient descent to train the model.
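A hedged sketch of such a train step (the exact notebook code may differ): one forward pass scores all documents of the query, the per-document coefficients λi are built from the pairwise dCij/doij of equation 3, and the factorised gradient Σi λi · doi/dWk is obtained by passing the λi to the tape as output_gradients:

```python
def train_step(model, optimizer, features, labels):
    # features: [n_docs, n_features] for one query; labels: [n_docs] relevance grades
    labels = tf.reshape(tf.cast(labels, tf.float32), [-1, 1])
    with tf.GradientTape() as tape:
        scores = model(features, training=True)          # [n_docs, 1]
    # Pairwise score differences o_ij = o_i - o_j
    o_ij = scores - tf.transpose(scores)                 # [n_docs, n_docs]
    # Target probabilities: 1 if doc i is more relevant than j, 0.5 on ties, 0 otherwise
    P_bar = 0.5 * (1.0 + tf.sign(labels - tf.transpose(labels)))
    # Equation 3: dC_ij/do_ij = sigmoid(o_ij) - P_bar_ij — cheap to evaluate
    lambda_ij = tf.sigmoid(o_ij) - P_bar
    # lambda_i = sum_j dC_ij/do_ij (the diagonal terms are zero)
    lambda_i = tf.reduce_sum(lambda_ij, axis=1, keepdims=True)   # [n_docs, 1]
    # dC/dW_k = sum_i lambda_i * do_i/dW_k, via output_gradients on the scores
    grads = tape.gradient(scores, model.trainable_variables, output_gradients=lambda_i)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
```

Passing the λi as output_gradients means a single backward pass over the batch replaces the n(n-1) per-pair back propagations.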

RankNet vs Factorised RankNet

Now that everything is ready, let’s compare the performance of RankNet vs Factorised RankNet.

Fig. 7, Loss Plot on Validation Set

As we can see from the graph above, Factorised RankNet reaches the same loss in less time than RankNet.

Here is the link to the Jupyter notebook which was used to generate the graph above.
