How Do Neural Networks Work?
When you first look at neural networks, they seem mysterious. While there is an intuitive way to understand linear models and decision trees, neural networks don’t have such clean explanations.
In this article we try to develop an intuitive understanding of neural networks. To do so, we first describe the workings of linear models and decision trees. The vocabulary and tools developed in the process will allow us to explain neural networks in simple terms later.
We drive the discussion using a simple binary classification task that has two input features, x and y. An example of such a task could be classifying avocados as good or bad. We encode the positive decision indicating a good avocado as 1, and use 0 for the negative decision corresponding to a bad avocado. The input x could represent how dark the color of the avocado is, and y could represent how firm the avocado feels. The figure below shows the input data available. You want to avoid avocados that are too raw or too ripe; the good ones are somewhere in the middle.
Each point on the x-y plane represents an avocado with a particular color and firmness. The classification task boils down to coming up with a curve on the x-y plane that can separate the points labeled 1 from the points labeled 0. One side of the curve will represent the positive decision region, the other side the negative decision region. When generalized to hundreds of features, the task becomes coming up with surfaces in a multi-dimensional space that separate the positives from the negatives.
Linear Models
Linear models work by drawing a straight line to separate the positives from the negatives. In the above example, a linear model might do the following separation.
Note that all linear classification boundaries can be expressed in the form:
w0 + w1*x + w2*y = 0
where {w0, w1, w2} are constant values such as {w0=-1, w1=2.0, w2=-3}. Equations like the one above are referred to as the “model”, and w0, w1, w2 as the “weights” of the model. The classification boundary separates the x-y plane into two regions described by:
- w0 + w1*x + w2*y > 0, where the model’s decision is positive
- w0 + w1*x + w2*y < 0, where the model’s decision is negative.
Given any point (x, y) on the plane, you can check which side of the classification boundary it lies on, and that gives you a decision for that point. In practical terms, given the color and firmness of any avocado, you can plug the x and y values into the model and get a decision for whether it’s good or bad, depending on whether w0 + w1*x + w2*y > 0 or w0 + w1*x + w2*y < 0.
To get a feel for different linear models, we use gnuplot. You can install gnuplot on Mac with “brew install gnuplot” or on Linux with “sudo apt-get install gnuplot-x11”.
Once you have gnuplot, you can see how different models behave. The following example plots the decision regions for a linear model with weights w0 = -5, w1 = -2, w2 = 1. You can change the weights to see the decision regions for different models.
gnuplot> f(x,y) = (-5 - 2*x + y > 0)? 1 : 1/0
gnuplot> unset colorbox
gnuplot> set isosample 300, 300
gnuplot> set sample 300
gnuplot> set pm3d map
gnuplot> splot [-5:5] [-5:5] f(x,y)
Once you start gnuplot and type those commands at the gnuplot prompt, you should see the figure below. The model gives a positive decision for all points in the shaded region, and it gives a negative decision for the points in the white region.
Linear models are simple enough that you don’t need this kind of visualization. But this exercise is in preparation for a similar exercise with neural nets.
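Besides plotting whole regions, you can also check individual points at the gnuplot prompt. Here is a minimal sketch with the same weights (w0 = -5, w1 = -2, w2 = 1) and two arbitrarily chosen points:
gnuplot> g(x,y) = -5 - 2*x + y                               # the linear model w0 + w1*x + w2*y
gnuplot> print (g(0, 8) > 0 ? "positive" : "negative")       # prints "positive" (g = 3)
gnuplot> print (g(2, 1) > 0 ? "positive" : "negative")       # prints "negative" (g = -8)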
Decision Trees
A decision tree model builds a tree structure, where each node in the tree “splits” the range of one of the input features. A sample tree and the resulting classification boundaries are shown below. The splits create decision boundaries that are perpendicular to the axis being split. For example, the split point “x=a” separates the x-y plane into two halves, x > a and x < a, followed by the horizontal splits along “y=b” and “y=c”.
To visualize the decision tree using gnuplot, we plot a tree with a = 3, b = 3, c = 1 below:
The commands to type at the gnuplot prompt would be:
gnuplot> f(x,y) = (x > 3? (y > 1? 1 : 1/0) : (y > 3? 1/0: 1))
gnuplot> unset colorbox
gnuplot> set isosample 300, 300
gnuplot> set sample 300
gnuplot> set pm3d map
gnuplot> splot [0:5] [0:5] f(x,y)
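If the one-line definition is hard to parse, the same tree can be written with the split points named explicitly, matching the values a = 3, b = 3, c = 1 used above:
gnuplot> a = 3; b = 3; c = 1
gnuplot> # right branch (x > a) splits on y = c; left branch splits on y = b
gnuplot> f(x,y) = x > a ? (y > c ? 1 : 1/0) : (y > b ? 1/0 : 1)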
Neural Networks
Ok, now we are ready for some action! To keep the discussion tractable, we focus on the most popular flavor of neural network, one based on Relu. In a Relu network, you start off just like a linear model. But then you apply the Relu activation, which is simply
output = input > 0 ? input : 0
Relu stands for Rectified Linear Unit; it’s just a linear model that is “rectified” in the way shown above.
This one line of code is at the heart of the neural net revolution, so it’s worth repeating. In a linear model, your output would be
linear_output = w0 + w1*x + w2*y
Inside a neural network, the Relu output would be a modification of the linear output as
relu_output = linear_output > 0 ? linear_output : 0
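If you want to see the rectification numerically, you can define Relu at the gnuplot prompt. A small sketch with the same linear model as before (w0 = -5, w1 = -2, w2 = 1) and two arbitrary points:
gnuplot> relu(t) = t > 0 ? t : 0                 # the Relu activation
gnuplot> linear(x,y) = -5 - 2*x + y              # the linear part, w0 + w1*x + w2*y
gnuplot> print relu(linear(0, 8))                # prints 3: a positive linear output passes through unchanged
gnuplot> print relu(linear(0, 3))                # prints 0: the negative linear output -2 is clamped to 0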
You then repeatedly apply this trick, feeding Relu outputs to other Relu units. This builds up the classic neural network architecture. Below we show a simple network with two Relu units. The first Relu has weights w0, w1, and w2. The second Relu has weights u0, u1, u2. We then combine the two Relus with a linear model with weights v0, v1, v2 to get the final output.
Note, without the Relu activations, the final output is a linear model on top of linear models, which is also a linear model.
output_without_relu
= v0 + v1*(w0 + w1*x + w2*y) + v2*(u0 + u1*x + u2*y)
= k0 + k1*x + k2*y
where
k0 = v0 + v1*w0 + v2*u0
k1 = v1*w1 + v2*u1
k2 = v1*w2 + v2*u2
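You can verify this collapse in gnuplot. With two linear hidden units, no Relu, and some made-up weights chosen purely for illustration, the decision boundary that comes out is just a straight line (with these weights it works out to y = -1):
gnuplot> h1(x,y) = -4 - 2*x + y                          # first hidden unit, no Relu applied
gnuplot> h2(x,y) = 4 + 2*x + y                           # second hidden unit, no Relu applied
gnuplot> out(x,y) = 1 + 0.5*h1(x,y) + 0.5*h2(x,y)        # linear output layer; simplifies to 1 + y
gnuplot> f(x,y) = out(x,y) > 0 ? 1 : 1/0
gnuplot> unset colorbox
gnuplot> set isosample 300, 300
gnuplot> set sample 300
gnuplot> set pm3d map
gnuplot> splot [-5:5] [-5:5] f(x,y)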
So without the Relus, the whole network reduces to yet another straight line classification boundary. The entire novelty, so to speak, is in the Relu activation. To understand the difference Relu makes, we apply Relu to the linear model we looked at before, with weights w0 = -5, w1 = -2, w2 = 1. The gnuplot commands are shown below. Relu introduces the additional “y > 0 ? ...” part.
gnuplot> f(x,y) = y > 0 ? ((-5 - 2*x + y > 0)? 1 : 1/0) : 1/0
gnuplot> unset colorbox
gnuplot> set isosample 300, 300
gnuplot> set sample 300
gnuplot> set pm3d map
gnuplot> splot [-5:5] [-5:5] f(x,y)
The resulting positive and negative decision regions are shown below. For comparison, we put the original linear model on the left.
With Relu, the classification boundary is no longer a straight line running across the x-y plane. The part below the x-axis is chopped off, resulting in the angular region shown in the plot on the right.
You can put in different weights to see how it affects the decision regions. Some examples are shown below.
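One variation you can try yourself, reusing the pm3d settings from the plot above: flipping the sign of w1 (an arbitrary tweak) mirrors the angular region to the other side of the plane.
gnuplot> f(x,y) = y > 0 ? ((-5 + 2*x + y > 0)? 1 : 1/0) : 1/0    # same model, but with w1 = +2
gnuplot> splot [-5:5] [-5:5] f(x,y)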
The set of Relus, also called the “hidden layer” of the network since they are hidden between the inputs and the outputs, produces these angular decision regions. The final output is constructed by applying a linear model that adds all these angular regions together. An example of such an output decision region is shown below:
gnuplot> f(x,y) = (y > 0 ? ((-4 -2*x + y > 0)?1:1/0) : 1/0) + (y>0?((4 + 2*x + y) > 0 ? 1: 1/0 ): 1/0) + (y>0?((4 - x - 2*y)>0? 1: 1/0):1/0) + (y>0?((9 + 2*x - y)>0 ? 1: 1/0): 1/0)
gnuplot> unset colorbox
gnuplot> set isosample 300, 300
gnuplot> set sample 300
gnuplot> set pm3d map
gnuplot> splot [-5:5] [-5:5] f(x,y)
The command looks a little complicated, but all it’s doing is adding four Relus together.
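Strictly speaking, that command just adds up four indicator regions; it does not compute a real network output. For completeness, here is a sketch of what the actual forward pass of a four-Relu network would look like in gnuplot. The hidden-unit weights mirror the ones in the command above, while the output-layer weights are made up for illustration:
gnuplot> relu(t) = t > 0 ? t : 0
gnuplot> h1(x,y) = relu(-4 - 2*x + y)                    # hidden Relu units
gnuplot> h2(x,y) = relu(4 + 2*x + y)
gnuplot> h3(x,y) = relu(4 - x - 2*y)
gnuplot> h4(x,y) = relu(9 + 2*x - y)
gnuplot> net(x,y) = -2 + 0.5*(h1(x,y) + h2(x,y) + h3(x,y) + h4(x,y))    # linear output layer, made-up weights
gnuplot> f(x,y) = net(x,y) > 0 ? 1 : 1/0                 # threshold the output to get a decision
gnuplot> unset colorbox
gnuplot> set isosample 300, 300
gnuplot> set sample 300
gnuplot> set pm3d map
gnuplot> splot [-5:5] [-5:5] f(x,y)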
What makes neural networks unique is this ability to draw arbitrarily shaped classification boundaries. Given enough Relus, you can approximate any curve. Returning to our original example, you can now imagine building a network with 20 to 30 Relus and getting a classification boundary that looks like the one below.
You often hear people describing neural networks as highly non-linear models. That’s because of these very angular regions the model builds. The intuitive summary is that neural networks have the ability to build arbitrarily shaped classification boundaries, which makes them a very effective tool.
In practice though, there are tons and tons of caveats. Most algorithms don’t look for the classification boundary directly. Instead they search for model weights that map the given positive examples as close to 1 as possible, and push the negative examples towards 0. The classification boundary is a by-product of this process, along with choosing a suitable classification threshold.
Another point is that practical linear models often have tricks that allow them to bend the classification boundaries. Neural networks have tricks that reduce some of the sharpness of the angular regions. The variant of decision trees that’s often used in practice builds lots of trees and then averages them together, much like the output node of a neural network. The application side is just full of nuances. Hopefully the visualizations above give you a base to dive deeper.