BindsNET — A Framework for Spiking Neuronal Simulation

Hananel Hazan · Published in PyTorch · May 20, 2020

Authors: Hananel Hazan and Michael Levin at The Allen Discovery Center at Tufts University

What is BindsNET?

BindsNET (Biologically Inspired Neural & Dynamical Systems in Networks) is an open-source Python framework built on top of PyTorch that enables rapid construction of rich spiking-network simulations in a concise syntax. The framework provides an easy way to test ideas while incorporating biological mechanisms into spiking neuronal networks, and it allows those networks to be evaluated with the same benchmarking tools used for standard neural networks. The framework is the result of a collaboration between computer scientists, neurocomputational scientists, and biologists. BindsNET explores the computational abilities of neurons with bio-inspired properties using standard machine learning benchmarks. While biologists study nervous systems to learn how gene expression functions and what the effects of ionic conductances are, other scientists investigate how these discoveries can benefit the computational world.

Do you want to try BindsNET? Just install it directly from its GitHub repository:

pip install git+https://github.com/BindsNET/bindsnet.git

BindsNET is joint work at the BINDS lab under its director, Hava Siegelmann, interim co-director Robert Kozma, and the hard-working students Daniel Saunders, Hassaan Khan, Devdhar Patel, Darpan Sanghavi, Sharath Ramkumar, and Cooper Sigrist.

Why is BindsNET useful?

BindsNET provides an easy and fast development environment for testing the computational abilities of nervous systems, particularly spiking nervous systems, on standardized benchmarks from mainstream computer science.

BindsNET is useful mainly for computer scientists specializing in bioinformatics and computational biology. At the same time, the tool can be used by specialists from fields such as neurobiology and neuroscience who are interested in information exchange in neuronal tissues and want to explore new computational concepts inspired by biology. When examining a problem from a computational perspective, these specialists are often confronted with difficult new concepts. First, there is the matter of overcoming differences in terminology and understanding the complexity of the ideas being introduced. Second, one needs standardized benchmarks to implement those ideas and compare them against a standard model. BindsNET helps to carry out these tasks in as easy and straightforward a way as possible.

For instance, a common fruit fly has about 250,000 neurons (fewer than the number of artificial neurons in AlphaGo), and when in danger it performs complex flight maneuvers that fighter pilots could only dream of. When it comes to learning how to utilize and test the computational power of spiking neuronal algorithms, the BindsNET framework makes it much easier to explore concepts such as reinforcement learning with spiking networks, deep learning, or information theory. It does so, for instance, by providing researchers with tools to use a spiking network in a reinforcement learning environment and to monitor and adapt the architecture based on the results. In addition to increasing our understanding of how biological systems process information, one of the main motivations for using spiking neurons is their lower energy consumption: spiking neurons can run on dedicated hardware that consumes much less power than standard artificial neurons. Until we all have our own neuromorphic chips, BindsNET can help us use our GPUs to run simulations more efficiently.

One of the best ways to understand how something works is to build a model of it from scratch using known principles, information, and rules, and then to test the system with the help of that model. One can say that the functioning of a biological system has been understood when the behaviors of the model and the original biological system correspond. Or, in the words of Richard Feynman, “What I cannot create, I do not understand.”

How does BindsNET work?

BindsNET utilizes the ease of use and facilities provided by PyTorch to serve the needs of bio-inspired computation. BindsNET can simulate a variety of biologically inspired cells. For now, we mainly use several variants of spiking neurons to model the functionality of the closest computational cell in a nervous system, the spiking neuron. One of the main advantages of BindsNET is the ability to intervene in every step of the computation without leaving the comfort of Python for another programming language. More importantly, it is very easy to create a hybrid network, for example one composed of spiking neurons (using BindsNET) and artificial neurons (using PyTorch) in the same code, under the same loop. PyTorch serves as a fast, dynamic computational engine and an efficient infrastructure for the BindsNET code, providing it with enhanced computational features. Thus, the running code can access any variable at any stage of the simulation. This is a great advantage compared with other solutions that do not support changing variables during operation: users can inspect every variable at any stage of the simulation and use PyTorch or other tools to analyze and intervene in the activity of the network.
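To make the hybrid idea concrete, here is a minimal sketch (our illustration, not code from the post), assuming a recent BindsNET version; the layer sizes, the random placeholder input, the simulation length, and the linear readout are arbitrary choices:

import torch
import torch.nn as nn
from bindsnet.network import Network
from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection
from bindsnet.network.monitors import Monitor
from bindsnet.encoding import poisson

time = 100
snn = Network(dt=1.0)
snn.add_layer(Input(n=100), name="X")
snn.add_layer(LIFNodes(n=50), name="Y")
snn.add_connection(Connection(source=snn.layers["X"], target=snn.layers["Y"]), source="X", target="Y")
monitor = Monitor(snn.layers["Y"], state_vars=["s"], time=time)
snn.add_monitor(monitor, name="Y_spikes")

readout = nn.Linear(50, 10)  # an ordinary PyTorch layer on top of the spiking part

datum = torch.rand(100) * 64.0                       # placeholder input intensities (rates in Hz)
spike_train = poisson(datum=datum, time=time)        # encode to a [time, 100] spike train
snn.run(inputs={"X": spike_train}, time=time)        # simulate the spiking layers
counts = monitor.get("s").float().sum(0).flatten()   # spike counts per hidden neuron
logits = readout(counts)                             # continue with standard PyTorch ops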

In the sections below, we demonstrate how one can create a self-organizing spiking network that classifies the MNIST dataset and rearranges itself in a self-organized fashion similar to the cortex’s sensory maps. We then show that spiking neurons can execute tasks at a level of performance similar to that of standard artificial neurons.

What is the future of BindsNET?

BindsNET aims to develop both on the purely computational side and on the computational biology side. On the computational side, BindsNET aims to add more models and benchmarking tools to facilitate comparisons between spiking neuronal networks and standard neural networks. On the computational biology side, the framework aims to incorporate additional biologically inspired features that will help explore and demonstrate the computational power of biological systems on standard benchmarks. Combining the two approaches will benefit both sides by providing better tools that help standardize the computational approach to biology.

Self-Organizing Maps (SOM) (a.k.a. Kohonen maps)

The existence of self-organizing properties in the natural world has always fascinated researchers. For instance, Teuvo Kohonen (see) made a significant contribution to the field of artificial neural networks with his work on Self-Organizing Maps (SOM). A SOM is a type of artificial neural network whose algorithm automatically forms clusters of similar features in the selected data. While this self-organizing behavior is driven by the input data, the neurons themselves also become organized: neurons that respond to similar features end up close to one another.

To the best of our knowledge, despite its importance in biology, this self-organizing feature has not previously been used with bio-inspired neurons. Building on the work of Diehl and Cook (see), we improve their model to exhibit a self-organizing property that improves overall performance and accuracy and adds another dimension to its classification ability.

Let’s build our network with the following:

network = IncreasingInhibitionNetwork(
    n_input=784,
    n_neurons=256,
    start_inhib=1,
    max_inhib=-20.0,
    theta_plus=0.05,
    tc_theta_decay=1e7,
    inpt_shape=(1, 28, 28),
    nu=(1e-4, 1e-2),
)

This code creates a network with two layers, an input layer and an output layer. The network has 784 input cells, corresponding to the MNIST image size (28x28). The output layer has 256 neurons and serves as the memory capacity for capturing differences in the input data; the larger the output layer, the higher the accuracy and the better the ability to classify the input data. The connectivity inside the output layer is inhibitory: when one neuron spikes, it inhibits the activity of all other neurons in the layer, and the magnitude of the inhibition depends on the distance between the cell and its neighbors. Each neuron is also equipped with an adaptive firing threshold that tunes its activity: the sliding threshold increases by 0.05 with each spike and decays with a time constant of 1e7.
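The sliding-threshold mechanism can be summarized in a few lines (our illustration, not BindsNET's internal code; the voltage and base-threshold values below are placeholders):

import math
import torch

# Illustrative only: the adaptive-threshold (theta) update described above,
# using theta_plus = 0.05 and tc_theta_decay = 1e7 from the network definition.
dt, theta_plus, tc_theta_decay = 1.0, 0.05, 1e7
thresh = -52.0                        # placeholder base threshold (mV)
v = torch.full((256,), -60.0)         # placeholder membrane voltages (mV)
theta = torch.zeros(256)              # per-neuron adaptive threshold term

theta = theta * math.exp(-dt / tc_theta_decay)   # slow exponential decay each timestep
spiked = v >= thresh + theta                     # spike when voltage crosses the sliding threshold
theta = theta + theta_plus * spiked.float()      # raise the threshold of neurons that fired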

Next, we will use the BindsNET data loader, which extends the functionality of the PyTorch DataLoader:

import os
from torchvision import transforms
from bindsnet.datasets import MNIST
from bindsnet.encoding import PoissonEncoder

# Load MNIST data.
dataset = MNIST(
    PoissonEncoder(time=time, dt=dt),
    None,
    root=os.path.join("data", "MNIST"),
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x * intensity)]
    ),
)

This code returns a BindsNET dataset connected to a Poisson encoder that converts the MNIST images into spike trains, using the simulation length (the time variable) and a pixel-intensity scale (the intensity variable).
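For intuition, the encoder can also be used on its own. The snippet below is purely illustrative (the random image, simulation length, and intensity scale are placeholders): each pixel's intensity becomes the rate of an independent Poisson spike train.

import torch
from bindsnet.encoding import PoissonEncoder

# Illustrative only: encode one 28x28 "image" into a Poisson spike train.
# Brighter pixels produce higher firing rates.
image = torch.rand(1, 28, 28)              # placeholder image in [0, 1]
encoder = PoissonEncoder(time=250, dt=1.0)
spike_train = encoder(image * 128)         # shape [250, 1, 28, 28], entries are 0/1 spikes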

Now, let’s iterate over the dataset, feed each image to the network, and train its weights. The following code demonstrates this:

for step, batch in enumerate(tqdm(dataloader)):
    # Get next input sample.
    inputs = {
        "X": batch["encoded_image"].view(time, 1, 1, 28, 28).to(device)
    }
    # Run the network on the input.
    network.run(inputs=inputs, time=time, input_time_dim=1)
    # Get spikes from the network.
    temp_spikes = spikes["Y"].get("s").squeeze()
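As an aside, once the spikes are collected, a simple (illustrative) readout is to count spikes per output neuron and predict the class previously assigned to the most active one. The neuron_labels tensor below is a hypothetical per-neuron class assignment built during training, not a variable from the snippet above.

import torch

# Illustrative readout, not part of the original example.
# temp_spikes: [time, n_neurons] binary spike tensor from the monitor above.
# neuron_labels: hypothetical [n_neurons] tensor of class assignments from training.
spike_counts = temp_spikes.sum(dim=0)     # total spikes per output neuron
winner = torch.argmax(spike_counts)       # index of the most active neuron
predicted_label = neuron_labels[winner]   # its previously assigned digit class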

To see how the network progresses, we can display the weights of the neurons:

square_weights = get_square_weights(input_exc_weights.view(784, n_neurons), n_sqrt, 28)
plot_weights(square_weights, im=weights_im, save=save_weights_fn)

The full source code can be found in the examples of BindsNET (link) and in the published paper here.

Running the example (with the training sped up) looks like this:

[Animation: Self-Organizing Maps (SOM) using spiking neurons]

[Figure: Network performance]

Memory Transfer between networks

In hardware, spiking neurons can be more energy-efficient and have been shown to be robust to noise. However, training spiking neurons to perform a complex task can be a daunting endeavor. Until we come up with a better way to train a spiking neuronal network, let’s try to use the trained weights of a regular network to perform the same task on a spiking network. As our work (see) illustrates, and as we will also demonstrate here, transferring knowledge between similar networks results in similar performance, with an added bonus: in our tests, the spiking network shows remarkable robustness to noise and to missing input compared with the standard network that was originally trained on the data.

From the biological side, we note that it is important to develop computational models for moving memory between two instances of a network. For example, when trained planarian flatworms have their heads amputated, the tail portion regenerates an entirely new brain, and once regeneration is complete, behavioral testing shows that they remember the original training. How does the somatic tissue imprint its memories onto the newly developing neural networks of the brain? Likewise, caterpillars metamorphosing into butterflies are able to transfer memories to a largely rebuilt brain, and memory transplants have been achieved with tissue and biochemical transfer across individuals. It is critical to begin developing computational models of memory storage, retrieval, and manipulation in plastic substrates, and BindsNET is a tool that can be used in these efforts. Critically, this goes well beyond neuroscience: it is now well appreciated that many cellular networks in the body process information for morphogenesis (during embryonic development, regeneration, and cancer suppression) via bioelectrical signaling. Modeling the computations carried out by all cell types is an essential effort, both for regenerative medicine and for inspiring new machine learning architectures based on ancient, pre-neural principles discovered by evolution long ago, such as in bacterial biofilms.

Here we will demonstrate this memory transfer on a smaller scale. First, we will train a standard network with three fully connected layers to play ATARI Breakout. Then we will copy its weights, scaled by an additional constant factor, to a spiking network with the same architecture and run the game.

The full source code can be found in the examples of BindsNET (link). The same process can be applied to a network at a bigger scale (see).

Let’s build a spiking network that plays the ATARI Breakout game. We will start by loading the gym environment with these lines:

environment = GymEnvironment('BreakoutDeterministic-v4', clip_rewards=True)

Let’s define a standard network and load its trained weights ('train_shallow_ANN.pt') using torch.load:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(6400, 1000)
        self.fc2 = nn.Linear(1000, 4)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Load the trained ANN.
dqn_network = torch.load('train_shallow_ANN.pt', map_location=device)

The following code shows how to create a spiking network with the same topology and how to apply the weights from the standard network with a magnification factor:

from bindsnet.network import Network
from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection
from bindsnet.network.monitors import Monitor

# Build network.
network = Network(dt=dt)

# Layers of neurons.
inpt = Input(n=6400, traces=False)  # Input layer
exc = LIFNodes(n=hidden_neurons, refrac=0, traces=True, thresh=-52.0, rest=-65.0)  # Excitatory layer
readout = LIFNodes(n=4, refrac=0, traces=True, thresh=-52.0, rest=-65.0)  # Readout layer
layers = {'X': inpt, 'E': exc, 'R': readout}

# Connections between layers.
# Input -> hidden.
input_exc_conn = Connection(source=layers['X'], target=layers['E'],
                            w=torch.transpose(dqn_network.fc1.weight, 0, 1) * layer1scale)
# Hidden -> readout.
exc_readout_conn = Connection(source=layers['E'], target=layers['R'],
                              w=torch.transpose(dqn_network.fc2.weight, 0, 1).view([1000, 4]) * layer2scale)

# Add all layers and connections to the network.
for layer in layers:
    network.add_layer(layers[layer], name=layer)
network.add_connection(input_exc_conn, source='X', target='E')
network.add_connection(exc_readout_conn, source='E', target='R')

# Monitor spikes in the readout layer.
spikes = {}
spikes['R'] = Monitor(network.layers['R'], state_vars=['s'], time=runtime)
network.add_monitor(spikes['R'], name='R_spikes')

Finding the right scaling coefficient for each layer can be a bit tricky. In our paper we use particle swarm optimization, but any other optimization method can be used to determine these factors.
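For example, a brute-force alternative is a coarse grid search over candidate scale factors. The sketch below is ours, not the method used in the paper; the candidate values are arbitrary, and evaluate_reward is a hypothetical helper that rebuilds the spiking network with the given scales, plays a few episodes, and returns the average reward.

# Hypothetical grid search over the two layer scale factors.
# evaluate_reward is a placeholder helper, not part of BindsNET.
best_scales, best_reward = None, float("-inf")
for layer1scale in (10.0, 30.0, 50.0, 70.0, 90.0):
    for layer2scale in (10.0, 30.0, 50.0, 70.0, 90.0):
        reward = evaluate_reward(layer1scale, layer2scale)
        if reward > best_reward:
            best_scales, best_reward = (layer1scale, layer2scale), reward
layer1scale, layer2scale = best_scales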

And now is the time to let the network play!
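A rough sketch of that play loop is shown below. The exact preprocessing, encoding, and timing in the published example may differ; in particular, the bernoulli encoder, the max_prob value, the added batch dimension, and the reset call are our assumptions.

import torch
from bindsnet.encoding import bernoulli

# Rough sketch of the play loop; it follows the spirit of the BindsNET example
# rather than its exact code.
obs = environment.reset()
done = False
while not done:
    # Encode the preprocessed observation into a [runtime, 1, 6400] spike train.
    encoded = bernoulli(datum=obs.flatten(), time=runtime, max_prob=0.1).unsqueeze(1)
    network.run(inputs={'X': encoded}, time=runtime)
    # Choose the action whose readout neuron fired the most.
    action = int(torch.argmax(spikes['R'].get('s').sum(dim=0)))
    obs, reward, done, info = environment.step(action)
    network.reset_state_variables()  # method name may differ across BindsNET versions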

[Animation: Shallow spiking network playing ATARI Breakout]

Do you have new ideas on how to use Spiking Neurons? Try them with BindsNET.

Finally, I would like to thank the PyTorch community for its support and hard work in providing the scientific community with the PyTorch libraries that help leverage the power of GPUs, and for maintaining clear and concise documentation and examples.

I also want to thank Alžběta Krausová for several improvements she made to this text.
