Keywords: TensorFlow | model saving | model restoration | checkpoints | SavedModel
Abstract: This article provides a comprehensive guide on saving and restoring trained models in TensorFlow, covering methods such as checkpoints, SavedModel, and HDF5 formats. It includes code examples using the tf.keras API and discusses advanced topics like custom objects. Aimed at machine learning developers and researchers.
In machine learning and deep learning, saving the state of a trained model is a critical step. It not only allows resuming training from interruptions but also facilitates model sharing and deployment. TensorFlow, as a popular deep learning framework, provides multiple methods for saving and restoring models. This article focuses on model saving and restoration in TensorFlow 2.x using the tf.keras API, covering checkpoints, SavedModel format, and HDF5 format, with detailed code examples.
Methods for Saving Models
TensorFlow 2.x recommends the high-level tf.keras API for building and training models, and offers three main ways to save them: checkpoints, the SavedModel format, and HDF5. Checkpoints store the values of tracked variables (model weights, and optionally optimizer state and other training state) but not the model architecture, which makes them suitable for resuming training. The SavedModel format stores the entire model, including architecture, weights, and training configuration, which makes it well suited to deployment. HDF5 (.h5) also stores the entire model, but it is a legacy format that is gradually being replaced by the newer .keras format. The .keras format is the recommended format in Keras 3: it is simpler and more efficient, and it uses name-based saving of weights.
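A minimal sketch contrasting the weights-only and whole-model options, assuming TensorFlow 2.13 or later so that the .keras extension is supported. The tiny model and the file names (demo.weights.h5, demo.keras, demo.h5) are illustrative only. Each restored copy should reproduce the original predictions.

```python
import numpy as np
import tensorflow as tf

def build_model():
    # Tiny model with an explicit input shape, so its weights exist immediately
    return tf.keras.Sequential([
        tf.keras.Input(shape=(3,)),
        tf.keras.layers.Dense(4, activation="relu"),
        tf.keras.layers.Dense(1),
    ])

model = build_model()
x = np.random.rand(5, 3).astype("float32")
reference = model.predict(x, verbose=0)

# 1) Weights only: the architecture must be rebuilt in code before loading
model.save_weights("demo.weights.h5")
clone = build_model()
clone.load_weights("demo.weights.h5")

# 2) Whole model in the recommended .keras format
model.save("demo.keras")
from_keras = tf.keras.models.load_model("demo.keras")

# 3) Whole model in the legacy HDF5 format
model.save("demo.h5")
from_h5 = tf.keras.models.load_model("demo.h5")

# Every restored copy should produce the same outputs as the original
for restored in (clone, from_keras, from_h5):
    print(np.allclose(reference, restored.predict(x, verbose=0)))
```

The weights-only file is the smallest, but it is useless without the build_model() code; the .keras and .h5 files carry the architecture with them.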
Saving Checkpoints
Checkpoints let you periodically save model weights and other training state during training, so that training can resume after an interruption. tf.train.Checkpoint groups the objects to save, and tf.train.CheckpointManager handles checkpoint numbering and the deletion of old checkpoints. The following example defines a model, trains it, and saves a checkpoint every ten steps.
```python
import tensorflow as tf

# Define a simple linear model
class Net(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.l1 = tf.keras.layers.Dense(5)

    def call(self, x):
        return self.l1(x)

# Create the model instance and optimizer
net = Net()
opt = tf.keras.optimizers.Adam(0.1)

# A toy dataset: 10 scalar inputs, each mapped to 5 linear targets
def toy_dataset():
    inputs = tf.range(10.0)[:, None]
    labels = inputs * 5.0 + tf.range(5.0)[None, :]
    return tf.data.Dataset.from_tensor_slices(dict(x=inputs, y=labels)).repeat().batch(2)

dataset = toy_dataset()
iterator = iter(dataset)

# Track the step counter, optimizer, model, and dataset iterator in one checkpoint
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
# Keep only the three most recent checkpoints in ./tf_ckpts
manager = tf.train.CheckpointManager(ckpt, "./tf_ckpts", max_to_keep=3)

# One optimization step with a mean absolute error loss
def train_step(net, example, optimizer):
    with tf.GradientTape() as tape:
        output = net(example["x"])
        loss = tf.reduce_mean(tf.abs(output - example["y"]))
    variables = net.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss

# Restore the latest checkpoint if one exists; otherwise keep the fresh values
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

# Training loop: save a checkpoint every 10 steps
for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
        save_path = manager.save()
        print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
```

Restoring the Model
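To restore the model in another script or session, first recreate the same objects and build a new tf.train.Checkpoint with the same attribute names (step, optimizer, net, iterator), then use the CheckpointManager to restore the latest checkpoint into them. The Net class and the toy_dataset function must be defined as in the example above.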
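The matching rule behind this restore step is that saved values are paired with objects by the attribute names given to tf.train.Checkpoint, not by Python variable names. A minimal sketch of that rule on two plain tf.Variables, assuming a temporary directory; the names weights and ckpt_dir are illustrative. It also shows max_to_keep discarding older checkpoints, and restore() returning a status object that can verify the match.

```python
import tempfile

import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()

# Save: the checkpoint records values under the keys "step" and "weights"
step = tf.Variable(0)
weights = tf.Variable([1.0, 2.0, 3.0])
ckpt = tf.train.Checkpoint(step=step, weights=weights)
manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=2)
for _ in range(3):
    step.assign_add(1)
    weights.assign_add([1.0, 1.0, 1.0])
    manager.save()  # with max_to_keep=2, the oldest checkpoint is deleted

# Restore into brand-new objects that use the same attribute names
new_step = tf.Variable(0)
new_weights = tf.Variable([0.0, 0.0, 0.0])
new_ckpt = tf.train.Checkpoint(step=new_step, weights=new_weights)
status = new_ckpt.restore(tf.train.latest_checkpoint(ckpt_dir))
status.assert_existing_objects_matched()  # raises if an object found no saved value

print(int(new_step), new_weights.numpy().tolist(), len(manager.checkpoints))
# 3 [4.0, 5.0, 6.0] 2
```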
```python
import tensorflow as tf

# Net and toy_dataset must be defined as in the previous example

# Recreate the objects with the same structure as when saving
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, "./tf_ckpts", max_to_keep=3)

# Restore the latest checkpoint into the recreated objects
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

# Continue training or evaluation; the restored iterator resumes where it left off
for _ in range(50):
    example = next(iterator)
    # Use the model for prediction or further training
    output = net(example["x"])
    print("Output:", output)
```

Using the SavedModel Format
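The SavedModel format saves the entire model, including its computation graph, so the model can be loaded and served without the original Python code, for example by TensorFlow Serving. With Keras 2 (TensorFlow 2.15 and earlier), calling model.save() with a path that has no file extension writes a SavedModel directory. Keras 3 (TensorFlow 2.16 and later) instead expects a .keras or .h5 filename in model.save() and exports SavedModels through model.export() or tf.saved_model.save().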
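Because a SavedModel stores graphs rather than Python code, it can also be written for any tf.Module with the lower-level tf.saved_model.save and tf.saved_model.load functions, which do not depend on the Keras version. A minimal sketch; the Scaler module and its factor variable are hypothetical.

```python
import tempfile

import tensorflow as tf

class Scaler(tf.Module):
    """Hypothetical module: multiplies its input by a stored factor."""

    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(2.0)

    # A fixed input_signature lets the SavedModel store one concrete graph
    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return x * self.factor

export_dir = tempfile.mkdtemp()
tf.saved_model.save(Scaler(), export_dir)

# Loading needs no Scaler class: the graph and variable values come from disk
restored = tf.saved_model.load(export_dir)
result = restored(tf.constant([1.0, 2.0, 3.0]))
print(result.numpy().tolist())  # [2.0, 4.0, 6.0]
```

Unlike tf.keras.models.load_model, tf.saved_model.load returns a generic object that exposes only the saved tf.function signatures and variables, not Keras methods such as predict().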
```python
import tensorflow as tf

# Save the model as a SavedModel directory (Keras 2 / TensorFlow 2.15 and earlier)
model.save('saved_model/my_model')  # Assume model is a trained tf.keras.Model instance
# Restore the model
new_model = tf.keras.models.load_model('saved_model/my_model')
# Use the restored model for prediction on one batch from the toy dataset
example = next(iter(toy_dataset()))
predictions = new_model.predict(example["x"])
print("Predictions:", predictions)
```

Advanced Topics: Custom Objects
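If the model contains custom layers or models, Keras must be told how to recreate those classes when loading. Two options are decorating the class with @tf.keras.utils.register_keras_serializable(), or passing the classes to load_model through the custom_objects argument. In both cases, a class whose constructor takes arguments should implement get_config() so that those arguments are saved.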
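For a class that is not registered, custom_objects maps the class name stored in the file back to the Python class. A minimal sketch, assuming TensorFlow 2.13 or later for the .keras format; the Doubling layer and the file name doubling.keras are hypothetical. Because Doubling takes no constructor arguments, the inherited get_config() is enough.

```python
import numpy as np
import tensorflow as tf

class Doubling(tf.keras.layers.Layer):
    """Hypothetical, unregistered layer with no weights."""

    def call(self, inputs):
        return inputs * 2.0

model = tf.keras.Sequential([tf.keras.Input(shape=(2,)), Doubling()])
model.save("doubling.keras")

# Without registration, the class must be supplied by name at load time
loaded = tf.keras.models.load_model(
    "doubling.keras", custom_objects={"Doubling": Doubling}
)
out = loaded.predict(np.array([[1.0, 3.0]], dtype="float32"), verbose=0)
print(out.tolist())  # [[2.0, 6.0]]
```

Registration is more convenient when the same class is loaded in many places; custom_objects keeps the mapping explicit at the call site.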
```python
import tensorflow as tf

# Example: a custom dense-like layer, registered so Keras can find it at load time
@tf.keras.utils.register_keras_serializable()
class CustomLayer(tf.keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # Weights are created once the input dimension is known
        self.w = self.add_weight(shape=(input_shape[-1], self.units),
                                 initializer='random_normal',
                                 trainable=True)
        self.b = self.add_weight(shape=(self.units,),
                                 initializer='random_normal',
                                 trainable=True)

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        # Save the constructor argument so the layer can be rebuilt from its config
        config = super().get_config()
        config.update({'units': self.units})
        return config

# Build a model with the custom layer; the Input gives it a concrete shape so weights exist
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    CustomLayer(10),
    tf.keras.layers.Dense(1)
])

# Save the model in the .keras format
model.save('custom_model.keras')
# Load the model; no custom_objects needed because CustomLayer is registered
loaded_model = tf.keras.models.load_model('custom_model.keras')
```

Conclusion
TensorFlow offers several complementary ways to persist a model: checkpoints for resuming training, SavedModel for deployment and serving, and the .keras format as the recommended way to save a complete Keras model. The examples in this article cover each of these paths. Choose the method that fits the scenario, and make sure custom layers and models are either registered or passed through custom_objects so that saved models load correctly.