What is a Recurrent Neural Network — The problem with RNNs (Part 2)
Recurrent Neural Networks (RNNs) are powerful models for processing sequential data, such as time series or natural language. However, they can be challenging to optimize because their hidden state depends on the entire input history. Unlike basic neural networks, the loss term in RNNs must account for every element of the input sequence, which makes optimization more delicate. In this blog post, we will explore the unique challenges of optimizing RNNs and discuss strategies for improving their performance.
Recurrent Loss Term (beware of math!)
Let’s denote the model prediction and hidden state as

ht = σh(Wh·xt + Uh·ht−1 + bh),  ot = σy(Wy·ht + by)
That is, the network outputs both ht and ot; we also denote Nh = ht, No = ot (as N outputs a tuple). From here, we may choose some loss function L (that could, for example, take the squared differences between N’s outputs and some labels y), but regardless of how we choose L, there will be a term in which we specifically differentiate the network’s prediction, so let us focus on that term:

(∂l/∂ot) · (∂ot/∂ht) · (∂ht/∂θ)
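To make the notation concrete, here is a minimal NumPy sketch of one forward pass, assuming σh = tanh and σy = identity (the dimensions, weight scales, and sequence length are arbitrary choices for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

# Dimensions of a toy network (arbitrary illustrative choices).
d_x, d_h, d_o = 3, 4, 2

# Parameters theta = (W_h, U_h, b_h, W_y, b_y).
W_h = rng.normal(scale=0.5, size=(d_h, d_x))
U_h = rng.normal(scale=0.5, size=(d_h, d_h))
b_h = np.zeros(d_h)
W_y = rng.normal(scale=0.5, size=(d_o, d_h))
b_y = np.zeros(d_o)

def rnn_step(x_t, h_prev):
    """One recurrent step: returns the pair (h_t, o_t)."""
    h_t = np.tanh(W_h @ x_t + U_h @ h_prev + b_h)  # sigma_h = tanh (assumed)
    o_t = W_y @ h_t + b_y                          # sigma_y = identity (assumed)
    return h_t, o_t

# Run the network over a sequence of T inputs, starting from h_0 = 0.
T = 5
xs = rng.normal(size=(T, d_x))
h = np.zeros(d_h)
outputs = []
for x_t in xs:
    h, o = rnn_step(x_t, h)
    outputs.append(o)

print(len(outputs), outputs[0].shape)
```

Note how the same parameters are reused at every time step; this weight sharing is exactly what couples the loss to the whole input history.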
Expanding the last factor with the chain rule gives

∂ht/∂θ = ∂⁺ht/∂θ + (∂ht/∂ht−1)·(∂ht−1/∂θ) + (∂ht/∂xt)·(∂xt/∂θ)

where ∂xt/∂θ = 0 because the input is not a function of the network parameters, and ∂⁺ht/∂θ denotes the direct (non-recursive) dependence of ht on θ. Now comes the tricky part, as we recursively expand ∂ht/∂θ for every t to arrive at the final result

∂ht/∂θ = Σ_{k=1..t} ( Π_{j=k+1..t} ∂hj/∂hj−1 ) · ∂⁺hk/∂θ
Note that l is some loss function over ot, so we can easily differentiate it with respect to ot. Furthermore, ot = σy(Wy·ht + by), so we can differentiate it with respect to ht as well. The third factor is given by the calculation above, and this concludes what we wished to achieve.
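The derivation above can be checked numerically. The sketch below implements the recursive expansion of ∂ht/∂θ (backpropagation through time) for the recurrent weight matrix, under illustrative assumptions: σh = tanh, a linear output layer, biases omitted, and a squared-error loss on the final output only. It then compares one entry of the analytic gradient against a finite-difference estimate:

```python
import numpy as np

rng = np.random.default_rng(1)
d_x, d_h, d_o, T = 3, 4, 2, 6

W_h = rng.normal(scale=0.5, size=(d_h, d_x))
U_h = rng.normal(scale=0.5, size=(d_h, d_h))
W_y = rng.normal(scale=0.5, size=(d_o, d_h))
xs = rng.normal(size=(T, d_x))
y = rng.normal(size=d_o)

def forward(U):
    """Run the RNN; return the loss and all hidden states h_0..h_T."""
    hs = [np.zeros(d_h)]
    for x_t in xs:
        hs.append(np.tanh(W_h @ x_t + U @ hs[-1]))
    o_T = W_y @ hs[-1]
    return 0.5 * np.sum((o_T - y) ** 2), hs

def bptt_grad(U):
    """dL/dU via the recursive expansion of dh_t/dtheta."""
    _, hs = forward(U)
    o_T = W_y @ hs[-1]
    delta = W_y.T @ (o_T - y)                # dL/dh_T
    grad = np.zeros_like(U)
    for t in range(T, 0, -1):
        da = delta * (1.0 - hs[t] ** 2)      # through tanh
        grad += np.outer(da, hs[t - 1])      # explicit (direct) dh_t/dU term
        delta = U.T @ da                     # the dh_t/dh_{t-1} factor
    return grad

# Central finite-difference check on one entry of U_h.
eps = 1e-6
U_plus, U_minus = U_h.copy(), U_h.copy()
U_plus[0, 1] += eps
U_minus[0, 1] -= eps
numeric = (forward(U_plus)[0] - forward(U_minus)[0]) / (2 * eps)
analytic = bptt_grad(U_h)[0, 1]
print(abs(numeric - analytic))
```

The backward loop is exactly the sum-over-k, product-over-j structure: each iteration adds one explicit term and multiplies the accumulated signal by one more ∂hj/∂hj−1 factor.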
Lastly, we emphasize that although the derivation of ∂ht/∂θ above can be computed in full, for large T it becomes extremely long. There are strategies for dealing with this problem (truncated backpropagation through time, for example), but we will not discuss them in this post.
Vanishing/Exploding Gradient Problem
When we look at the term

Π_{j=k+1..t} cj,  where cj = ∂hj/∂hj−1,

it is reasonable to think that if cj is relatively big (say, of magnitude greater than 1), the product grows exponentially in t. This causes a problem: our gradient iteration will not bring us to the solution (in fact, it will probably diverge). Conversely, if cj is sufficiently small, we face the opposite problem of vanishing gradients, and the network will struggle to converge, if it converges at all.
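A small numeric experiment makes this concrete: multiplying T random Jacobian-like factors cj, scaled slightly below or above 1, shrinks or blows up a vector's norm exponentially (the dimensions, scales, and number of steps here are arbitrary illustrative choices):

```python
import numpy as np

rng = np.random.default_rng(2)
d, T = 8, 60  # state dimension and sequence length (illustrative)

def product_norm(scale):
    """Norm of a vector pushed through T factors c_j = scale * A_j."""
    v = np.ones(d) / np.sqrt(d)  # unit-norm starting vector
    for _ in range(T):
        A = rng.normal(size=(d, d)) / np.sqrt(d)  # roughly unit-scale factor
        v = scale * A @ v
    return np.linalg.norm(v)

small = product_norm(0.5)  # c_j "small": the product vanishes
large = product_norm(1.5)  # c_j "big": the product explodes
print(small, large)
```

Even a modest per-step factor of 0.5 or 1.5, compounded over 60 steps, moves the gradient magnitude by many orders of magnitude in either direction.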
The problem described above isn’t unique to RNNs; for example, a very deep feed-forward NN can encounter vanishing gradients too. To mitigate the effect, “skip connections” were introduced. A skip connection performs a rather straightforward operation: for a given vector x and a layer F, it is defined as SC(x) = F(x) + x, essentially allowing the vector to “skip” the layer. Since ∂SC/∂x = ∂F/∂x + I, the identity term lets gradients flow through even when ∂F/∂x is small.
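A quick sketch of this effect, assuming a hypothetical “weak” layer F whose weights are near zero: the directional derivative of F alone is tiny, while the skip connection SC(x) = F(x) + x keeps the derivative close to the identity:

```python
import numpy as np

rng = np.random.default_rng(3)
d = 5

W = rng.normal(scale=0.01, size=(d, d))  # a "weak" layer: dF/dx is tiny

def F(x):
    return np.tanh(W @ x)

def skip_connection(x):
    return F(x) + x  # SC(x) = F(x) + x

x = rng.normal(size=d)
v = rng.normal(size=d)  # direction for a directional derivative
eps = 1e-6

# Directional derivatives via central differences.
dF  = (F(x + eps * v) - F(x - eps * v)) / (2 * eps)
dSC = (skip_connection(x + eps * v) - skip_connection(x - eps * v)) / (2 * eps)

# Without the skip, the signal nearly dies; with it, the identity
# path keeps the derivative close to v itself.
print(np.linalg.norm(dF), np.linalg.norm(dSC), np.linalg.norm(v))
```

Stacking many such blocks therefore preserves gradient magnitude far better than stacking the bare layers.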
In the next blog post, we will see how to solve (to some extent) the gradient problem, and how the solution relates to skip connections. Hint: LSTMs 🤫