Keywords: Keras | Accuracy Visualization | Deep Learning Monitoring
Abstract: This article provides a comprehensive guide on visualizing accuracy and loss curves during neural network training in Keras, with special focus on test set accuracy plotting. Through analysis of model training history and test set evaluation results, multiple visualization methods including matplotlib and plotly implementations are presented, along with in-depth discussion of EarlyStopping callback usage. The article includes complete code examples and best practice recommendations for comprehensive model performance monitoring.
Introduction
Accurate monitoring of the training process is crucial for evaluating model performance and preventing overfitting in deep learning development. While Keras provides robust tools for recording and visualizing training metrics, many developers face challenges when visualizing test set accuracy. This article systematically explains how to comprehensively plot accuracy curves for training, validation, and test sets.
Model Training and History Data Recording
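Keras's fit method returns a History object whose history attribute is a dictionary that maps each metric name ('loss', 'accuracy', and, when validation data is supplied, 'val_loss' and 'val_accuracy') to a list with one value per epoch. The validation entries exist only if fit receives validation data, so the training call must be configured accordingly: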
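from keras.models import Sequential
from keras.layers import Dense, Input
from keras.callbacks import EarlyStopping
# Assumes X_train, y_train, X_test, y_test are NumPy arrays prepared beforehand:
# 20 features per sample and binary (0/1) labels
# Build example model
model = Sequential()
model.add(Input(shape=(20,)))
model.add(Dense(64, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# Compile model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# Stop when validation loss has not improved for 3 epochs, then roll back to the best weights
early_stop = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
# Train model and record history data; 20% of the training data is held out for
# validation so that the test set stays unseen until the final evaluation
history = model.fit(X_train, y_train,
                    validation_split=0.2,
                    epochs=50,
                    batch_size=32,
                    callbacks=[early_stop])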
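In the code above, validation_split reserves the last 20% of the training arrays (taken before any shuffling, so sorted data should be shuffled first) as a validation set, which keeps the test set independent of training. EarlyStopping monitors val_loss rather than the training loss: training loss usually keeps falling even after the model starts to overfit, so only a validation metric can signal when further training stops helping.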
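The patience rule behind EarlyStopping can be illustrated without Keras. The function below is a simplified, hypothetical model of that rule (it ignores options such as baseline and start_from_epoch) and is not Keras's own code: given the per-epoch values of the monitored metric, it reports the epoch at which training would stop and the best epoch, which is the one restore_best_weights=True is designed to bring back.

```python
def simulate_early_stopping(values, patience=3, mode='min', min_delta=0.0):
    """Return (stop_epoch, best_epoch), both 1-based, for a monitored metric.

    Simplified patience rule: stop once `patience` consecutive epochs pass
    without an improvement larger than `min_delta`.
    """
    # Flip the sign for 'max' metrics (e.g. val_accuracy) so both modes minimize
    sign = 1 if mode == 'min' else -1
    best = float('inf')
    best_epoch = 0
    wait = 0  # epochs since the last improvement
    for epoch, value in enumerate(values, start=1):
        score = sign * value
        if score < best - min_delta:
            best, best_epoch, wait = score, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    # The stop condition never triggered: training ran through every epoch
    return len(values), best_epoch
```

For example, with val_loss values [0.60, 0.50, 0.45, 0.46, 0.47, 0.48, 0.40] and patience=3, the rule stops after epoch 6 and keeps epoch 3; the later improvement at epoch 7 is never reached, which is the trade-off a larger patience value addresses.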
Basic Visualization Methods
With matplotlib, the accuracy and loss curves stored in history.history can be plotted side by side for the training and validation sets:
import matplotlib.pyplot as plt
# Plot accuracy curves
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
# Plot loss curves
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.tight_layout()
plt.show()
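The two-panel figure above hard-codes the metric names. As a hedged generalization, the hypothetical helper plot_history below assumes only that its input has the shape of history.history (metric name mapped to a per-epoch list, with validation metrics prefixed by val_) and draws one panel per training metric, numbering epochs from 1.

```python
import matplotlib.pyplot as plt


def plot_history(history_dict):
    """Draw one subplot per training metric, overlaying its 'val_' counterpart."""
    # Training metrics are the keys without the 'val_' prefix
    train_keys = [k for k in history_dict if not k.startswith('val_')]
    fig, axes = plt.subplots(1, len(train_keys),
                             figsize=(6 * len(train_keys), 4), squeeze=False)
    for ax, key in zip(axes[0], train_keys):
        epochs = range(1, len(history_dict[key]) + 1)  # 1-based epoch numbers
        ax.plot(epochs, history_dict[key], label=f'Training {key}')
        val_key = 'val_' + key
        if val_key in history_dict:  # validation curve only if it was recorded
            ax.plot(epochs, history_dict[val_key], label=f'Validation {key}')
        ax.set_title(key)
        ax.set_xlabel('Epoch')
        ax.legend()
    fig.tight_layout()
    return fig
```

Typical use is fig = plot_history(history.history) followed by plt.show(). Any extra per-epoch entry, for instance a learning-rate log that some callbacks add, would simply get its own panel without a validation curve.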
Test Set Accuracy Acquisition and Visualization
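Test set accuracy is not part of the training history: it has to be computed separately with model.evaluate once training has finished, and then overlaid on the training curves. Because it is measured once, on the final model, it appears as a horizontal reference line rather than as a curve: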
# Evaluate test set performance
test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=0)
# Epoch numbers for the x-axis, starting at 1
epochs = range(1, len(history.history['accuracy']) + 1)
plt.figure(figsize=(10, 5))
plt.plot(epochs, history.history['accuracy'], 'b-', label='Training Accuracy', alpha=0.7)
plt.plot(epochs, history.history['val_accuracy'], 'r-', label='Validation Accuracy', alpha=0.7)
# Add test set accuracy horizontal line
plt.axhline(y=test_accuracy, color='g', linestyle='--',
            label=f'Test Accuracy: {test_accuracy:.4f}')
plt.title('Model Accuracy Comparison')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
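A quick sanity check on test_accuracy is to recompute it from raw predictions. For a single sigmoid output compiled with loss='binary_crossentropy' and metrics=['accuracy'], Keras resolves 'accuracy' to binary accuracy with a 0.5 threshold; the NumPy function below, a hypothetical helper named binary_accuracy, reproduces that definition under this assumption.

```python
import numpy as np


def binary_accuracy(y_true, y_prob, threshold=0.5):
    """Fraction of samples whose thresholded probability matches the 0/1 label."""
    y_true = np.asarray(y_true).ravel()
    # A prediction counts as class 1 only when the probability is strictly above the threshold
    y_pred = (np.asarray(y_prob).ravel() > threshold).astype(y_true.dtype)
    return float(np.mean(y_pred == y_true))
```

With the model from this article, binary_accuracy(y_test, model.predict(X_test, verbose=0)) should agree with test_accuracy up to floating-point rounding; a large discrepancy usually points to mismatched label shapes or a different threshold.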
Advanced Visualization Solutions
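Plotly produces interactive charts (hover tooltips, zooming, and toggling traces from the legend). Here loss is drawn on the primary y-axis and accuracy on a secondary y-axis so that both fit in one figure: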
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# Create dual Y-axis chart
fig = make_subplots(specs=[[{"secondary_y": True}]])
# Add loss curves (primary Y-axis)
fig.add_trace(
    go.Scatter(x=list(epochs), y=history.history['loss'],
               name="Training Loss", line=dict(color='blue')),
    secondary_y=False,
)
fig.add_trace(
    go.Scatter(x=list(epochs), y=history.history['val_loss'],
               name="Validation Loss", line=dict(color='red')),
    secondary_y=False,
)
# Add accuracy curves (secondary Y-axis)
fig.add_trace(
    go.Scatter(x=list(epochs), y=history.history['accuracy'],
               name="Training Accuracy", line=dict(color='green')),
    secondary_y=True,
)
fig.add_trace(
    go.Scatter(x=list(epochs), y=history.history['val_accuracy'],
               name="Validation Accuracy", line=dict(color='orange')),
    secondary_y=True,
)
# Add test set accuracy
fig.add_trace(
    go.Scatter(x=[min(epochs), max(epochs)], y=[test_accuracy, test_accuracy],
               name=f"Test Accuracy: {test_accuracy:.4f}",
               line=dict(color='purple', dash='dash')),
    secondary_y=True,
)
# Configure chart properties
fig.update_layout(
    title_text="Comprehensive Model Performance Analysis",
    width=800,
    height=500
)
fig.update_xaxes(title_text="Epoch")
fig.update_yaxes(title_text="<b>Loss</b>", secondary_y=False)
fig.update_yaxes(title_text="<b>Accuracy</b>", secondary_y=True)
fig.show()
DataFrame Quick Visualization
For rapid exploratory analysis, use pandas to directly plot all historical metrics:
import pandas as pd
# Convert history data to DataFrame
history_df = pd.DataFrame(history.history)
# Add test set accuracy
history_df['test_accuracy'] = test_accuracy
# Plot all metrics
history_df.plot(figsize=(10, 6))
plt.title('All Training Metrics')
plt.xlabel('Epoch')
plt.ylabel('Metric Value')
plt.grid(True, alpha=0.3)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()
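A combined plot shows every metric, but reading off the best epoch or the gap between training and validation accuracy is easier in tabular form. A minimal pandas sketch, where summarize_history is a hypothetical helper that assumes the history.history layout and the metric names 'accuracy' and 'val_accuracy' used in this article:

```python
import pandas as pd


def summarize_history(history_dict, test_accuracy=None):
    """Return a per-epoch DataFrame (epochs numbered from 1) and the best epoch."""
    df = pd.DataFrame(history_dict)
    df.index = range(1, len(df) + 1)  # 1-based epoch numbers
    df.index.name = 'epoch'
    # Training minus validation accuracy; a widening gap hints at overfitting
    df['acc_gap'] = df['accuracy'] - df['val_accuracy']
    if test_accuracy is not None:
        df['test_accuracy'] = test_accuracy  # scalar repeated on every row
    # Epoch with the highest validation accuracy (first one on ties)
    best_epoch = int(df['val_accuracy'].idxmax())
    return df, best_epoch
```

A call such as history_df, best_epoch = summarize_history(history.history, test_accuracy) gives both the table and the epoch to inspect. A gap that keeps widening while val_accuracy stalls is the typical overfitting signature.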
Best Practices and Considerations
In practical applications, consider the following points:
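1. Data Splitting Strategy: Keep the training, validation, and test sets disjoint to avoid data leakage, and make sure they follow the same distribution (for imbalanced labels, split in a stratified way). Never use the test set for validation or early stopping; otherwise the reported test accuracy is optimistically biased.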
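2. Early Stopping Optimization: Tune the monitored metric and the patience value to the task. A longer patience tolerates noisy validation curves, and restore_best_weights=True rolls the model back to its best epoch: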
# More refined early stopping configuration
optimal_early_stop = EarlyStopping(
    monitor='val_accuracy',
    patience=10,
    restore_best_weights=True,
    mode='max'
)
3. Performance Monitoring: Save the model whenever the monitored validation metric improves, so that the best version is available for later analysis and deployment:
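from keras.callbacks import ModelCheckpoint
# Save the full model each time val_accuracy improves;
# the .keras extension selects the native Keras format
checkpoint = ModelCheckpoint(
    'best_model.keras',
    monitor='val_accuracy',
    save_best_only=True,
    mode='max',
    verbose=1
)
# Pass the callbacks to training, e.g. model.fit(..., callbacks=[optimal_early_stop, checkpoint])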
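Item 1 implies that the test set must be carved out before training and left untouched until the final evaluate call. A minimal sketch, assuming NumPy arrays X and y: the hypothetical helper three_way_split shuffles the sample indices once and slices them into three disjoint partitions. It does not stratify; for imbalanced labels, scikit-learn's train_test_split with its stratify argument is a common alternative.

```python
import numpy as np


def three_way_split(X, y, val_frac=0.15, test_frac=0.15, seed=0):
    """Shuffle once, then cut into disjoint train / validation / test partitions."""
    X, y = np.asarray(X), np.asarray(y)
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))  # every sample index appears exactly once
    n_test = int(round(len(X) * test_frac))
    n_val = int(round(len(X) * val_frac))
    test_idx = idx[:n_test]
    val_idx = idx[n_test:n_test + n_val]
    train_idx = idx[n_test + n_val:]  # the remainder goes to training
    return ((X[train_idx], y[train_idx]),
            (X[val_idx], y[val_idx]),
            (X[test_idx], y[test_idx]))
```

After (X_train, y_train), (X_val, y_val), (X_test, y_test) = three_way_split(X, y), the validation pair can be passed to fit as validation_data=(X_val, y_val) instead of validation_split, which also avoids validation_split's "last 20% before shuffling" behavior.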
Conclusion
Through the methods introduced in this article, developers can comprehensively monitor deep learning model training processes. Combining visualization of training, validation, and test set accuracy enables more accurate model performance evaluation and timely detection of overfitting or underfitting issues. It is recommended to select appropriate visualization solutions based on specific project requirements and establish complete model evaluation workflows.