Decision tree regression
The decision tree model, as the name suggests, is a tree-like model made up of a root, branches, and leaf nodes. Decision tree regression uses this model to predict continuous-valued outputs instead of discrete ones.
Mean square error:
We can evaluate the accuracy of a decision tree regression model using the mean squared error (MSE). Just as the name suggests, MSE is the mean of all the squared errors. Its value lies between 0 and ∞; the lower the MSE, the closer the predictions are to the actual values.
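In formula form, MSE = (1/n) Σ (yᵢ − ŷᵢ)². A quick illustration with made-up numbers (the arrays below are illustrative, not from the tutorial's dataset):

```python
import numpy as np
from sklearn.metrics import mean_squared_error

y_true = np.array([3.0, 5.0, 2.0])  # made-up actual values
y_pred = np.array([2.5, 5.0, 4.0])  # made-up predictions

# Manual computation: mean of the squared differences
mse_manual = np.mean((y_true - y_pred) ** 2)

# Same result via scikit-learn
mse_sklearn = mean_squared_error(y_true, y_pred)

print(mse_manual, mse_sklearn)  # both are (0.25 + 0 + 4) / 3 ≈ 1.4167
```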
Implementation:
We’ll build a sample decision tree regression model for better understanding.
Import all the necessary modules
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor
Read the csv
df=pd.read_csv('decision-tree-regression-dataset.csv')
df
You can download the csv here
Select the input and output columns
x=df.iloc[:,0:1].values # all rows, column 0 only (kept 2-D)
y=df.iloc[:,1].values
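The slice `0:1` matters: it keeps x two-dimensional, which is the shape `fit` expects, while the plain index `1` gives a 1-D array for y. A sketch with a tiny made-up frame standing in for the CSV (the column names are illustrative):

```python
import pandas as pd

# Made-up stand-in for the tutorial's CSV
df = pd.DataFrame({'Level': [1, 2, 3], 'Salary-LPA': [10, 20, 40]})

x = df.iloc[:, 0:1].values  # 2-D array, shape (3, 1) -- what the regressor expects
y = df.iloc[:, 1].values    # 1-D array, shape (3,)

print(x.shape, y.shape)  # (3, 1) (3,)
```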
Here, x is the input and y is the output
Fit decision tree regressor to the dataset
reg = DecisionTreeRegressor(random_state=0)
reg.fit(x, y)
Now, predict the output by giving some random value
y_predict=reg.predict([[4.3]])
You get the predicted value as output
array([60.])
We can get the MSE value to check how accurate our model is
from sklearn.metrics import mean_squared_error
# Given value
Y_true = [60]
# Calculation of Mean Squared Error (MSE)
mean_squared_error(Y_true,y_predict)
This returns 0, which means our model predicts this point perfectly, although this rarely happens in real-life scenarios.
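The 0 here is not a coincidence: with no depth limit, a decision tree can keep splitting until every training point sits in its own leaf, so it reproduces the training targets exactly. A minimal sketch with made-up data (the arrays below are illustrative, not the tutorial's CSV):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error

# Made-up training data with unique inputs
x = np.arange(1, 11).reshape(-1, 1)
y = np.array([10, 12, 15, 25, 60, 80, 100, 150, 300, 500], dtype=float)

reg = DecisionTreeRegressor(random_state=0)
reg.fit(x, y)

# An unrestricted tree memorizes the training set, so training MSE is exactly 0.
# Zero training error is memorization, not proof of good generalization.
print(mean_squared_error(y, reg.predict(x)))  # 0.0
```

Limiting the tree with a parameter such as `max_depth` trades this memorization for smoother, more general predictions.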
Now, we’ll plot a graph to understand the relationship between the input and output variables
plt.scatter(x,y,color='blue')
plt.plot(x,reg.predict(x),color='red')
plt.xlabel("Level")
plt.ylabel("Salary-LPA")
plt.title("Salary vs level")
plt.show()
From the graph, we can infer that as the level increases, the salary keeps decreasing. In other words, level and salary are inversely related.
Also, we can infer from the graph that the outputs are continuous values rather than discrete ones. This is the main reason why we chose decision tree regression for this problem.
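One detail the plot above hides is that `reg.predict(x)` is only evaluated at the training points. On a finer grid of inputs, a decision tree's predictions form a step function: every input landing in the same leaf receives the same value. A sketch with made-up data (the arrays and the `tree_steps.png` filename are illustrative):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the script runs headless
import matplotlib.pyplot as plt

# Made-up data in the same spirit as the level/salary columns
x = np.arange(1, 11).reshape(-1, 1)
y = np.array([500, 300, 150, 100, 80, 60, 40, 25, 12, 10], dtype=float)

reg = DecisionTreeRegressor(random_state=0).fit(x, y)

# Predict on a fine grid: the red curve comes out piecewise constant,
# because the tree can only output one value per leaf
x_grid = np.arange(1, 10, 0.01).reshape(-1, 1)
plt.scatter(x, y, color='blue')
plt.plot(x_grid, reg.predict(x_grid), color='red')
plt.xlabel("Level")
plt.ylabel("Salary-LPA")
plt.title("Decision tree prediction on a fine grid")
plt.savefig("tree_steps.png")
```

Every value the tree predicts on the grid is one of the training targets, which is exactly what makes the curve look like stairs rather than a smooth line.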
Happy coding !