10 Must-Know Models for ML Beginners: Linear Regression
The Line Fitter
This article is part of the series 10 Must-Know Models for ML Beginners.
Introduction
Linear regression is a cornerstone of the machine learning and statistical world. If you’re starting your journey into either field, linear regression is one of the first fundamental concepts you need in your toolkit. But what exactly is this all-important idea, and why should you care? Let’s break it down.
What is Linear Regression?
Imagine you’ve collected data on the relationship between the size of a house (in square feet) and its price. As you plot this data, you notice that larger houses generally tend to be more expensive. That shouldn’t be too surprising.
Linear regression allows you to model this relationship mathematically. Its goal is to find a straight line, usually known as the “best fit line,” that most closely represents the relationship between your two variables (house size and price).
Why Does It Matter?
Here’s why linear regression is important:
- Understanding Relationships: Linear regression helps reveal whether a relationship exists between two variables and quantifies it. For example, it’s not enough to know that size affects house price; we want to know by how much.
- Simple Predictions: Using a linear regression model, you can make predictions. If you know the size of a house, you can plug it into your trained regression model, and it will predict the likely price.
- Foundations for More Complexity: While linear regression itself can seem simple, it’s a fundamental building block for understanding more complex machine learning models.
How Does It Work?
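At its heart, linear regression looks for the line that minimizes the errors between its predictions and the actual values. Three pieces work together: the model equation, a loss function that measures the error, and an optimization method such as gradient descent that minimizes it.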
Linear Regression Equation
The core equation of linear regression is:
Y = wX + b
Where:
- Y is the target variable we want to predict
- X is the independent variable
- w is the slope of the line (how much Y changes with a one-unit change in X)
- b is the y-intercept (where the line crosses the y-axis)
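To make the roles of w and b concrete, here is a minimal sketch that evaluates the single-feature equation. The helper name `predict` and the coefficient values (w = 150 price units per square foot, b = 20,000) are hypothetical, chosen for illustration rather than fitted to any real data.

```python
def predict(x: float, w: float, b: float) -> float:
    """Evaluate the single-feature linear model Y = wX + b."""
    return w * x + b

# Hypothetical coefficients, for illustration only
w, b = 150.0, 20_000.0

price_1500 = predict(1500, w, b)  # 150 * 1500 + 20000
price_1501 = predict(1501, w, b)  # one extra square foot

print(price_1500)               # 245000.0
print(price_1501 - price_1500)  # 150.0: a one-unit change in X moves Y by w
print(predict(0, w, b))         # 20000.0: at X = 0 the line sits at the intercept b
```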
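A common practice is to prepend a column of ones to X as its first column, which lets us fold b into w as its first entry. With n data points and m features, the model becomes:
Y = Xw
In this form, X is an (n, m+1) matrix, w is an (m+1)-dimensional vector whose first entry is b, and Y is an n-dimensional vector.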
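A quick way to see that the ones-column trick changes nothing about the predictions is to compute them both ways. This is a minimal sketch on a tiny hypothetical matrix (n = 3 points, m = 2 features); the names `X_aug` and `w_aug` are illustrative, and `np.c_` is the same NumPy helper the implementation later in this article uses.

```python
import numpy as np

# Hypothetical data: n = 3 data points, m = 2 features each
X = np.array([[1.0, 2.0],
              [3.0, 4.0],
              [5.0, 6.0]])
w = np.array([0.5, -1.0])   # one weight per feature
b = 2.0                     # intercept

# Separate form: the intercept b is added to every row's prediction
y_separate = X @ w + b

# Unified form: prepend a column of ones so b becomes the first weight
X_aug = np.c_[np.ones((X.shape[0], 1)), X]   # shape (n, m+1) = (3, 3)
w_aug = np.concatenate([[b], w])             # shape (m+1,) = (3,)
y_unified = X_aug @ w_aug

print(np.allclose(y_separate, y_unified))  # True
```

The first column of `X_aug` is all ones, so its dot product with `w_aug` always contributes exactly `b`.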
Loss Function
Linear regression often uses the mean squared error to measure how far off its predictions are:
MSE = (1/n) * Σ(Yi - Ŷi)^2
Where:
- n is the number of data points
- Yi is the actual value for the i-th data point
- Ŷi is the predicted value for the i-th data point
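The formula is easy to check by hand. This sketch computes the MSE for four hypothetical actual/predicted pairs, with the intermediate arithmetic spelled out in comments; the helper name `mse` is an illustrative choice.

```python
import numpy as np

def mse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mean squared error: (1/n) * sum((y_i - y_hat_i)^2)."""
    return float(np.mean((y_true - y_pred) ** 2))

# Hypothetical actual and predicted values for n = 4 data points
y_true = np.array([3.0, 5.0, 7.0, 9.0])
y_pred = np.array([2.5, 5.0, 8.0, 9.5])

# Errors: 0.5, 0.0, -1.0, -0.5  ->  squares: 0.25, 0.0, 1.0, 0.25
# Sum = 1.5, divided by n = 4  ->  0.375
print(mse(y_true, y_pred))  # 0.375
```

Because errors are squared, the single error of 1.0 accounts for two-thirds of the total, which is why large misses dominate this loss.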
Gradient Descent
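Gradient descent needs to know in which direction to change w (which, with the column of ones added, also holds b) to reduce the error. Applying the chain rule to the loss function gives the gradient for linear regression:
Gradient of w = (2/n) * Σ(Xi * (Ŷi - Yi))
Where:
- Xi is the i-th row of X, i.e. the feature values of the i-th data point, including the leading 1
Each iteration then moves w a small step against the gradient, w ← w - learning_rate * Gradient of w, repeating for a fixed number of iterations or until the error stops decreasing.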
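One way to convince yourself the gradient formula is right is to compare it with a numerical estimate. This sketch, on random data invented purely for the check, evaluates the formula in its matrix form (2/n) * Xᵀ(Ŷ - Y), which is the same sum of Xi * (Ŷi - Yi) over all rows, against a central finite-difference approximation of the MSE, then takes one gradient descent step. All helper names are hypothetical.

```python
import numpy as np

def mse_loss(X, y, w):
    """MSE of the linear model y_hat = X @ w."""
    return np.mean((X @ w - y) ** 2)

def analytic_gradient(X, y, w):
    """(2/n) * sum_i X_i * (y_hat_i - y_i), written as (2/n) * X^T (y_hat - y)."""
    n = X.shape[0]
    return (2.0 / n) * X.T @ (X @ w - y)

def numerical_gradient(X, y, w, eps=1e-6):
    """Central finite-difference estimate of the MSE gradient, one weight at a time."""
    grad = np.zeros_like(w)
    for j in range(w.size):
        step = np.zeros_like(w)
        step[j] = eps
        grad[j] = (mse_loss(X, y, w + step) - mse_loss(X, y, w - step)) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
n = 20
X = np.hstack([np.ones((n, 1)), rng.normal(size=(n, 2))])  # ones column + 2 features
y = rng.normal(size=n)
w = rng.normal(size=3)

g = analytic_gradient(X, y, w)
print(np.allclose(g, numerical_gradient(X, y, w), atol=1e-5))  # True

# One gradient descent step: w <- w - learning_rate * gradient
learning_rate = 0.05
w_new = w - learning_rate * g
print(mse_loss(X, y, w_new) < mse_loss(X, y, w))  # True: the step reduces the loss
```

Since the MSE is quadratic in w, the central difference is exact up to floating-point rounding, which makes it a reliable check here.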
When to Use Linear Regression
Linear regression is a powerful tool, but it’s important to recognize when it’s the right tool for the job. Here’s a guide:
- When it Fits: Linear regression shines when the relationship between your variables appears reasonably linear. You can check this by plotting the data upfront. If you see a roughly straight-line pattern, linear regression is a good starting point.
- Simplicity and Interpretability: If you need a model that is easy to explain, linear regression often fits the bill. Each coefficient tells you how much the prediction changes for a one-unit change in its variable, though comparing coefficients as a measure of importance only makes sense when the features are on similar scales.
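Coefficients are expressed in the units of their features, which is why their raw sizes can mislead. This sketch fits the same line to hypothetical synthetic house data twice, once with size in square feet and once in thousands of square feet. It uses `np.linalg.lstsq` as a closed-form least-squares solver for brevity rather than gradient descent; the helper `fit` is illustrative.

```python
import numpy as np

rng = np.random.default_rng(1)
size_sqft = rng.uniform(500, 3000, size=50)   # hypothetical house sizes
# Hypothetical prices: 100 per square foot plus 10,000, with noise
price = 100 * size_sqft + 10_000 + rng.normal(0, 5_000, size=50)

def fit(x, y):
    """Ordinary least-squares fit of y = b + w*x; returns (b, w)."""
    X_b = np.column_stack([np.ones_like(x), x])
    theta, *_ = np.linalg.lstsq(X_b, y, rcond=None)
    return theta

b1, w1 = fit(size_sqft, price)          # feature measured in square feet
b2, w2 = fit(size_sqft / 1000, price)   # same feature in thousands of sq ft

print(w2 / w1)   # close to 1000: the coefficient scales with the unit
# The predictions themselves are unchanged by the change of units
print(np.allclose(b1 + w1 * size_sqft, b2 + w2 * size_sqft / 1000))  # True
```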
When to Consider Alternatives
- Nonlinear Relationships: When the relationship in your data is clearly curved or shows another nonlinear pattern, linear regression will struggle. Other models like polynomial regression, decision trees, or neural networks might be better suited to capture complex patterns.
- Categorical Data: Linear regression works on numeric inputs. If your key variables are categorical (e.g., “type of house”, “city”), you’ll need to encode them numerically first, for example with one-hot (dummy) encoding or embeddings.
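- Outliers: Linear regression can be sensitive to outliers (extreme data points). Because the error is squared, a few extreme points can pull the fitted line toward them. Always plot your data, and if outliers are suspected, investigate them or consider a more robust loss function such as the Huber loss.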
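The outlier sensitivity is easy to demonstrate. This sketch fits a line to ten hypothetical points lying exactly on y = 2x + 1, then corrupts just one of them. It uses `np.linalg.lstsq` as the least-squares solver; the helper `fit_line` is illustrative.

```python
import numpy as np

def fit_line(x, y):
    """Ordinary least-squares fit of y = b + w*x; returns (b, w)."""
    X_b = np.column_stack([np.ones_like(x), x])
    theta, *_ = np.linalg.lstsq(X_b, y, rcond=None)
    return theta

x = np.arange(10, dtype=float)
y = 2.0 * x + 1.0                 # ten points exactly on y = 2x + 1

b_clean, w_clean = fit_line(x, y)

y_outlier = y.copy()
y_outlier[-1] += 50.0             # corrupt only the last point

b_out, w_out = fit_line(x, y_outlier)

print(f"clean slope: {w_clean:.3f}, with outlier: {w_out:.3f}")
# clean slope: 2.000, with outlier: 4.727
```

One bad point out of ten more than doubles the slope and drags the intercept from 1 to about -6.27.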
Python Implementation from Scratch
Let’s implement linear regression from scratch with Python. The code is available in this Colab notebook.
```python
import numpy as np
import matplotlib.pyplot as plt

def compute_gradient(X, y, theta):
    """Calculates the gradient of the MSE loss for linear regression.

    Args:
        X: The feature matrix, shape (n, m+1), including the column of ones.
        y: The target values, shape (n,).
        theta: The current model parameters, shape (m+1,).

    Returns:
        The gradient vector, shape (m+1,).
    """
    n_samples = X.shape[0]
    prediction = np.dot(X, theta)
    error = prediction - y
    # Matrix form of (2/n) * sum_i X_i * (y_hat_i - y_i)
    gradient = (2 / n_samples) * np.dot(X.T, error)
    return gradient

def gradient_descent(X, y, learning_rate=0.01, iterations=1000):
    """Performs gradient descent to optimize linear regression parameters.

    Args:
        X: The feature matrix, shape (n, m+1).
        y: The target values, shape (n,).
        learning_rate: The step size for updating parameters.
        iterations: The number of iterations to run gradient descent.

    Returns:
        The optimized model parameters (theta).
    """
    # One parameter per column of X (the intercept plus one weight per feature)
    theta = np.zeros(X.shape[1])
    for _ in range(iterations):
        gradient = compute_gradient(X, y, theta)
        theta -= learning_rate * gradient
    return theta

# Generate sample linear regression data
np.random.seed(42)  # For reproducibility
n = 100
X = 2 * np.random.rand(n, 1)
# y = b + w*x + noise, with true b = 3 and w = 5
y = 3 + 5 * X + np.random.randn(n, 1)
# Reshape y from (n, 1) to (n,)
y = y.flatten()

# Add a column of ones for the intercept term
X_b = np.c_[np.ones((n, 1)), X]

# Run gradient descent
theta = gradient_descent(X_b, y)
print(f"Theta (Intercept, Slope): {theta}")

# Evaluation (using mean squared error)
predictions = np.dot(X_b, theta)
mse = np.mean((predictions - y) ** 2)
print(f"Mean Squared Error: {mse}")

# Visualization
plt.figure(figsize=(10, 6))
plt.scatter(X, y, s=30)
plt.plot(X, predictions, color='red', linewidth=2)
plt.xlabel("X")
plt.ylabel("y")
plt.title("Linear Regression with Gradient Descent")
plt.show()
```
Output:
Theta (Intercept, Slope): [3.23378652 4.75361082]
Mean Squared Error: 0.8066900678064775
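As a sanity check on these numbers, the same data can be solved in closed form. This sketch regenerates the data with the same seed and compares the gradient descent result against `np.linalg.lstsq`, which computes the ordinary least-squares solution directly. The gradient descent function here is a compact restatement of the one above, and the 10,000-iteration run is an extra experiment not in the original code. With `learning_rate=0.01`, 1,000 iterations appear to stop slightly short of the optimum, so expect close but not identical coefficients.

```python
import numpy as np

# Regenerate the synthetic data with the same seed as above
np.random.seed(42)
n = 100
X = 2 * np.random.rand(n, 1)
y = (3 + 5 * X + np.random.randn(n, 1)).flatten()
X_b = np.c_[np.ones((n, 1)), X]

def gradient_descent(X, y, learning_rate=0.01, iterations=1000):
    """Batch gradient descent on the MSE, starting from all-zero weights."""
    theta = np.zeros(X.shape[1])
    for _ in range(iterations):
        gradient = (2 / X.shape[0]) * X.T @ (X @ theta - y)
        theta -= learning_rate * gradient
    return theta

# Closed-form ordinary least-squares solution for comparison
theta_exact, *_ = np.linalg.lstsq(X_b, y, rcond=None)

theta_1k = gradient_descent(X_b, y, iterations=1000)
theta_10k = gradient_descent(X_b, y, iterations=10_000)

print("lstsq:        ", theta_exact)
print("GD, 1k steps: ", theta_1k)
print("GD, 10k steps:", theta_10k)
```

Running more iterations closes the gap, which suggests the small difference comes from stopping early rather than from an error in the gradient.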
Conclusion
Linear regression might seem simple, but it’s a powerful tool for understanding relationships in data and making predictions. It’s easy to explain and acts as a great starting point for exploring more complex machine learning techniques. If you’re diving into the world of data analysis, don’t underestimate the value of linear regression!