Understanding Spiking Neural Networks through Code Implementation

Arthur Lagacherie · Published in The Deep Hub · 4 min read · Jul 12, 2024

If you want to understand how spiking neural networks work and implement them with PyTorch, you’ve come to the right place.

image generated by Adobe Firefly

Introduction

Before diving into the topic, let’s start with a brief introduction to set the context:

Today, deep learning is primarily based on artificial neural networks (ANNs), which, despite their exceptional performance in almost every field, are not truly representative of actual intelligence and brain function.

This is why spiking neural networks (SNNs) were developed to mimic the neurons in the brain more closely; their origins go back to the Hodgkin–Huxley neuron model of 1952.

Explanation:

Each neuron has a state, and every time one or more spikes arrive from other neurons, each spike is multiplied by a weight (as in ANNs) and accumulated into that state. If the state exceeds a threshold, the neuron generates a spike of its own.

image by me
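In equation form, this gives a leaky integrate-and-fire style update (with a fixed decay factor of 0.5, the value we will use in the code below):

state_t = 0.5 · state_{t−1} + x_t · W
spike_t = 1 if state_t ≥ threshold, else 0 (and a firing neuron’s state is reset to 0)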

Now that we’ve covered the basics, let’s get down to business and start implementing our own SNN layer.

Implementation

First, we need to define the variables that we will use later.

import torch
import torch.nn as nn

inp_size = 3
out_size = 5

state = torch.zeros(1, out_size) # membrane state, one per output neuron
w = torch.randn(inp_size, out_size) # weights (random; all-zero weights would never produce a spike)
threshold = 0.6 # firing threshold

x = torch.randn(1, inp_size) # input

Next, we begin the feedforward process.

1. multiply the inputs by the weights

This first step is exactly the same as in a basic neural network: we multiply the input vector by the weight matrix.

out = x @ w

2. update the state

To update the state, we first decay it by half; without this decay the state would only ever grow, and eventually every neuron would spike at every step. Then we add the new output to the decayed state.

state = state*0.5 # decay: halve the previous state
state = state + out # add the new input contribution

3. generate spikes

A spike is generated whenever the state reaches or exceeds the threshold.

spikes = torch.where(state < threshold, 0, 1) # if < 0.6 => 0 else 1

4. reset the states that generated spikes

All states that generated a spike will be reset to 0.

reset = -state * spikes # equals -state where a spike occurred, 0 elsewhere
state = state + reset # so the neurons that fired are set back to 0

OK, now that we have the maths, we can wrap everything in a class.

class SNNLayer(nn.Module):
    def __init__(self, inp_size, out_size, threshold=0.6):
        super().__init__()

        self.threshold = threshold
        self.w = nn.Parameter(torch.randn(inp_size, out_size))
        self.w.requires_grad = False  # no gradients: we won't train with backprop

        self.state = torch.zeros(1, out_size)  # membrane state, one per output neuron

    def forward(self, x):
        out = x @ self.w                     # weighted input
        self.state = self.state*0.5 + out    # decay the state, then integrate
        spikes = torch.where(self.state < self.threshold, 0, 1)  # fire where threshold reached

        reset = -self.state * spikes
        self.state += reset                  # reset the neurons that fired

        return spikes
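As a quick sanity check (a minimal sketch; the sizes and number of timesteps here are arbitrary choices of mine), we can feed the layer a few random inputs over time and watch which neurons fire:

layer = SNNLayer(inp_size=3, out_size=5)

for t in range(4):
    x = torch.randn(1, 3)  # random input at this timestep
    spikes = layer(x)
    print(f"t={t} spikes: {spikes.tolist()}")

Because the state decays by half at each step, a neuron can also fire from sub-threshold inputs that accumulate over several timesteps.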

Learning

Well, that’s all very well for our layer, but now it has to learn. To do this we’ll use unsupervised learning, more specifically the Hebbian learning rule.

The principle is simple: when an output neuron fires, it will

  • strengthen its connections to the input neurons that were active, and
  • weaken its connections to the input neurons that were inactive.

Note: the code below is deliberately naive, with nested loops, because I wanted to make it as easy to understand as possible.

for n in range(spikes.shape[1]):      # for each output neuron
    if spikes[0, n] == 1:             # if it fired

        for i in range(x.shape[1]):   # for each input
            if x[0, i] > threshold:   # input considered "active"
                w[i, n] += 0.05       # strengthen the connection
            else:
                w[i, n] -= 0.05       # weaken the connection
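As an aside, the same update can be written without loops using an outer product; this is my own vectorized sketch, equivalent to the loops above:

pre = (x > threshold).float()  # (1, inp_size): 1 where the input is "active", 0 otherwise
post = spikes.float()          # (1, out_size): 1 where the output neuron fired
# +0.05 for active inputs, -0.05 for inactive inputs, only in columns of firing neurons
w += 0.05 * (2*pre - 1).T @ post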

That’s the maths; now all we have to do is fold it into the class’s forward pass.

class SNNLayer(nn.Module):
    def __init__(self, inp_size, out_size, threshold=0.6):
        super().__init__()

        self.threshold = threshold
        self.w = nn.Parameter(torch.randn(inp_size, out_size))
        self.w.requires_grad = False  # no gradients: we train with the Hebbian rule instead

        self.state = torch.zeros(1, out_size)  # membrane state, one per output neuron

    def forward(self, x):
        out = x @ self.w                     # weighted input
        self.state = self.state*0.5 + out    # decay the state, then integrate
        spikes = torch.where(self.state < self.threshold, 0, 1)  # fire where threshold reached

        reset = -self.state * spikes
        self.state += reset                  # reset the neurons that fired

        # Hebbian learning
        for n in range(spikes.shape[1]):          # for each output neuron
            if spikes[0, n] == 1:                 # if it fired

                for i in range(x.shape[1]):       # for each input
                    if x[0, i] > self.threshold:  # input considered "active"
                        self.w[i, n] += 0.05      # strengthen the connection
                    else:
                        self.w[i, n] -= 0.05      # weaken the connection

        return spikes
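Here is a small demo of the rule in action (the seed, pattern, and step count are arbitrary choices of mine): we present the same binary pattern for several timesteps and print the spikes.

torch.manual_seed(0)
layer = SNNLayer(inp_size=3, out_size=5)

pattern = torch.tensor([[1.0, 0.0, 1.0]])  # inputs 0 and 2 are "active"

for step in range(5):
    spikes = layer(pattern)
    print(f"step {step}: spikes = {spikes.tolist()}")

# whenever a neuron fires, its weights from inputs 0 and 2 grow by 0.05
# and its weight from input 1 shrinks by 0.05
print(layer.w)

Over time, the weight columns of the neurons that fire drift toward the presented pattern, which is exactly the "neurons that fire together wire together" behaviour we wanted.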

And that’s it, you’ve got a complete and functional implementation of an SNN layer. 👍

I hope this article has been of use to you, and if it has, there’s no reason why you shouldn’t clap for it. :)




I am a French high school student passionate about artificial intelligence. I enjoy sharing my curiosity with others.