Display Prediction Probabilities of Multiclass Classification Using a Bar Chart

Photo by Annie Spratt on Unsplash

Visualize, visualize, visualize

As a data scientist, you have probably heard that quote many times. It is so plain that it needs no explanation. Visualization is one of the things we do most as data scientists. Why? Because most people understand a visual far more easily than a mass of raw data.

In this article, I will keep things as brief as I can. The idea came up when I wanted to display the prediction probabilities of a multiclass classification model instead of outputting only a single value. My data looks like this.

As you can see above, the data is an array of nine values representing the prediction probabilities for your classification problem, one per class. You can reproduce this output by creating the same array with NumPy, so you only need to focus on creating the bar chart.
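The original array is not reproduced in this text, so here is a minimal sketch that builds a NumPy array with the same (1, 9) shape. The probability values are made up for illustration and chosen to sum to 1, as softmax outputs do. The extra outer brackets mimic a batch of one sample, which is the shape a Keras-style model.predict call typically returns for a single input (an assumption about your setup).

```python
import numpy as np

# Made-up prediction probabilities for one sample and nine classes.
# The extra outer brackets give the array shape (1, 9): a "batch" of one.
a = np.array([[0.02, 0.45, 0.20, 0.03, 0.05, 0.02, 0.18, 0.03, 0.02]])

print(a)
# Softmax-style probabilities for a single sample should sum to 1.
print(round(float(a.sum()), 6))  # 1.0
```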

If you run the code in Google Colaboratory, you will see this output.

Because the array is multidimensional, we first need to extract the part we need. If you run a.shape, you will see the output (1, 9). I will try to explain it using the image below.

From the shape, you can tell this is a two-dimensional array: the outer dimension contains a single element, and that element is the inner dimension. The inner dimension contains nine values, the prediction probabilities we want to display. To learn more about multidimensional arrays, you could read this article.

With the picture above in mind, we know that our prediction probabilities lie at index 0 of the array. You can run a[0] to access the inner dimension.
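To see the two dimensions and the indexing step in code, here is a short sketch on a hypothetical (1, 9) array with illustrative values. a[0] selects the single row of the outer dimension; a.squeeze() is an alternative NumPy call that drops length-1 axes and gives the same one-dimensional result.

```python
import numpy as np

# Hypothetical (1, 9) array: one sample, nine class probabilities.
a = np.array([[0.02, 0.45, 0.20, 0.03, 0.05, 0.02, 0.18, 0.03, 0.02]])

print(a.shape)  # (1, 9): 1 row in the outer dimension, 9 values in the inner one
print(a.ndim)   # 2

# Index 0 of the outer dimension is the row holding the probabilities.
inner = a[0]
print(inner.shape)  # (9,)
print(inner.ndim)   # 1

# squeeze() removes axes of length 1, giving an equivalent 1-D array.
print(np.array_equal(inner, a.squeeze()))  # True
```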

Because the inner dimension is a one-dimensional array, we can display it in whatever chart we want. Before that, save the inner dimension into a variable to make the data easier to plot. You also need to declare the class names of your predictions.
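A minimal sketch of this step follows. The variable name probabilities and the nine class names are hypothetical placeholders; use the labels your model was trained with, in the same order as its output columns, because the chart pairs each label with a probability by position.

```python
import numpy as np

# Hypothetical model output with illustrative values (shape (1, 9)).
a = np.array([[0.02, 0.45, 0.20, 0.03, 0.05, 0.02, 0.18, 0.03, 0.02]])

# Keep the inner dimension in its own variable to simplify plotting.
probabilities = a[0]

# Placeholder class names, listed in the same order as the model's outputs.
class_names = ["airplane", "bird", "cat", "deer", "dog",
               "frog", "horse", "ship", "truck"]

# Every probability needs exactly one label; otherwise plt.bar would
# raise a shape-mismatch error.
assert len(class_names) == len(probabilities)

for name, p in zip(class_names, probabilities):
    print(f"{name}: {p:.2f}")
```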

Now you have the data and the labels. The next step is creating the bar chart. Below is the code that I used.

Let's break the code down:

  1. Figure Size → This line sets the size of the figure (width and height in inches).
  2. Declare Plot → This line creates the bar chart with plt.bar. You could try other chart types with commands such as plt.barh or plt.scatter.
  3. Title → This line gives your bar chart a title.
  4. Label (xlabel & ylabel) → These lines label the axes, which makes your bar chart easier to understand.
  5. Rotation (xticks) → This line rotates the tick labels. Here I rotate the x-axis labels so the class names do not overlap.
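  6. Show → This line displays your plot cleanly, without extra text output (such as the object representation a notebook prints for the last expression in a cell).
  7. Savefig → If you want to keep the image after running the code, this line saves your bar chart to an image file. Call it before plt.show(), because in many environments show() closes the figure and a later savefig writes a blank image. In Colab, the file is saved to the runtime's file system, so download it from there.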
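The original code is not reproduced in this text, so the following is only a sketch that follows the seven steps above. The data, class names, figure size, and file name are hypothetical. Note that savefig comes before show here, so the saved file is not blank.

```python
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical data (illustrative values, not the author's).
a = np.array([[0.02, 0.45, 0.20, 0.03, 0.05, 0.02, 0.18, 0.03, 0.02]])
probabilities = a[0]
class_names = ["airplane", "bird", "cat", "deer", "dog",
               "frog", "horse", "ship", "truck"]

# 1. Figure size: width and height in inches.
plt.figure(figsize=(10, 5))

# 2. Declare the plot: one bar per class, labelled by class name.
bars = plt.bar(class_names, probabilities)

# 3. Title.
plt.title("Prediction probabilities")

# 4. Axis labels; probabilities always fall between 0 and 1.
plt.xlabel("Class")
plt.ylabel("Probability")
plt.ylim(0, 1)

# 5. Rotate the x tick labels so long class names do not overlap.
plt.xticks(rotation=45)

# 7. Save first: show() may close the figure, leaving nothing to save.
plt.savefig("prediction_probabilities.png", bbox_inches="tight")

# 6. Display the chart.
plt.show()
```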

There are more commands you could explore, but they are beyond the scope of this article. You can read the Matplotlib documentation here for more insight. After you run the code, you will get output like this.

From the output above, we can see that the model predicts the input I gave it is a bird, with a probability above 0.4; the model also considers a horse or a cat, each with a probability of around 0.2, and so on. That is easier to understand, right? Looking at the prediction probabilities also gives us insight when the model makes a wrong prediction.
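To turn that reading into numbers, a small sketch can rank the classes and measure how far the top guess is ahead of the runner-up. The data and class names are hypothetical, and using the top-two margin as a confidence signal is an added illustration rather than something from the original article.

```python
import numpy as np

# Hypothetical probabilities and class names (illustrative only).
probabilities = np.array([0.02, 0.45, 0.20, 0.03, 0.05, 0.02, 0.18, 0.03, 0.02])
class_names = ["airplane", "bird", "cat", "deer", "dog",
               "frog", "horse", "ship", "truck"]

# Sort class indices from most to least probable.
order = np.argsort(probabilities)[::-1]

# Print the three most likely classes.
top_k = 3
for i in order[:top_k]:
    print(f"{class_names[i]:>8}: {probabilities[i]:.2f}")

# Gap between the two best guesses; a small gap suggests the model is unsure,
# which is worth checking when a prediction turns out to be wrong.
margin = probabilities[order[0]] - probabilities[order[1]]
print(f"margin between top two classes: {margin:.2f}")
```

With these made-up values, the ranking is bird, cat, horse, and the margin is 0.25.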

Conclusion

In this article, I shared how to display the prediction probabilities of a multiclass classification in a bar chart. This gives us insight whenever our model predicts the wrong class, because, remember this motto:

There is no perfect model, only us who create the model ~ handhikayp

Thanks for reading.
