Demystifying Neural Networks: RNN

Stateful Neural Network for Sequential Data

Dagang Wei
5 min read · Feb 8, 2024

This article is part of the series Demystifying Neural Networks.

Introduction

Deep learning, a powerhouse within the realm of artificial intelligence, continues to pave the way for groundbreaking advancements in fields like computer vision and natural language processing. This progress is due in no small part to the extraordinary capabilities of neural networks — computational systems drawing inspiration from the biological brain. In this blog post, we’ll dive into a specific class of neural networks known as Recurrent Neural Networks (RNNs), designed to excel at handling sequential data.

What is an RNN?

Traditional neural networks typically assume that inputs (and outputs) are independent of each other. But consider tasks like understanding a sentence or predicting the next word in a sequence. The context provided by the words that came before is crucial! This is where RNNs shine.

RNNs possess an internal ‘memory’ that allows them to incorporate information from previous inputs when processing the current one. This makes them an ideal choice for:

  • Natural Language Processing (NLP): Machine translation, text generation, sentiment analysis
  • Time Series Analysis: Stock market prediction, weather forecasting
  • Speech Recognition: Transcribing spoken language

Why RNNs Matter

The core strength of RNNs lies in their ability to make sense of sequential data. Consider these cases:

  • Understanding long-term dependencies: In a sentence like “The clouds are in the sky,” an RNN can connect the word “sky” back to the earlier word “clouds.”
  • Flexible sequence lengths: RNNs aren’t limited to fixed-size inputs. They can process sequences of varying lengths, making them versatile for real-world applications.

How RNNs Work

Let’s break down the mechanics of an RNN:

  1. The RNN Cell: An RNN is built around a recurrent cell that is applied at every timestep. The cell receives the current input and the hidden state from the previous timestep.
  2. Hidden State: This is the RNN’s “memory.” It encapsulates information gleaned from past inputs.
  3. Calculations: Within the cell, the input and hidden state are combined through learned weights and a nonlinearity (typically tanh) to compute a new hidden state and produce an output (see the sketch after this list).
  4. Unfolding: The concept of unfolding an RNN over time lets us visualize it as multiple copies of the same cell connected in a chain.
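
Concretely, steps 1–3 amount to a couple of matrix multiplications followed by a nonlinearity. Here is a minimal NumPy sketch of a single cell step (an illustration only; the weight names Wxh, Whh, Why mirror the from-scratch implementation later in this post):

import numpy as np

# One RNN cell step (minimal sketch; shapes and names match the
# from-scratch implementation further down)
def rnn_cell(x, h_prev, Wxh, Whh, Why, bh, by):
    # Combine the current input with the previous hidden state ("memory")
    h_next = np.tanh(np.dot(Wxh, x) + np.dot(Whh, h_prev) + bh)
    # Produce an output from the new hidden state
    y = np.dot(Why, h_next) + by
    return y, h_next

Unfolding (step 4) simply means calling this same function once per timestep, feeding each call’s h_next back in as the next call’s h_prev.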

Limitations of RNNs

While RNNs excel at processing sequences, they do have certain shortcomings worth considering:

  • Vanishing and Exploding Gradients: During training, gradients (which drive the learning process) can become either exceedingly small (vanishing) or extremely large (exploding). These issues make it difficult to train RNNs effectively, especially on long sequences where dependencies span many timesteps (a tiny numerical illustration follows this list).
  • Difficulty Handling Long-Term Dependencies: While RNNs have a ‘memory’ mechanism, in practice, they often struggle to retain and connect information across very long input sequences. This leads to errors when the relevant context lies many steps back.
  • Training Challenges: RNNs can be computationally expensive to train due to the sequential nature of calculations. They can be sensitive to hyperparameter choices, requiring careful tuning.
  • Limited Interpretability: It can be difficult to understand precisely how RNNs arrive at their outputs, making their decision-making process somewhat opaque.
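
To get an intuition for the vanishing/exploding gradient problem, note that backpropagating through time repeatedly multiplies the gradient by the recurrent weight matrix. The toy snippet below (hypothetical numbers, ignoring the tanh nonlinearity) shows how the gradient norm collapses or blows up over 50 timesteps depending on the scale of those weights:

import numpy as np

np.random.seed(0)
grad = np.random.randn(5, 1)  # some initial gradient signal

for scale in (0.5, 1.5):  # recurrent weights that shrink vs. amplify the signal
    Whh = scale * np.eye(5)
    g = grad.copy()
    for _ in range(50):   # backpropagating through 50 timesteps
        g = Whh.T @ g
    print(f"scale={scale}: gradient norm after 50 steps = {np.linalg.norm(g):.3e}")

# scale=0.5 drives the norm toward zero (vanishing);
# scale=1.5 blows it up (exploding)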

Improvements

Researchers have developed techniques to address some of these challenges:

  • LSTM and GRU: Specialized RNN architectures, the Long Short-Term Memory (LSTM) and the Gated Recurrent Unit (GRU), add gating mechanisms that help preserve information across longer sequences, mitigating the vanishing/exploding gradient problem to some extent (a minimal GRU cell sketch follows this list).
  • Attention Mechanisms: Used alongside RNNs, attention allows the model to selectively focus on the most relevant parts of the input sequence at each stage, improving its ability to handle long-term dependencies.
  • Transformers: In recent years, Transformer architectures have emerged as a powerful alternative to RNNs for many sequence processing tasks. Transformers rely entirely on attention mechanisms, avoiding the sequential processing bottleneck of RNNs. This enables them to better handle long-range dependencies and often leads to faster training and improved performance.
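
To make the gating idea concrete, here is a minimal sketch of a single GRU cell step in NumPy. This is an illustration only: the parameter names (Wz, Uz, bz, and so on) are placeholders, and in practice these weights are learned during training:

import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

# One GRU cell step (illustrative sketch; parameters are assumed to be learned)
def gru_step(x, h_prev, Wz, Uz, bz, Wr, Ur, br, Wh, Uh, bh):
    z = sigmoid(Wz @ x + Uz @ h_prev + bz)               # update gate: how much old state to keep
    r = sigmoid(Wr @ x + Ur @ h_prev + br)               # reset gate: how much old state to use
    h_tilde = np.tanh(Wh @ x + Uh @ (r * h_prev) + bh)   # candidate hidden state
    h_next = z * h_prev + (1 - z) * h_tilde              # gated blend of old state and candidate
    return h_next

Because the update gate can stay close to 1, the hidden state can be carried across many timesteps nearly unchanged, which is what helps preserve information and gradients over longer sequences.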

Example: Sequence Prediction

Let’s say we have a short text snippet: “The cat sat.” Our goal is to train a basic RNN to predict the next word.

  1. Vocabulary: Create a vocabulary (e.g., {“the”, “cat”, “sat”, “on”, “mat”})
  2. One-hot Encoding: Represent each word with a vector, with a ‘1’ at the word’s position and ‘0’ elsewhere.
  3. Training: We feed the words “the”, “cat”, and “sat” into the RNN one at a time. For each word, the RNN computes a new hidden state, and at the end of the sequence it predicts the probability of the next word being “on”, “mat”, and so on. (A short code sketch of the vocabulary and one-hot encoding follows this list.)
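
Here is a quick sketch of steps 1 and 2 in NumPy, using the toy vocabulary from the example (a real model would then feed these vectors into a trained RNN):

import numpy as np

vocab = ["the", "cat", "sat", "on", "mat"]
word_to_index = {word: i for i, word in enumerate(vocab)}

def one_hot(word):
    vec = np.zeros((len(vocab), 1))
    vec[word_to_index[word]] = 1.0  # '1' at the word's position, '0' elsewhere
    return vec

# The input sequence "the cat sat" becomes three one-hot vectors,
# fed into the RNN one timestep at a time
inputs = [one_hot(w) for w in ["the", "cat", "sat"]]
print(inputs[0].ravel())  # [1. 0. 0. 0. 0.]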

Implementing RNN from Scratch

For illustrative purposes, here’s a simplified Python implementation of an RNN that learns to predict the next number in the modulo-7 Fibonacci sequence. To keep things readable, the backward pass is truncated to a single timestep rather than full backpropagation through time. The code is available in this Colab notebook.

import numpy as np

# Derivative of tanh, expressed in terms of its output h = tanh(z)
def tanh_derivative(h):
    return 1 - h ** 2

# Initialize RNN parameters
input_size = 1
hidden_size = 5   # A handful of hidden units is enough for this pattern
output_size = 1

# Weights and biases
Wxh = np.random.randn(hidden_size, input_size) * 0.01   # input to hidden
Whh = np.random.randn(hidden_size, hidden_size) * 0.01  # hidden to hidden
Why = np.random.randn(output_size, hidden_size) * 0.01  # hidden to output
bh = np.zeros((hidden_size, 1))  # hidden bias
by = np.zeros((output_size, 1))  # output bias

# Learning rate
learning_rate = 0.005  # Small learning rate for stability

def rnn_step_forward(x, h_prev):
    # One cell step: combine input and previous hidden state, then predict
    h_next = np.tanh(np.dot(Wxh, x) + np.dot(Whh, h_prev) + bh)
    y_pred = np.dot(Why, h_next) + by
    return y_pred, h_next

def train_sequence(sequence, epochs=500):
    global Wxh, Whh, Why, bh, by
    for epoch in range(epochs):
        h_prev = np.zeros((hidden_size, 1))
        overall_loss = 0
        for i in range(len(sequence) - 1):
            x = np.array([[sequence[i]]], dtype=float)
            y_true = np.array([[sequence[i + 1]]], dtype=float)
            y_pred, h_next = rnn_step_forward(x, h_prev)

            # Squared-error loss for this step
            loss = (y_pred - y_true) ** 2 / 2
            overall_loss += loss.item()

            # Backward pass, truncated to a single timestep for simplicity
            dy_pred = y_pred - y_true                # dL/dy
            dWhy = np.dot(dy_pred, h_next.T)
            dby = dy_pred
            dh = np.dot(Why.T, dy_pred)              # gradient flowing into h_next
            dz = tanh_derivative(h_next) * dh        # back through the tanh
            dWxh = np.dot(dz, x.T)
            dWhh = np.dot(dz, h_prev.T)              # uses the *previous* hidden state
            dbh = dz

            # Gradient descent update
            Wxh -= learning_rate * dWxh
            Whh -= learning_rate * dWhh
            Why -= learning_rate * dWhy
            bh -= learning_rate * dbh
            by -= learning_rate * dby

            # Carry the hidden state forward to the next timestep
            h_prev = h_next

        if epoch % 50 == 0:
            print(f"Epoch {epoch}, Loss: {overall_loss}")

def predict_next_number(sequence):
    # Run the whole sequence through the RNN to build up the hidden state,
    # then use the final output as the prediction
    h_prev = np.zeros((hidden_size, 1))
    y_pred = np.zeros((output_size, 1))
    for value in sequence:
        x = np.array([[value]], dtype=float)
        y_pred, h_prev = rnn_step_forward(x, h_prev)
    return round(y_pred.item()) % 7  # Map the raw prediction back into the 0-6 range

# Generate a modulo-7 Fibonacci sequence
def generate_mod_fibonacci(n):
    sequence = [0, 1]
    for _ in range(n - 2):
        next_val = (sequence[-1] + sequence[-2]) % 7
        sequence.append(next_val)
    return sequence

# Example: generate, train, and predict
mod_fib_sequence = generate_mod_fibonacci(20)  # Longer sequence for more training data
print("Modified Fibonacci sequence:", mod_fib_sequence)

# Train the RNN
train_sequence(mod_fib_sequence, epochs=500)

# Predict the next number after training
predicted_next = predict_next_number(mod_fib_sequence)
print(f"Predicted next number (modulo 7): {predicted_next}")

Conclusion

RNNs are powerful tools for dealing with sequential data, capable of capturing temporal dynamics and dependencies. While this introduction and simple implementation scratch the surface, they offer a glimpse into the potential of RNNs for solving complex problems in various domains.
