Kolmogorov–Arnold Networks (KAN) Are About To Change The AI World Forever

Forget Everything You Know About Neural Networks, KANs Are Here to Rewrite the Rules

Manushi Mukhi
Accredian
May 10, 2024


In AI, the Multi-Layer Perceptron (MLP) stands as the bedrock, its architecture shaping countless applications. Kolmogorov–Arnold Networks (KANs) seek to disrupt this foundation by rethinking where the learnable nonlinearity of a neural network lives.

Source: PyKAN Github (https://github.com/KindXiaoming/pykan)

Introduction:

In the ever-evolving landscape of machine learning, a recent research paper titled “KAN: Kolmogorov–Arnold Networks” has sparked a wave of excitement among enthusiasts. This approach challenges the conventional wisdom of Multi-Layer Perceptrons (MLPs), offering a fresh perspective on neural network architectures.

What are Kolmogorov–Arnold Networks (KANs)?

At the heart of this concept lies the Kolmogorov–Arnold representation theorem, a mathematical result due to Andrey Kolmogorov and Vladimir Arnold. The theorem states that any continuous multivariate function on a bounded domain can be written as a finite composition of continuous one-dimensional functions and addition. This is the foundation for the unique structure of KANs.

Source: https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Arnold_representation_theorem
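For reference, the usual statement of the theorem can be written as follows. Here f is a continuous function on the unit cube [0,1]^n, and the symbols Φ_q (outer functions) and φ_{q,p} (inner functions) follow the common textbook notation rather than anything specific to this article:

```latex
f(x_1, \dots, x_n) \;=\; \sum_{q=0}^{2n} \Phi_q\!\left( \sum_{p=1}^{n} \phi_{q,p}(x_p) \right),
\qquad \Phi_q : \mathbb{R} \to \mathbb{R}, \quad \phi_{q,p} : [0,1] \to \mathbb{R}.
```

Every function on the right-hand side takes a single real argument; the only multivariate operation is the sum.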

Now, the obvious question is: what are these “simpler one-dimensional functions”? The theorem only says they are continuous, so a KAN needs a concrete, learnable way to represent them. For anyone with a bit of knowledge of mathematics or computer graphics, the choice is an age-old and trusted tool: piecewise polynomials known as splines.

Example of a B-Spline (Google.com)

The Secret Sauce of KAN, Splines!

Splines are mathematical functions that enable the creation of smooth curves by connecting a series of control points. Splines offer flexibility in adjusting the shape of the curve while ensuring continuity and smoothness between adjacent segments.

To create a spline, one typically starts with a set of control points that define the path of the curve. The curve is then constructed by interpolating or approximating the path between these control points as a weighted sum of basis functions, as in B-splines or Bézier curves.

Source: Unity Manual | Getting Started with Splines

In essence, splines provide a versatile tool for representing complex curves or surfaces with precision and flexibility, making them invaluable in various fields.
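To make the idea of "a curve as a weighted sum of basis functions" concrete, here is a minimal NumPy sketch of B-spline evaluation using the Cox-de Boor recursion. The function name bspline_basis, the uniform knot vector, and the coefficient values are all illustrative assumptions, not taken from PyKAN.

```python
import numpy as np

def bspline_basis(x, knots, degree):
    """Evaluate every B-spline basis function of the given degree at x.

    Assumes strictly increasing knots (no repeats), so no division by zero.
    Returns an array of shape (len(x), len(knots) - 1 - degree).
    """
    x = np.asarray(x, dtype=float)[:, None]
    t = np.asarray(knots, dtype=float)
    # Degree 0: indicator of each half-open knot interval [t_i, t_{i+1})
    B = ((x >= t[:-1]) & (x < t[1:])).astype(float)
    # Cox-de Boor recursion: build degree d from degree d - 1
    for d in range(1, degree + 1):
        left = (x - t[:-(d + 1)]) / (t[d:-1] - t[:-(d + 1)])
        right = (t[d + 1:] - x) / (t[d + 1:] - t[1:-d])
        B = left * B[:, :-1] + right * B[:, 1:]
    return B

degree = 3                                       # cubic spline
knots = np.linspace(-2.0, 2.0, 9)                # 9 uniform knots -> 5 cubic basis functions
coeffs = np.array([0.0, 1.0, -0.5, 2.0, 0.3])    # illustrative "control" values
xs = np.linspace(-0.5, 0.5, 50)                  # interval where all cubic pieces overlap

basis = bspline_basis(xs, knots, degree)         # shape (50, 5)
curve = basis @ coeffs                           # the spline: weighted sum of basis functions
```

Changing one entry of coeffs only changes the curve near the corresponding basis function, which is what makes splines convenient to adjust locally, and convenient to learn by gradient descent.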

But how are these splines used in a KAN architecture?

Simplest Way To Understand the Working of KAN

KANs diverge from traditional MLPs by replacing the fixed activation functions at the nodes with learnable one-dimensional functions (B-splines) placed along the edges of the network.

This adaptive architecture allows KANs to effectively model complex functions while maintaining interpretability and reducing the number of parameters needed.

Source: PyKAN Github (https://github.com/KindXiaoming/pykan)

In an MLP, each edge carries a single learnable weight and each neuron applies a fixed nonlinearity to the weighted sum it receives. In a KAN the roles are swapped: each edge carries its own learnable function that reshapes itself in response to the data, and each neuron simply adds up the signals arriving on its incoming edges.

This transformative shift is made possible through the adoption of learnable activation functions situated at the edges of the network.

Source: PyKAN Github (https://github.com/KindXiaoming/pykan)

Leveraging the expressive power of B-Splines, these functions imbue KANs with unparalleled flexibility and adaptability, enabling them to navigate complex data landscapes with ease.
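The edge-function idea can be sketched in a few lines of NumPy. This is an illustration only, not PyKAN's implementation: to stay short it uses degree-1 B-splines (piecewise-linear "hat" functions) instead of cubic ones, and the names hat_basis and kan_layer are hypothetical.

```python
import numpy as np

def hat_basis(x, centers):
    """Degree-1 B-spline (hat) basis on a uniform grid: (batch,) -> (batch, n_basis)."""
    h = centers[1] - centers[0]
    return np.maximum(0.0, 1.0 - np.abs(x[:, None] - centers[None, :]) / h)

def kan_layer(x, centers, coeffs):
    """Forward pass of one KAN-style layer.

    x:      (batch, in_dim) inputs
    coeffs: (in_dim, out_dim, n_basis), one learnable 1-D function per edge
    Output j is the plain sum over inputs i of phi_{j,i}(x_i).
    """
    basis = np.stack([hat_basis(x[:, i], centers) for i in range(x.shape[1])], axis=1)
    # basis: (batch, in_dim, n_basis); sum over inputs and basis functions
    return np.einsum("bin,ion->bo", basis, coeffs)

rng = np.random.default_rng(0)
centers = np.linspace(-1.0, 1.0, 5)            # grid shared by all edges
x = rng.uniform(-1.0, 1.0, size=(4, 2))        # 4 samples, 2 input features

# Random edge functions: 2 inputs x 3 outputs, 5 coefficients each
coeffs = rng.normal(size=(2, 3, 5))
out = kan_layer(x, centers, coeffs)            # shape (4, 3)

# Sanity check: if every edge function is the identity, each output is x_1 + x_2
identity = np.broadcast_to(centers, (2, 3, 5))
summed = kan_layer(x, centers, identity)
```

Training a KAN amounts to adjusting the coefficient array so that the per-edge curves take whatever shape fits the data; the nodes themselves have nothing to learn.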

Major Advantages of KANs:

Enhanced Scalability

The KAN paper reports faster neural scaling laws than MLPs, meaning the loss falls more quickly as parameters are added. The ability to decompose a complex function into simpler one-dimensional components lets the network exploit compositional structure in the data. These results come mostly from function-fitting and scientific tasks, so how well the advantage carries over to very large datasets is still to be shown.

Improved Accuracy

In the paper's experiments on data fitting and PDE solving, much smaller KANs reach comparable or better accuracy than much larger MLPs. This is attributed to their ability to adaptively model relationships within data, resulting in more precise predictions and better generalization to unseen examples.
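The parameter claim deserves a quick back-of-the-envelope check. The sketch below counts only the spline coefficients of a KAN (grid + k per edge, as in the paper's estimate) and the weights and biases of an MLP. The helper names and the MLP size [2, 16, 16, 2] are hypothetical choices for illustration, and PyKAN itself stores a few extra scaling parameters per edge that are ignored here.

```python
def kan_spline_params(widths, grid, k):
    """Spline coefficients in a KAN: one function per edge, (grid + k) coefficients each."""
    edges = sum(a * b for a, b in zip(widths[:-1], widths[1:]))
    return edges * (grid + k)

def mlp_params(widths):
    """Weights plus biases of a fully connected MLP."""
    return sum(a * b + b for a, b in zip(widths[:-1], widths[1:]))

kan_count = kan_spline_params([2, 2], grid=3, k=3)   # 4 edges * 6 coefficients
mlp_count = mlp_params([2, 16, 16, 2])               # a small two-hidden-layer MLP

print("KAN [2,2] spline coefficients:", kan_count)
print("MLP [2,16,16,2] parameters:", mlp_count)
```

Per edge, a KAN is more expensive than an MLP (several coefficients instead of one weight). Any saving comes from being able to use far fewer and narrower layers for the same accuracy.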

Interpretable Models

The structure of KANs facilitates interpretability, enabling researchers to derive symbolic formulas that represent learned patterns effectively. Unlike black-box models, KANs offer insights into how input features are transformed throughout the network, enhancing transparency and understanding.

Now we know what KANs are and why they are such a big deal in the artificial intelligence landscape, but the world doesn’t run only on theories and models that look good in papers.

The best thing about KANs is that they are very simple to apply to your own data science problems using the new Python library “PyKAN”.

Let’s end our discussion with an example of how to implement this architecture in Python.

Python Implementation of KANs (PyKAN):

Let’s use a Classification Problem for our demonstration.

Creating the Dataset

We will create a synthetic dataset using the “make_moons” function from the scikit-learn library.

import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
import torch
import numpy as np

dataset = {}
train_input, train_label = make_moons(n_samples=1000, shuffle=True, noise=0.1, random_state=None)
test_input, test_label = make_moons(n_samples=1000, shuffle=True, noise=0.1, random_state=None)

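# Cast inputs to the model's float dtype and labels to integer class indices
dtype = torch.get_default_dtype()
dataset['train_input'] = torch.from_numpy(train_input).type(dtype)
dataset['test_input'] = torch.from_numpy(test_input).type(dtype)
dataset['train_label'] = torch.from_numpy(train_label).type(torch.long)
dataset['test_label'] = torch.from_numpy(test_label).type(torch.long)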

X = dataset['train_input']
y = dataset['train_label']
plt.scatter(X[:,0], X[:,1], c=y[:])

Output (The Dataset Visualized)

Creating and Training a KAN

from kan import KAN

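# width=[2,2]: 2 input features, 2 output logits (one per class)
# grid=3: grid intervals per spline; k=3: cubic splines
model = KAN(width=[2,2], grid=3, k=3)

def train_acc():
    # Fraction of training points whose largest logit matches the label
    return torch.mean((torch.argmax(model(dataset['train_input']),
                                    dim=1) == dataset['train_label']).float())

def test_acc():
    # Same metric on the held-out test set
    return torch.mean((torch.argmax(model(dataset['test_input']),
                                    dim=1) == dataset['test_label']).float())

# Note: newer pykan releases may expose this method as model.fit(...)
results = model.train(dataset, opt="LBFGS", steps=20,
                      metrics=(train_acc, test_acc),
                      loss_fn=torch.nn.CrossEntropyLoss())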

Obtaining the Symbolic Formula from the Model

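After training, a symbolic formula is derived that represents what the model has learned from the data. Each learned spline is first matched to the best-fitting function from a library of candidates (the list below is one reasonable choice), and the matched functions are then combined into one expression per output logit.

model.auto_symbolic(lib=['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs'])
formula1, formula2 = model.symbolic_formula()[0]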

Calculating the Accuracy

Finally, the accuracy can be obtained from the learned formula

def acc(formula1, formula2, X, y):
    batch = X.shape[0]
    correct = 0
    for i in range(batch):
        # Substitute the two input features into each logit formula
        logit1 = np.array(formula1.subs('x_1', X[i,0])
                                  .subs('x_2', X[i,1])).astype(np.float64)
        logit2 = np.array(formula2.subs('x_1', X[i,0])
                                  .subs('x_2', X[i,1])).astype(np.float64)
        # Predict class 1 when its logit is the larger one
        correct += (logit2 > logit1) == y[i]
    return correct/batch

# Print Accuracy
print('train acc of the formula:', acc(formula1, formula2,
                                       dataset['train_input'],
                                       dataset['train_label']))

print('test acc of the formula:', acc(formula1, formula2,
                                      dataset['test_input'],
                                      dataset['test_label']))

Output

train acc of the formula: tensor(0.9700)
test acc of the formula: tensor(0.9660)
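Substituting values one row at a time is slow for larger datasets. One possible alternative, sketched below, is to compile each formula once with SymPy's lambdify and evaluate it on whole NumPy columns. The two formulas here are hypothetical stand-ins (a trained model would supply its own), and formula_accuracy is an illustrative helper name, not part of PyKAN.

```python
import numpy as np
import sympy

x_1, x_2 = sympy.symbols("x_1 x_2")

# Hypothetical stand-ins for the formulas returned by the trained model
formula1 = sympy.sin(x_1) - 0.5 * x_2
formula2 = x_1**2 + x_2

# Compile each symbolic expression into a vectorized NumPy function
f1 = sympy.lambdify((x_1, x_2), formula1, "numpy")
f2 = sympy.lambdify((x_1, x_2), formula2, "numpy")

def formula_accuracy(f1, f2, X, y):
    """Accuracy of the rule 'predict class 1 when logit2 > logit1'."""
    logit1 = f1(X[:, 0], X[:, 1])
    logit2 = f2(X[:, 0], X[:, 1])
    pred = (logit2 > logit1).astype(int)
    return float(np.mean(pred == y))

# Tiny hand-made example; with real data pass NumPy arrays,
# e.g. dataset['test_input'].numpy() and dataset['test_label'].numpy()
X = np.array([[0.0, 1.0], [1.0, -1.0], [2.0, 0.0]])
y = np.array([1, 0, 1])
score = formula_accuracy(f1, f2, X, y)
print("accuracy of the stand-in formulas:", score)
```

The result is a plain Python float rather than a tensor, which is also easier to log or compare.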

Conclusion

In conclusion, Kolmogorov–Arnold Networks (KANs) represent a paradigm shift in neural network architecture. While further research and experimentation are needed to fully unlock their potential, KANs hold promise as a valuable tool for advancing machine learning and scientific discovery in the years to come.

As the field continues to evolve, KANs stand at the forefront of innovation, shaping the future of intelligent systems and revolutionizing the way we approach complex data analysis and modeling.

