Plotting Scikit-Learn Classification Report for Analysis

Doug Creates
9 min read · Mar 23, 2024


The task is to create a visual representation of a classification report generated by scikit-learn, using matplotlib to turn the numerical metrics into a chart that makes model performance easier to understand and analyze.

This is a recipe from PythonFleek. Get the free e-book today!

Code

Create a visual depiction of a scikit-learn classification report with matplotlib for sharper model analysis.

import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
import seaborn as sns
import pandas as pd

# Load dataset
iris = load_iris()
X, y = iris.data, iris.target

# Split dataset
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Train model
model = RandomForestClassifier(random_state=42)
model.fit(X_train, y_train)

# Predict
y_pred = model.predict(X_test)

# Generate classification report
class_report = classification_report(y_test, y_pred, output_dict=True)


def visualize_report(report):
    # Define custom colors
    colors = ['#ffdfdf', '#dfffff', '#dfffdf', '#dfdfff', '#ffdfff', '#ffffdf']

    # Extracting metrics for each class
    metrics = ['precision', 'recall', 'f1-score', 'support']
    data = {metric: [] for metric in metrics}
    labels = []

    # Include class-specific metrics plus the macro and weighted averages
    for cls, metrics_values in report.items():
        if cls.isdigit() or cls in ['macro avg', 'weighted avg']:
            labels.append(cls)
            for metric in metrics:
                data[metric].append(metrics_values.get(metric, None))

    # Convert data to DataFrame for easy plotting
    df = pd.DataFrame(data, index=labels)

    # Creating subplots for each metric
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    for ax, metric in zip(axes, metrics[:-1]):  # Exclude 'support' for plotting
        sns.barplot(ax=ax, x=df.index, y=metric, data=df, palette=colors)
        ax.set_title(f'{metric.capitalize()} by Class')
        ax.set_ylim(0, 1.1)

    plt.tight_layout()
    plt.show()

visualize_report(class_report)

Explanation

This recipe is aimed at data scientists, machine learning engineers, and analysts who need to evaluate and present the performance of classification models.

Import necessary libraries: matplotlib for plotting, sklearn.metrics for generating the classification report, sklearn.model_selection for splitting the dataset, sklearn.datasets for loading a sample dataset, sklearn.ensemble for the RandomForestClassifier, plus seaborn and pandas for arranging and drawing the metrics.
Load dataset: The Iris dataset is loaded using sklearn’s load_iris function.
Split dataset: The dataset is split into training and testing sets using train_test_split.
Train model: A RandomForestClassifier model is trained on the training set.
Predict: The trained model is used to make predictions on the test set.
Generate classification report: A classification report is generated with sklearn’s classification_report function, using output_dict=True so the metrics can be accessed programmatically for plotting (see the snippet after this list).
Plotting setup: The per-class metrics (plus the macro and weighted averages) are gathered into a pandas DataFrame, and a matplotlib figure with three subplots is created.
Plotting the report: Precision, recall, and F1-score are drawn as seaborn bar plots, one subplot per metric, with one bar per class.
Show plot: The plot is displayed using plt.show().
Note: The code snippet assumes familiarity with sklearn’s API and basic plotting with matplotlib.
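For reference, classification_report with output_dict=True returns a nested dictionary: one entry per class label plus 'accuracy', 'macro avg', and 'weighted avg'. A quick way to inspect it, assuming class_report from the Code section above:

from pprint import pprint

# Top-level keys: per-class labels plus the aggregate entries
print(list(class_report.keys()))  # ['0', '1', '2', 'accuracy', 'macro avg', 'weighted avg']

# Each class entry holds precision, recall, f1-score, and support
pprint(class_report['0'])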

Creating Visual Classification Reports with Scikit-learn and Matplotlib

Transform complex classification metrics into intuitive visual reports to enhance model performance analysis.

Why: Visual representations of classification reports provide a more intuitive understanding of model performance metrics such as precision, recall, and F1-score, facilitating better communication and decision-making.

Install: pip install scikit-learn matplotlib seaborn pandas

Algorithm

A method to plot a classification report generated by scikit-learn using matplotlib, making it easier to understand and analyze the performance of machine learning classification models.

import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
import numpy as np

# Generating synthetic dataset
X, y = make_classification(n_samples=1000, n_features=20, n_classes=3, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Training a simple RandomForest Classifier
clf = RandomForestClassifier(random_state=42)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

# Generating classification report
report = classification_report(y_test, y_pred, output_dict=True)

# Function to plot the per-class metrics as grouped horizontal bars
def plot_classification_report(report):
    labels = list(report.keys())[:-3]  # Exclude 'accuracy', 'macro avg', 'weighted avg'
    metrics = ['precision', 'recall', 'f1-score']
    colors = ['blue', 'green', 'red']
    y_pos = np.arange(len(labels))
    bar_height = 0.25

    fig, ax = plt.subplots(figsize=(10, 6))
    for i, (metric, color) in enumerate(zip(metrics, colors)):
        values = [report[label][metric] for label in labels]
        ax.barh(y_pos + i * bar_height, values, bar_height, color=color, label=metric.capitalize())

    ax.set_yticks(y_pos + bar_height)
    ax.set_yticklabels(labels)
    ax.set_xlabel('Scores')
    ax.set_title('Classification Report')
    ax.legend()
    plt.tight_layout()
    plt.show()

# Plotting the classification report
plot_classification_report(report)

Demo 1

Create a visual depiction of scikit-learn’s classification report, leveraging matplotlib for insightful model performance analysis.

# Demo 1: Simple visualization of a classification report using matplotlib
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
import numpy as np

# Generating synthetic dataset
X, y = make_classification(n_samples=1000, n_features=20, n_classes=2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Training a simple RandomForest Classifier
clf = RandomForestClassifier(random_state=42)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

# Generating classification report
report = classification_report(y_test, y_pred, output_dict=True)

# Function to plot the per-class metrics as grouped horizontal bars
def plot_classification_report(report):
    labels = list(report.keys())[:-3]  # Exclude 'accuracy', 'macro avg', 'weighted avg'
    metrics = ['precision', 'recall', 'f1-score']
    colors = ['#9999ff', '#99ff99', '#ff9999']
    y_pos = np.arange(len(labels))
    bar_height = 0.25

    fig, ax = plt.subplots(figsize=(10, 6))
    for i, (metric, color) in enumerate(zip(metrics, colors)):
        values = [report[label][metric] for label in labels]
        ax.barh(y_pos + i * bar_height, values, bar_height, color=color, label=metric.capitalize())

    ax.set_yticks(y_pos + bar_height)
    ax.set_yticklabels(labels)
    ax.set_xlabel('Scores')
    ax.set_title('Classification Report')
    ax.legend()
    plt.tight_layout()
    plt.show()

# Plotting the classification report
plot_classification_report(report)

Case Study

Suppose we were tasked with improving the interpretability of machine learning model performance for a team of data scientists at a tech company. The team has been using scikit-learn to build classification models but finds the numerical classification reports challenging to quickly interpret. To address this, we decide to create a visual representation of the classification report using matplotlib, a popular plotting library in Python. We start by generating a classification report for a model predicting customer churn. Next, we extract the precision, recall, f1-score, and support for each class from the report and plot these metrics as a bar chart. This visual approach allows the team to easily identify which classes the model performs well on and which ones need improvement, leading to more efficient model tuning and better overall performance.
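Here is a minimal sketch of that workflow. It uses make_classification as a stand-in for the churn data, and the class names 'stays'/'churns' are illustrative assumptions rather than details from the original case.

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

# Stand-in for a customer-churn dataset (imbalanced binary problem)
X, y = make_classification(n_samples=2000, n_features=10, weights=[0.8, 0.2], random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

model = RandomForestClassifier(random_state=0).fit(X_train, y_train)
churn_report = classification_report(y_test, model.predict(X_test),
                                     target_names=['stays', 'churns'], output_dict=True)

# Grouped bar chart of precision / recall / F1-score per class
classes = ['stays', 'churns']
metrics = ['precision', 'recall', 'f1-score']
x = np.arange(len(classes))
fig, ax = plt.subplots(figsize=(8, 5))
for i, metric in enumerate(metrics):
    ax.bar(x + i * 0.25, [churn_report[c][metric] for c in classes], 0.25, label=metric)
ax.set_xticks(x + 0.25)
ax.set_xticklabels(classes)
ax.set_ylim(0, 1.05)
ax.set_ylabel('Score')
ax.set_title('Churn model: per-class metrics')
ax.legend()
plt.tight_layout()
plt.show()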

Pitfalls

Complexity of Data: Handling and parsing the classification report’s data can be complex, especially for multi-class scenarios (see the parsing sketch after this list).
Plot Customization: Customizing the plot for clarity, such as adjusting labels and colors, can be challenging for those unfamiliar with matplotlib.
Scalability: The approach might need to be adapted for very large datasets or a high number of classes, which could clutter the visualization.
Interpretation: Misinterpretation of the plotted data can lead to incorrect conclusions about model performance.
Integration: Integrating this visualization into existing model evaluation workflows may require additional steps or adjustments.
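One way to tame that parsing, sketched under the assumption that report comes from classification_report(..., output_dict=True): drop the scalar 'accuracy' entry and let pandas build the table.

import pandas as pd

# Flatten the nested report dict into a tidy table (rows: classes + averages)
report_df = pd.DataFrame(
    {label: scores for label, scores in report.items() if isinstance(scores, dict)}
).transpose()

print(report_df)                               # columns: precision, recall, f1-score, support
print(report_df.filter(regex='avg', axis=0))   # just the 'macro avg' / 'weighted avg' rows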

Tips for Production

Automation: Automate the process of generating and plotting the classification report as part of the model evaluation pipeline (a helper sketch follows this list).
Interactivity: Implement interactive elements in the plot, such as tooltips or zooming, to allow users to explore the data more deeply.
Customization: Provide options for customization, including the ability to select which metrics to display and how to format the plot.
Integration: Ensure the plotting function integrates seamlessly with existing data science and machine learning workflows.
Scalability: Optimize the plotting function for performance to handle large datasets and a high number of classes efficiently.
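A minimal sketch of such an automated step; evaluate_and_plot is a hypothetical helper name, and it assumes a plotting function like visualize_report from the Code section above.

from sklearn.metrics import classification_report

# Hypothetical helper: predict, build the report dict, and hand it to any plotting function
def evaluate_and_plot(model, X_test, y_test, plot_fn):
    y_pred = model.predict(X_test)
    report = classification_report(y_test, y_pred, output_dict=True)
    plot_fn(report)    # e.g. visualize_report defined earlier
    return report      # keep the raw numbers for logging or regression checks

# Usage (assuming model, X_test, y_test, and visualize_report already exist):
# evaluate_and_plot(model, X_test, y_test, visualize_report)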

Recommended Reading

Scikit-learn: A powerful Python library for machine learning that provides simple and efficient tools for data analysis and modeling. https://en.wikipedia.org/wiki/Scikit-learn
Matplotlib: A comprehensive library for creating static, animated, and interactive visualizations in Python. https://en.wikipedia.org/wiki/Matplotlib
Classification report: A performance evaluation metric in machine learning that shows the precision, recall, F1-score, and support for each class. https://en.wikipedia.org/wiki/Precision_and_recall

Encore

Throughout this manual, we’ve explored the intricacies of visualizing scikit-learn classification reports using matplotlib. By generating synthetic datasets, training RandomForestClassifiers, and crafting detailed classification reports, we’ve demonstrated how to transform these numerical reports into insightful visual representations. This approach not only aids in the interpretability of model performance metrics but also enhances the decision-making process for data scientists and machine learning engineers. As we conclude, it’s clear that integrating visual analysis into model evaluation workflows can significantly improve our understanding and communication of model performance. Moving forward, I plan to further refine these visualization techniques, ensuring they remain an integral part of our model evaluation and tuning processes.

# Import necessary libraries
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification

# Generate a synthetic dataset
X, y = make_classification(n_samples=1000, n_features=20, n_classes=2, random_state=42)

# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Initialize and train the classifier
clf = RandomForestClassifier(random_state=42)
clf.fit(X_train, y_train)

# Predict on the test set
y_pred = clf.predict(X_test)

# Generate the classification report
report = classification_report(y_test, y_pred, output_dict=True)

# Plotting the classification report
fig, ax = plt.subplots(figsize=(10, 6))

# Categories for the classification
categories = list(report.keys())[:-3] # Exclude 'accuracy', 'macro avg', and 'weighted avg'

# Extracting precision, recall, and f1-score
precision = [report[category]['precision'] for category in categories]
recall = [report[category]['recall'] for category in categories]
f1_score = [report[category]['f1-score'] for category in categories]

# Setting the positions and width for the bars
pos = list(range(len(categories)))
width = 0.25

# Plotting each metric
plt.bar(pos, precision, width, alpha=0.5, color='#ffdfdf', label='Precision')
plt.bar([p + width for p in pos], recall, width, alpha=0.5, color='#dfffff', label='Recall')
plt.bar([p + width*2 for p in pos], f1_score, width, alpha=0.5, color='#dfffdf', label='F1-Score')

# Adding the aesthetics
plt.xlabel('Category')
plt.ylabel('Score')
plt.title('Classification Report')
plt.xticks([p + width for p in pos], categories)

# Adding the legend and showing the plot
plt.legend(['Precision', 'Recall', 'F1-Score'], loc='upper left')
plt.grid()
plt.show()

# Conclusion: This script demonstrates how to create a visual representation of a classification report using matplotlib. It includes generating a synthetic dataset, training a RandomForestClassifier, predicting on the test set, generating a classification report, and finally plotting precision, recall, and f1-score for each category in the report.

Demo 2

# Demo 2: Comprehensive visualization of classification report with precision, recall, f1-score, and support using matplotlib
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
import numpy as np

# Generating synthetic dataset
X, y = make_classification(n_samples=1000, n_features=20, n_classes=2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Training a RandomForest Classifier
clf = RandomForestClassifier(random_state=42)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

# Generating classification report
report = classification_report(y_test, y_pred, output_dict=True)

# Function to plot classification report with support
def plot_classification_report_with_support(report):
    labels = list(report.keys())[:-3]  # Exclude 'accuracy', 'macro avg', 'weighted avg'
    metrics = ['precision', 'recall', 'f1-score', 'support']
    data = np.array([[report[label][metric] for metric in metrics] for label in labels])

    fig, ax = plt.subplots(figsize=(12, 6))
    cax = ax.matshow(data, cmap='coolwarm')
    plt.xticks(range(len(metrics)), metrics)
    plt.yticks(range(len(labels)), labels)
    plt.colorbar(cax)

    # Adding the text
    for (i, j), val in np.ndenumerate(data):
        ax.text(j, i, f'{val:.2f}', ha='center', va='center', color='white')

    plt.xlabel('Metrics')
    plt.ylabel('Classes')
    plt.title('Classification Report with Support')
    plt.show()

# Plotting the classification report with support
plot_classification_report_with_support(report)

