What is a Recurrent Neural Network — The problem with RNNs (Part 2)

Hadar Sharvit
3 min read · Feb 26, 2024

Recurrent Neural Networks (RNNs) are powerful models for processing sequential data, such as time series or natural language. However, they can be challenging to optimize precisely because they process the entire history of inputs: unlike in basic feed-forward networks, the loss term in an RNN must account for every earlier value in the input sequence, which makes optimization more delicate. In this blog post, we will explore the unique challenges of optimizing RNNs and discuss strategies for improving their performance.

Recurrent Loss Term (beware of math!)

Let’s denote the model prediction and hidden state as

$$h_t = \sigma_h\!\left(W_x x_t + W_h h_{t-1} + b_h\right), \qquad o_t = \sigma_y\!\left(W_y h_t + b_y\right)$$

That is, the network outputs both h_t and o_t. We also write N_h = h and N_o = o, since the network N outputs the tuple (h, o).
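To make the notation concrete, here is a minimal NumPy sketch of this forward recurrence. The choice of tanh and sigmoid for the activations σ_h and σ_y, and the toy dimensions, are illustrative assumptions rather than anything fixed by the definition:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def rnn_forward(xs, W_x, W_h, b_h, W_y, b_y, h0):
    """Roll the recurrence over a sequence xs and return all (h_t, o_t) pairs."""
    h, hs, os_ = h0, [], []
    for x_t in xs:                                 # one step per time index t
        h = np.tanh(W_x @ x_t + W_h @ h + b_h)     # hidden state h_t
        o = sigmoid(W_y @ h + b_y)                 # prediction o_t
        hs.append(h)
        os_.append(o)
    return hs, os_

# toy example
rng = np.random.default_rng(0)
d_in, d_h, d_out, T = 3, 4, 2, 5
xs = [rng.normal(size=d_in) for _ in range(T)]
W_x = rng.normal(size=(d_h, d_in))
W_h = rng.normal(size=(d_h, d_h))
W_y = rng.normal(size=(d_out, d_h))
b_h, b_y = np.zeros(d_h), np.zeros(d_out)
hs, os_ = rnn_forward(xs, W_x, W_h, b_h, W_y, b_y, np.zeros(d_h))
print(os_[-1])   # the prediction at the last time step
```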

From here, we may choose some loss function L (which could, for example, be the squared difference between the network's output and some labels y). Regardless of how we choose L, its gradient will contain a term in which we differentiate the network's output with respect to the parameters θ, so let us focus on that term:

$$\frac{\partial h_t}{\partial \theta} = \frac{\partial h_t}{\partial h_{t-1}}\,\frac{\partial h_{t-1}}{\partial \theta} + \frac{\partial h_t}{\partial x_t}\,\frac{\partial x_t}{\partial \theta} + \left.\frac{\partial h_t}{\partial \theta}\right|_{\text{explicit}}$$

where ∂x_t/∂θ = 0, because the input is not a function of the network parameters, and the last term is the explicit derivative of h_t with respect to θ (through W_x, W_h and b_h) with h_{t-1} and x_t held fixed. Now comes the tricky part: we recursively expand ∂h_{t-1}/∂θ in the same way for every t, to obtain the final result

$$\frac{\partial L}{\partial \theta} = \sum_{t=1}^{T} \frac{\partial \ell}{\partial o_t}\,\frac{\partial o_t}{\partial h_t}\,\frac{\partial h_t}{\partial \theta}$$

Note that ℓ is some loss function over o_t, so we can easily differentiate it with respect to o_t. Furthermore, o_t = σ_y(W_y h_t + b_y), so we can differentiate it with respect to h_t as well. The third factor is given by the calculation above, and this concludes what we wished to achieve.

Lastly, we emphasize that although ∂h_t/∂θ can be computed fully as above, for large T the recursive expansion gets extremely long. There are strategies for dealing with this problem, but we will not discuss them in this post.
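To show what computing this gradient through time looks like in practice, here is a self-contained sketch of backpropagation through time for the loss derivative above. It assumes a tanh hidden activation, a linear output layer and a squared-error loss; these choices are mine for the sake of a runnable example, not something prescribed by the derivation:

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_h, d_out, T = 3, 8, 2, 20
W_x = rng.normal(scale=0.3, size=(d_h, d_in))
W_h = rng.normal(scale=0.3, size=(d_h, d_h))
W_y = rng.normal(scale=0.3, size=(d_out, d_h))
xs = [rng.normal(size=d_in) for _ in range(T)]
ys = [rng.normal(size=d_out) for _ in range(T)]   # arbitrary targets

# forward pass: every h_t is stored, because each one reappears in the backward pass
hs, os_ = [np.zeros(d_h)], []
for x_t in xs:
    hs.append(np.tanh(W_x @ x_t + W_h @ hs[-1]))   # h_t (bias omitted for brevity)
    os_.append(W_y @ hs[-1])                       # linear prediction o_t

# backward pass (BPTT): dh_next carries the factor ∂h_{t+1}/∂h_t back in time,
# which is exactly the repeated term that appears when ∂h_t/∂θ is expanded recursively
dW_h = np.zeros_like(W_h)
dh_next = np.zeros(d_h)
for t in reversed(range(T)):
    dl_do = os_[t] - ys[t]               # ∂l/∂o_t for a squared-error loss
    dh = W_y.T @ dl_do + dh_next         # ∂L/∂h_t: this step plus all later steps
    dz = (1.0 - hs[t + 1] ** 2) * dh     # through the tanh nonlinearity
    dW_h += np.outer(dz, hs[t])          # contribution of step t to ∂L/∂W_h
    dh_next = W_h.T @ dz                 # propagate one step further back in time

print(np.linalg.norm(dW_h))              # accumulated gradient w.r.t. W_h
```

Note that the backward loop runs once per time step, which is exactly why the computation becomes long for large T.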

Vanishing/Exploding Gradient Problem

When we expand ∂h_t/∂θ recursively, products of the following form appear:

$$\prod_{j=k+1}^{t} \frac{\partial h_j}{\partial h_{j-1}} = \prod_{j=k+1}^{t} c_j$$

where we write c_j = ∂h_j/∂h_{j-1} for each factor. It is reasonable to expect that if the c_j are relatively big, the product grows exponentially with the length of the sequence. This causes a problem: our gradient iteration will not bring us to the solution (in fact, it will probably not produce any useful solution). Conversely, if the c_j are sufficiently small, we face the opposite problem of a vanishing gradient, and the network will struggle to converge, if it converges at all.
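The effect is easy to reproduce numerically. In the sketch below I simplify to an identity activation, so that every factor c_j is exactly W_h and the product over T steps is just W_h multiplied by itself T times; the scales 0.3 and 1.5 are arbitrary choices to show both regimes:

```python
import numpy as np

def product_norms(W_h, T=30):
    """Norms of the running product c_t * ... * c_1 when every factor c_j equals W_h."""
    prod, norms = np.eye(W_h.shape[0]), []
    for _ in range(T):
        prod = W_h @ prod                      # multiply in one more factor c_j
        norms.append(np.linalg.norm(prod, 2))  # spectral norm of the product so far
    return norms

rng = np.random.default_rng(0)
d = 8
W = rng.normal(size=(d, d)) / np.sqrt(d)

print(product_norms(0.3 * W)[::10])   # shrinks towards zero: vanishing gradients
print(product_norms(1.5 * W)[::10])   # blows up with T: exploding gradients
```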
The problem described above isn't unique to RNNs; a very deep feed-forward network can encounter vanishing gradients too. To mitigate the effect, "skip connections" were introduced. A skip connection performs a rather straightforward operation: for a given vector x and a layer F, it is defined as SC(x) = F(x) + x, essentially allowing the vector to "skip" the layer.
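As a tiny illustration (the single ReLU layer standing in for F below is my own arbitrary choice), a skip connection in code is just an addition:

```python
import numpy as np

def F(x, W, b):
    """Some layer; here a single ReLU layer, purely for illustration."""
    return np.maximum(0.0, W @ x + b)

def skip_connection(x, W, b):
    # SC(x) = F(x) + x: the "+ x" gives the input (and its gradient)
    # a direct path around the layer
    return F(x, W, b) + x

rng = np.random.default_rng(0)
d = 4
x = rng.normal(size=d)
W, b = rng.normal(size=(d, d)), np.zeros(d)
print(skip_connection(x, W, b))
```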

In the next blog post, we will see how the gradient problem can be solved (to some extent), and how the solution relates to skip connections. Hint: LSTMs 🤫
