Recurrent Neural Network — Lesson 5: Backpropagation Through Time (BPTT) and RNN Training

Machine Learning in Plain English
2 min read · Aug 13, 2023

--

Understanding BPTT: Unfolding Computation Graphs

  • RNNs as Unfolded Computational Graphs: When training RNNs, it’s helpful to visualize them as an “unfolded” computation graph across time steps. This makes the sequential structure explicit: each time step looks like a separate layer, except that all of these “layers” share the same weights.
  • How BPTT Works: Just as standard backpropagation does in feedforward networks, BPTT propagates gradients backward through the unfolded graph, i.e., backward in time. The gradient contribution of each time step is computed and, because the weights are shared across steps, summed over all time steps before being used to update the RNN’s weights.
  • Truncated BPTT: In practice, sequences can span a very large number of time steps, which makes full BPTT computationally expensive. Truncated BPTT limits the backward pass to a fixed window of recent time steps. This saves computation and memory and keeps exploding gradients in check, at the price of not learning dependencies longer than the truncation window; a minimal sketch follows this list.
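
To make the truncation concrete, here is a minimal PyTorch sketch (the model, chunk size k, and data names are illustrative assumptions, not from the article): the sequence is processed in chunks of k steps, and the hidden state is detached between chunks so gradients stop at the truncation boundary.

```python
import torch
import torch.nn as nn

# Illustrative sketch of truncated BPTT: assumes a simple RNNCell model,
# a long input sequence `seq` of shape (T, batch, input_size), and targets
# of shape (T, batch). Names and sizes are assumptions for demonstration.
input_size, hidden_size, num_classes = 8, 32, 4
T, batch, k = 200, 16, 20          # sequence length, batch size, truncation window

cell = nn.RNNCell(input_size, hidden_size)
readout = nn.Linear(hidden_size, num_classes)
params = list(cell.parameters()) + list(readout.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)
criterion = nn.CrossEntropyLoss()

seq = torch.randn(T, batch, input_size)              # dummy inputs
targets = torch.randint(0, num_classes, (T, batch))  # dummy labels

h = torch.zeros(batch, hidden_size)
for start in range(0, T, k):
    chunk_x = seq[start:start + k]
    chunk_y = targets[start:start + k]

    # Detach the carried hidden state: its forward value is kept, but
    # gradients will not propagate back past this truncation boundary.
    h = h.detach()

    loss = 0.0
    for t in range(chunk_x.size(0)):
        h = cell(chunk_x[t], h)
        loss = loss + criterion(readout(h), chunk_y[t])

    optimizer.zero_grad()
    loss.backward()      # BPTT only through the last k time steps
    optimizer.step()
```

The key line is `h = h.detach()`: the hidden state still carries information forward across chunks, but the backward pass for each chunk stops at the chunk boundary instead of running all the way back to the start of the sequence.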

The Challenge of Training RNNs: Computational and Memory Requirements

  • Vanishing and Exploding Gradients: As previously discussed, RNNs, especially deep ones or those processing long sequences, suffer from gradients that shrink toward zero or blow up as they are propagated back through many time steps, because the backward pass repeatedly multiplies by the Jacobian of the recurrent transition. This makes training unstable or very slow (a toy numerical illustration follows this list).
  • Memory Usage: Because of the recurrent structure, the backward pass needs the activations from every time step of the forward pass, so they must all be stored. Memory therefore grows roughly linearly with sequence length, which becomes significant for long sequences.
  • Computational Complexity: Training RNNs on long sequences is computationally intensive because unfolding across time makes each forward and backward pass proportionally longer than in a standard feedforward network, and the time steps must be processed sequentially rather than in parallel.
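
As a rough numerical illustration of the first point (a toy example with assumed sizes, not taken from the article): in BPTT, the gradient at a given step is pushed back to earlier steps by repeated multiplication with the transpose of the recurrent weight matrix (times activation derivatives, ignored here for simplicity). If the scale of that matrix is below 1 the gradient norm shrinks geometrically; above 1 it grows.

```python
import numpy as np

# Toy illustration (assumed setup): propagate a gradient vector backward
# through n "time steps" by repeatedly multiplying with W.T. W is a scaled
# orthogonal matrix, so every singular value equals `scale` and the gradient
# norm changes by exactly a factor of `scale` at each step.
rng = np.random.default_rng(0)
n_steps, hidden = 50, 64

def final_grad_norm(scale):
    Q, _ = np.linalg.qr(rng.standard_normal((hidden, hidden)))
    W = scale * Q                      # recurrent weight matrix with known scale
    grad = rng.standard_normal(hidden)
    for _ in range(n_steps):
        grad = W.T @ grad              # one step of the backward recurrence
    return np.linalg.norm(grad)

print("scale 0.9 ->", final_grad_norm(0.9))  # shrinks by 0.9**50 (vanishing)
print("scale 1.1 ->", final_grad_norm(1.1))  # grows by 1.1**50 (exploding)
```

Over 50 steps the two runs differ by a factor of roughly (1.1/0.9)^50, which is why even modest deviations from a scale of 1 make long-range gradients either negligible or enormous.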

Strategies to Address Overfitting in RNNs

  • Dropout: While dropout is a commonly used regularization technique in feedforward networks, applying it in RNNs is trickier: naively sampling a fresh dropout mask at every time step disrupts the hidden state’s dynamics and hurts performance. Instead, “variational dropout” is often used, where the same dropout mask is reused across all time steps of a sequence.
  • Regularization: L1 and L2 regularization can also be applied to the weights of RNNs. However, one should be cautious, as too much regularization might hinder the network from capturing intricate temporal patterns.
  • Gradient Clipping: Strictly speaking this is not a regularization method against overfitting, but it is a crucial technique for preventing exploding gradients in RNNs. If the gradient norm exceeds a chosen threshold, the gradient is rescaled so that excessively large updates do not destabilize training. A short sketch combining variational dropout and gradient clipping follows this list.
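
To show how the last two ideas fit into a training step, here is a hedged PyTorch sketch (model, shapes, and tensor names are illustrative assumptions): a single Bernoulli dropout mask is sampled per sequence and reused at every time step (the “variational” part), and gradients are clipped by global norm with torch.nn.utils.clip_grad_norm_ before the optimizer update.

```python
import torch
import torch.nn as nn

# Illustrative training step (assumed shapes/names): variational dropout on
# the hidden state plus gradient clipping before the parameter update.
input_size, hidden_size, num_classes = 8, 32, 4
T, batch, p_drop, clip_norm = 30, 16, 0.3, 1.0

cell = nn.RNNCell(input_size, hidden_size)
readout = nn.Linear(hidden_size, num_classes)
params = list(cell.parameters()) + list(readout.parameters())
optimizer = torch.optim.SGD(params, lr=0.1)
criterion = nn.CrossEntropyLoss()

x = torch.randn(T, batch, input_size)                # dummy inputs
y = torch.randint(0, num_classes, (T, batch))        # dummy labels

# Variational dropout: sample ONE mask per sequence and reuse it at every
# time step, instead of sampling a fresh mask at each step.
mask = torch.bernoulli(torch.full((batch, hidden_size), 1 - p_drop)) / (1 - p_drop)

h = torch.zeros(batch, hidden_size)
loss = 0.0
for t in range(T):
    h = cell(x[t], h) * mask       # same dropout mask at every time step
    loss = loss + criterion(readout(h), y[t])

optimizer.zero_grad()
loss.backward()
# Gradient clipping: rescale all gradients if their global norm exceeds clip_norm.
torch.nn.utils.clip_grad_norm_(params, max_norm=clip_norm)
optimizer.step()
```

Note that clipping happens after loss.backward() and before optimizer.step(), so it only caps the size of the update without changing the direction of small, well-behaved gradients.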
