Keywords: Confusion Matrix | Scikit-learn | Matplotlib | Data Visualization | Machine Learning Evaluation
Abstract: This article provides a comprehensive guide on visualizing classifier performance with labeled confusion matrices using Scikit-learn and Matplotlib. It begins by analyzing the limitations of basic confusion matrix plotting, then focuses on methods to add custom labels via the Matplotlib artist API, including setting axis labels, titles, and ticks. The article compares multiple implementation approaches, such as using Seaborn heatmaps and Scikit-learn's ConfusionMatrixDisplay class, with complete code examples and step-by-step explanations. Finally, it discusses practical applications and best practices for confusion matrices in model evaluation.
Introduction
In machine learning classification tasks, the confusion matrix is a crucial tool for evaluating model performance. It illustrates the correspondence between true labels and predicted labels, helping to identify patterns of misclassification. However, when using Scikit-learn's confusion_matrix function and Matplotlib's matshow for basic plotting, the default output only displays numerical values without class labels, reducing interpretability. Based on high-scoring Stack Overflow answers, this article delves into how to add custom labels through the Matplotlib artist API and offers comprehensive solutions by integrating other methods.
Limitations of Basic Confusion Matrix Plotting
The initial code computes the confusion matrix with confusion_matrix and visualizes it via matshow:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
y_test = ['business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business']
pred = ['health', 'business', 'business', 'business', 'business', 'business', 'health', 'health', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'health', 'health', 'business', 'health']
cm = confusion_matrix(y_test, pred)
plt.matshow(cm)
plt.title('Confusion matrix of the classifier')
plt.colorbar()
plt.show()

This code generates a confusion matrix plot, but the axes show only numerical indices (e.g., 0, 1) instead of the actual class labels (e.g., 'business', 'health'), making interpretation difficult. The issue arises because Matplotlib has no knowledge of the label names and falls back to numeric tick positions.
Adding Labels Using the Matplotlib Artist API
To address this, leverage Matplotlib's low-level artist API by explicitly setting axis tick labels. Here is the improved code:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
labels = ['business', 'health']
cm = confusion_matrix(y_test, pred, labels=labels)
print(cm)
fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(cm)
plt.title('Confusion matrix of the classifier')
fig.colorbar(cax)
ax.set_xticklabels([''] + labels)
ax.set_yticklabels([''] + labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()

Step-by-step code analysis:
- Define the label list: labels = ['business', 'health'] fixes the class order so the plot stays consistent with the confusion matrix.
- Compute the confusion matrix: confusion_matrix(y_test, pred, labels=labels) passes the labels parameter so matrix rows and columns follow the specified order.
- Create the figure and axes: fig.add_subplot(111) creates a single subplot and returns an ax object for further customization.
- Display the matrix: ax.matshow(cm) draws the confusion matrix on the axes and returns a cax object used for the colorbar.
- Set tick labels: ax.set_xticklabels([''] + labels) and ax.set_yticklabels([''] + labels) set the x-axis and y-axis tick labels. The leading empty string compensates for the extra default tick position so the labels line up; note that recent Matplotlib versions warn about setting tick labels without fixing tick positions, in which case set the ticks first with ax.set_xticks(range(len(labels))).
- Add a title and axis labels: plt.xlabel('Predicted') and plt.ylabel('True') clarify what each axis represents.
This method provides direct control over Matplotlib elements, offering high flexibility for complex customizations.
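To illustrate that flexibility, the count in each cell can be drawn directly with ax.text, one of the artist-API calls that higher-level tools wrap for you. A minimal sketch, reusing the y_test, pred, and labels values from above (the Agg backend and the output filename are arbitrary choices for a headless run):

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch also runs without a display
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

y_test = ['business'] * 20
pred = ['health', 'business', 'business', 'business', 'business', 'business',
        'health', 'health', 'business', 'business', 'business', 'business',
        'business', 'business', 'business', 'business', 'health', 'health',
        'business', 'health']
labels = ['business', 'health']

cm = confusion_matrix(y_test, pred, labels=labels)

fig, ax = plt.subplots()
ax.matshow(cm, cmap='Blues')
# write each count into its cell, flipping the text color on dark cells
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        ax.text(j, i, str(cm[i, j]), ha='center', va='center',
                color='white' if cm[i, j] > cm.max() / 2 else 'black')
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels)
ax.set_yticks(range(len(labels)))
ax.set_yticklabels(labels)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
fig.savefig('confusion_matrix.png')  # or plt.show() in an interactive session
```

Setting the tick positions explicitly before the tick labels also avoids the deprecation warning that newer Matplotlib versions emit for the [''] + labels workaround.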
Comparison of Alternative Implementation Approaches
Using Seaborn Heatmap
Seaborn's heatmap function offers a high-level interface that simplifies label addition:
import seaborn as sns
import matplotlib.pyplot as plt
ax = plt.subplot()
sns.heatmap(cm, annot=True, fmt='g', ax=ax)
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
ax.set_title('Confusion Matrix')
ax.xaxis.set_ticklabels(['business', 'health'])
ax.yaxis.set_ticklabels(['business', 'health'])
plt.show()

Advantages include automatic cell annotation (annot=True) and plain-number formatting (fmt='g', which keeps counts from being rendered in scientific notation), but this approach requires installing the Seaborn library.
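Seaborn can also take the labels directly through heatmap's xticklabels and yticklabels parameters, which removes the separate set_ticklabels calls. A compact variant, with the counts from the earlier example hard-coded for illustration:

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend for headless environments
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

labels = ['business', 'health']
cm = np.array([[14, 6],
               [0, 0]])  # counts from the earlier example

# xticklabels/yticklabels place the class names in one call
ax = sns.heatmap(cm, annot=True, fmt='g',
                 xticklabels=labels, yticklabels=labels, cmap='Blues')
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
ax.set_title('Confusion Matrix')
plt.savefig('seaborn_cm.png')
```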
Using Scikit-learn's ConfusionMatrixDisplay
Scikit-learn's ConfusionMatrixDisplay class (available since version 0.22, with the from_estimator and from_predictions helpers added in 1.0) is designed specifically for confusion matrix visualization:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
cm = confusion_matrix(y_test, pred, labels=labels)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
disp.plot()
plt.show()

This class encapsulates the common settings, allowing labels to be specified directly via display_labels and supporting advanced features such as normalization. The from_predictions class method goes a step further and builds the plot straight from the label arrays, so no separate disp.plot() call is needed:

disp = ConfusionMatrixDisplay.from_predictions(y_test, pred, display_labels=labels)
plt.show()

This approach results in concise, integrated code and is recommended for standard scenarios.
Practical Applications and Best Practices
Confusion matrices are not only for visualization but also for comprehensive model evaluation when combined with other metrics:
- Normalization: pass normalize='true' to convert the matrix into row-wise proportions, which makes results comparable across datasets of different sizes.
- Multi-class handling: for multi-class problems, ensure the labels list includes every class; classes missing from the list are dropped from the matrix.
- Performance analysis: calculate metrics such as accuracy, precision, and recall from the confusion matrix to identify model weaknesses, such as systematic misclassification of specific classes.
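Precision and recall can be read straight off the matrix: the diagonal holds the correct predictions, column sums give the precision denominators, and row sums give the recall denominators. A minimal sketch using the counts from the earlier two-class example:

```python
import numpy as np

# confusion matrix from the earlier example: rows = true labels, columns = predictions
cm = np.array([[14, 6],
               [0, 0]])

# per-class precision and recall; errstate silences the 0/0 warning for the
# empty 'health' row, and nan_to_num maps the undefined scores to 0.0
with np.errstate(divide='ignore', invalid='ignore'):
    precision = np.nan_to_num(np.diag(cm) / cm.sum(axis=0))
    recall = np.nan_to_num(np.diag(cm) / cm.sum(axis=1))

print(precision)  # business: 1.0, health: 0.0
print(recall)     # business: 0.7, health: 0.0
```

The zero scores for 'health' expose exactly the weakness described above: every 'health' prediction here is a misclassified 'business' sample.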
Example: Adding accuracy to the plot:
import numpy as np
accuracy = np.trace(cm) / np.sum(cm)
plt.xlabel(f'Predicted\nAccuracy: {accuracy:.2f}')
plt.show()

In real-world projects, using ConfusionMatrixDisplay is advised to keep the code consistent and maintainable.
Conclusion
Adding labels to confusion matrices via the Matplotlib artist API is an effective method that offers high customizability. Integrating tools like Seaborn or Scikit-learn can further streamline the process. Understanding these techniques aids in creating clear, informative visualizations that enhance model evaluation efficiency. Developers should choose the appropriate approach based on project needs and adhere to best practices to ensure accurate and interpretable results.