Understanding Spiking Neural Networks Through Code
If you want to understand how spiking neural networks work and implement them with PyTorch, you've come to the right place.
Introduction
Before diving into the topic, let's start with a brief introduction to set the context:
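Today, deep learning is primarily based on artificial neural networks (ANNs). Despite their exceptional performance in almost every field, ANNs work quite differently from the brain: their neurons exchange continuous values, whereas biological neurons communicate through short, discrete electrical pulses called spikes.
Spiking Neural Networks (SNNs) are designed to mimic this behavior more closely. Their roots go back to 1952, when Hodgkin and Huxley published their mathematical model of how neurons generate spikes.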
How it works:
Each neuron has an internal state (in biology, its membrane potential). Every time one or more spikes arrive from other neurons, each spike is multiplied by a weight (as in an ANN) and the results are added to the state. If the state exceeds a threshold, the neuron generates a spike of its own and its state is reset.
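To make this concrete, here is a minimal sketch of a single neuron with three incoming connections, simulated step by step in plain Python. The weights and input spike trains are made-up example values; the 0.5 decay, the 0.6 threshold, and the reset to 0 mirror the implementation that follows.

```python
# Hypothetical example values: three presynaptic neurons feeding one neuron.
weights = [0.5, 0.3, -0.2]  # the negative weight acts as an inhibitory connection
threshold = 0.6
decay = 0.5  # the state is halved at every step (a "leak")

# Incoming spike trains: one row per time step, one column per input neuron.
input_spikes = [
    [1, 0, 0],
    [1, 0, 0],
    [0, 1, 0],
    [1, 1, 1],
    [0, 0, 1],
]

state = 0.0
output_spikes = []
for spikes_in in input_spikes:
    # Each incoming spike is multiplied by its weight and the results are summed.
    weighted_input = sum(w * s for w, s in zip(weights, spikes_in))
    state = state * decay + weighted_input
    fired = state >= threshold
    output_spikes.append(int(fired))
    if fired:
        state = 0.0  # reset after a spike

print(output_spikes)  # [0, 1, 0, 1, 0]
```

With these values the neuron fires at steps 1 and 3 (counting from 0): a single spike from the first input (weight 0.5) is not enough on its own, but two in a row accumulate past the threshold despite the leak. At the last step, the inhibitory input pushes the state below zero.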
Now that we’ve covered the basics, let’s get down to business and start implementing our own SNN layer.
Implementation
First, we need to define the variables that we will use later.
import torch
import torch.nn as nn
inp_size = 3
out_size = 5
state = torch.zeros(1, out_size) # state
w = torch.randn(inp_size, out_size) # weights (random: with all-zero weights, no neuron could ever spike)
threshold = 0.6
x = torch.randn(1, inp_size) # input
Next, we begin the feedforward process, which takes four steps.
1. Multiply the inputs by the weights
This is exactly the same operation as in a standard neural network layer.
out = x @ w
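In matrix form, x @ w computes, for every output neuron, the weighted sum of all inputs. The sketch below checks this against an explicit loop; the seed and the random values are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
inp_size, out_size = 3, 5
x = torch.randn(1, inp_size)         # one input vector (batch of 1)
w = torch.randn(inp_size, out_size)  # one column of weights per output neuron

out = x @ w  # shape (1, 3) @ (3, 5) -> (1, 5)

# The same computation written as explicit weighted sums, one per output neuron.
manual = torch.zeros(1, out_size)
for n in range(out_size):
    for i in range(inp_size):
        manual[0, n] += x[0, i] * w[i, n]

print(out.shape)                                 # torch.Size([1, 5])
print(torch.allclose(out, manual, atol=1e-6))    # True
```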
2. Update the state
To update the state, we first divide it by two. This "leak" makes old inputs fade over time; without it, the state would keep accumulating and the neurons would end up spiking far too often. Then we add out (the weighted input) to the state.
state = state*0.5 # divide by 2 (leak)
state = state + out # add the output
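Halving the state at every step makes it a leaky accumulator, as in the leaky integrate-and-fire (LIF) neuron model. One consequence worth seeing: with a constant input c and no spikes, the update s ← 0.5·s + c converges to 2c. So with the 0.6 threshold used here, a constant input of 0.25 never triggers a spike (the state only approaches 0.5), while 0.35 does after a few steps. The sketch below checks this; the helper name steps_until_spike and the two input values are our own illustrative choices.

```python
def steps_until_spike(c, threshold=0.6, decay=0.5, max_steps=50):
    """Return the step at which a constant input c first makes the state
    reach the threshold, or None if it never does within max_steps."""
    state = 0.0
    for step in range(1, max_steps + 1):
        state = state * decay + c  # leak, then add the input
        if state >= threshold:
            return step
    return None

print(steps_until_spike(0.25))  # None: the state only approaches 2 * 0.25 = 0.5
print(steps_until_spike(0.35))  # 3: 0.35 -> 0.525 -> 0.6125
```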
3. Generate spikes
A neuron generates a spike if its state is greater than or equal to the threshold.
spikes = torch.where(state < threshold, 0, 1) # if < 0.6 => 0 else 1
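Note that torch.where(state < threshold, 0, 1) counts a state exactly equal to the threshold as a spike, and it returns an integer tensor. An equivalent formulation, sketched below with arbitrary example states, compares directly and converts the boolean mask to floats, so the spikes share the state's dtype.

```python
import torch

threshold = 0.6
state = torch.tensor([[0.2, 0.75, 0.9, -0.1, 0.59]])  # example states

# Boolean comparison, converted to 0.0 / 1.0 floats.
spikes = (state >= threshold).float()

# Same result as the torch.where formulation (written with explicit tensors).
spikes_where = torch.where(state < threshold,
                           torch.zeros_like(state),
                           torch.ones_like(state))

print(spikes)  # tensor([[0., 1., 1., 0., 0.]])
assert torch.equal(spikes, spikes_where)
```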
4. Reset the neurons that spiked
The state of every neuron that generated a spike is reset to 0.
reset = -state * spikes # minus the state, kept only where a spike occurred
state = state + reset
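Putting the four steps together, one time step of the layer can be written as below. The notation is ours: x_t is the input row vector, W the weight matrix, s_t the state, θ the threshold (0.6 here), and ⊙ element-wise multiplication. The last line also shows that subtracting the masked state is the same as multiplying the state by (1 − z_t).

```latex
\begin{aligned}
u_t &= x_t W && \text{(1) weighted input} \\
v_t &= 0.5\, s_{t-1} + u_t && \text{(2) leak and integrate} \\
z_t &= \begin{cases} 1 & \text{if } v_t \ge \theta \\ 0 & \text{otherwise} \end{cases} && \text{(3) spike, element-wise} \\
s_t &= v_t - v_t \odot z_t = v_t \odot (1 - z_t) && \text{(4) reset}
\end{aligned}
```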
OK, now that we have the math, we can create the class:
class SNNLayer(nn.Module):
    def __init__(self, inp_size, out_size, threshold=0.6):
        super().__init__()
        self.threshold = threshold
        self.w = nn.Parameter(torch.randn(inp_size, out_size))
        self.w.requires_grad = False  # no gradient: the weights are not trained by backpropagation
        self.state = torch.zeros(1, out_size)  # neurons start at rest, as in the step-by-step version

    def forward(self, x):
        out = x @ self.w  # weighted input
        self.state = self.state * 0.5 + out  # leak + integrate
        spikes = torch.where(self.state < self.threshold, 0, 1)  # fire
        reset = -self.state * spikes
        self.state += reset  # reset the neurons that spiked
        return spikes
Learning
Our layer works, but now it has to learn. To do this, we'll use unsupervised learning, more specifically the Hebbian learning rule.
The principle is simple: if a neuron fires, it will
- strengthen its connections with all the inputs that were active, and
- weaken its connections with all the inputs that were inactive.
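The same rule can be written as a formula. In the notation below (ours, not the original's), η = 0.05 is the step size used in the code, z_n ∈ {0, 1} is the spike of output neuron n, and a_i ∈ {0, 1} indicates whether input i is active, which the code defines as being above the threshold:

```latex
\Delta w_{in} = \eta \, z_n \, (2a_i - 1) =
\begin{cases}
+\eta & \text{if } z_n = 1 \text{ and } a_i = 1 \\
-\eta & \text{if } z_n = 1 \text{ and } a_i = 0 \\
0 & \text{if } z_n = 0
\end{cases}
```

Only the connections of neurons that just fired change, in the spirit of the classic summary of Hebb's rule: "neurons that fire together wire together".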
Note: the code below deliberately uses explicit loops instead of vectorized operations, to keep it as easy to understand as possible.
for n in range(spikes.shape[1]):  # for each output neuron
    if spikes[0, n] == 1:  # if it fired
        for i in range(x.shape[1]):  # for each input
            if x[0, i] > threshold:  # the input is active
                w[i, n] += 0.05  # strengthen the connection
            else:
                w[i, n] -= 0.05  # weaken the connection
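The nested loops can also be written as a single matrix operation: the update is the outer product of a ±1 "input activity" vector and the output spike vector, scaled by 0.05. The sketch below checks that both versions give the same weights on arbitrary random values; the function names hebbian_loops and hebbian_vectorized are our own.

```python
import torch

def hebbian_loops(w, x, spikes, threshold=0.6, lr=0.05):
    """Loop version of the rule above (returns an updated copy of w)."""
    w = w.clone()
    for n in range(spikes.shape[1]):
        if spikes[0, n] == 1:
            for i in range(x.shape[1]):
                if x[0, i] > threshold:
                    w[i, n] += lr
                else:
                    w[i, n] -= lr
    return w

def hebbian_vectorized(w, x, spikes, threshold=0.6, lr=0.05):
    """Same rule as an outer product: +lr / -lr per input, masked by spikes."""
    pre = (x > threshold).float() * 2 - 1        # +1 if input active, else -1; shape (1, inp)
    return w + lr * pre.T @ spikes.float()        # (inp, 1) @ (1, out) -> (inp, out)

torch.manual_seed(0)  # arbitrary seed
x = torch.randn(1, 3)
w = torch.randn(3, 5)
spikes = torch.tensor([[1, 0, 1, 1, 0]])

print(torch.allclose(hebbian_loops(w, x, spikes), hebbian_vectorized(w, x, spikes)))  # True
```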
That's all the math; now all we have to do is add it to the class.
class SNNLayer(nn.Module):
    def __init__(self, inp_size, out_size, threshold=0.6):
        super().__init__()
        self.threshold = threshold
        self.w = nn.Parameter(torch.randn(inp_size, out_size))
        self.w.requires_grad = False  # no gradient: the weights are learned with the Hebbian rule
        self.state = torch.zeros(1, out_size)  # neurons start at rest

    def forward(self, x):
        out = x @ self.w  # weighted input
        self.state = self.state * 0.5 + out  # leak + integrate
        spikes = torch.where(self.state < self.threshold, 0, 1)  # fire
        reset = -self.state * spikes
        self.state += reset  # reset the neurons that spiked

        # learning (Hebbian rule)
        for n in range(spikes.shape[1]):  # for each output neuron
            if spikes[0, n] == 1:  # if it fired
                for i in range(x.shape[1]):  # for each input
                    if x[0, i] > self.threshold:  # the input is active
                        self.w[i, n] += 0.05  # strengthen the connection
                    else:
                        self.w[i, n] -= 0.05  # weaken the connection
        return spikes
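As a usage sketch, the layer can be driven over several time steps. The version below is a hypothetical vectorized variant (VectorizedSNNLayer is our name, not from the original): it keeps the same dynamics (0.5 leak, 0.6 threshold, reset to 0) and the same Hebbian rule, but replaces the Python loops with the outer-product update and stores the state as a module buffer. The random inputs and the 20-step run are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

class VectorizedSNNLayer(nn.Module):
    """Hypothetical vectorized variant of SNNLayer (same dynamics and rule)."""

    def __init__(self, inp_size, out_size, threshold=0.6, lr=0.05):
        super().__init__()
        self.threshold = threshold
        self.lr = lr
        # Weights are learned with the Hebbian rule, not by backpropagation.
        self.w = nn.Parameter(torch.randn(inp_size, out_size), requires_grad=False)
        # The state belongs to the module but is not a trainable parameter.
        self.register_buffer("state", torch.zeros(1, out_size))

    def forward(self, x):
        self.state = self.state * 0.5 + x @ self.w       # leak + integrate
        spikes = (self.state >= self.threshold).float()  # fire
        self.state = self.state * (1 - spikes)           # reset the neurons that fired
        # Hebbian update: +lr for active inputs, -lr for inactive ones,
        # applied only to the columns of output neurons that spiked.
        pre = (x > self.threshold).float() * 2 - 1
        self.w.add_(self.lr * pre.T @ spikes)
        return spikes

torch.manual_seed(0)  # arbitrary seed
layer = VectorizedSNNLayer(inp_size=3, out_size=5)
spike_counts = torch.zeros(1, 5)
for _ in range(20):
    spike_counts += layer(torch.randn(1, 3))
print(spike_counts)  # how many times each output neuron fired in 20 steps
```

Like the original, this sketch processes one input at a time, because the state has shape (1, out_size); feeding a batch would mix the Hebbian updates of different samples together.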
And that’s it, you’ve got a complete and functional implementation of an SNN layer. 👍
I hope this article has been of use to you, and if it has, there’s no reason why you shouldn’t clap it. = )