Why do neural nets work?
Neural networks are probably one of the worst-named concepts in all of computer science. They were loosely inspired by biology, but what we will find is that it is more appropriate to think of neural networks as a sort of kinky linear algebra.
Many machine learning and AI problems have a simple functional formulation[1]. A program takes an input vector[2] (X) and performs some computation to produce an output vector (Y). The trick is that we get very meta: instead of trying to figure out the program to compute Y from X ourselves, we’ll use some combination of real-world examples and intuition to get a computer to figure out what a good function from X to Y would be. This is important because many functions (e.g. “Is there a dog in this image?” or “What is the melody line of this music?” or “Is this text in German?”) are intuitively obvious to humans, but we cannot articulate the precise algorithm that our brains use.
Linear regression
Linear regression is literally the first chapter in your favorite machine learning textbook. It is simple and has a lot of useful mathematical properties. We assume that Y=WX, for some matrix W, some input X, and some output Y. Then, we define an error function for any (X,Y) pair as E=(Y-WX)². Linear regression comes with a procedure for computing W given some data. That procedure is only possible because we know how to differentiate E with respect to W, which lets us compute the value of W that minimizes the sum of errors across all training examples. [3]
For example, X might be (mass, height, pitch-of-speaking-voice, heart-rate, hair-length) and Y might be (sex, age), where sex is encoded as >0 for female and <0 for male. You could imagine figuring out a formula to guess sex and age from those parameters. Using linear regression, a computer can figure it out for you from a few examples.
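To make this concrete, here is a minimal sketch of fitting W from examples (numpy assumed; the feature values are entirely made up):

```python
import numpy as np

# Made-up training data: each row of X is one person's
# (mass, height, pitch-of-speaking-voice, heart-rate, hair-length),
# each row of Y is the corresponding (sex, age), with sex >0 for female.
X = np.array([[70.0, 180.0, 110.0, 65.0,  5.0],
              [58.0, 165.0, 210.0, 72.0, 30.0],
              [85.0, 175.0, 120.0, 70.0,  8.0],
              [62.0, 160.0, 220.0, 75.0, 40.0]])
Y = np.array([[-1.0, 35.0],
              [ 1.0, 28.0],
              [-1.0, 42.0],
              [ 1.0, 31.0]])

# With rows as examples, the model reads Y ≈ X @ W; least squares picks the W
# that minimizes the summed squared error, exactly the E defined above.
W, *_ = np.linalg.lstsq(X, Y, rcond=None)

predictions = X @ W   # predicted (sex, age) for the training rows
```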
There are two important facts about this approach that we will make use of:
- Matrix multiplication can only represent a small subset of functions (linear maps), and therefore regression can only work for relatively simple problems. You might use regression to guess sex given physical parameters like height and weight. You won’t use it to guess whether a photograph features a dog.
- Because matrix multiplication is linear, we know that a small change in input will cause a small change in the output. If we differentiate Y with respect to X, we get W! This is the critical property that we need to be able to efficiently estimate an optimal value for W, given some example inputs.
In essence, matrix multiplication is easy to reason about, but insufficiently expressive.
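The same W can also be found iteratively. Because E is differentiable with respect to W, we can simply walk downhill on the summed error; a sketch of that loop (numpy assumed, learning rate picked out of thin air — real code would normalize the inputs first):

```python
import numpy as np

def fit_linear(X, Y, steps=10_000, lr=1e-7):
    """Estimate W in Y ≈ X @ W by gradient descent on the squared error."""
    W = np.zeros((X.shape[1], Y.shape[1]))
    for _ in range(steps):
        error = X @ W - Y          # residuals for every training example
        grad = 2 * X.T @ error     # dE/dW, available in closed form
        W -= lr * grad             # small step downhill
    return W
```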
What is a neural network?
A neural network consists of a number of layers of neurons. Each neuron takes input from every neuron in the previous layer and provides output to every neuron in the next layer. The first and last layers are the input and output respectively. Each neuron computes a weighted sum over its inputs and passes it through a slightly non-linear sigmoid function to produce its output. [4]


OUTPUT=Sigmoid(CONST_1*INPUT_1 + CONST_2*INPUT_2 +…)
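In code, a whole layer is just a matrix-vector product pushed through the sigmoid, element by element (a sketch, numpy assumed):

```python
import numpy as np

def sigmoid(z):
    # Smoothly squashes any real number into (0, 1).
    return 1.0 / (1.0 + np.exp(-z))

def layer(W, x):
    # Each neuron's weighted sum is one row of W @ x;
    # the sigmoid is then applied to each of those sums.
    return sigmoid(W @ x)
```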
What would happen to our neural network if we just used a straight line (aka linear, aka f(x)=x) instead of our sigmoid function? “Weighted sums over inputs” is exactly the dot product, so each level of a linear neural network is just matrix multiplication, like we did with linear regression. We know that matrix multiplication follows the associative law, so regardless of how many levels of neural network we have, we could condense it down to a single level.
In other words, linear regression is a one-level neural network.
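You can check the collapse numerically: with the sigmoid removed, two layers of random weights are indistinguishable from the single matrix W2 @ W1 (a sketch, numpy assumed):

```python
import numpy as np

rng = np.random.default_rng(0)
W1 = rng.standard_normal((4, 3))
W2 = rng.standard_normal((2, 4))
x = rng.standard_normal(3)

two_linear_layers = W2 @ (W1 @ x)   # a "network" whose activation is f(x)=x
one_matrix = (W2 @ W1) @ x          # the same map, collapsed by associativity
assert np.allclose(two_linear_layers, one_matrix)
```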
But this hints at a powerful idea. By introducing some sort of non-linearity we can make our linear regression more powerful. If we pick a non-linear function with the right properties, we might be able to improve our expressiveness without sacrificing our ability to optimize W.
We upgrade from linear regression to a neural network by converting Y=W*X to Y=S(W3*S(W2*S(W1*X))), where S is a “sigmoid” function that is applied element-wise to its input vector. Intuitively, given the shape of S, a small change in input will still cause a small change in the output. If we apply the chain rule to compute dY/dW[n], we find that the derivative for each layer depends only on information that is cheap to get at: that layer’s input, its output, and the error signal passed back from the layers after it. With dE/dW in hand, we can then use gradient descent (and the chain rule, again) to find a good W. In fact, this is exactly how the backpropagation algorithm for neural networks works.
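As a minimal sketch of what that looks like (two layers rather than three for brevity, numpy assumed, bias terms omitted as in the formulas above):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forward(W1, W2, x):
    # Forward pass: each layer is a matrix multiply followed by the sigmoid.
    h = sigmoid(W1 @ x)        # hidden layer output
    y = sigmoid(W2 @ h)        # network output
    return h, y

def backprop(W1, W2, x, target):
    # Chain rule, applied layer by layer from the output backwards.
    h, y = forward(W1, W2, x)
    # dE/dy for E = (target - y)^2, times the sigmoid's derivative y*(1-y).
    delta2 = 2 * (y - target) * y * (1 - y)
    grad_W2 = np.outer(delta2, h)             # dE/dW2 uses only this layer's input/output
    delta1 = (W2.T @ delta2) * h * (1 - h)    # error signal pushed back through W2
    grad_W1 = np.outer(delta1, x)             # dE/dW1, again purely local quantities
    return grad_W1, grad_W2
```

A training loop would then just repeat W1 -= lr * grad_W1 and W2 -= lr * grad_W2, exactly as in the linear case.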
Conclusion
We see that neural networks are more usefully understood in the context of linear algebra than through any sort of biological analogy. A neural network is simply iterated matrix multiplication where each step is followed by an element-wise, non-linear sigmoid function. It is possible to quickly compute the derivative of the error function with respect to the weights, which lets us use gradient descent to quickly train a neural network.
I honestly don’t know why we still insist on calling these things neural nets. I suppose Kinky Linear Algebra would not be PC but we should be able to come up with something. I’ll vote for Quasi-Linear Modelling.
[1] Here’s a good reference on machine learning: https://www.amazon.com/Pattern-Recognition-Learning-Information-Statistics/dp/0387310738
[2] A vector is just a bunch of numbers. These could correspond to a text document or an image or a recording or anything else.
