Confusion matrix

Rakesh Rajpurohit
3 min read · May 12, 2018


In this story, I explain how to calculate and plot a confusion matrix using Python, and how to read the result.

Introduction:

A confusion matrix summarizes the performance of a classification model in tabular form: for each class, it counts how many labels were predicted correctly and how many were predicted incorrectly.

Calculating/Plotting confusion matrix:

Below is the process for calculating a confusion matrix.

  1. You need a test dataset or a validation dataset with expected outcome values.
  2. Make a prediction for each row in your test dataset.
  3. From the expected outcomes and the predictions, count:
  • The number of correct predictions for each class, and
  • The number of incorrect predictions for each class, organized by the class that was predicted.
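The counting in step 3 can be sketched in a few lines of plain Python. The class names and label lists below are made up purely for illustration; they are not taken from the iris example used later in this story.

```python
from collections import Counter

# Hypothetical expected (true) and predicted labels for six test rows.
expected = ["cat", "cat", "dog", "dog", "dog", "bird"]
predicted = ["cat", "dog", "dog", "dog", "bird", "bird"]

# Count every (expected, predicted) pair once.
pair_counts = Counter(zip(expected, predicted))

# Correct predictions per class: pairs where both labels agree.
correct = {cls: n for (cls, pred), n in pair_counts.items() if cls == pred}

# Incorrect predictions, organized by the class that was predicted.
incorrect = Counter()
for (cls, pred), n in pair_counts.items():
    if cls != pred:
        incorrect[pred] += n

print("correct:", correct)
print("incorrect (by predicted class):", dict(incorrect))
```

Counting the pairs first keeps both tallies consistent, because each test row contributes to exactly one pair.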

These numbers are then organized into a table, or a matrix as follows:

  • Expected down the side: Each row of the matrix corresponds to an expected (actual) class.
  • Predicted across the top: Each column of the matrix corresponds to a predicted class.

The counts of correct and incorrect classification are then filled into the table.
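As a rough sketch of how the table is filled, the snippet below builds the matrix by hand, with expected classes as rows and predicted classes as columns. The class names and labels are hypothetical; in practice `sklearn.metrics.confusion_matrix` does this for you.

```python
# Hypothetical class names and labels, for illustration only.
classes = ["cat", "dog", "bird"]
expected = ["cat", "cat", "dog", "dog", "dog", "bird"]
predicted = ["cat", "dog", "dog", "dog", "bird", "bird"]

# Map each class name to its row/column position.
index = {name: i for i, name in enumerate(classes)}

# Start with an all-zero square table: one row and one column per class.
matrix = [[0] * len(classes) for _ in classes]

# Row = expected class, column = predicted class.
for actual, guess in zip(expected, predicted):
    matrix[index[actual]][index[guess]] += 1

for name, row in zip(classes, matrix):
    print(f"{name:>5}", row)
```

Each test row increments exactly one cell, so the cells always sum to the number of test rows.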

Reading Confusion matrix:

The total number of correct predictions for a class goes into the expected row for that class value and the predicted column for that same class value.

In the same way, the total number of incorrect predictions for a class goes into the expected row for that class value and the predicted column for the class that was predicted instead.

The diagonal elements represent the number of points for which the predicted label is equal to the true label, while off-diagonal elements are those that are mislabelled by the classifier. The higher the diagonal values of the confusion matrix the better, indicating many correct predictions.
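To make the role of the diagonal concrete, here is a minimal sketch that reads the correct and mislabelled counts off a small matrix. The 3×3 matrix is a hypothetical example, with rows as true labels and columns as predicted labels.

```python
import numpy as np

# Hypothetical confusion matrix: rows = true label, columns = predicted label.
cm = np.array([[1, 1, 0],
               [0, 2, 1],
               [0, 0, 1]])

correct = int(np.trace(cm))        # sum of the diagonal elements
total = int(cm.sum())              # every test row lands in exactly one cell
misclassified = total - correct    # sum of the off-diagonal elements
accuracy = correct / total

print("correct:", correct)
print("misclassified:", misclassified)
print("accuracy:", round(accuracy, 3))
```

A matrix with all of its mass on the diagonal would give an accuracy of 1.0.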

Plot Confusion Matrix with Python:

import itertools

import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

# Import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target
class_names = iris.target_names

# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Run classifier, using a model that is too regularized (C too low) to see
# the impact on the results
classifier = svm.SVC(kernel='linear', C=0.01)
y_pred = classifier.fit(X_train, y_train).predict(X_test)


def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        # Divide each row by its total so every row sums to 1
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Write each count in its cell; use white text on dark cells
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')


# Compute confusion matrix
cnf_matrix = confusion_matrix(y_test, y_pred)
np.set_printoptions(precision=2)

# Plot non-normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names,
                      title='Confusion matrix, without normalization')

# Plot normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
                      title='Normalized confusion matrix')

plt.show()

Output:

Confusion matrix, without normalization
[[13  0  0]
 [ 0 10  6]
 [ 0  0  9]]
Normalized confusion matrix
[[ 1.    0.    0.  ]
 [ 0.    0.62  0.38]
 [ 0.    0.    1.  ]]
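As a quick check on how the normalized values relate to the raw counts, the sketch below re-derives them from the non-normalized matrix printed above. It assumes the rows are in the order of `iris.target_names`; the variable names are my own.

```python
import numpy as np

# Raw counts copied from the output above (rows = true label).
cm = np.array([[13, 0, 0],
               [0, 10, 6],
               [0, 0, 9]])

# Divide each row by its total, as the normalize=True branch does.
row_totals = cm.sum(axis=1, keepdims=True)
normalized = cm / row_totals

# The diagonal of the row-normalized matrix is the per-class recall.
per_class_recall = np.diag(normalized)

# Overall accuracy: correct predictions over all test rows.
accuracy = np.trace(cm) / cm.sum()

print(normalized)
print("per-class recall:", per_class_recall)
print("accuracy:", accuracy)
```

The middle row shows where this heavily regularized model struggles: 6 of the 16 test samples whose true label is the second class were predicted as the third class, which is the 0.38 in the normalized matrix.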
