Implement the xLSTM paper from scratch with Pytorch

Arthur Lagacherie
11 min read · Jul 1, 2024

--

Want to implement a simple research paper? Or just learn more about xLSTM? You’ve come to the right place.

pytorch logo

In this article, I’ll show you how to implement the xLSTM paper from scratch. If you don’t know what the xLSTM architecture is, you can read this article to learn the basics of the math behind it.

In short, the xLSTM model is an improvement on the LSTM model. It introduces two new cell types:

  • mLSTM, which is fully parallelizable and has a matrix memory and a covariance update rule.
  • sLSTM, which has a scalar memory, a scalar update, and a new memory-mixing mechanism.

Summary

This article is divided into several parts:

  • the mLSTM implementation
  • the sLSTM implementation
  • the xLSTM implementation
  • a test on the Tiny Shakespeare dataset

mLSTM

You can find all the code in my repository.

mLSTM architecture

Let’s proceed step by step. First, we implement the first two layers and initialize the states.

class mLSTMblock(nn.Module):
    """First step of the mLSTM block: input layer norm plus the two projection branches.

    This is an intermediate stage of the step-by-step build: ``forward``
    computes the left and right branches but does not return anything yet.

    Args:
        x_example: example input; its last dimension (``shape[2]``) sets the
            feature size and its device is used for the state tensors.
        factor: up-projection factor (``hidden_size = int(input_size * factor)``).
        depth: accepted for signature compatibility; not used at this step.
        dropout: accepted for signature compatibility; not used at this step.
    """

    def __init__(self, x_example, factor, depth, dropout=0.2):
        super().__init__()
        in_dim = x_example.shape[2]
        self.input_size = in_dim
        self.hidden_size = int(in_dim*factor)

        self.ln = nn.LayerNorm(in_dim)

        # Two independent up-projections: one per branch.
        self.left = nn.Linear(in_dim, self.hidden_size)
        self.right = nn.Linear(in_dim, self.hidden_size)

        self.init_states(x_example)

    def init_states(self, x_example):
        """Zero the carried cell state ``ct_1`` and normalizer state ``nt_1``."""
        zeros_shape = [1, 1, self.hidden_size]
        target = x_example.device
        self.ct_1 = torch.zeros(zeros_shape, device=target)
        self.nt_1 = torch.zeros(zeros_shape, device=target)

    def forward(self, x):
        """Compute both branches for a 3-D input (partial step: returns None)."""
        assert x.ndim == 3

        normed = self.ln(x)
        # Left branch: plain projection. Right branch: projection + swish (SiLU).
        left, right = self.left(normed), F.silu(self.right(normed))
        return None

Second, we implement the causal convolution (conv4) and the block-diagonal projection (the squares marked “BS=4” in the figure).

import torch
import torch.nn as nn
import torch.nn.functional as F

