Opening the Black Box of the Simplest Neural Network

Visualise the decision process (weights) of a neural network

Luis F. Camarillo-Guerrero
The Startup
7 min read · Oct 29, 2020


Neural networks (NNs) are often described as a ‘black box’, meaning that we cannot easily pinpoint exactly how they make decisions. But since NNs store their knowledge in their weights, examining those weights should reveal some insight into their decision process.

In this article, we are going to train NNs that recognise handwritten numbers (0–9) and then open their ‘black box’ by visualising their weights.

All the code is written in Python and can be found here.

Reading handwritten digits with a neural network

We are going to use the handwritten digits stored in the MNIST database. Each digit is in greyscale and is 28x28 pixels (width x height), i.e. 784 pixels in total. Each pixel can take any value from 0 (black) to 255 (white); any value in between corresponds to a shade of grey.

Sample of the first 40 handwritten digits of the MNIST database

Next, we need a proper representation for each handwritten digit so that it can be read by our NNs. The simplest way is to generate a vector (x) containing the 784 pixel values of each image, e.g. x = [12,0,0,…,234] (reading the pixel values from left to right and top to bottom).
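In NumPy, this flattening is a one-liner. A minimal sketch (the random array here is just a stand-in for a real MNIST image):

import numpy as np

img = np.random.randint(0, 256, size=(28, 28))  # stand-in for a 28x28 MNIST digit
x = img.flatten()  # row-major: left to right, top to bottom
print(x.shape)     # (784,)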

Regarding the topology of our NN, it has a single output neuron that receives the 784 pixel values (the input layer) from the handwritten digit (see below). The output neuron outputs the probability that a given image (x) corresponds to the number it was trained to recognise. Thus, we will need 10 different NNs to recognise all the numbers (0–9) in the MNIST database (this strategy is known as One-vs-All in multi-class classification problems).

As you can see from the above image, each pixel has an associated weight (w) with a subscript indicating the pixel of origin (e.g. w1 corresponds to the very first pixel). It’s precisely these weights that will reveal how the neuron is processing each image presented to it.

Interpreting the weights

Mathematically, the output neuron multiplies each weight by its corresponding pixel and then sums over all the products, plus a bias (b). However, we can simplify this operation with a trick: add an extra weight (w0) paired with an imaginary pixel fixed at +1.
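In symbols (Z is the value fed to the activation):

Z = w1*x1 + w2*x2 + … + w784*x784 + b = w0*1 + w1*x1 + w2*x2 + … + w784*x784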

See how the last equivalence sums over 785 instead of 784 values? That’s because b = w0*1.

Does this transformation look familiar? Indeed, it’s the famous dot product from physics, which involves two vectors. In a nutshell, the dot product tends to increase when two vectors point in the same direction¹. An equivalent way of expressing the dot product is as the product of the lengths of the two vectors times the cosine of the angle between them (a⋅b = ∥a∥∥b∥cos(θ)). If two vectors are exactly aligned, the angle is 0 and the cosine function is maximised (cos(0)=1).

You may be wondering why we are suddenly talking about vectors. Well, it turns out each handwritten digit can be visualised as a vector in a 784-dimensional space (see below). Grasping this idea is critical to interpreting the weights of our NNs.

It’s tempting to think that instances of the same handwritten number (e.g. all the 3’s in our dataset) have similar pixel values, and thus their vectors will tend to point in the same direction². If that’s true, then the cosine values between pairs of 3’s will be larger than the cosine values between 3’s and 6’s. Let’s find out.
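Here’s a minimal sketch of that check, assuming X holds the flattened MNIST images and labels holds the raw digit labels (both loaded as in the code further below):

import numpy as np

def cosine(a, b):
    a, b = a.astype(float), b.astype(float)  # avoid uint8 overflow in the dot product
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

ref = X[np.where(labels == 3)[0][0]]  # one reference '3'
for digit in range(10):
    idx = np.where(labels == digit)[0]  # all examples of this digit
    print(digit, np.mean([cosine(ref, X[i]) for i in idx]))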

In the graph above, I took one example of a “3” and computed the cosine of the angle against other MNIST numbers (considering only instances found in the first 1000 examples). As you can see, our intuition was correct. The vector of a ‘3’ is closest to the vectors of other 3’s, so the cosine of the angle between them is highest (the angle between them is smallest). Against the other MNIST numbers, the cosine value drops, indicating that they lie farther away in the 784-dimensional space.

There’s still one part of the computation that I haven’t discussed: the sigmoid function. This function squashes any real number (in this case Z) into the range 0–1. The higher the value of Z, the closer the output is to 1.
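For reference, the sigmoid is σ(Z) = 1/(1 + e^(−Z)): as Z grows large and positive the output approaches 1, and as Z grows large and negative it approaches 0.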

As you can see from the sigmoid function above, the output neuron wants to generate large values of Z so that it gets the label right (values closer to 1)³. Here’s the key question: when will our output neuron assign large values of Z to positive examples? When the dot product between the image vector and the weight vector is highest. And that happens when the weight vector lies close to the image vectors of the positive examples.

Just as images can be placed in a 784-dimensional space, there’s no reason why W (which has 784 components) cannot be placed in the same hyper-dimensional space. What will the W vector look like if we visualise it as an image? Indeed, like a 3! The same thing happens for a NN trained to recognise any other MNIST digit. During training, a NN moves the weight vector through the 784-dimensional space, and the best results are achieved when it places the weight vector next to the positive examples (maximising the dot product).

Training our neural net and visualising its weights

Time to put this theory to the test (complete Python code can be found here and data files here). First, read in the first 500 handwritten numbers in MNIST and save them in X along with their corresponding labels in y. We also normalise the pixel values of the images to facilitate the learning process (I’m just subtracting the mean and dividing by the standard deviation). Regarding the labels, since we are going to train our NN to recognise only the number 3 (y=1), we assign y=0 to all the other numbers.

import gzip
import numpy as np

image_size = 28   # 28x28 pixel images
num_images = 500  # Number of images to read
num_class = 3     # Digit to recognise

def normIM(x):
    # Normalise an image: subtract the mean and divide by the standard deviation
    return (x - x.mean()) / x.std()

# Reading handwritten digits
f = gzip.open('train-images-idx3-ubyte.gz', 'r')  # File with digits
f.read(16)  # Skip the 16-byte header
buf = f.read(image_size * image_size * num_images)
X = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
X = X.reshape(num_images, image_size * image_size)
X = np.array([normIM(x) for x in X])  # Normalising images

# Reading labels
f = gzip.open('train-labels-idx1-ubyte.gz', 'r')  # File with labels
f.read(8)  # Skip the 8-byte header
buf = f.read(num_images)
labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
y = np.array([1 if y_i == num_class else 0 for y_i in labels])  # One-vs-All labels

Now let’s wire up our NN with Keras. We only have 1 neuron, with 784 inputs and a sigmoid activation.

from keras.models import Sequential
from keras.layers import Dense

model = Sequential()
model.add(Dense(1, input_dim=784, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='RMSprop', metrics=['accuracy'])

Finally, we train our NN.

cw = {0: len(y) / (2 * (len(y) - y.sum())), 1: len(y) / (2 * y.sum())}  # Keras needs class_weight as a dict; this mimics 'balanced' weighting
model.fit(X, y, epochs=250, class_weight=cw, verbose=True)

After this step, you can recover the weights from the model and visualise them with:

import matplotlib.pyplot as plt

w = model.get_weights()[0].flatten()  # The 784 weights
b = model.get_weights()[1][0]         # The bias (b)
plt.imshow(w.reshape(28, 28), cmap='gray')
plt.show()

When I repeat this procedure for all the MNIST numbers, I can visualise the 10 weight vectors, one for each of the 10 digits (a sketch of that loop is below).
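A minimal sketch of that loop, where build_model is a hypothetical helper wrapping the Sequential/Dense/compile steps above and labels holds the raw digit labels:

fig, axes = plt.subplots(1, 10, figsize=(20, 2))
for digit, ax in zip(range(10), axes):
    y_d = np.array([1 if y_i == digit else 0 for y_i in labels])  # One-vs-All labels
    model = build_model()  # hypothetical helper: the Keras model definition above
    model.fit(X, y_d, epochs=250, verbose=False)
    ax.imshow(model.get_weights()[0].reshape(28, 28), cmap='gray')
    ax.set_title(str(digit))
    ax.axis('off')
plt.show()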

As you can see, each NN’s weights look like the number it is trying to recognise. The reason is that when you project these weights into the 784-dimensional space, they lie close to the target numbers in order to maximise the dot product, and thus Z, pushing the sigmoid to output values closer to 1 (and reducing the training error).

Conclusion

In this post, I tried to explain some of the intuition behind the weights of neural networks. In particular, I analysed the weights of the simplest NN, which consists of a single output neuron. Interpreting the weights of more complex neural networks is not an easy task; however, when working with images, visualising the weights can give you some insight. For instance, it’s known that deep neural networks learn features of images (e.g. the eyes in faces), and that learning can be monitored through weight visualisation.

  1. This is not always true because the magnitude of the dot product is also proportional to the product of the lengths of the two vectors. a⋅b = ∥a∥∥b∥cos(θ)
  2. This idea works fine for the MNIST digits because the images are relatively simple and uniform. You can imagine that different image backgrounds can affect this heuristic.
  3. Depending on the cost function, there will be another driving force pushing in the opposite direction to assign low values of Z to negative examples to make the neuron output values closer to 0.
