Importance Sampling for Deep Learning

(I thank the reader for her or his comments. As I am not fluent in English, if you see grammatical mistakes that reduce the quality of this article, please contact me!)

In this article, I will talk about importance sampling and propose it as an alternative to going through the entire dataset at each epoch when training a deep learning model. The idea is, once the model begins to make correct predictions, to examine more frequently the training examples on which it performs poorly. Epochs are abandoned and replaced by the Metropolis-Hastings algorithm, which uses the loss function to derive the probability of picking a particular training example.

Introduction

Deep learning is a revolution that everybody can easily take part in. Indeed, with open access to high-performance frameworks such as TensorFlow or PyTorch, anybody can reproduce powerful models and do amazing things at home, even when these models have only just been developed by leading research teams in artificial intelligence (see, for example, melody generators or neural style transfer, and their pedagogical use to teach deep learning). In this recent field, creativity and imagination prevail over mathematical knowledge: even if understanding AI research articles requires effort, an individual can understand, reproduce, and even apply to new cases at home what researchers did.

However, the biggest difficulty in training a model is the time it requires. Datasets are often very large, and a single training epoch can take a long time. As a result, people at home with limited resources may have brilliant ideas but be unable to realize any of them. I therefore propose here an alternative to the usual epoch-by-epoch browsing of the training set, based on importance sampling.

Classic model training

Models are trained by a series of forward and backward propagations:

  1. The training set is subdivided into batches,
  2. The model evaluates a batch by calculating its loss function (forward propagation),
  3. The model computes the gradients and updates its weights (backpropagation),
  4. The model repeats steps 2 and 3 until it reaches the end of the training set (this makes one epoch),
  5. The model repeats the whole process as many times as the user requires. Optionally, the examples can be shuffled before each new epoch, which produces different batches and can reduce the risk of overfitting.
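The steps above can be sketched as follows. This is a minimal, framework-agnostic illustration with NumPy; `train_step` is a hypothetical placeholder for the forward and backward propagation that a framework such as TensorFlow or PyTorch would perform on a batch.

```python
import numpy as np


def epoch_batches(n_examples, batch_size, rng):
    """Yield index batches that cover the training set exactly once (one epoch)."""
    order = rng.permutation(n_examples)  # optional shuffling before each epoch
    for start in range(0, n_examples, batch_size):
        yield order[start:start + batch_size]


def train(n_examples, batch_size, n_epochs, train_step, seed=0):
    """Classic training loop: every example is visited once per epoch."""
    rng = np.random.default_rng(seed)
    for _ in range(n_epochs):
        for batch in epoch_batches(n_examples, batch_size, rng):
            train_step(batch)  # hypothetical: forward pass, loss, backprop, weight update


# Usage: record the batches visited during two epochs over 10 examples.
visited = []
train(n_examples=10, batch_size=4, n_epochs=2, train_step=visited.append)
print([len(b) for b in visited])  # [4, 4, 2, 4, 4, 2]
```

Every example is seen exactly once per epoch, whatever its loss, which is the behaviour questioned in the next paragraph.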

This classic way of training a deep learning model has one drawback: it makes no distinction between examples on which the model performs very well (low individual loss) and those on which it performs poorly (high individual loss). This is depicted in the figure below.

If we assume that the examples in green have a near-zero loss (almost perfect match between prediction and label), they will contribute little to improving the model. One solution would be to set the green examples aside so that batches contain more yellow and red ones. It would be even better if the model spent more time training on the red examples than on the yellow ones. But how can this be done? Importance sampling is the solution.

Importance sampling: the basics

Importance sampling consists in drawing a random sample from a set according to a probability distribution over its elements. In the case of a training batch, we attach weights to the training examples, and elements with a high weight have a higher chance of being selected.
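As an illustration, drawing a batch with probabilities proportional to per-example weights can be sketched with NumPy's `Generator.choice`. The loss values below are made up for the example, and this direct approach assumes that all individual losses are already known.

```python
import numpy as np

# Hypothetical individual losses of a 5-example training set.
losses = np.array([0.01, 0.02, 0.5, 1.0, 2.0])

# Normalize the weights into a probability distribution.
p = losses / losses.sum()

rng = np.random.default_rng(42)
# Draw a batch of 8 indices with replacement: high-loss examples are favoured,
# and the same example may appear several times in one batch.
batch = rng.choice(len(losses), size=8, replace=True, p=p)
print(batch)
```

Low-loss examples keep a small but nonzero chance of being drawn, which matches the idea that the "green" examples are not totally forgotten.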

Each batch is generated here (almost) independently: in practice, the Metropolis-Hastings algorithm creates correlations between consecutive batches, since a walker may stay on the same example. An example can perfectly well appear in two consecutive batches. The red examples are selected more frequently than the yellow and green ones; the green examples are not totally forgotten but appear less frequently. This is expected to save training time, but how should importance sampling be applied? The individual losses of the training examples (the loss function evaluated on a single example) are natural candidates for the weights. However, one question comes to mind: how can we sample the training set without calculating all the losses? Our goal is an algorithm that saves computation time, so we want to avoid calculating the losses of examples we will not use in the next batch. This is achieved with the Metropolis-Hastings algorithm.
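To make "individual loss" concrete, here is a sketch that computes one loss value per example instead of the batch mean. Cross-entropy is only an assumed example loss (the article does not fix one), and the logits are made up. In PyTorch, the same unreduced values are returned by `torch.nn.functional.cross_entropy` with `reduction='none'`.

```python
import numpy as np


def per_example_cross_entropy(logits, labels):
    """Individual (unreduced) cross-entropy losses, one value per example."""
    # Numerically stable log-softmax.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels]


logits = np.array([[4.0, 0.0, 0.0],    # confident and correct -> low loss ("green")
                   [1.0, 1.0, 0.0],    # uncertain             -> medium loss
                   [0.0, 0.0, 4.0]])   # confident and wrong   -> high loss ("red")
labels = np.array([0, 0, 0])

losses = per_example_cross_entropy(logits, labels)  # one weight candidate per example
batch_loss = losses.mean()                          # what a standard loop backpropagates
```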

Random walks - Metropolis-Hastings algorithm

The idea behind the Metropolis-Hastings algorithm is to perform random walks inside the set. In the literature about this algorithm, the moving units performing the random walks are referred to as walkers. Given one batch of training examples, we attach one walker to every example (so the number of walkers equals the fixed batch size). At the end of the forward-backward propagation step, when a new batch is needed, every walker independently proposes a move to a new, randomly chosen training example. A new set of examples is thus proposed for the next batch; however, the loss function has not yet been used for importance sampling.
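The proposal step can be sketched as follows, assuming (one possible choice, not specified in the text) that each walker proposes a training example drawn uniformly at random from the whole set. The function name `propose` is hypothetical.

```python
import numpy as np


def propose(current, n_examples, rng):
    """Propose a new training example for each walker, uniformly at random.

    `current` holds the index of the example each walker currently sits on;
    there is one walker per batch slot, so len(current) is the batch size.
    """
    # Assumption: uniform proposal over the whole set. It is symmetric
    # (proposing x' from x is as likely as proposing x from x').
    return rng.integers(0, n_examples, size=len(current))


rng = np.random.default_rng(0)
walkers = rng.integers(0, 1000, size=32)  # initial batch: 32 walkers on a 1000-example set
proposals = propose(walkers, 1000, rng)   # candidate examples for the next batch
```

A symmetric proposal is convenient because it cancels out of the Metropolis-Hastings acceptance test, leaving only the ratio of losses.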

Importance sampling enters when we decide, for each walker move, whether to accept it or not. For each walker, a random number t is drawn uniformly between 0 and 1; it decides whether the walker moves to the proposed example for the next batch or keeps its current one. The relative losses determine a threshold f to be compared with t: if t < f, the move is accepted; otherwise it is rejected and the walker stays on its current example. The algorithm therefore tends to keep examples with high losses longer than those with low losses.

The number f is referred to as the acceptance probability in the literature. In the standard algorithm it is capped at 1, but here we can directly take the ratio f = loss(proposed example)/loss(current example): a value larger than 1 simply means that the move is always accepted. For example, if the proposed example has twice the loss of the current one, the move is always accepted; if it has half the loss, the move is accepted with probability 1/2. With a symmetric proposal (such as picking the new example uniformly at random), this rule makes each walker visit the examples with a long-run frequency proportional to their loss.
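A minimal sketch of one Metropolis-Hastings update for all walkers is given below, using the standard Metropolis acceptance test (accept when t < loss(proposed)/loss(current)). It assumes a uniform, hence symmetric, proposal and a hypothetical `loss_of` function returning the individual losses of the given examples; only the losses of the proposed examples are evaluated, never the whole set. The small constant `eps` guarding against a zero loss is also an assumption. In real training the losses change after every weight update, so reusing the stored losses of the current examples is a simplification.

```python
import numpy as np


def mh_step(current, current_loss, loss_of, n_examples, rng, eps=1e-12):
    """One Metropolis-Hastings move for every walker.

    current: indices of the examples in the current batch (one per walker)
    current_loss: their stored individual losses (assumption: not recomputed)
    loss_of: hypothetical callable mapping indices to individual losses
    """
    proposed = rng.integers(0, n_examples, size=len(current))  # symmetric proposal
    proposed_loss = loss_of(proposed)                           # only B new evaluations
    # Acceptance probability f = l(proposed) / l(current); f >= 1 means always accept.
    f = proposed_loss / np.maximum(current_loss, eps)
    t = rng.random(len(current))                                # uniform in [0, 1)
    accept = t < f
    new = np.where(accept, proposed, current)
    new_loss = np.where(accept, proposed_loss, current_loss)
    return new, new_loss


# Toy check with fixed losses: visit frequencies should approach l(x) / sum(l).
losses = np.array([1.0, 2.0, 3.0, 4.0])


def toy_loss_of(idx):
    return losses[idx]


rng = np.random.default_rng(1)
walkers = rng.integers(0, 4, size=64)
walker_loss = losses[walkers]
counts = np.zeros(4)
for step in range(5000):
    walkers, walker_loss = mh_step(walkers, walker_loss, toy_loss_of, 4, rng)
    if step >= 500:  # discard a short burn-in
        counts += np.bincount(walkers, minlength=4)
freq = counts / counts.sum()
print(np.round(freq, 2))  # close to [0.1, 0.2, 0.3, 0.4]
```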

We now see that the Metropolis-Hastings algorithm may help us save time in the way we wanted, because an example with a high loss is examined more often than one with a low loss. But how can we track the global loss function if we no longer compute it over a full epoch? We could take the average over one batch, or over several batches, but a uniform average would overestimate the actual loss: each batch tends to contain more examples with a high loss, so a uniform average of the individual losses is biased upwards. The solution is to weight them with an appropriate probability, as described now.
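A small worked example makes this bias concrete. Take a hypothetical set of four examples with losses 1, 2, 3 and 4, so that the true loss is L = 2.5. If examples are drawn with probability proportional to their loss, the plain (unweighted) average of the sampled losses converges to the sum of l(x)² divided by the sum of l(x), here 30/10 = 3.0 instead of 2.5.

```python
import numpy as np

losses = np.array([1.0, 2.0, 3.0, 4.0])  # hypothetical individual losses
L = losses.mean()                         # true global loss: 2.5

p = losses / losses.sum()                 # sampling probability proportional to loss
biased = np.sum(p * losses)               # limit of an unweighted average of sampled losses
```

More generally, by the Cauchy-Schwarz inequality the biased value is always at least L, with equality only when all losses are equal.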

Computation of the loss function

The loss function over the entire set is defined as the average of the individual losses of the training examples:

$$L = \frac{1}{N}\sum_{i=1}^{N} l(x_i) = \mathbb{E}_{u}\left[\, l(x) \,\right]$$

where $N$ is the number of training examples, $l(x)$ is the individual loss of example $x$, and $\mathbb{E}_{u}$ denotes the statistical expected value with respect to the uniform distribution $u(x) = 1/N$ over the training set.

If we sample the training set uniformly at random, the average of the sampled individual losses is a good approximation of $L$. With importance sampling, however, the examples are not picked uniformly but according to a probability distribution $p$ over the $N$ examples. We can still recover $L$ by weighting the individual loss $l(x)$ of each sampled example $x$ by $1/(N\,p(x))$. Averaging these weighted losses over the $B$ examples of a batch, and over many batches, gives

$$\frac{1}{B}\sum_{x \in \text{batch}} \frac{l(x)}{N\,p(x)} \;\longrightarrow\; \sum_{x} p(x)\,\frac{l(x)}{N\,p(x)} = \frac{1}{N}\sum_{x} l(x) = L$$

when the number of batches tends to infinity.
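The reweighting can be sketched as below, with the same kind of hypothetical losses 1, 2, 3, 4 (so L = 2.5) and a distribution p normalized to sum to 1 over the N examples. Each sampled loss is divided by N p(x); the factor 1/N is what turns the sum of all losses into their average L. For simplicity, the batches are drawn independently with `Generator.choice` rather than with the random walk, which is an assumption of this sketch.

```python
import numpy as np

losses = np.array([1.0, 2.0, 3.0, 4.0])  # hypothetical individual losses, L = 2.5
N = len(losses)
p = losses / losses.sum()                 # normalized: p sums to 1 over the N examples


def reweighted_batch_mean(idx, losses, p):
    """Importance-sampling estimate of L from one batch of sampled indices."""
    n = len(losses)
    return np.mean(losses[idx] / (n * p[idx]))


rng = np.random.default_rng(0)
# Assumption: independent draws from p stand in for the Metropolis-Hastings walk.
batches = [rng.choice(N, size=16, p=p) for _ in range(1000)]

naive = np.mean([losses[b].mean() for b in batches])                       # biased upward
weighted = np.mean([reweighted_batch_mean(b, losses, p) for b in batches])  # estimates L
```

The unweighted average lands near 3.0, while the reweighted one gives 2.5. With p exactly proportional to the loss, every term l(x)/(N p(x)) equals L, so the estimator has no variance at all; but computing that p requires the sum of all losses, that is, L itself.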

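It remains to determine how to calculate $p$. Since we want to pick examples with a probability proportional to their individual loss, $p$ is itself tied to the global loss function through its normalization:

$$p(x) = \frac{l(x)}{\sum_{y} l(y)} = \frac{l(x)}{N\,L}$$

The Metropolis-Hastings walk does not need this normalization, since only the ratio of two losses enters the acceptance test; $L$ is needed only to reweight the losses, and $L$ is precisely the quantity we want to estimate.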

For now, we leave this question open and will first perform some tests.

Conclusion and perspectives

With importance sampling, we proposed an algorithm based on random walks that samples training examples with a high loss more often. Tests remain to be done to evaluate the efficiency of this method, and its impact on overfitting should also be studied.

Further reading

This article offers only an overview of what can be done with importance sampling in the specific case of deep learning. The subject is much richer than what is shown here: many important concepts, such as the detailed balance condition that the walkers obey, have been left out to keep this overview brief. The Wikipedia page
https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm
gives deeper theoretical insight.

Acknowledgment

The author thanks his former Ph.D. advisor Michele Casula for having taught him stochastic methods in numerical computation. I also thank the readers who will suggest further improvements to this article, to make it clearer and with fewer mistakes (mathematical or grammatical), or who simply engage in positive discussions.