class BlockDiagonal(nn.Module):
    """Linear layer whose output is produced by ``num_blocks`` independent sub-layers.

    Each sub-layer is ``nn.Linear(in_features, out_features // num_blocks)``
    applied to the *whole* input. Their outputs are concatenated along the last
    dimension. Because every block sees all ``in_features``, the combined
    weight is equivalent to a dense ``Linear(in_features, out_features)``, not
    a true block-diagonal matrix.

    Args:
        in_features: size of the last input dimension.
        out_features: size of the last output dimension; must be divisible by
            ``num_blocks``.
        num_blocks: number of sub-layers.
        bias: whether each sub-layer has a bias.

    Raises:
        AssertionError: if ``out_features % num_blocks != 0``.
    """

    def __init__(self, in_features, out_features, num_blocks, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_blocks = num_blocks

        assert out_features % num_blocks == 0

        per_block = out_features // num_blocks
        self.blocks = nn.ModuleList(
            nn.Linear(in_features, per_block, bias=bias) for _ in range(num_blocks)
        )

    def forward(self, x):
        """Apply every sub-layer to ``x`` and concatenate along the last dim."""
        return torch.cat([blk(x) for blk in self.blocks], dim=-1)

class CausalConv1D(nn.Module):
    """1-D convolution whose output at step t depends only on inputs at steps <= t.

    The wrapped ``nn.Conv1d`` pads both ends by ``(kernel_size - 1) * dilation``.
    The trailing padding positions are then cut off, so the output has the same
    length as the input and no step sees the future.

    Args:
        in_channels, out_channels, kernel_size, dilation, **kwargs: forwarded
            to ``nn.Conv1d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs):
        super().__init__()
        self.padding = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              padding=self.padding, dilation=dilation, **kwargs)

    def forward(self, x):
        """Convolve ``x`` of shape ``(batch, channels, length)`` causally."""
        x = self.conv(x)
        # With kernel_size == 1 there is no padding to remove; slicing with
        # ``:-0`` would return an empty tensor.
        if self.padding == 0:
            return x
        return x[:, :, :-self.padding]

Third, we implement the left branch (convolution, block-diagonal projections, learnable skip connection, and the mLSTM cell).

class mLSTMblock(nn.Module):
    """Intermediate mLSTM block step: both branches plus the causal conv and mLSTM cell.

    This stage stops before the output projection: ``forward`` returns the
    normalized left branch and the right (SiLU) branch so they can be inspected.

    Args:
        x_example: example input; only ``x_example.shape[2]`` (feature size) and
            ``x_example.device`` are used.
        factor: up-projection factor (``hidden_size = int(input_size * factor)``).
        depth: number of blocks in each ``BlockDiagonal`` q/k/v projection.
        dropout: base dropout probability (conv branch uses ``dropout + 0.1``,
            q/k/v use ``dropout / 2``).

    Cross-call memory lives in ``self.ct_1`` / ``self.nt_1`` (shape
    ``[1, 1, hidden_size]``, detached). Call ``init_states`` to reset it.
    """

    def __init__(self, x_example, factor, depth, dropout=0.2):
        super().__init__()
        self.input_size = x_example.shape[2]
        self.hidden_size = int(self.input_size*factor)

        self.ln = nn.LayerNorm(self.input_size)

        self.left = nn.Linear(self.input_size, self.hidden_size)
        self.right = nn.Linear(self.input_size, self.hidden_size)

        # max(1, ...) keeps the kernel size valid for input sizes below 10.
        self.conv = CausalConv1D(self.hidden_size, self.hidden_size, max(1, int(self.input_size/10)))
        self.drop = nn.Dropout(dropout+0.1)

        self.lskip = nn.Linear(self.hidden_size, self.hidden_size)

        self.wq = BlockDiagonal(self.hidden_size, self.hidden_size, depth)
        self.wk = BlockDiagonal(self.hidden_size, self.hidden_size, depth)
        self.wv = BlockDiagonal(self.hidden_size, self.hidden_size, depth)
        self.dropq = nn.Dropout(dropout/2)
        self.dropk = nn.Dropout(dropout/2)
        self.dropv = nn.Dropout(dropout/2)

        self.i_gate = nn.Linear(self.hidden_size, self.hidden_size)
        self.f_gate = nn.Linear(self.hidden_size, self.hidden_size)
        self.o_gate = nn.Linear(self.hidden_size, self.hidden_size)

        self.ln_c = nn.LayerNorm(self.hidden_size)
        self.ln_n = nn.LayerNorm(self.hidden_size)

        self.lnf = nn.LayerNorm(self.hidden_size)
        self.lno = nn.LayerNorm(self.hidden_size)
        self.lni = nn.LayerNorm(self.hidden_size)

        # Used by forward on (ht + l_skip); previously missing -> AttributeError.
        self.GN = nn.LayerNorm(self.hidden_size)

        self.drop2 = nn.Dropout(dropout)

        # Without this, forward fails on the first read of self.ct_1.
        self.init_states(x_example)

    def init_states(self, x_example):
        """Reset the carried cell (``ct_1``) and normalizer (``nt_1``) states to zeros."""
        self.ct_1 = torch.zeros([1, 1, self.hidden_size], device=x_example.device)
        self.nt_1 = torch.zeros([1, 1, self.hidden_size], device=x_example.device)

    @staticmethod
    def _causal_mean(x):
        """Running mean along dim 1: ``out[:, t] == x[:, :t + 1].mean(dim=1)``."""
        steps = torch.arange(1, x.shape[1] + 1, device=x.device, dtype=x.dtype)
        return x.cumsum(dim=1) / steps.view(1, -1, 1)

    def forward(self, x):
        """Run the left (conv + mLSTM) and right branches.

        Args:
            x: 3-D tensor whose last dim is ``input_size`` and whose dim 1 is the
                sequence axis (presumably ``(batch, seq, input_size)``).

        Returns:
            ``(left, right)``: both of shape ``x.shape[:2] + (hidden_size,)``.
        """
        assert x.ndim == 3

        x = self.ln(x)

        left = self.left(x)
        right = F.silu(self.right(x))  # right branch: swish (SiLU)

        # The conv runs over dim 1 (the sequence axis), hence the transposes.
        left_left = left.transpose(1, 2)
        left_left = F.silu(self.drop(self.conv(left_left).transpose(1, 2)))
        l_skip = self.lskip(left_left)

        # start mLSTM
        q = self.dropq(self.wq(left_left))
        k = self.dropk(self.wk(left_left))
        v = self.dropv(self.wv(left))  # v is taken from the pre-convolution branch

        i = torch.exp(self.lni(self.i_gate(left_left)))
        f = torch.exp(self.lnf(self.f_gate(left_left)))
        o = torch.sigmoid(self.lno(self.o_gate(left_left)))

        # Causal running mean over time instead of a mean over the whole batch
        # and sequence: position t sees only positions <= t of its own sequence
        # (plus the state carried from the previous call).
        ct = self._causal_mean(self.ln_c(f*self.ct_1 + i*v*k))
        nt = self._causal_mean(self.ln_n(f*self.nt_1 + i*k))
        # Carry the last step, averaged over the batch, into the next call.
        self.ct_1 = ct[:, -1:].mean(dim=0, keepdim=True).detach()
        self.nt_1 = nt[:, -1:].mean(dim=0, keepdim=True).detach()

        # Per-token normalizer max(|n . q|, 1): always >= 1, so it can neither
        # flip the output's sign nor divide by ~0.
        denom = torch.clamp((nt*q).sum(dim=-1, keepdim=True).abs(), min=1.0)
        ht = o * (ct*q) / denom
        # end mLSTM

        left = self.drop2(self.GN(ht + l_skip))
        return left, right

Finally, we add the output layers.

import torch
import torch.nn as nn
import torch.nn.functional as F
from xLSTM.utils import BlockDiagonal, CausalConv1D

class mLSTMblock(nn.Module):
    """mLSTM residual block: layer norm, two branches, causal conv + mLSTM cell, gated output.

    Args:
        x_example: example input; only ``x_example.shape[2]`` (feature size) and
            ``x_example.device`` are used.
        factor: up-projection factor (``hidden_size = int(input_size * factor)``).
        depth: number of blocks in each ``BlockDiagonal`` q/k/v projection.
        dropout: base dropout probability (conv branch uses ``dropout + 0.1``,
            q/k/v use ``dropout / 2``).

    Cross-call memory lives in ``self.ct_1`` / ``self.nt_1`` (shape
    ``[1, 1, hidden_size]``, detached). Call ``init_states`` to reset it.
    """

    def __init__(self, x_example, factor, depth, dropout=0.2):
        super().__init__()
        self.input_size = x_example.shape[2]
        self.hidden_size = int(self.input_size*factor)

        self.ln = nn.LayerNorm(self.input_size)

        self.left = nn.Linear(self.input_size, self.hidden_size)
        self.right = nn.Linear(self.input_size, self.hidden_size)

        # max(1, ...) keeps the kernel size valid for input sizes below 10.
        self.conv = CausalConv1D(self.hidden_size, self.hidden_size, max(1, int(self.input_size/10)))
        self.drop = nn.Dropout(dropout+0.1)

        self.lskip = nn.Linear(self.hidden_size, self.hidden_size)

        self.wq = BlockDiagonal(self.hidden_size, self.hidden_size, depth)
        self.wk = BlockDiagonal(self.hidden_size, self.hidden_size, depth)
        self.wv = BlockDiagonal(self.hidden_size, self.hidden_size, depth)
        self.dropq = nn.Dropout(dropout/2)
        self.dropk = nn.Dropout(dropout/2)
        self.dropv = nn.Dropout(dropout/2)

        self.i_gate = nn.Linear(self.hidden_size, self.hidden_size)
        self.f_gate = nn.Linear(self.hidden_size, self.hidden_size)
        self.o_gate = nn.Linear(self.hidden_size, self.hidden_size)

        self.ln_c = nn.LayerNorm(self.hidden_size)
        self.ln_n = nn.LayerNorm(self.hidden_size)

        self.lnf = nn.LayerNorm(self.hidden_size)
        self.lno = nn.LayerNorm(self.hidden_size)
        self.lni = nn.LayerNorm(self.hidden_size)

        self.GN = nn.LayerNorm(self.hidden_size)
        self.ln_out = nn.LayerNorm(self.hidden_size)

        self.drop2 = nn.Dropout(dropout)

        self.proj = nn.Linear(self.hidden_size, self.input_size)
        self.ln_proj = nn.LayerNorm(self.input_size)

        self.init_states(x_example)

    def init_states(self, x_example):
        """Reset the carried cell (``ct_1``) and normalizer (``nt_1``) states to zeros."""
        self.ct_1 = torch.zeros([1, 1, self.hidden_size], device=x_example.device)
        self.nt_1 = torch.zeros([1, 1, self.hidden_size], device=x_example.device)

    @staticmethod
    def _causal_mean(x):
        """Running mean along dim 1: ``out[:, t] == x[:, :t + 1].mean(dim=1)``."""
        steps = torch.arange(1, x.shape[1] + 1, device=x.device, dtype=x.dtype)
        return x.cumsum(dim=1) / steps.view(1, -1, 1)

    def forward(self, x):
        """Apply the block.

        Args:
            x: 3-D tensor whose last dim is ``input_size`` and whose dim 1 is the
                sequence axis (presumably ``(batch, seq, input_size)``).

        Returns:
            Tensor with the same shape as ``x``.
        """
        assert x.ndim == 3

        x = self.ln(x)

        left = self.left(x)
        right = F.silu(self.right(x))  # right branch: swish (SiLU)

        # The conv runs over dim 1 (the sequence axis), hence the transposes.
        left_left = left.transpose(1, 2)
        left_left = F.silu(self.drop(self.conv(left_left).transpose(1, 2)))
        l_skip = self.lskip(left_left)

        # start mLSTM
        q = self.dropq(self.wq(left_left))
        k = self.dropk(self.wk(left_left))
        v = self.dropv(self.wv(left))  # v is taken from the pre-convolution branch

        i = torch.exp(self.lni(self.i_gate(left_left)))
        f = torch.exp(self.lnf(self.f_gate(left_left)))
        o = torch.sigmoid(self.lno(self.o_gate(left_left)))

        # Causal running mean over time instead of a mean over the whole batch
        # and sequence: position t sees only positions <= t of its own sequence
        # (plus the state carried from the previous call). The old [0, 1] mean
        # leaked future tokens and other batch elements into every position.
        ct = self._causal_mean(self.ln_c(f*self.ct_1 + i*v*k))
        nt = self._causal_mean(self.ln_n(f*self.nt_1 + i*k))
        # Carry the last step, averaged over the batch, into the next call.
        self.ct_1 = ct[:, -1:].mean(dim=0, keepdim=True).detach()
        self.nt_1 = nt[:, -1:].mean(dim=0, keepdim=True).detach()

        # Per-token normalizer max(|n . q|, 1): always >= 1, so it can neither
        # flip the output's sign nor divide by ~0 (a global torch.max could).
        denom = torch.clamp((nt*q).sum(dim=-1, keepdim=True).abs(), min=1.0)
        ht = o * (ct*q) / denom
        # end mLSTM

        left = self.drop2(self.GN(ht + l_skip))

        out = self.ln_out(left * right)
        out = self.ln_proj(self.proj(out))

        return out

sLSTM

You can find all the code in my repository.

As with the mLSTM, we proceed step by step for the sLSTM.

First, we implement the block-diagonal projections, the convolution, and the sLSTM cell.

import torch
import torch.nn as nn
import torch.nn.functional as F
from xLSTM.utils import BlockDiagonal, CausalConv1D

class sLSTMblock(nn.Module):
    """Intermediate sLSTM block step: layer norm, causal conv and the sLSTM cell.

    This stage stops before the up/down projection: ``forward`` returns the
    group-normalized sLSTM hidden states.

    Args:
        x_example: example input; only ``x_example.shape[2]`` (feature size) and
            ``x_example.device`` are used.
        depth: number of blocks in each ``BlockDiagonal`` gate projection.
        dropout: dropout probability applied after the convolution.

    Cross-call memory (``ht_1``, ``ct_1``, ``nt_1``, ``mt_1``; each
    ``[1, 1, input_size]`` and detached) is carried between ``forward`` calls.
    Call ``init_states`` to reset it.
    """

    def __init__(self, x_example, depth, dropout=0.2):
        super().__init__()
        self.input_size = x_example.shape[2]

        self.ln = nn.LayerNorm(self.input_size)

        # max(1, ...) keeps the kernel size valid for input sizes below 8.
        self.conv = CausalConv1D(self.input_size, self.input_size, max(1, int(self.input_size/8)))
        self.drop = nn.Dropout(dropout)

        # Input-to-gate projections.
        self.i_gate = BlockDiagonal(self.input_size, self.input_size, depth)
        self.f_gate = BlockDiagonal(self.input_size, self.input_size, depth)
        self.o_gate = BlockDiagonal(self.input_size, self.input_size, depth)
        self.z_gate = BlockDiagonal(self.input_size, self.input_size, depth)

        # Recurrent (hidden-to-gate) projections.
        self.ri_gate = BlockDiagonal(self.input_size, self.input_size, depth, bias=False)
        self.rf_gate = BlockDiagonal(self.input_size, self.input_size, depth, bias=False)
        self.ro_gate = BlockDiagonal(self.input_size, self.input_size, depth, bias=False)
        self.rz_gate = BlockDiagonal(self.input_size, self.input_size, depth, bias=False)

        self.ln_i = nn.LayerNorm(self.input_size)
        self.ln_f = nn.LayerNorm(self.input_size)
        self.ln_o = nn.LayerNorm(self.input_size)
        self.ln_z = nn.LayerNorm(self.input_size)

        self.GN = nn.LayerNorm(self.input_size)
        # ln_c / ln_n are not used by forward: normalizing c and n separately
        # would distort the c / n ratio. They are kept so checkpoints still load.
        self.ln_c = nn.LayerNorm(self.input_size)
        self.ln_n = nn.LayerNorm(self.input_size)
        self.ln_h = nn.LayerNorm(self.input_size)

        # Without this, forward fails on the first read of self.ht_1.
        self.init_states(x_example)

    def init_states(self, x):
        """Reset the carried hidden/cell/normalizer/stabilizer states to zeros."""
        self.nt_1 = torch.zeros(1, 1, x.shape[2], device=x.device)
        self.ct_1 = torch.zeros(1, 1, x.shape[2], device=x.device)
        self.ht_1 = torch.zeros(1, 1, x.shape[2], device=x.device)
        self.mt_1 = torch.zeros(1, 1, x.shape[2], device=x.device)

    def forward(self, x):
        """Run the sLSTM cell over a sequence.

        Args:
            x: tensor whose last dim is ``input_size`` and whose dim 1 is the
                sequence axis (presumably ``(batch, seq, input_size)``).

        Returns:
            Per-token hidden states, same shape as ``x``.
        """
        x = self.ln(x)

        x_conv = F.silu(self.drop(self.conv(x.transpose(1, 2)).transpose(1, 2)))

        # start sLSTM
        # Input contributions for every step at once; only the recurrent
        # contributions need the sequential loop. i and f see the convolved
        # input, o and z the un-convolved input.
        xi = self.i_gate(x_conv)
        xf = self.f_gate(x_conv)
        xo = self.o_gate(x)
        xz = self.z_gate(x)

        state_shape = (x.shape[0], 1, self.input_size)
        ht, ct, nt, mt = (s.expand(state_shape) for s in (self.ht_1, self.ct_1, self.nt_1, self.mt_1))

        hs = []
        for t in range(x.shape[1]):
            log_i = self.ln_i(xi[:, t:t+1] + self.ri_gate(ht))
            log_f = self.ln_f(xf[:, t:t+1] + self.rf_gate(ht))
            o = torch.sigmoid(self.ln_o(xo[:, t:t+1] + self.ro_gate(ht)))
            z = torch.tanh(self.ln_z(xz[:, t:t+1] + self.rz_gate(ht)))

            # Stabilized exponential gating: m tracks the running log-scale so
            # exp() cannot overflow. c and n are rescaled by the same factor,
            # so their ratio is unchanged.
            m_new = torch.maximum(log_f + mt, log_i)
            i = torch.exp(log_i - m_new)
            f = torch.exp(log_f + mt - m_new)
            mt = m_new

            ct = f*ct + i*z
            nt = f*nt + i  # positive: i > 0 and the carried n is non-negative
            ht = o * ct / nt.clamp(min=1e-6)  # clamp only guards against underflow
            hs.append(ht)

        # Carry the final step, averaged over the batch, into the next call.
        self.ht_1 = ht.mean(dim=0, keepdim=True).detach()
        self.ct_1 = ct.mean(dim=0, keepdim=True).detach()
        self.nt_1 = nt.mean(dim=0, keepdim=True).detach()
        self.mt_1 = mt.mean(dim=0, keepdim=True).detach()
        # end sLSTM

        slstm_out = self.GN(self.ln_h(torch.cat(hs, dim=1)))
        return slstm_out

Second, and finally, we add the output layers.

import torch
import torch.nn as nn
import torch.nn.functional as F
from xLSTM.utils import BlockDiagonal, CausalConv1D

class sLSTMblock(nn.Module):
    """sLSTM residual block: layer norm, causal conv, sLSTM cell, gated up/down projection.

    Args:
        x_example: example input; only ``x_example.shape[2]`` (feature size) and
            ``x_example.device`` are used.
        depth: number of blocks in each ``BlockDiagonal`` gate projection.
        dropout: dropout probability applied after the convolution.

    Cross-call memory (``ht_1``, ``ct_1``, ``nt_1``, ``mt_1``; each
    ``[1, 1, input_size]`` and detached) is carried between ``forward`` calls.
    Call ``init_states`` to reset it.
    """

    def __init__(self, x_example, depth, dropout=0.2):
        super().__init__()
        self.input_size = x_example.shape[2]

        self.ln = nn.LayerNorm(self.input_size)

        # max(1, ...) keeps the kernel size valid for input sizes below 8.
        self.conv = CausalConv1D(self.input_size, self.input_size, max(1, int(self.input_size/8)))
        self.drop = nn.Dropout(dropout)

        # Input-to-gate projections.
        self.i_gate = BlockDiagonal(self.input_size, self.input_size, depth)
        self.f_gate = BlockDiagonal(self.input_size, self.input_size, depth)
        self.o_gate = BlockDiagonal(self.input_size, self.input_size, depth)
        self.z_gate = BlockDiagonal(self.input_size, self.input_size, depth)

        # Recurrent (hidden-to-gate) projections.
        self.ri_gate = BlockDiagonal(self.input_size, self.input_size, depth, bias=False)
        self.rf_gate = BlockDiagonal(self.input_size, self.input_size, depth, bias=False)
        self.ro_gate = BlockDiagonal(self.input_size, self.input_size, depth, bias=False)
        self.rz_gate = BlockDiagonal(self.input_size, self.input_size, depth, bias=False)

        self.ln_i = nn.LayerNorm(self.input_size)
        self.ln_f = nn.LayerNorm(self.input_size)
        self.ln_o = nn.LayerNorm(self.input_size)
        self.ln_z = nn.LayerNorm(self.input_size)

        self.GN = nn.LayerNorm(self.input_size)
        # ln_c / ln_n are not used by forward: normalizing c and n separately
        # would distort the c / n ratio. They are kept so checkpoints still load.
        self.ln_c = nn.LayerNorm(self.input_size)
        self.ln_n = nn.LayerNorm(self.input_size)
        self.ln_h = nn.LayerNorm(self.input_size)

        up_size = int(self.input_size*(4/3))
        self.left_linear = nn.Linear(self.input_size, up_size)
        self.right_linear = nn.Linear(self.input_size, up_size)

        self.ln_out = nn.LayerNorm(up_size)

        self.proj = nn.Linear(up_size, self.input_size)

        self.init_states(x_example)

    def init_states(self, x):
        """Reset the carried hidden/cell/normalizer/stabilizer states to zeros."""
        self.nt_1 = torch.zeros(1, 1, x.shape[2], device=x.device)
        self.ct_1 = torch.zeros(1, 1, x.shape[2], device=x.device)
        self.ht_1 = torch.zeros(1, 1, x.shape[2], device=x.device)
        self.mt_1 = torch.zeros(1, 1, x.shape[2], device=x.device)

    def forward(self, x):
        """Apply the block.

        Args:
            x: tensor whose last dim is ``input_size`` and whose dim 1 is the
                sequence axis (presumably ``(batch, seq, input_size)``).

        Returns:
            Tensor with the same shape as ``x``. Each position depends only on
            positions <= t of its own sequence and on the carried state.
        """
        x = self.ln(x)

        x_conv = F.silu(self.drop(self.conv(x.transpose(1, 2)).transpose(1, 2)))

        # start sLSTM
        # Input contributions for every step at once; only the recurrent
        # contributions need the sequential loop. i and f see the convolved
        # input, o and z the un-convolved input.
        xi = self.i_gate(x_conv)
        xf = self.f_gate(x_conv)
        xo = self.o_gate(x)
        xz = self.z_gate(x)

        state_shape = (x.shape[0], 1, self.input_size)
        ht, ct, nt, mt = (s.expand(state_shape) for s in (self.ht_1, self.ct_1, self.nt_1, self.mt_1))

        hs = []
        for t in range(x.shape[1]):
            log_i = self.ln_i(xi[:, t:t+1] + self.ri_gate(ht))
            log_f = self.ln_f(xf[:, t:t+1] + self.rf_gate(ht))
            o = torch.sigmoid(self.ln_o(xo[:, t:t+1] + self.ro_gate(ht)))
            z = torch.tanh(self.ln_z(xz[:, t:t+1] + self.rz_gate(ht)))

            # Stabilized exponential gating: m tracks the running log-scale so
            # exp() cannot overflow. c and n are rescaled by the same factor,
            # so their ratio is unchanged.
            m_new = torch.maximum(log_f + mt, log_i)
            i = torch.exp(log_i - m_new)
            f = torch.exp(log_f + mt - m_new)
            mt = m_new

            ct = f*ct + i*z
            nt = f*nt + i  # positive: i > 0 and the carried n is non-negative
            ht = o * ct / nt.clamp(min=1e-6)  # clamp only guards against underflow
            hs.append(ht)

        # Carry the final step, averaged over the batch, into the next call.
        self.ht_1 = ht.mean(dim=0, keepdim=True).detach()
        self.ct_1 = ct.mean(dim=0, keepdim=True).detach()
        self.nt_1 = nt.mean(dim=0, keepdim=True).detach()
        self.mt_1 = mt.mean(dim=0, keepdim=True).detach()
        # end sLSTM

        # Per-token outputs; the old mean over [0, 1] collapsed the whole block
        # output to one vector shared by every token and sequence.
        slstm_out = self.GN(self.ln_h(torch.cat(hs, dim=1)))

        left = self.left_linear(slstm_out)
        right = F.gelu(self.right_linear(slstm_out))

        out = self.ln_out(left*right)
        out = self.proj(out)
        return out

xLSTM

To finish the implementation, we just need to stack the two block types to build a complete xLSTM model.

import torch
import torch.nn as nn
import torch.nn.functional as F
from xLSTM.mLSTMblock import mLSTMblock
from xLSTM.sLSTMblock import sLSTMblock

class xLSTM(nn.Module):
    """Stack of residual sLSTM ('s') and mLSTM ('m') blocks.

    Args:
        layers: iterable of layer codes, e.g. ``"msm"``; one block per code.
        x_example: example input forwarded to every block's constructor.
        depth: number of diagonal blocks passed to each block.
        factor: up-projection factor passed to the mLSTM blocks.

    Raises:
        ValueError: if a layer code is neither ``'s'`` nor ``'m'``.
    """

    def __init__(self, layers, x_example, depth=4, factor=2):
        super().__init__()

        self.layers = nn.ModuleList()
        for layer_type in layers:
            if layer_type == 's':
                layer = sLSTMblock(x_example, depth)
            elif layer_type == 'm':
                layer = mLSTMblock(x_example, factor, depth)
            else:
                raise ValueError(f"Invalid layer type: {layer_type}. Choose 's' for sLSTM or 'm' for mLSTM.")
            self.layers.append(layer)

    def init_states(self, x):
        """Reset the carried recurrent state of every block."""
        for layer in self.layers:
            layer.init_states(x)

    def forward(self, x):
        """Apply each block with a per-layer residual connection: ``x = block(x) + x``."""
        # Each block adds onto its own input. The previous code added the stack's
        # original input at every layer, which discarded the residual stream.
        for layer in self.layers:
            x = layer(x) + x
        return x

That’s all for the implementation. Now we can test a model on the Tiny Shakespeare dataset. I’ll spare you the training code, but you can always find it in my repository.

Test

The code is in my repository.

The hyperparameters are:

import torch  # needed for the device check below; the excerpt used torch without importing it

# hyperparameters
batch_size = 24  # how many independent sequences are processed in parallel
block_size = 124  # maximum context length for predictions (little impact in the author's runs)
eval_iters = 3  # evaluation iterations; lower is faster
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
n_embd = 164  # embedding size (big impact in the author's runs)

dropout = 0.2  # no noticeable impact in the author's runs
config_block = "msm"  # layer codes, presumably passed as xLSTM's `layers` ('m' = mLSTM, 's' = sLSTM)
num_heads = 8  # number of diagonal blocks (presumably xLSTM's `depth`) — TODO confirm
head_size = 4  # hidden-size multiplier (presumably xLSTM's `factor`) — TODO confirm
# ------------ 6 196 201 parameters (as reported by the author)
First epoch:

step 0: train loss 4.3270, val loss 4.3219
step 100: train loss 2.5443, val loss 2.5433
step 200: train loss 2.2394, val loss 2.2802
step 300: train loss 2.0870, val loss 2.1471
step 400: train loss 1.9984, val loss 2.0720
step 500: train loss 1.9027, val loss 1.9928
step 600: train loss 1.8615, val loss 1.9432
step 700: train loss 1.8409, val loss 1.9058
step 800: train loss 1.7475, val loss 1.8943
step 900: train loss 1.7450, val loss 1.8395
step 1000: train loss 1.7428, val loss 1.8279
step 1100: train loss 1.6523, val loss 1.7903
step 1200: train loss 1.6506, val loss 1.7470
step 1300: train loss 1.6484, val loss 1.7698
step 1400: train loss 1.5625, val loss 1.7178
step 1500: train loss 1.5710, val loss 1.7284
step 1600: train loss 1.5428, val loss 1.7289
step 1700: train loss 1.5986, val loss 1.7167
step 1800: train loss 1.5377, val loss 1.7134
step 1900: train loss 1.5549, val loss 1.7062
step 2000: train loss 1.5613, val loss 1.6671
step 2100: train loss 1.5179, val loss 1.6790
step 2200: train loss 1.5278, val loss 1.6495
step 2300: train loss 1.4849, val loss 1.6532
step 2400: train loss 1.5001, val loss 1.6386
step 2500: train loss 1.5027, val loss 1.6386
step 2600: train loss 1.4460, val loss 1.6156
step 2700: train loss 1.4861, val loss 1.6227
step 2800: train loss 1.4734, val loss 1.6330
step 2900: train loss 1.4346, val loss 1.6448
step 3000: train loss 1.4397, val loss 1.6578
step 3100: train loss 1.4703, val loss 1.6631
step 3200: train loss 1.4505, val loss 1.6331
step 3300: train loss 1.4627, val loss 1.6056

Second epoch:
step 0: train loss 1.4553, val loss 1.6180
step 100: train loss 1.4511, val loss 1.5716
step 200: train loss 1.4466, val loss 1.6184
step 300: train loss 1.4587, val loss 1.5864
step 400: train loss 1.4191, val loss 1.6034
step 500: train loss 1.4374, val loss 1.5921
step 600: train loss 1.4766, val loss 1.5779
step 700: train loss 1.4460, val loss 1.6704
step 800: train loss 1.4235, val loss 1.6071
step 900: train loss 1.3958, val loss 1.5877
step 1000: train loss 1.4000, val loss 1.6223
step 1100: train loss 1.4050, val loss 1.5730
step 1200: train loss 1.4398, val loss 1.5611
step 1300: train loss 1.3808, val loss 1.5550
step 1400: train loss 1.3967, val loss 1.5965
step 1500: train loss 1.3856, val loss 1.5943
step 1600: train loss 1.3697, val loss 1.5646
step 1700: train loss 1.3749, val loss 1.6092
step 1800: train loss 1.4000, val loss 1.5567
step 1900: train loss 1.3721, val loss 1.5831
step 2000: train loss 1.4092, val loss 1.6221
step 2100: train loss 1.3879, val loss 1.5526
step 2200: train loss 1.4292, val loss 1.6169
step 2300: train loss 1.3557, val loss 1.5500
step 2400: train loss 1.4180, val loss 1.5842
step 2500: train loss 1.3877, val loss 1.5540
step 2600: train loss 1.3817, val loss 1.5708
step 2700: train loss 1.3729, val loss 1.5689
step 2800: train loss 1.3902, val loss 1.6510
step 2900: train loss 1.3939, val loss 1.5258
step 3000: train loss 1.3804, val loss 1.5745
step 3100: train loss 1.3531, val loss 1.5454
step 3200: train loss 1.3447, val loss 1.5827
step 3300: train loss 1.3709, val loss 1.5592

Sample prediction:

And of give I loldiied out ways.

ESCALUS:
If youEly kinsbatter:
Sweet thieat, I may not to hour, eye you his royal that my heart. Lody of us, good lord.

Clifford, my heavy
gatch’d my fear turnsit mhalf
Antecia what he gone, for it is: where is helpased to gody,
Againss the noisest, that
unsweal bethis futwarsisty, how clesp: to my hold.
Whose ’tis wanddy.

LUEEN:
Is much caprient show’d in bree letmaging that
else partions you ort and King Hence! Vinceland!
’tis pricking of I’ll,
Unlesson, so

Conclusion

The model isn’t bad for its very small number of parameters: it overfits somewhat, but it learns well.

I hope you found this article interesting. If so, feel free to clap and follow me. =)

--

--

Arthur Lagacherie

I am a French high school student with a passion for artificial intelligence. I like to share my curiosity with others.👍