Understanding xLSTM through code implementation
Recently, a new research paper was released: “xLSTM: Extended Long Short-Term Memory”. It introduces a new model architecture specialised in NLP (Natural Language Processing).
If you don’t know what xLSTM is, here is a short explanation. The xLSTM model is an improvement of the LSTM model. It is composed of two new cells: the mLSTM cell, which has a matrix memory, and the sLSTM cell, which is good at memory selection.
The article is structured into three sections:
- the implementation of sLSTM,
- the implementation of mLSTM,
- and the testing of xLSTM in NLP. I do not fully implement xLSTM in this article; for testing, I use my own repository.
sLSTM implementation
We begin with the mathematical formula of the sLSTM.
It can be divided into several parts.
The gates:
And the states:
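Written out to match the code in this section, the sLSTM step can be sketched as follows. The symbols W, R and b (input weights, recurrent weights and biases) are notation assumed here for illustration; the products and the division in the state equations are element-wise.

```latex
\begin{aligned}
z_t &= \tanh(W_z x_t + R_z h_{t-1} + b_z) \\
i_t &= \exp(W_i x_t + R_i h_{t-1} + b_i) \\
f_t &= \exp(W_f x_t + R_f h_{t-1} + b_f) \\
o_t &= \sigma(W_o x_t + R_o h_{t-1} + b_o) \\
c_t &= f_t \odot c_{t-1} + i_t \odot z_t \\
n_t &= f_t \odot n_{t-1} + i_t \\
h_t &= o_t \odot \frac{c_t}{n_t}
\end{aligned}
```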
Let’s proceed step by step. First, we need to define the variables we will use for the implementation.
import torch
import torch.nn as nn
inp_size = 3
out_size = 5
batch_size = 1
x = torch.zeros(batch_size, inp_size)
h = torch.zeros(batch_size, out_size)
ht_1 = torch.zeros(batch_size, out_size) # ht-1
ct_1 = torch.zeros(batch_size, out_size)
nt_1 = torch.zeros(batch_size, out_size)
Second, we can implement the forward pass (the gates) z, i, f, and o. Each one is just the sum of two linear layers: one takes the previous hidden state (ht-1) as input and the other takes the current input (xt).
Then we apply a function to the sum (σ represents the sigmoid, φ represents the tanh function, and exp represents the exponential function).
z_gate = nn.Linear(inp_size, out_size)
i_gate = nn.Linear(inp_size, out_size)
f_gate = nn.Linear(inp_size, out_size)
o_gate = nn.Linear(inp_size, out_size)
zr_gate = nn.Linear(out_size, out_size, bias=False)
ir_gate = nn.Linear(out_size, out_size, bias=False)
fr_gate = nn.Linear(out_size, out_size, bias=False)
or_gate = nn.Linear(out_size, out_size, bias=False)
z = torch.tanh(z_gate(x) + zr_gate(ht_1)) # (1, out_size)
i = torch.exp(i_gate(x) + ir_gate(ht_1)) # (1, out_size)
f = torch.exp(f_gate(x) + fr_gate(ht_1)) # (1, out_size)
o = torch.sigmoid(o_gate(x) + or_gate(ht_1)) # (1, out_size)
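Because i and f go through exp, large pre-activations can overflow. The xLSTM paper handles this with an additional stabilizer state m that rescales both gates. Here is a minimal sketch of that idea; the function name `stabilized_gates` and its argument names are assumptions made for illustration, not part of the code in this article.

```python
import torch

def stabilized_gates(i_tilde, f_tilde, m_prev):
    """Return stabilized input/forget gates and the new stabilizer state.

    i_tilde, f_tilde: gate pre-activations (the values before exp).
    m_prev: stabilizer state from the previous step.
    """
    # log(exp(f_tilde)) is just f_tilde, so the max is taken in log space
    m = torch.maximum(f_tilde + m_prev, i_tilde)
    i = torch.exp(i_tilde - m)           # never larger than 1
    f = torch.exp(f_tilde + m_prev - m)  # never larger than 1
    return i, f, m

# pre-activations chosen so that a plain exp overflows in float32
i_tilde = torch.tensor([[100.0, -2.0]])
f_tilde = torch.tensor([[90.0, 1.0]])
m_prev = torch.zeros(1, 2)

i, f, m = stabilized_gates(i_tilde, f_tilde, m_prev)
print(torch.exp(i_tilde))  # contains inf
print(i, f, m)             # all finite
```

Since c and n are both built from the same rescaled gates, the ratio c/n used for the hidden state is meant to stay the same as without the stabilizer.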
Third, we can implement the hidden state and the output.
c = f*ct_1 + i*z
n = f*nt_1 + i
h = c/n
h = o * h # (1, out_size)
ct_1 = c
nt_1 = n
ht_1 = h
# output = h (batch_size, out_size)
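The three steps above can be collected in a small module and run over a sequence. This is a minimal sketch under some assumptions: the class name `SLSTMCellSketch` is hypothetical, the four gates are fused into one linear layer for brevity, and the stabilizer state and the multi-head structure described in the paper are left out.

```python
import torch
import torch.nn as nn

class SLSTMCellSketch(nn.Module):
    def __init__(self, inp_size, out_size):
        super().__init__()
        # one layer produces the z, i, f, o pre-activations at once
        self.w = nn.Linear(inp_size, 4 * out_size)
        self.r = nn.Linear(out_size, 4 * out_size, bias=False)

    def forward(self, x, state):
        h_prev, c_prev, n_prev = state
        z, i, f, o = (self.w(x) + self.r(h_prev)).chunk(4, dim=-1)
        z, i, f, o = torch.tanh(z), torch.exp(i), torch.exp(f), torch.sigmoid(o)
        c = f * c_prev + i * z   # cell state
        n = f * n_prev + i       # normalizer state
        h = o * (c / n)          # hidden state / output
        return h, (h, c, n)

torch.manual_seed(0)
cell = SLSTMCellSketch(inp_size=3, out_size=5)
seq = torch.randn(4, 1, 3)  # (time, batch, inp_size)
state = tuple(torch.zeros(1, 5) for _ in range(3))

outputs = []
for x_t in seq:  # one step per time index
    h, state = cell(x_t, state)
    outputs.append(h)
outputs = torch.stack(outputs)  # (time, batch, out_size)
print(outputs.shape)
```

The state tuple is passed back in at every step, which is the same role that ht_1, ct_1 and nt_1 play in the code above.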
mLSTM implementation
As with the sLSTM, we begin with the formula.
It is divided into three parts:
- gates (i, f, o)
- inputs (v, k, q)
- states (h, n, C)
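As a reference for these three groups, the mLSTM step from the xLSTM paper can be summarised as below, using the exponential option for the forget gate, which is the one the code in this section uses. In the paper, i_t and f_t are scalars, C_t is a d × d matrix, and d is the dimension of the keys.

```latex
\begin{aligned}
q_t &= W_q x_t + b_q \\
k_t &= \tfrac{1}{\sqrt{d}} W_k x_t + b_k \\
v_t &= W_v x_t + b_v \\
i_t &= \exp(w_i^\top x_t + b_i) \\
f_t &= \exp(w_f^\top x_t + b_f) \\
o_t &= \sigma(W_o x_t + b_o) \\
C_t &= f_t C_{t-1} + i_t \, v_t k_t^\top \\
n_t &= f_t n_{t-1} + i_t k_t \\
h_t &= o_t \odot \frac{C_t q_t}{\max\{|n_t^\top q_t|,\ 1\}}
\end{aligned}
```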
First, let’s define the variables.
inp_size = 3
out_size = 5
batch_size = 1
x = torch.zeros(batch_size, inp_size)
Ct_1 = torch.zeros(batch_size, out_size)
nt_1 = torch.zeros(batch_size, out_size)
Second, implement the gates.
i_gate = nn.Linear(inp_size, out_size)
f_gate = nn.Linear(inp_size, out_size)
o_gate = nn.Linear(inp_size, out_size)
i = torch.exp(i_gate(x)) # (batch_size, out_size)
f = torch.exp(f_gate(x)) # (batch_size, out_size)
o = torch.sigmoid(o_gate(x)) # (batch_size, out_size)
Third, implement the inputs.
query = nn.Linear(inp_size, out_size)
key = nn.Linear(inp_size, out_size)
value = nn.Linear(inp_size, out_size)
q = query(x) # (batch_size, out_size)
k = key(x) / (out_size ** 0.5) # scaled by 1/sqrt(d), (batch_size, out_size)
v = value(x) # (batch_size, out_size)
And to finish we can implement the states.
C = f*Ct_1 + i * v*k # (batch_size, out_size)
n = f*nt_1 +i*k # (batch_size, out_size)
h = C*q / torch.clamp(torch.abs((n*q).sum(dim=-1, keepdim=True)), min=1.0) # divide by max(|n·q|, 1), (batch_size, out_size)
h = o*h # (batch_size, out_size)
nt_1 = n
Ct_1 = C
# If you wrap the mLSTM or sLSTM in a class and keep the states between calls,
# don't forget to detach them:
# nt_1 = n.detach()
# Ct_1 = C.detach()
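The code above keeps C as a vector and multiplies element-wise. In the paper, C is a matrix that is updated with the outer product of v and k. Here is a minimal sketch of one such step, assuming one scalar input gate and one scalar forget gate per sample; the function name `mlstm_matrix_step` is hypothetical.

```python
import torch

def mlstm_matrix_step(q, k, v, i, f, o, C_prev, n_prev):
    """One mLSTM step with a (batch, d, d) matrix memory.

    q, k, v, o: (batch, d); i, f: (batch, 1);
    C_prev: (batch, d, d); n_prev: (batch, d).
    """
    # outer product v k^T for every sample: (batch, d, d)
    vk = torch.einsum("bi,bj->bij", v, k)
    C = f.unsqueeze(-1) * C_prev + i.unsqueeze(-1) * vk
    n = f * n_prev + i * k
    # read from the memory with the query: C q -> (batch, d)
    num = torch.einsum("bij,bj->bi", C, q)
    # normalizer max(|n^T q|, 1) -> (batch, 1)
    den = torch.clamp((n * q).sum(dim=-1, keepdim=True).abs(), min=1.0)
    h = o * (num / den)
    return h, C, n

torch.manual_seed(0)
batch, d = 2, 5
q, k, v = torch.randn(3, batch, d)    # three (batch, d) tensors
k = k / d ** 0.5                      # key scaling by 1/sqrt(d)
i = torch.exp(torch.randn(batch, 1))  # exponential input gate
f = torch.exp(torch.randn(batch, 1))  # exponential forget gate
o = torch.sigmoid(torch.randn(batch, d))
C_prev = torch.zeros(batch, d, d)
n_prev = torch.zeros(batch, d)

h, C, n = mlstm_matrix_step(q, k, v, i, f, o, C_prev, n_prev)
print(h.shape, C.shape, n.shape)
```

The memory grows from d values to d × d values per sample, which is the price of storing key-value pairs instead of a single vector.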
Test
For the implementation of xLSTM, which is simply mLSTM combined with sLSTM, I use my own repository. You can install it by executing this command:
pip install git+https://github.com/styalai/xLSTM-pytorch
For the test, we will use the Tiny Shakespeare dataset from Karpathy. First, we need to import the necessary libraries.
import time
import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np
from math import *
from xLSTM.xLSTM import xLSTM as xlstm
Next, we download the dataset and create a function to get the inputs.
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
print(len(text))
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string
# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.1*len(data)) # first 10% will be val, the remaining 90% train
train_data = data[n:]
val_data = data[:n]
# data loading
def get_batch(split, block_size, batch_size):
    # generate a small batch of data of inputs x and targets y
    # (device is the global defined with the hyperparameters)
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y
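To see what the encoder and the batch function produce, here is a small self-contained sketch on a toy string instead of the Shakespeare file. The string and the sizes are arbitrary choices for illustration, and this `get_batch` takes the data directly instead of a split name.

```python
import torch

text = "hello world"
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: "".join(itos[i] for i in l)

data = torch.tensor(encode(text), dtype=torch.long)

def get_batch(data, block_size, batch_size):
    # random start positions, one per sequence in the batch
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    # targets are the inputs shifted by one character
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x, y

torch.manual_seed(0)
xb, yb = get_batch(data, block_size=4, batch_size=2)
print(decode(xb[0].tolist()), "->", decode(yb[0].tolist()))
```

Each target row is the input row moved one character to the right, so the model learns to predict the next character at every position.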
Second, we can implement the model and the estimate_loss function.
class Model(nn.Module):
    def __init__(self, vocab_size, x, config_layers, device):
        super().__init__()
        self.vocab_size = vocab_size
        self.n_embd = x.shape[2]
        self.block_size = x.shape[1]
        self.device = device
        # token and position embeddings
        self.token_embedding_table = nn.Embedding(self.vocab_size, self.n_embd)
        self.position_embedding_table = nn.Embedding(self.block_size, self.n_embd)
        self.xlstm = xlstm(config_layers, x)
        self.ln_f = nn.LayerNorm(self.n_embd)
        self.head = nn.Linear(self.n_embd, self.vocab_size)

    def init_states(self, x):
        self.xlstm.init_states(x)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=self.device)) # T, C
        x = tok_emb + pos_emb # (B, T, C)
        x = self.xlstm(x)
        x = self.ln_f(x)
        logits = self.head(x)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last self.block_size tokens
            idx_cond = idx[:, -self.block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx, idx_next

@torch.no_grad()
def estimate_loss(model, eval_iters, block_size, batch_size):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, block_size, batch_size)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
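The reshaping in forward is there because F.cross_entropy expects logits of shape (N, C) and targets of shape (N,). Here is a small sketch of just that step, with arbitrary sizes and all-zero logits so that the result is easy to check: a model that predicts every character with equal probability has a loss of log(C).

```python
import math
import torch
from torch.nn import functional as F

B, T, C = 2, 4, 65  # batch, time, vocabulary size (arbitrary values)
logits = torch.zeros(B, T, C)          # uniform prediction over C classes
targets = torch.randint(C, (B, T))     # random "next characters"

# flatten batch and time so every position is one classification example
loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
print(loss.item(), math.log(C))
```

This gives a useful reference point when reading the training logs: a loss close to log(vocab_size) means the model has not learned anything yet.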
Third, initialize the hyperparameters and get the number of parameters.
# hyperparameters
batch_size = 24 # how many independent sequences will we process in parallel?
block_size = 124 # what is the maximum context length for predictions? (has little impact)
eval_iters = 3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
n_embd = 164
dropout = 0.2
config_block = "m" # just a mLSTM
num_heads = 8 # define number of block diagonal
head_size = 4 # define hidden size
# ------------
torch.set_default_device(device)
x = torch.zeros(batch_size, block_size, n_embd)
model = Model(vocab_size, x, config_block, device)
m = model.to(device)
paras = sum(p.numel() for p in m.parameters()) # total number of parameters
print(paras)
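The parameter count is just the sum of the number of elements in every parameter tensor. Here is a small sketch on a single linear layer, where the result can be checked by hand; the helper name `count_parameters` is an assumption for illustration.

```python
import torch.nn as nn

def count_parameters(module):
    # sum the element counts of all trainable parameter tensors
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

layer = nn.Linear(3, 5)
print(count_parameters(layer))  # 3*5 weights + 5 biases = 20
```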
Fourth, we can develop the training function.
def train(m):
    learning_rate = 3e-4
    optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)
    # scheduler.step() is never called below, so the learning rate stays constant
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 300, eta_min=1e-5)
    loss_list_t = []
    loss_list_v = []
    max_iters_total = 80000
    max_iters = int(max_iters_total/batch_size)
    eval_interval = 100
    chunk_size = 25
    for iter in tqdm(range(max_iters)):
        # every once in a while evaluate the loss on train and val sets
        if iter % eval_interval == 0:
            torch.save(m.state_dict(), '/kaggle/working/model')
            print(optimizer.param_groups[0]["lr"])
            losses = estimate_loss(m, eval_iters, block_size, batch_size)
            loss_list_t.append(losses['train'].cpu())
            loss_list_v.append(losses['val'].cpu())
            print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if iter % chunk_size == 0:
            # reset the recurrent states every chunk_size iterations
            m.xlstm.init_states(x)
        # sample a batch of data
        xb, yb = get_batch('train', block_size, batch_size)
        # evaluate the loss
        logits, loss = m(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
    torch.save(m.state_dict(), '/kaggle/working/model')
    # draw loss
    print(loss)
    plt.plot(range(len(loss_list_t)), loss_list_t)
    plt.plot(range(len(loss_list_v)), loss_list_v)
    plt.xlabel("Number of Iterations")
    plt.ylabel("Loss")
    plt.show()
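The training function above creates a CosineAnnealingLR scheduler, but a scheduler only changes the learning rate when `scheduler.step()` is called. Here is a small sketch, separate from the training above, of how the same settings would move the rate from 3e-4 down to 1e-5 over 300 steps. The dummy parameter is only there so that an optimizer can be built.

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter
optimizer = torch.optim.AdamW([param], lr=3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 300, eta_min=1e-5)

lrs = []
for step in range(300):
    optimizer.step()   # in real training this follows loss.backward()
    scheduler.step()   # advance the cosine schedule by one step
    lrs.append(optimizer.param_groups[0]["lr"])

# learning rate after 1, 150 and 300 steps
print(lrs[0], lrs[149], lrs[-1])
```

If the schedule keeps being stepped past 300, the cosine curve makes the rate rise again, so the second argument is usually set to the total number of training steps.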
And finally train the model.
model = Model(vocab_size, x, config_block, device) # ≈ 18M parameters
model = model.to(device)
train(model) # 80 000 iters
We can see that the model learns well, but the validation loss doesn’t decrease very well. And after 80k more iterations:
Conclusion: the mLSTM model (my implementation) overfits a little, but the result isn’t bad.
To find out more about xLSTM, read this article.
And if you want to implement the xLSTM architecture, not just the math, you can read this article:
I hope you found this article interesting. If so, you can clap for it. =)