How to plot a wholesome confusion matrix?

Deepanshu Jindal
May 11, 2019


A confusion matrix provides an elegant way to analyse and evaluate the results of your classifier models. It is a specific table layout that gives you insight into what your classifier is getting right and which classes it is getting “confused” between. It also provides a good way to analyse the F1 scores of your models. The need to plot a confusion matrix arises when you wish to make sense of the results without going through raw numbers, and when you want to present your results in a clear and appealing way.

If you are not familiar with confusion matrices, now is a good time to read about them here before moving ahead.

While evaluating classifiers I often had to look up StackOverflow or API references for specific syntax related to plotting and labelling confusion matrices, spending time searching for questions like “How to plot confusion matrix with labels?”, “How to plot with different colour maps?” or “How to add raw numbers to confusion matrix plots?”. So here I present a comprehensive guide to plotting different styles of confusion matrices.

The quick and easy way

Often one needs just a rough idea of how the model is performing, and a quick, clean way to visualise the confusion matrix without bothering about labels and other details. In such cases the standard sklearn and matplotlib modules are a great way to go.

The following is a confusion matrix I generated for a sentiment-mining classifier on the Yelp dataset, which predicts the number of stars given by a user based on the review text.
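A minimal sketch of this quick approach might look like the following. It assumes sklearn's confusion_matrix and matplotlib's matshow; the y_true and y_pred lists are made-up stand-ins for the real Yelp labels and predictions, and this is not necessarily the exact code behind the plot shown here.

```python
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# Toy stand-ins for the true and predicted star ratings (1-5).
# These values are invented for illustration, not the Yelp results.
y_true = [1, 2, 3, 4, 5, 5, 4, 5, 3, 5]
y_pred = [1, 3, 3, 4, 5, 4, 4, 5, 2, 5]

# Rows correspond to true labels, columns to predicted labels.
conf_mat = confusion_matrix(y_true, y_pred, labels=[1, 2, 3, 4, 5])

# Draw the matrix as a colour-coded grid with a colour bar.
fig, ax = plt.subplots()
im = ax.matshow(conf_mat, cmap="Blues")
fig.colorbar(im)
ax.set_xlabel("Predicted label")
ax.set_ylabel("True label")
```

Calling plt.show() or fig.savefig("conf_mat.png") afterwards displays or saves the figure; the file name is only an example.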

With these 5–10 lines of code you can get a nice visual representation of the confusion matrix instead of an ugly table of raw numbers.

However, there is a problem with this confusion matrix. The matrix is not very informative for labels 1 and 2 due to the heavy skew in the distribution towards the later classes. When the class distribution is heavily skewed, it is often helpful to normalize the confusion matrix. To do so, we compute the sum of each row and divide each element by the sum of its row, i.e. we divide each entry of a row by the number of instances whose true label is that row.

conf_mat_norm = conf_mat.astype('float') / conf_mat.sum(axis=1)[:, np.newaxis]
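To see what this line does, here is a small worked example on a hypothetical 3-class matrix (the counts are invented, with a deliberate skew towards the last class). After normalisation every row sums to 1, so each entry reads as the fraction of that true class assigned to each predicted class.

```python
import numpy as np

# Hypothetical 3-class confusion matrix, heavily skewed towards the last class.
conf_mat = np.array([[5, 3, 2],
                     [4, 10, 6],
                     [10, 20, 170]])

# Row totals = number of instances whose true label is that row.
row_sums = conf_mat.sum(axis=1)[:, np.newaxis]

# Broadcasting divides every entry by the total of its own row.
conf_mat_norm = conf_mat.astype("float") / row_sums

print(conf_mat_norm)
```

Note that a class with no true instances has a row sum of zero, so its row would come out as NaN; such rows need separate handling.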

Now, using the normalised confusion matrix, we get the following plot.

Clearly, this normalised plot makes it easier to see which classes our classifier is confusing with one another. Such analysis is crucial while optimising for F1 scores.
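To make the link to F1 concrete, per-class precision, recall and F1 can be read directly off the raw (unnormalised) confusion matrix. The sketch below assumes rows are true labels and columns are predicted labels; per_class_f1 is a hypothetical helper name, not an sklearn function.

```python
import numpy as np

def per_class_f1(conf_mat):
    """Return per-class (precision, recall, f1) from a raw confusion matrix.

    Assumes rows are true labels and columns are predicted labels.
    Classes with no predictions or no true instances get a score of 0.
    """
    conf_mat = np.asarray(conf_mat, dtype=float)
    true_pos = np.diag(conf_mat)        # correctly classified counts
    predicted = conf_mat.sum(axis=0)    # column totals: times each class was predicted
    actual = conf_mat.sum(axis=1)       # row totals: true instances of each class

    zeros = np.zeros_like(true_pos)
    precision = np.divide(true_pos, predicted, out=zeros.copy(), where=predicted > 0)
    recall = np.divide(true_pos, actual, out=zeros.copy(), where=actual > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=zeros.copy(), where=denom > 0)
    return precision, recall, f1

# Hypothetical counts, skewed towards the last class.
precision, recall, f1 = per_class_f1([[5, 3, 2],
                                      [4, 10, 6],
                                      [10, 20, 170]])
print(f1)
```

Off-diagonal mass in a row lowers that class's recall, while off-diagonal mass in a column lowers its precision, which is why the confused pairs in the plot are the natural place to look when F1 is low.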

Building a fully customizable confusion matrix

The method above was a quick one, using a few calls to sklearn and matplotlib functions to get decent confusion matrix plots. But often we need a detailed confusion matrix with lots of customizations to suit our needs. In such a scenario we need a customizable function to handle all our requirements.

I found such a holy-grail function for confusion matrices here. After adding a few touches of my own, here is the modified version, which I always go to when plotting a confusion matrix.

This function takes a confusion matrix as input and produces a plot, with options to annotate each entry of the matrix with its value, normalize the matrix, and add class labels.
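A function along these lines could be sketched as follows. This is an illustrative version built only on numpy and matplotlib, not the exact function referred to above; the name plot_confusion_matrix and its parameters (class_names, normalize, annotate, cmap, title) are assumptions chosen for this example.

```python
import itertools

import matplotlib.pyplot as plt
import numpy as np

def plot_confusion_matrix(conf_mat, class_names=None, normalize=False,
                          annotate=True, cmap="Blues", title="Confusion matrix"):
    """Plot a confusion matrix (rows = true labels, columns = predictions).

    Assumes integer counts as input; returns the matplotlib (fig, ax) pair.
    """
    conf_mat = np.asarray(conf_mat)
    if normalize:
        # Row-wise normalisation, same as the one-liner shown earlier.
        conf_mat = conf_mat.astype("float") / conf_mat.sum(axis=1)[:, np.newaxis]

    n_classes = conf_mat.shape[0]
    if class_names is None:
        class_names = [str(i) for i in range(n_classes)]

    fig, ax = plt.subplots()
    im = ax.imshow(conf_mat, interpolation="nearest", cmap=cmap)
    fig.colorbar(im, ax=ax)
    ax.set_title(title)

    # Class labels on both axes.
    ax.set_xticks(np.arange(n_classes))
    ax.set_yticks(np.arange(n_classes))
    ax.set_xticklabels(class_names, rotation=45, ha="right")
    ax.set_yticklabels(class_names)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")

    if annotate:
        # Write each cell's value, in white on dark cells and black on light ones.
        fmt = ".2f" if normalize else "d"
        threshold = conf_mat.max() / 2.0
        for i, j in itertools.product(range(n_classes), range(n_classes)):
            ax.text(j, i, format(conf_mat[i, j], fmt),
                    ha="center", va="center",
                    color="white" if conf_mat[i, j] > threshold else "black")

    fig.tight_layout()
    return fig, ax
```

For example, plot_confusion_matrix(conf_mat, class_names=["1", "2", "3", "4", "5"], normalize=True, cmap="Greens") would draw a normalised, annotated plot with star ratings as labels; follow it with plt.show() or fig.savefig(...) to display or save the result.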

A sample confusion matrix generated using the above code.

Confusion matrix plots are a great way to evaluate classifier performance and often provide crucial insights into the data distribution and the errors made by the model. Moreover, such plots are great for showing model performance in presentations :)

Deepanshu Jindal

IIT Delhi Senior Undergrad, Artificial Intelligence enthusiast