Save and load models
Model progress can be saved during and after training. This means a model can resume where it left off and avoid long training times. Saving also means you can share your model and others can recreate your work. When publishing research models and techniques, most machine learning practitioners share:
code to create the model, and
the trained weights, or parameters, for the model
Sharing this data helps others understand how the model works and try it themselves with new data.
Caution: TensorFlow models are code and it is important to be careful with untrusted code. See Using TensorFlow Securely for details.
Options
There are different ways to save TensorFlow models depending on the API you're using. This guide uses tf.keras, a high-level API to build and train models in TensorFlow. For other approaches, see the TensorFlow Save and Restore guide or Saving in eager.
Setup
Installs and imports
Install and import TensorFlow and dependencies:
pip install pyyaml h5py # Required to save models in HDF5 format
import os
import tensorflow as tf
from tensorflow import keras
print(tf.version.VERSION)
2.5.0
Get an example dataset
To demonstrate how to save and load weights, you'll use the MNIST dataset. To speed up these runs, use the first 1000 examples:
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
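The slicing and scaling above turn each 28 x 28 grayscale image into a flat vector of 784 values between 0 and 1, which matches the model's input_shape=(784,). A minimal NumPy sketch of the same transformation follows. It uses deterministic stand-in pixel values rather than the real MNIST arrays (an assumption made so it runs without downloading data):

```python
import numpy as np

# Stand-in for 1000 MNIST images: uint8 pixels of 28 x 28 cycling through 0..255 (not real data)
images = (np.arange(1000 * 28 * 28) % 256).astype(np.uint8).reshape(1000, 28, 28)

# Flatten each image into one 784-element row and scale pixel values from [0, 255] to [0.0, 1.0]
flat = images.reshape(-1, 28 * 28) / 255.0

print(flat.shape)              # (1000, 784)
print(flat.dtype)              # float64
print(flat.min(), flat.max())  # 0.0 1.0
```

Dividing by the float 255.0 also converts the uint8 pixels to floating point. Integer pixels would otherwise be fed to the first Dense layer unscaled.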
Define a model
Start by building a simple sequential model:
# Define a simple sequential model
def create_model():
    model = tf.keras.models.Sequential([
        keras.layers.Dense(512, activation='relu', input_shape=(784,)),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10)
    ])
    model.compile(optimizer='adam',
                  loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
# Create a basic model instance
model = create_model()
# Display the model's architecture
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                (None, 512)               401920
_________________________________________________________________
dropout (Dropout)            (None, 512)               0
_________________________________________________________________
dense_1 (Dense)              (None, 10)                5130
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
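The Param # column follows directly from the layer sizes. A Dense layer with n inputs and m units has an n x m kernel plus m biases, and Dropout has no weights. The arithmetic can be checked in plain Python. The dense_params helper is just an illustrative name, not a Keras function:

```python
def dense_params(n_inputs, n_units):
    # Kernel weights (n_inputs * n_units) plus one bias per unit
    return n_inputs * n_units + n_units


first = dense_params(784, 512)   # dense: 784 * 512 + 512
dropout = 0                      # Dropout has no trainable parameters
second = dense_params(512, 10)   # dense_1: 512 * 10 + 10
total = first + dropout + second

print(first, second, total)      # 401920 5130 407050
```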
Save checkpoints during training
You can use a trained model without having to retrain it, or pick up training where you left off if the training process was interrupted. The tf.keras.callbacks.ModelCheckpoint callback allows you to continually save the model both during and at the end of training.
Checkpoint callback usage
Create a tf.keras.callbacks.ModelCheckpoint callback that saves only the model's weights during training:
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)
# Train the model with the new callback
model.fit(train_images,
          train_labels,
          epochs=10,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])  # Pass callback to training
# This may generate warnings related to saving the state of the optimizer.
# These warnings (and similar warnings throughout this notebook)
# are in place to discourage outdated usage, and can be ignored.
Epoch 1/10
32/32 [==============================] - 1s 7ms/step - loss: 1.1572 - sparse_categorical_accuracy: 0.6500 - val_loss: 0.7440 - val_sparse_categorical_accuracy: 0.7800
Epoch 00001: saving model to training_1/cp.ckpt
Epoch 2/10
32/32 [==============================] - 0s 4ms/step - loss: 0.4429 - sparse_categorical_accuracy: 0.8740 - val_loss: 0.5538 - val_sparse_categorical_accuracy: 0.8260
Epoch 00002: saving model to training_1/cp.ckpt
Epoch 3/10
32/32 [==============================] - 0s 4ms/step - loss: 0.2902 - sparse_categorical_accuracy: 0.9310 - val_loss: 0.5104 - val_sparse_categorical_accuracy: 0.8410
Epoch 00003: saving model to training_1/cp.ckpt
Epoch 4/10
32/32 [==============================] - 0s 4ms/step - loss: 0.2225 - sparse_categorical_accuracy: 0.9430 - val_loss: 0.4639 - val_sparse_categorical_accuracy: 0.8530
Epoch 00004: saving model to training_1/cp.ckpt
Epoch 5/10
32/32 [==============================] - 0s 4ms/step - loss: 0.1649 - sparse_categorical_accuracy: 0.9610 - val_loss: 0.4476 - val_sparse_categorical_accuracy: 0.8610
Epoch 00005: saving model to training_1/cp.ckpt
Epoch 6/10
32/32 [==============================] - 0s 4ms/step - loss: 0.1192 - sparse_categorical_accuracy: 0.9800 - val_loss: 0.4489 - val_sparse_categorical_accuracy: 0.8570
Epoch 00006: saving model to training_1/cp.ckpt
Epoch 7/10
32/32 [==============================] - 0s 4ms/step - loss: 0.0888 - sparse_categorical_accuracy: 0.9870 - val_loss: 0.4190 - val_sparse_categorical_accuracy: 0.8650
Epoch 00007: saving model to training_1/cp.ckpt
Epoch 8/10
32/32 [==============================] - 0s 4ms/step - loss: 0.0674 - sparse_categorical_accuracy: 0.9920 - val_loss: 0.4086 - val_sparse_categorical_accuracy: 0.8670
Epoch 00008: saving model to training_1/cp.ckpt
Epoch 9/10
32/32 [==============================] - 0s 4ms/step - loss: 0.0507 - sparse_categorical_accuracy: 0.9960 - val_loss: 0.4145 - val_sparse_categorical_accuracy: 0.8630
Epoch 00009: saving model to training_1/cp.ckpt
Epoch 10/10
32/32 [==============================] - 0s 4ms/step - loss: 0.0385 - sparse_categorical_accuracy: 0.9990 - val_loss: 0.4140 - val_sparse_categorical_accuracy: 0.8670
Epoch 00010: saving model to training_1/cp.ckpt
<tensorflow.python.keras.callbacks.History at 0x7f213dba4fd0>
This creates a single collection of TensorFlow checkpoint files that are updated at the end of each epoch:
os.listdir(checkpoint_dir)
['cp.ckpt.index', 'cp.ckpt.data-00000-of-00001', 'checkpoint']
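These files together form one checkpoint identified by the prefix cp.ckpt. The .index file records where each variable is stored, the .data-00000-of-00001 file holds the values (shard 0 of 1), and the checkpoint file records the latest prefix. load_weights takes that prefix, rather than any single file, as checkpoint_path. Because this filepath contains no placeholders, every epoch overwrote the same prefix. The Keras ModelCheckpoint documentation notes that filepath may contain named formatting options such as {epoch}, which are filled in when each checkpoint is saved.

The plain-Python sketch below shows two things: the names a hypothetical epoch-stamped template would produce, and how to recover prefixes from a directory listing. The training_2 path and the checkpoint_prefixes helper are illustrative assumptions and are not part of TensorFlow:

```python
# Hypothetical epoch-stamped template; formatted with str.format-style placeholders
template = "training_2/cp-{epoch:04d}.ckpt"
names = [template.format(epoch=e) for e in range(1, 4)]
print(names)  # ['training_2/cp-0001.ckpt', 'training_2/cp-0002.ckpt', 'training_2/cp-0003.ckpt']


def checkpoint_prefixes(filenames):
    """Return the sorted checkpoint prefixes found in a directory listing.

    Each checkpoint has one '<prefix>.index' file, so stripping that
    suffix yields the path stem that load_weights expects.
    """
    suffix = ".index"
    return sorted(name[:-len(suffix)] for name in filenames if name.endswith(suffix))


# The listing printed above for training_1
listing = ['cp.ckpt.index', 'cp.ckpt.data-00000-of-00001', 'checkpoint']
print(checkpoint_prefixes(listing))  # ['cp.ckpt']
```

In TensorFlow itself, tf.train.latest_checkpoint(checkpoint_dir) reads the checkpoint file and returns the newest prefix. This avoids parsing filenames by hand.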
As long as two models share the same architecture, you can share weights between them. So, when restoring a model from a weights-only checkpoint, create a model with the same architecture as the original model and then set its weights.
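"Same architecture" means that every weight tensor in one model lines up with a tensor of the same shape in the other. The sketch below is a simplified illustration, not how Keras actually matches checkpoint variables. It describes each model by the kernel and bias shapes of its Dense layers and compares them. The weight_shapes helper and the wider model are hypothetical:

```python
def weight_shapes(layer_sizes):
    """Return the (kernel, bias) shapes of a stack of Dense layers.

    layer_sizes lists the input width followed by each Dense layer's
    units; Dropout layers are omitted because they have no weights.
    """
    shapes = []
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:]):
        shapes.append((n_in, n_out))  # kernel
        shapes.append((n_out,))       # bias
    return shapes


original = weight_shapes([784, 512, 10])  # the model from create_model()
rebuilt = weight_shapes([784, 512, 10])   # a fresh model with the same architecture
wider = weight_shapes([784, 1024, 10])    # hypothetical model with a wider hidden layer

print(original)            # [(784, 512), (512,), (512, 10), (10,)]
print(original == rebuilt)  # True: weights can be transferred
print(original == wider)    # False: shapes disagree, so the weights do not fit
```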
Now rebuild a fresh, untrained model and evaluate it on the test set. An untrained model will perform at chance levels (~10% accuracy):
# Create a basic model instance
model = create_model()
# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 2.3208 - sparse_categorical_accuracy: 0.0990
Untrained model, accuracy: 9.90%
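Both numbers match what chance predicts for 10 classes. Because the loss is built with from_logits=True, the final Dense(10) layer outputs raw scores. The loss then applies a softmax and takes the negative log-probability of the true label. If an untrained model's logits are roughly equal, each class gets probability of about 1/10. The loss is then about ln(10), or 2.30, and a guess is right about 10% of the time, which is close to the 2.3208 and 9.90% printed above.

A plain-Python sketch of that calculation follows. The uniform logits and the simulated random guesses are illustrative assumptions, not outputs of the model:

```python
import math
import random


def sparse_ce_from_logits(logits, label):
    """Cross-entropy for one example: -log(softmax(logits)[label])."""
    m = max(logits)  # subtract the max before exponentiating for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


# Equal logits for all 10 classes give every class probability 1/10
uniform_logits = [0.0] * 10
loss = sparse_ce_from_logits(uniform_logits, label=3)
print(round(loss, 4))  # 2.3026, i.e. ln(10)

# Guessing a class uniformly at random is right about 1 time in 10
rng = random.Random(0)
trials = 100_000
hits = sum(rng.randrange(10) == rng.randrange(10) for _ in range(trials))
print(hits / trials)  # a value close to 0.1
```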
Then load the weights from the checkpoint and re-evaluate:
# Loads the weights
model.load_weights(checkpoint_path)
# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))