Loading and Continuing Training of Keras Models: Technical Analysis of Saving and Resuming Training States

Dec 07, 2025 · Programming · 11 views · 7.8

Keywords: Keras | Model Saving | Continued Training | Optimizer State | TensorFlow Format

Abstract: This article provides an in-depth exploration of saving partially trained Keras models and continuing their training. By analyzing model saving mechanisms, optimizer state preservation, and the impact of different data formats, it explains how to effectively implement training pause and resume. With concrete code examples, the article compares H5 and TensorFlow formats and discusses the influence of hyperparameters like learning rate on continued training outcomes, offering systematic guidance for model management in deep learning practice.

Basic Mechanisms of Model Saving and Loading

In deep learning practice, it is often necessary to save partially trained models for later continuation with new data. The Keras framework provides convenient model saving functionality through the model.save() method. This method saves not only the model architecture and weights but also the optimizer state, which is crucial for seamless training continuation.

From a technical implementation perspective, when calling model.save('partly_trained.h5'), Keras serializes the following information to file: model structure definition (layer configuration, connectivity), current values of all trainable parameters, optimizer type and its internal state (such as momentum caches, learning rate scheduling states). This comprehensive saving approach ensures that training can resume from exactly the same state after reloading.

Preservation and Verification of Optimizer State

As noted in the best answer, the integrity of optimizer state is key to successful continued training. To verify this, a simple test procedure can be designed: first train the model for several epochs, save the model, then immediately reload and continue training on the same dataset. If the optimizer state is correctly saved, the training loss should continue to decrease smoothly without significant jumps or resets.

In actual code, this process can be represented as:

# Initial training
model.fit(first_data, first_labels, epochs=10, batch_size=32)

# Save model
model.save('intermediate_model.h5')

# Reload
from keras.models import load_model
reloaded_model = load_model('intermediate_model.h5')

# Verify optimizer state
reloaded_model.fit(first_data, first_labels, epochs=5, batch_size=32)

If the training loss curve after reloading smoothly connects with the pre-save curve, it indicates correct optimizer state preservation. If sudden loss increases occur, it may suggest the optimizer state was not properly restored.

Data Format Selection: H5 vs. TensorFlow Format Comparison

With the widespread adoption of TensorFlow 2.x, the choice of model saving format has become more important. The traditional H5 format (.h5 files) can completely save model information in most cases but may have limitations when dealing with custom loss functions or complex optimizer configurations.

The TensorFlow format (.tf directory structure) provides more reliable saving mechanisms, particularly for preserving optimizer states. Code for saving models in TensorFlow format:

# Save using TensorFlow format
model.save('./saved_model_tf', save_format='tf')

# Load TensorFlow format model
loaded_model = tf.keras.models.load_model('./saved_model_tf')

This format saves the model as a directory structure containing multiple files, where saved_model.pb stores the model architecture and the variables/ directory stores weights and optimizer states. This separated storage approach improves compatibility and extensibility.

Continuity of Learning Rate and Optimizer Configuration

When continuing training, optimizer configuration parameters must remain consistent with those at saving time. Particularly the learning rate—if a different initial learning rate is used when retraining, it may cause training instability. For example, if the original training ended with a learning rate reduced to 0.000003 via a scheduler, but retraining starts at 0.0003, training loss may experience severe fluctuations.

To ensure learning rate continuity, the current learning rate can be recorded before saving the model:

# Get current learning rate
current_lr = keras.backend.get_value(model.optimizer.lr)
print(f"Current learning rate: {current_lr}")

# Save model
model.save('model_with_lr.h5')

After reloading, verify that the learning rate remains consistent:

loaded_model = load_model('model_with_lr.h5')
loaded_lr = keras.backend.get_value(loaded_model.optimizer.lr)
print(f"Loaded learning rate: {loaded_lr}")

Practical Considerations in Real Applications

When continuing model training in real projects, several additional important factors must be considered:

First, new training data should have distribution characteristics similar to the original data. If data distribution changes significantly, it may be necessary to reevaluate whether continuing training is appropriate or consider using transfer learning techniques.

Second, batch size should remain consistent. Changing batch size affects the variance of gradient estimates, thereby impacting the optimization process. If batch size must be changed, learning rate adjustments may be necessary.

Third, for training using callback functions such as ModelCheckpoint, ReduceLROnPlateau, etc., ensure these callback states can also be correctly saved and restored. TensorFlow format typically handles this situation better.

Finally, regularly evaluating model performance on validation sets is crucial. Continued training should not proceed blindly but should be based on informed decisions from validation metrics to avoid overfitting.

Code Example: Complete Continued Training Workflow

The following is a complete example demonstrating model saving and continued training on the MNIST dataset:

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Set random seed for reproducibility
np.random.seed(42)

# Load and preprocess data
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0

# Create model
model = keras.Sequential([
    layers.Dense(512, activation='relu', input_shape=(784,)),
    layers.Dropout(0.2),
    layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# First phase training
print("Starting first phase training...")
history1 = model.fit(x_train[:3000], y_train[:3000],
                     epochs=10,
                     batch_size=32,
                     validation_split=0.2,
                     verbose=1)

# Save model
print("\nSaving model...")
model.save('partly_trained_model', save_format='tf')

# Reload model
print("\nReloading model...")
loaded_model = keras.models.load_model('partly_trained_model')

# Verify optimizer state
print("\nVerifying optimizer state...")
initial_loss, initial_acc = loaded_model.evaluate(x_test, y_test, verbose=0)
print(f"Initial test loss after loading: {initial_loss:.4f}, accuracy: {initial_acc:.4f}")

# Second phase training
print("\nStarting second phase training...")
history2 = loaded_model.fit(x_train[3000:6000], y_train[3000:6000],
                            epochs=10,
                            batch_size=32,
                            validation_split=0.2,
                            verbose=1)

# Final evaluation
final_loss, final_acc = loaded_model.evaluate(x_test, y_test, verbose=0)
print(f"\nFinal test loss: {final_loss:.4f}, accuracy: {final_acc:.4f}")

This example shows the complete workflow: initial training, model saving, reloading, state verification, continued training, and final evaluation. By monitoring training loss and validation metrics, the effectiveness of the continued training process can be ensured.

Conclusion

The Keras framework provides powerful model saving and loading capabilities, making continued training possible. The key to successfully implementing this functionality lies in ensuring complete preservation of optimizer states, selecting appropriate data formats, and maintaining consistency in training configurations. By following the best practices introduced in this article, developers can effectively manage long-term training processes, fully utilize additional data without needing to retrain models from scratch.

As deep learning projects continue to scale, model management becomes increasingly important. Mastering the techniques of model saving and continued training not only improves development efficiency but also establishes the foundation for implementing advanced application scenarios such as incremental learning and continual learning.

Copyright Notice: All rights in this article are reserved by the operators of DevGex. Reasonable sharing and citation are welcome; any reproduction, excerpting, or re-publication without prior permission is prohibited.