Kolmogorov-Arnold Networks (KANs) Explained: A Superior Alternative to MLPs
A mathematical deep dive into KANs and their advantages over conventional Neural Networks
While the GenAI world has been getting all the attention lately, a notable advancement was recently made in the field of neural networks themselves: Kolmogorov-Arnold Networks (KANs) were introduced as an alternative to MLPs and are claimed to be significantly better, especially at capturing really complex relationships.
In this post, I will break down this revolutionary yet simple modification to neural networks, covering:
Kolmogorov-Arnold representation theorem
What are B-Splines?
What are KANs?
KANs Advantages and Disadvantages over MLPs
Before we start, we need to understand a very crucial theorem:
Kolmogorov-Arnold representation theorem
Imagine you want to approximate a very complex function f(x,y,z) that takes three inputs (x, y, z) and gives one output. The Kolmogorov-Arnold theorem states that you can break this multivariate function down into two steps:
1. Interior Step: First, you take each of the three inputs (x, y, z) and pass it through a simple 1D (univariate) function ψ. These 1D functions ψ could be something like
ψ₁(x) = x²,
ψ₂(y) = sin(y),
ψ₃(z) = log(z+1)
You combine the outputs of these 1D functions using multiplication by some constants and addition. So, at the end of the Interior step, we get
A*x² + B*sin(y) + C*log(z+1), where A, B, C are constants
2. Outer Step: The result from the interior step is then passed through another 1D function Φ to give the final output.
So, we finally have an approximation, i.e.
f(x,y,z) ≈ Φ(A*x² + B*sin(y) + C*log(z+1))
If, for example, Φ(u) = D*√u, then
Φ(A*x² + B*sin(y) + C*log(z+1)) = D*√(A*x² + B*sin(y) + C*log(z+1))
where A, B, C, D are all constants
So instead of trying to approximate the complex multivariate function f(x,y,z) directly, you break it down into smaller 1D pieces (the ψ functions) in the interior, and then combine their outputs using another 1D function (Φ) in the outer part.
This two-step process can theoretically approximate any continuous multivariate function f(x,y,z) to arbitrary accuracy, according to the theorem.
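To make this concrete, here is a minimal Python sketch of the two-step approximation above; the constants A, B, C, D and the choice Φ(u) = D*√u are arbitrary values used only for illustration:

```python
import numpy as np

# Arbitrary constants, just for illustration
A, B, C, D = 0.5, 1.2, 2.0, 3.0

def interior(x, y, z):
    # Interior step: pass each input through its own 1D function,
    # then combine the outputs linearly
    return A * x**2 + B * np.sin(y) + C * np.log(z + 1)

def phi(u):
    # Outer step: a single 1D function applied to the combined value
    return D * np.sqrt(u)

def f_approx(x, y, z):
    # f(x, y, z) ≈ Φ(A*x² + B*sin(y) + C*log(z + 1))
    return phi(interior(x, y, z))

print(f_approx(1.0, 2.0, 3.0))
```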
Note: This is an oversimplified version of the theorem; the actual representation is a bit more involved.
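For reference, the full theorem states that any continuous function of n variables can be written using 2n+1 outer functions Φ_q and n·(2n+1) inner functions ψ_{q,p}:

```latex
f(x_1, \dots, x_n) = \sum_{q=1}^{2n+1} \Phi_q\!\left( \sum_{p=1}^{n} \psi_{q,p}(x_p) \right)
```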
Also, the 1D functions used in KANs aren't as simple as the ones we mentioned above; they are B-Splines.
What are B-Splines?
Splines are functions defined by piecewise polynomials whose pieces are smoothly connected at certain points called knots.
A piecewise polynomial is a function that uses different polynomial expressions over different intervals of its domain. For example:
m*x + a → x < 5
m*x + n*x² → 5 ≤ x < 10
p*x³ → x ≥ 10
As you can see, depending on the interval of x, the polynomial expression changes. The knot points are 5 and 10 (where the transitions happen).
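Here is a quick NumPy sketch of the same piecewise polynomial; the coefficients m, n, a, p are arbitrary placeholder values:

```python
import numpy as np

# Arbitrary placeholder coefficients
m, n, a, p = 1.0, 0.5, 2.0, 0.1

x = np.linspace(0, 15, 100)

# A different polynomial expression on each interval; knots at x = 5 and x = 10
y = np.piecewise(
    x,
    [x < 5, (x >= 5) & (x < 10), x >= 10],
    [lambda x: m * x + a,
     lambda x: m * x + n * x**2,
     lambda x: p * x**3],
)
```

Note that this snippet only shows the piecewise structure; an actual spline additionally requires the pieces to join smoothly at the knots.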
Depending on the degree of the spline, different basis functions are available to construct the original, more complex function. An example of the basis functions for different degrees can be:
Degree 0: basis functions = f(x) = 1
Degree 1: basis functions = f(x) = 1, f(x) = x
Degree 2: basis functions = f(x) = 1, f(x) = x, f(x) = x²
and so on. The final B-Spline function is then a weighted sum of all the basis functions defined.
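If you want to experiment with actual B-Splines, SciPy ships an implementation. A minimal sketch (the knot vector and coefficients below are arbitrary, chosen only for illustration):

```python
import numpy as np
from scipy.interpolate import BSpline

k = 2                                      # spline degree
t = np.arange(10, dtype=float)             # knot vector: 0, 1, ..., 9
c = np.array([1.0, -2.0, 0.5, 3.0, 1.5, -1.0, 2.0])  # one coefficient per basis function

# The spline is a weighted sum of its basis functions
spl = BSpline(t, c, k)
print(spl(4.5))                            # evaluate the spline at x = 4.5

# A single basis function can also be built directly from k + 2 consecutive knots
b0 = BSpline.basis_element(t[: k + 2])
print(b0(1.5))
```

Here the knot vector and degree imply len(t) - k - 1 = 7 basis functions, one per coefficient.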
Basis functions: These are a set of simple functions that can be combined to represent complex non-linear functions.
Assume we have f(x) = 5 + 2x²
Now, to represent this non-linear function, we can use the following set of basis functions: f₁(x) = 1, f₂(x) = x, f₃(x) = x²
Hence, f(x) = 5*f₁(x) + 0*f₂(x) + 2*f₃(x)
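The same idea in code, using the basis functions from the example:

```python
import numpy as np

# Basis functions from the example
f1 = lambda x: np.ones_like(x)   # f1(x) = 1
f2 = lambda x: x                 # f2(x) = x
f3 = lambda x: x**2              # f3(x) = x²

weights = [5.0, 0.0, 2.0]        # coefficients multiplying each basis function

def f(x):
    # f(x) = 5*f1(x) + 0*f2(x) + 2*f3(x) = 5 + 2x²
    return weights[0] * f1(x) + weights[1] * f2(x) + weights[2] * f3(x)

print(f(np.array([0.0, 1.0, 2.0])))   # -> [ 5.  7. 13.]
```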
Now that we know almost everything required, let's jump to
What are KANs?
Before we jump to KANs, let's understand how a baseline MLP works:
Assume x1, x2, x3, …, xn are the inputs to a hidden layer H. The output of this hidden layer would be
Output_hidden =
activation_function(w1*x1 + w2*x2 + … + wn*xn)
Now, if you notice, there is a major issue with this equation: if we have a highly complex, non-linear relationship to capture, it might take many more hidden layers, as the non-linearity is introduced only by the activation function. The w1*x1 + w2*x2 + … + wn*xn part is linear and hence doesn't contribute to capturing non-linearity.
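In code, a single MLP hidden layer looks like this minimal NumPy sketch (the weights, the ReLU activation, and the bias term, which is usually present even though it is omitted in the formula above, are arbitrary choices for illustration):

```python
import numpy as np

def mlp_hidden_layer(x, W, b):
    # Linear part: w1*x1 + w2*x2 + ... + wn*xn (+ bias) for every hidden unit
    z = W @ x + b
    # Fixed non-linearity (ReLU) applied afterwards -- the only source of non-linearity
    return np.maximum(z, 0.0)

x = np.array([0.5, -1.0, 2.0])   # n = 3 inputs
W = np.random.randn(4, 3)        # 4 hidden units, one scalar weight per connection
b = np.zeros(4)
print(mlp_hidden_layer(x, W, b))
```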
What if even this part could be non-linear, so that capturing more complex non-linear relationships becomes faster and more efficient?
This is what KANs do!!
In MLPs, each connection between neurons has a single scalar weight value (w). The weighted sum of inputs is then passed through a fixed non-linear activation function like ReLU or sigmoid.
In contrast, in KANs, each connection or “edge” between nodes has a learnable non-linear activation function instead of a single scalar weight. In other words, KANs replace scalar float weights with learnable non-linear functions.
KANs have no linear weight matrices at all
What are these non-linear learnable functions?
Remember the Kolmogorov-Arnold representation theorem we discussed at the beginning of the post? The 1D functions (B-Splines) we discussed there are these non-linear learnable functions!
So, in the case of an MLP, we:
1. Multiply weights with the inputs (w1*x1, w2*x2, …, wn*xn)
2. Add these values (w1*x1 + w2*x2 + … + wn*xn)
3. Apply a fixed activation function: σ(w1*x1 + w2*x2 + … + wn*xn)
But in the case of a KAN, we:
1. Apply a different learnable 1D function to each input (ψ₁(x1), ψ₂(x2), ψ₃(x3))
2. Add these values (ψ₁(x1) + ψ₂(x2) + ψ₃(x3))
3. Apply another learnable 1D function Φ to this sum: Φ(ψ₁(x1) + ψ₂(x2) + ψ₃(x3)) (see the code sketch below)
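Here is a conceptual PyTorch sketch of that KAN-style computation. To keep it readable, each edge function is a simple parametrized polynomial ψ(x) = A*x² + B*x rather than the B-spline parametrization the paper actually uses, so treat this as an illustration of the idea, not a faithful KAN implementation:

```python
import torch
import torch.nn as nn

class LearnableEdge(nn.Module):
    """A learnable 1D function ψ(x) = A*x² + B*x (a stand-in for a B-spline)."""
    def __init__(self):
        super().__init__()
        self.A = nn.Parameter(torch.randn(1))
        self.B = nn.Parameter(torch.randn(1))

    def forward(self, x):
        return self.A * x**2 + self.B * x

class KANNodeSketch(nn.Module):
    """One KAN-style node: a learnable 1D function per input edge,
    summed, then passed through another learnable 1D function Φ."""
    def __init__(self, n_inputs):
        super().__init__()
        self.edges = nn.ModuleList([LearnableEdge() for _ in range(n_inputs)])
        self.phi = LearnableEdge()        # the outer 1D function is learnable too

    def forward(self, x):                 # x: tensor of shape (n_inputs,)
        s = sum(edge(xi) for edge, xi in zip(self.edges, x))
        return self.phi(s)

node = KANNodeSketch(n_inputs=3)
print(node(torch.tensor([0.5, -1.0, 2.0])))
```

In the paper, each edge function is (roughly) a B-spline whose coefficients are the learnable parameters, plus a residual base activation, but the structure of the computation is the same.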
A question that came to my mind: why are we calling these 1D functions learnable? Because these functions have tunable parameters, say
ψ₁(x) = A*x² + B*x, where both A and B are to be learnt.
Similarly,
ψ₂(y) = C*sin(y) + D*cos(y), where C and D are learnable parameters.
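In a framework like PyTorch, "learnable" simply means those constants are tensors with gradients enabled, so they get updated by gradient descent. A tiny sketch using the hypothetical ψ₁ from above:

```python
import torch

# ψ₁(x) = A*x² + B*x, with A and B marked as learnable
A = torch.randn(1, requires_grad=True)
B = torch.randn(1, requires_grad=True)

x = torch.tensor([2.0])
target = torch.tensor([3.0])

loss = ((A * x**2 + B * x) - target).pow(2).mean()
loss.backward()                    # gradients w.r.t. A and B
print(A.grad, B.grad)              # these gradients drive the updates to ψ₁
```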
What are the advantages of using KANs?
The paper claims that much smaller KANs can achieve comparable or better accuracy than much larger MLPs in tasks like data fitting and solving partial differential equations (PDEs)
Theoretically and empirically, KANs are shown to possess faster neural scaling laws than MLPs, meaning their performance improves more rapidly as the model size increases.
KANs can be intuitively visualized since they use learnable activation functions on the edges instead of fixed activations on nodes like MLPs.
But the grass isn’t greener on the other side always. There are a few cons as well which we should discuss before ending this long post
Training KANs may be very slow: instead of a single scalar weight for each connection, we now might be tuning multiple parameters (those used in the learnable 1D functions), which is an overhead.
Heavy computational resources might be required for training purposes
As the model itself requires heavy computational resources, training it on big datasets might be challenging as well.
Nonetheless, the metrics look fascinating, and KANs are surely worth a try. With this, it's a wrap!