Demystifying the Math behind Machine Learning

My dad once told me, “A computer is a dumb machine: it does exactly what it is told to do, and nothing more.” He was right: a computer merely executes a set of instructions coded by a programmer. So how can a computer be ‘intelligent’ or ‘learn’? While the basic mathematical concepts behind Machine Learning (ML) are pretty simple, most posts out there are jargon-heavy and scare non-experts away. I decided to write this post as an introduction for people who don’t have a math background but would like to get a sense of what is actually happening under the hood in ML.

In search of a magical machine

We are looking for a Machine

All the problems that ML solves are the same in nature: We are looking to build a ‘machine’ that converts something (let’s call it Input) into something else (let’s call it Output). For example, we want a ‘machine’ that takes as an input a picture, and as an output it tells us if it’s a cat or not. Or it can be a machine that takes some music as an input, and as an output it tells us the name of the singer. Another example would be a machine that takes a human as an input, and tells us his or her gender as output — You get the point, the input and output can be literally anything (as long as they are related).

Everything is numbers

The very cool thing about computers is that music, text, video, or really anything is actually represented in numbers. For example, a picture is nothing more than a bunch of numbers, each representing a color and a position on the screen.

We (humans) can hand-pick and measure the key features of any object and have a simplified representation of the object as numbers. For example, we can represent a human by his/her height and weight (if we think that’s what the machine needs to make the prediction). In the jargon, converting complex objects into numbers is called feature extraction, and choosing which of those numbers to keep is called feature selection.

As far as computers are concerned, the ‘machine’ we are looking for takes a number (or a set of numbers) as an input and gives a number as an output. Machines of this type, which deal with numbers, are called mathematical functions. We typically represent them with the letter f.

A Quiz!

Now that we have narrowed down the problem to numbers, let’s do a quick exercise — if you look at the chart below, and assume that you know the outputs on the left side, can you ‘predict’ what the output for 5 is?


YES! The answer is 10. And congratulations, the process that your brain just went through is exactly what Machine Learning is all about! Here is probably how you handled the exercise:

Step 0: Get some data: a bunch of inputs for which we know the corresponding output. (Note: I did this step for you 😊)

Step 1: Assume the function has a ‘familiar’ form. (Note: In the quiz you probably assumed that there is some basic multiplication or addition involved)

Step 2: Look at the data and search for a function that seems to work. (Note: In the quiz the function you found was Output = 2 * Input a.k.a f(x) = 2 * x )

Step 3: Test the function on the rest of the data to see if it really works on all inputs in the data.

Step 4: Apply the function to the new input (5) in order to predict the output. (Note: f(5) = 2 * 5 = 10)
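For readers who like to see things in code, here is a minimal sketch of steps 0 to 4. The training data is an assumption on my part (only the pairs 4 → 8 and 7 → 14 are mentioned in this post), and the names make_function and train are made up for the illustration.

```python
# Step 0: training data. Hypothetical values, chosen to follow Output = 2 * Input.
training_data = [(1, 2), (2, 4), (4, 8), (7, 14)]

# Step 1: assume a 'familiar' form, here f(x) = a * x for some whole number a.
def make_function(a):
    return lambda x: a * x

# Steps 2 and 3: try candidate values of a and keep the first one
# that reproduces every example in the training data.
def train(data, candidates=range(-10, 11)):
    for a in candidates:
        f = make_function(a)
        if all(f(x) == y for x, y in data):
            return a
    return None  # nothing in our search space fits the data

a = train(training_data)
f = make_function(a)

# Step 4: apply the function to the new input.
prediction = f(5)
print(a, prediction)  # 2 10
```

Notice that the computer did nothing clever here: we told it where to search (whole numbers from -10 to 10) and how to check a candidate.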

Let’s pause here and define some jargon: the data you got at step 0 is called the ‘training data’. Steps 2 and 3 are called ‘model training’. Model training is basically the part where you’re searching for a general function that seems to replicate the training data. This is the part that can be automated: the automated process that searches for and finds the right function is called a ‘machine learning algorithm’.

Two types of Machine Learning: Classification and Regression

Before we go any further, I just would like to introduce two types of problems we solve with machine learning: Classification and Regression.

Classification is when the output can only take a finite number of possibilities (say for example we are trying to predict if an email is spam or not spam, or if a picture contains a cat or not, or if a person is male or female).

Regression is when the output can take a continuous set of values, say for example we want to predict how much a customer is going to spend based on their demographics (age, earnings, etc.). The output can be $10, or $10.1 or $10.11 or $10.12 or $20, etc.

Can a computer find the right function by itself?

Computers are not creative; they do exactly what they’re told. There is however one thing in which computers are superior to humans: they calculate very fast. They can therefore search for and find a function that works relatively quickly, as long as we tell them where to search and give them a step-by-step process for how to search. The following paragraphs explain a few approaches for teaching a computer to search for a function.

Approach 1: Want to classify? Draw a thick line! (a.k.a SVM)

Before we get into this approach, let me just point out that a popular way of teaching ML involves functions that take 2 numbers as an input. That’s because in the real world we need to solve problems where we are predicting outputs based on more than just a single input. We typically choose to use 2 numbers as input because it’s easier to visualize graphically, with two axes representing the two inputs.

Say for example we are looking to use ML to build a tool that predicts a person’s gender based on his/her height and weight. Because the output can only take 2 possible values (male or female), this is a classification problem.

From a graphical perspective, the training data can be represented this way:

Height, Weight and Gender (I made this chart up, it is not based on actual data)

The pink and blue points are our training data, and the color is the known output (gender). Since the pinks seem to be close to each other, and the blues seem to be close to each other, let’s just draw a straight line. Everything that falls on one side of the line would be pink, and everything that falls on the other side would be blue.

As you can see in the chart, there is more than 1 line that does the job, and they’re not all equally good. If you look at the two lines shown above, line 1 passes awfully close to a pink data-point, while line 2 leaves more room for separation.

A popular algorithm to draw a good separating line is called Support Vector Machines (SVM). SVM basically looks for the straight line that maximizes the distance between the line and the closest training data-points, so that we ‘draw a thick separating line’. SVM is useful because it turns out maximizing things is something we have ‘off-the-shelf’ algorithms for computers to do (see below). Even though it may sound conceptually very simple to just draw a line, SVM is pretty widely used. For example, medical researchers have used it to classify cancers using genetic data.
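To make the ‘thickness’ idea concrete, here is a small sketch that compares two candidate lines by how far they stay from the closest training point. This is not an SVM solver, just the quantity SVM tries to maximize. The data points and both lines are made up for the illustration, and a line is written as the three numbers (w1, w2, b) in w1 * height + w2 * weight + b = 0.

```python
import math

# Hypothetical training data: ((height in cm, weight in kg), color).
points = [
    ((155, 50), "pink"), ((160, 55), "pink"), ((165, 58), "pink"),
    ((175, 75), "blue"), ((180, 80), "blue"), ((185, 85), "blue"),
]

def side(line, point):
    """Value of w1*x + w2*y + b; its sign tells us which side of the line we are on."""
    w1, w2, b = line
    x, y = point
    return w1 * x + w2 * y + b

def separates(line, data):
    """True if all pinks fall on one side of the line and all blues on the other."""
    pink = {side(line, p) > 0 for p, color in data if color == "pink"}
    blue = {side(line, p) > 0 for p, color in data if color == "blue"}
    return len(pink) == 1 and len(blue) == 1 and pink != blue

def margin(line, data):
    """Distance from the line to the closest training point."""
    w1, w2, _ = line
    return min(abs(side(line, p)) / math.hypot(w1, w2) for p, _ in data)

line_1 = (1, 1, -225)    # passes very close to a pink point
line_2 = (1, 1, -236.5)  # leaves more room on both sides

for name, line in [("line 1", line_1), ("line 2", line_2)]:
    print(name, separates(line, points), round(margin(line, points), 2))
```

Both lines separate the two colors, but line 2 has the larger margin, so it is the one a ‘thick line’ approach would prefer.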

Approach 2: Show me your neighbor, I’ll tell you who you are

Let’s go back to our quiz. Here is a pretty simple ML algorithm that we could have used to solve it: to find the output corresponding to 5 we look at the training data, find the closest number to 5 (in this case that’s 4) and assume that the output for 5 is just the same as the output for its nearest neighbor. This means f(5) = f(4) = 8. Sounds pretty dumb, right? Yes it is, but it kind of works (8 is kind of close to 10). To get a little more sophisticated, we could take the average of the outputs of the 2 closest inputs to 5, in this case that’s 4 and 7. That would mean f(5) = average(8, 14) = 11. In fact this can be generalized to take the ‘k’ nearest neighbors. This algorithm is called k-Nearest Neighbors (a.k.a. k-NN).

The fundamental assumption behind k-NN is that inputs that are close to each other are likely to have the same output. This algorithm works well for functions for which the outputs can only take a few possible distinct values. Here is a real-life example in which medical researchers used k-NN to classify and diagnose different types of cancer.
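The quiz version of k-NN fits in a few lines. In this sketch the training data is assumed (only 4 → 8 and 7 → 14 come from this post), and knn_predict is just a name I picked.

```python
def knn_predict(data, x, k):
    """Average the outputs of the k training inputs closest to x."""
    nearest = sorted(data, key=lambda pair: abs(pair[0] - x))[:k]
    return sum(y for _, y in nearest) / k

# Hypothetical training data following Output = 2 * Input.
training_data = [(1, 2), (2, 4), (4, 8), (7, 14)]

print(knn_predict(training_data, 5, k=1))  # 8.0
print(knn_predict(training_data, 5, k=2))  # 11.0
```

There is no real ‘training’ step here: the algorithm just keeps the training data around and looks things up when asked for a prediction.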

Approach 3: Linear Regression

Yes! Linear regression is technically machine learning 😊. It’s not really used in complex Machine Learning, but the concepts apply to most regression algorithms so I think it is useful to understand how it works.

In linear regression, we assume that for any input x, the output is in the form of f(x) = a * x + b, where a and b are some numbers that we don’t know (yet). Note that we here assume (before we actually do any math or anything) that our function has to look that way. We then search for the ‘best’ a and b i.e. the ones that replicate the training data (or at least get as close as possible).

Side note with some jargon: When we assumed that the function has to ‘look a certain way’, we introduced new variables that we called a and b. In ML, these variables a and b are called parameters. A lot of time in ML is spent searching for the best parameters. Some models have millions of parameters!

In order to find the best a and b, for each pair (a,b) we define Error(a,b), which basically measures how badly the function does at replicating the training data. Mathematically, Error(a,b) adds up, over all the training data, the gap between what the function should have predicted (according to the training data) and what it actually predicts (according to the formula). A common choice is to square each gap before adding them up, so that overshooting and undershooting can’t cancel each other out.
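Here is what such an Error could look like in code. I am assuming one common choice, the sum of squared differences, and the same made-up training data as in the quiz. Other ways of measuring the error exist.

```python
# Hypothetical training data following Output = 2 * Input.
training_data = [(1, 2), (2, 4), (4, 8), (7, 14)]

def error(a, b, data):
    """Sum of squared gaps between the known outputs and f(x) = a * x + b."""
    return sum((y - (a * x + b)) ** 2 for x, y in data)

print(error(2, 0, training_data))  # 0: f(x) = 2 * x replicates the data exactly
print(error(1, 1, training_data))  # 46: f(x) = x + 1 does a worse job
print(error(3, 0, training_data))  # 70: f(x) = 3 * x is worse still
```

The smaller the number, the better the pair (a,b) replicates the training data.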

We have just transformed our ML problem from “let’s find a function” into “let’s find parameters that minimize the error”. As it turns out, we have off-the-shelf algorithms that teach computers to find minimums (or maximums). This is the second time we came across this problem so let’s dig into it a little bit.

Another Quiz — How to find a maximum?

Assume you are visiting a coastal city (say Seattle), you are walking in the street, check the news on your phone, and hear that there is a massive tsunami coming your way. You’re alone, have no idea what the landscape looks like, but you want to get as high from the sea-level as possible to protect yourself. What do you do? (PS: Google maps is not working, you can’t get inside or climb any building, and you’re surrounded by tall buildings, so you can’t see much apart from your immediate surroundings)

The best you can do is to follow the slope. Just walk uphill, and at every crossing check the slope again: if turning right or left gets you higher, do that; if not, keep going. This approach is called gradient ascent when you walk uphill to find a maximum, and gradient descent when you walk downhill to find a minimum. So to minimize our Error or maximize the ‘thickness’ of the separating line, and solve our above problems, the computer would pick some random starting point, and start walking downhill or uphill from there.
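As a rough sketch of ‘walking downhill’, here is gradient descent applied to the linear regression problem from earlier. It assumes a squared-gap Error and made-up training data. The step size and the number of iterations are values I picked by hand for this data, not universal settings.

```python
# Hypothetical training data following Output = 2 * Input.
training_data = [(1, 2), (2, 4), (4, 8), (7, 14)]

def gradient(a, b, data):
    """Slope of the squared-gap Error in the a direction and in the b direction."""
    da = sum(-2 * x * (y - (a * x + b)) for x, y in data)
    db = sum(-2 * (y - (a * x + b)) for x, y in data)
    return da, db

def gradient_descent(data, a=0.0, b=0.0, step=0.005, iterations=5000):
    """Start somewhere, then repeatedly take a small step downhill."""
    for _ in range(iterations):
        da, db = gradient(a, b, data)
        a -= step * da  # move against the slope, i.e. downhill
        b -= step * db
    return a, b

a, b = gradient_descent(training_data)
print(a, b)  # a ends up very close to 2 and b very close to 0
```

If the step is too large the walk overshoots and never settles, and if it is too small it takes forever, which is why picking it is part of the craft.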

I know what you’re thinking — this approach is kind of dumb. Obviously at some point you may get stuck at the top of a ‘hill’: any direction you take will be going down. This means you may miss out on a potential mountain nearby. This problem is called the ‘local optimum’ problem. To solve this, some algorithms introduce ‘randomness’ in order to escape any local hill. These are called probabilistic algorithms or randomized algorithms.
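One simple way to add randomness is to start the climb from several random places and keep the best result. The sketch below does that on a made-up landscape with a small hill and a taller mountain. The height function, the step size and the number of restarts are all assumptions for the illustration, and random restarts are only one of several ways to escape a local hill.

```python
import math
import random

def height(x):
    """A made-up landscape: a small hill near x = 2 and a taller mountain near x = 8."""
    return math.exp(-(x - 2) ** 2) + 2 * math.exp(-(x - 8) ** 2)

def hill_climb(x, step=0.01):
    """Keep stepping left or right for as long as that gets us higher."""
    while True:
        best = max((x - step, x + step), key=height)
        if height(best) <= height(x):
            return x  # every direction goes down: we are on top of some hill
        x = best

def climb_with_restarts(restarts=20, seed=0):
    """Climb from several random starting points and keep the highest summit."""
    rng = random.Random(seed)  # fixed seed so the run is repeatable
    starts = [rng.uniform(0, 10) for _ in range(restarts)]
    return max((hill_climb(s) for s in starts), key=height)

stuck = hill_climb(1.0)       # starts near the small hill and stays there
best = climb_with_restarts()  # at least one start lands near the mountain
print(round(stuck, 1), round(best, 1))
```

Starting from x = 1 the climber stops on the small hill, while the version with restarts ends up on the mountain.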

Approach 4: Let’s replicate a brain — Neural Networks

If, when you hear ‘Neural Networks’, you think of some magical thing that replicates a human brain with free will and all, sorry to disappoint you: it’s much simpler than that. Neural Networks are just a regression algorithm. Just like in linear regression, the Neural Network algorithm assumes that the function has to look a certain way, and searches for the best parameters.

In Neural Networks, the basic idea is that the function we are looking for is assumed to be a weighted sum of step-functions. A step-function is a function that has an output of 0 when the input is below a certain threshold, and an output of 1 above that threshold. So the algorithm has 2 types of parameters to look for: the weights, and the thresholds. Why does the algorithm assume the function has to look this way? Because it is mathematically proven that if we pick the right number of step-functions, with the right thresholds, and the right weights for each, we can approximate pretty much any function, whatever its shape is.

A problem we face with step functions is that when the input is very close to the threshold, the output can change radically (from 0 to 1) if the input is changed just a little bit. And this is not a natural phenomenon: in real life there is a certain ‘smoothness’ in the world. A picture of a cat is not going to turn into something else if we change the color of 1 pixel a little bit. To solve this problem, instead of using step-functions we use smooth activation functions (the sigmoid is a common one), which are similar but transition smoothly from 0 to 1.

Step Function (Not smooth)
Activation Function (Smooth)
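The difference is easy to see with two inputs sitting just on either side of the threshold. This sketch uses the sigmoid as the smooth function, which is one common choice among several. Treating an input exactly at the threshold as ‘above’ is my own convention here.

```python
import math

def step(x, threshold=0.0):
    """0 below the threshold, 1 at or above it."""
    return 1.0 if x >= threshold else 0.0

def sigmoid(x, threshold=0.0):
    """A smooth version of the step: climbs gradually from 0 to 1 around the threshold."""
    return 1.0 / (1.0 + math.exp(-(x - threshold)))

for x in (-0.01, 0.01):
    print(x, step(x), round(sigmoid(x), 4))
```

Moving the input from -0.01 to 0.01 flips the step function from 0 to 1, while the sigmoid barely moves (it stays around 0.5).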

What I have just described above is a 1-layer Neural Network. A multi-layer neural network is a succession of 1-layer neural networks: The output of layer N is the input of layer N+1. The output of the last layer is the actual output we are looking for.
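Following the description in this post, a layer can be sketched as a weighted combination of smooth steps, and a multi-layer network as layers chained one after the other. This is a simplified picture with a single number going in and out. Real networks also put weights on the inputs of each unit and work with many numbers at once. The weights and thresholds below are arbitrary values, not trained ones.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def one_layer(x, weights, thresholds):
    """Weighted combination of smooth steps, one (weight, threshold) pair per step."""
    return sum(w * sigmoid(x - t) for w, t in zip(weights, thresholds))

def multi_layer(x, layers):
    """The output of layer N becomes the input of layer N+1."""
    for weights, thresholds in layers:
        x = one_layer(x, weights, thresholds)
    return x

# Two layers with arbitrary parameters: (weights, thresholds) for each layer.
layers = [([2.0], [0.0]), ([1.0], [1.0])]
print(multi_layer(0.0, layers))  # 0.5
```

Training the network means searching for the weights and thresholds that make the Error as small as possible, typically by walking downhill as described earlier.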

Supervised vs Unsupervised Learning

Everything we have talked about so far falls under the supervised learning category. Unsupervised learning is when your training data has no outputs, just a bunch of inputs. For example, in supervised learning your training data would be a bunch of pictures labeled ‘cat’ and ‘not cat’, while in unsupervised learning your training data would be just a bunch of pictures. The algorithm would by itself find that there are two categories of pictures, and come up with a function that classifies the inputs so that those that are similar have the same output.
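To give a feel for how an algorithm can find categories by itself, here is a tiny sketch that splits a list of numbers into two groups. It is a stripped-down version of a well-known method called k-means (with k = 2), which this post does not cover. The data is made up, and the code assumes the list contains at least two different values.

```python
def two_means(values, iterations=10):
    """Find two 'centers' so that each value sits close to one of them."""
    low, high = min(values), max(values)  # initial guesses for the two centers
    for _ in range(iterations):
        # Assign every value to its closest center...
        group_low = [v for v in values if abs(v - low) <= abs(v - high)]
        group_high = [v for v in values if abs(v - low) > abs(v - high)]
        # ...then move each center to the average of its group.
        low = sum(group_low) / len(group_low)
        high = sum(group_high) / len(group_high)
    return low, high

def classify(value, centers):
    """Output 0 or 1 depending on which center the value is closest to."""
    low, high = centers
    return 0 if abs(value - low) <= abs(value - high) else 1

# Hypothetical unlabeled inputs: no outputs are given, just numbers.
inputs = [1.0, 1.2, 0.8, 5.0, 5.3, 4.9]
centers = two_means(inputs)
print(centers, classify(1.1, centers), classify(4.5, centers))
```

Nobody told the algorithm which numbers belong together. It only used the fact that some inputs are close to each other.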


The idea behind this post is to give a high-level understanding of what Machine Learning is about from a conceptual point of view, without getting into the weeds of the mathematical and computational algorithms. I hope this post will be useful to people who are curious about the topic!