Tensorflow: how to save/restore a model?

  • After you train a model in Tensorflow:

    1. How do you save the trained model?
    2. How do you later restore this saved model?
      September 14, 2020 2:42 PM IST
  • I am improving my answer to add more details for saving and restoring models.

    In (and after) TensorFlow version 0.11:

    Save the model:

    import tensorflow as tf
    # Prepare to feed input, i.e. feed_dict and placeholders
    w1 = tf.placeholder("float", name="w1")
    w2 = tf.placeholder("float", name="w2")
    b1 = tf.Variable(2.0, name="bias")
    feed_dict = {w1: 4, w2: 8}
    # Define a test operation that we will restore
    w3 = tf.add(w1, w2)
    w4 = tf.multiply(w3, b1, name="op_to_restore")
    sess = tf.Session()
    # Variables must be initialized before they can be run or saved
    # (tf.initialize_all_variables() in releases before 0.12)
    sess.run(tf.global_variables_initializer())
    # Create a saver object which will save all the variables
    saver = tf.train.Saver()
    # Run the operation by feeding input
    print(sess.run(w4, feed_dict))
    # Prints 24.0, which is (w1 + w2) * b1 = (4 + 8) * 2
    # Now, save the graph
    saver.save(sess, 'my_test_model', global_step=1000)
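    The restore step refers to the files that saver.save writes, so it helps to see them. Below is a self-contained sketch that rebuilds the same toy graph, saves it into a temporary directory, and lists what was written. It assumes TensorFlow 1.14+ or 2.x used through the tensorflow.compat.v1 layer (the answer itself targets TF 1.x); the helper name save_and_list is hypothetical. With the default V2 checkpoint format you should see a "checkpoint" state file plus my_test_model-1000.meta (graph), .index and .data-* (variable values).

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF >= 1.14 or 2.x via the v1 compat layer

tf.disable_v2_behavior()  # graph mode + tf.Session, as in the answer


def save_and_list(checkpoint_dir):
    """Hypothetical helper: rebuild the toy graph, save it, report what was written."""
    tf.reset_default_graph()
    w1 = tf.placeholder("float", name="w1")
    w2 = tf.placeholder("float", name="w2")
    b1 = tf.Variable(2.0, name="bias")
    w4 = tf.multiply(tf.add(w1, w2), b1, name="op_to_restore")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        result = sess.run(w4, {w1: 4, w2: 8})
        # save() returns the checkpoint prefix, here ".../my_test_model-1000"
        prefix = saver.save(sess, os.path.join(checkpoint_dir, "my_test_model"),
                            global_step=1000)
    return result, prefix, sorted(os.listdir(checkpoint_dir))


result, prefix, files = save_and_list(tempfile.mkdtemp())
print(result)  # 24.0
print(prefix)
for name in files:
    print(name)
```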

    Restore the model:

    import tensorflow as tf
    sess = tf.Session()
    # First let's load the meta graph and restore the weights
    saver = tf.train.import_meta_graph('my_test_model-1000.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    # Access saved Variables directly
    print(sess.run('bias:0'))
    # This will print 2.0, which is the value of bias that we saved
    # Now, let's access the placeholders and
    # create a feed_dict to feed new data
    graph = tf.get_default_graph()
    w1 = graph.get_tensor_by_name("w1:0")
    w2 = graph.get_tensor_by_name("w2:0")
    feed_dict = {w1: 13.0, w2: 17.0}
    # Now, access the op that you want to run
    op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
    print(sess.run(op_to_restore, feed_dict))
    # This will print 60.0, which is calculated as (13 + 17) * 2
    # from the new inputs and the saved bias
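    To check the save/restore pair end to end in one script, the sketch below saves the toy graph and then restores it into a brand-new, empty graph, which is what a separate restore script effectively does. It assumes TensorFlow 1.14+ or 2.x through tensorflow.compat.v1, and uses a temporary directory instead of './'. Nothing from the first graph is reused: the placeholders and op are found purely by name in the imported meta graph.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF >= 1.14 or 2.x via the v1 compat layer

tf.disable_v2_behavior()

ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, "my_test_model")

# --- Save: same toy graph as in the answer ---
tf.reset_default_graph()
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1 = tf.Variable(2.0, name="bias")
w4 = tf.multiply(tf.add(w1, w2), b1, name="op_to_restore")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().save(sess, prefix, global_step=1000)

# --- Restore into a fresh, empty graph ---
tf.reset_default_graph()
with tf.Session() as sess:
    saver = tf.train.import_meta_graph(prefix + "-1000.meta")
    saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
    graph = tf.get_default_graph()
    w1 = graph.get_tensor_by_name("w1:0")
    w2 = graph.get_tensor_by_name("w2:0")
    op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
    restored = sess.run(op_to_restore, {w1: 13.0, w2: 17.0})

print(restored)  # (13 + 17) * 2 = 60.0
```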
      October 14, 2020 4:04 PM IST
  • My environment: Python 3.6, Tensorflow 1.3.0

    Though there are many solutions, most of them are based on tf.train.Saver. When we load a .ckpt saved by Saver, we have to either redefine the TensorFlow network or use some weird, hard-to-remember tensor names, e.g. 'placehold_0:0', 'dense/Adam/Weight:0'. Here I recommend using tf.saved_model instead. One of the simplest examples is given below; you can learn more from the "Serving a TensorFlow Model" guide:

    Save the model:

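    import tensorflow as tf
    # define the TensorFlow network (and train it, if needed)
    x = tf.placeholder("float", name="x")
    w = tf.Variable(2.0, name="w")
    b = tf.Variable(0.0, name="bias")
    h = tf.multiply(x, w)
    y = tf.add(h, b, name="y")
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    # save the model
    export_path = './savedmodel'
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)
    tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
    tensor_info_y = tf.saved_model.utils.build_tensor_info(y)
    prediction_signature = (
        tf.saved_model.signature_def_utils.build_signature_def(
            inputs={'x_input': tensor_info_x},
            outputs={'y_output': tensor_info_y},
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                prediction_signature
        })
    builder.save()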

    Load the model:

    import tensorflow as tf
    sess = tf.Session()
    signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    input_key = 'x_input'
    output_key = 'y_output'
    export_path = './savedmodel'
    meta_graph_def = tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        export_path)
    signature = meta_graph_def.signature_def
    x_tensor_name = signature[signature_key].inputs[input_key].name
    y_tensor_name = signature[signature_key].outputs[output_key].name
    x = sess.graph.get_tensor_by_name(x_tensor_name)
    y = sess.graph.get_tensor_by_name(y_tensor_name)
    y_out = sess.run(y, {x: 3.0})
    print(y_out)  # 6.0, i.e. 3.0 * w + bias with the saved w = 2.0 and bias = 0.0
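    A shorter variant of the same round trip, as a sketch: tf.saved_model.simple_save (a TF 1.x convenience function, deprecated in later releases, swapped in here for the hand-built signature) stores a predict signature under DEFAULT_SERVING_SIGNATURE_DEF_KEY with the SERVING tag, so the loading code can stay the same. It assumes TensorFlow 1.14+ or 2.x through tensorflow.compat.v1, and exports to a path that does not exist yet, because the SavedModel builder refuses to write into a non-empty directory.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF >= 1.14 or 2.x via the v1 compat layer

tf.disable_v2_behavior()

# The export directory must not already exist (or must be empty).
export_path = os.path.join(tempfile.mkdtemp(), "savedmodel")

# --- Save: same network as in the answer, exported with simple_save ---
tf.reset_default_graph()
x = tf.placeholder("float", name="x")
w = tf.Variable(2.0, name="w")
b = tf.Variable(0.0, name="bias")
y = tf.add(tf.multiply(x, w), b, name="y")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Writes a SERVING-tagged graph with a predict signature keyed by
    # DEFAULT_SERVING_SIGNATURE_DEF_KEY, i.e. what the builder code does by hand.
    tf.saved_model.simple_save(sess, export_path,
                               inputs={"x_input": x}, outputs={"y_output": y})

# --- Load into a fresh graph and look tensors up through the signature ---
tf.reset_default_graph()
with tf.Session() as sess:
    meta_graph_def = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], export_path)
    signature = meta_graph_def.signature_def[
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    x_t = sess.graph.get_tensor_by_name(signature.inputs["x_input"].name)
    y_t = sess.graph.get_tensor_by_name(signature.outputs["y_output"].name)
    y_out = sess.run(y_t, {x_t: 3.0})

print(y_out)  # 3.0 * 2.0 + 0.0 = 6.0
```

    The loader never needs the original tensor names ("x:0", "y:0"); it only needs the signature keys 'x_input' and 'y_output', which is the main advantage this answer claims over plain Saver checkpoints.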
      October 14, 2020 4:09 PM IST
  • In (and after) TensorFlow version 0.11.0RC1, you can save and restore your model directly by calling tf.train.export_meta_graph and tf.train.import_meta_graph according to https://www.tensorflow.org/programmers_guide/meta_graph.

    Save the model

    import tensorflow as tf
    w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
    w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
    tf.add_to_collection('vars', w1)
    tf.add_to_collection('vars', w2)
    saver = tf.train.Saver()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'my-model')
    # The `save` method calls `export_meta_graph` implicitly,
    # so you will also get the saved graph file my-model.meta

    Restore the model

    import tensorflow as tf
    sess = tf.Session()
    new_saver = tf.train.import_meta_graph('my-model.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('./'))
    all_vars = tf.get_collection('vars')
    for v in all_vars:
        v_ = sess.run(v)
        print(v_)
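    Since saver.save only calls export_meta_graph implicitly, here is a sketch that makes the two pieces explicit: the checkpoint is written with write_meta_graph=False, and the MetaGraph (graph, collections such as 'vars', and the saver configuration via saver_def) is written with tf.train.export_meta_graph. It assumes TensorFlow 1.14+ or 2.x through tensorflow.compat.v1, where disable_v2_behavior() also restores TF1-style variables, so the 'vars' entries come back as tensors whose values sess.run can fetch.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF >= 1.14 or 2.x via the v1 compat layer

tf.disable_v2_behavior()

model_dir = tempfile.mkdtemp()
prefix = os.path.join(model_dir, "my-model")

# --- Save: variable values and graph structure written separately ---
tf.reset_default_graph()
w1 = tf.Variable(tf.truncated_normal(shape=[10]), name="w1")
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name="w2")
tf.add_to_collection("vars", w1)
tf.add_to_collection("vars", w2)
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saved_values = sess.run([w1, w2])
    # Values only: no .meta file from this call.
    saver.save(sess, prefix, write_meta_graph=False)
# Graph + collections + saver config, exported explicitly.
tf.train.export_meta_graph(filename=prefix + ".meta", saver_def=saver.as_saver_def())

# --- Restore in a fresh graph ---
tf.reset_default_graph()
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph(prefix + ".meta")
    new_saver.restore(sess, tf.train.latest_checkpoint(model_dir))
    restored_values = [sess.run(v) for v in tf.get_collection("vars")]

for before, after in zip(saved_values, restored_values):
    print(np.allclose(before, after))
```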
      September 15, 2020 11:27 AM IST
  • For TensorFlow version < 0.11.0RC1:

    The checkpoints that are saved contain values for the Variables in your model, not the model/graph itself, which means that the graph should be the same when you restore the checkpoint.

    Here's an example for a linear regression where there's a training loop that saves variable checkpoints and an evaluation section that will restore variables saved in a prior run and compute predictions. Of course, you can also restore variables and continue training if you'd like.

    import tensorflow as tf
    # FLAGS: command-line flags defined elsewhere
    # (train, training_steps, checkpoint_steps, checkpoint_dir)
    x = tf.placeholder(tf.float32)
    y = tf.placeholder(tf.float32)
    w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
    b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
    y_hat = tf.add(b, tf.matmul(x, w))
    # ...more setup for optimization and what not...
    saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
    with tf.Session() as sess:
        # renamed tf.global_variables_initializer() in TF 0.12
        sess.run(tf.initialize_all_variables())
        if FLAGS.train:
            for i in range(FLAGS.training_steps):
                # ...training loop...
                if (i + 1) % FLAGS.checkpoint_steps == 0:
                    saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt',
                               global_step=i + 1)
        else:
            # Here's where you're restoring the variables w and b.
            # Note that the graph is exactly as it was when the variables were
            # saved in a prior training run.
            ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                pass  # ...no checkpoint found...
            # Now you can run the model to get predictions
            batch_x = ...  # ...load some data...
            predictions = sess.run(y_hat, feed_dict={x: batch_x})

    See the TensorFlow docs for Variables, which cover saving and restoring, and the docs for tf.train.Saver.
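    To make the "same graph on both sides" requirement concrete, here is a runnable sketch of the train-then-evaluate pattern from this answer, with the FLAGS branches replaced by two runs in one script. It assumes TensorFlow 1.14+ or 2.x through tensorflow.compat.v1; build_graph, the toy data y = 3x + 1, the learning rate and the step counts are all hypothetical choices for illustration. The key point is that the evaluation run calls the same graph-building function before restoring, so Saver finds variables with the same names it wrote.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF >= 1.14 or 2.x via the v1 compat layer

tf.disable_v2_behavior()


def build_graph():
    """Hypothetical helper: builds the identical linear-regression graph every time."""
    x = tf.placeholder(tf.float32, shape=[None, 1], name="x")
    y = tf.placeholder(tf.float32, shape=[None, 1], name="y")
    w = tf.Variable(tf.zeros([1, 1]), name="w")
    b = tf.Variable(tf.ones([1, 1]), name="b")
    y_hat = tf.add(b, tf.matmul(x, w))
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
        tf.reduce_mean(tf.square(y_hat - y)))
    return x, y, y_hat, train_op


checkpoint_dir = tempfile.mkdtemp()
data_x = np.linspace(-1.0, 1.0, 20, dtype=np.float32).reshape(-1, 1)
data_y = 3.0 * data_x + 1.0  # hypothetical toy data

# --- Training run: fit, checkpointing w and b every 50 steps ---
tf.reset_default_graph()
x, y, y_hat, train_op = build_graph()
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(100):
        sess.run(train_op, {x: data_x, y: data_y})
        if (i + 1) % 50 == 0:
            saver.save(sess, os.path.join(checkpoint_dir, "model.ckpt"),
                       global_step=i + 1)
    trained = sess.run(y_hat, {x: data_x})

# --- Evaluation run: rebuild the identical graph, then restore w and b ---
tf.reset_default_graph()
x, y, y_hat, _ = build_graph()
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        raise RuntimeError("no checkpoint found in " + checkpoint_dir)
    predictions = sess.run(y_hat, {x: data_x})

print(np.allclose(trained, predictions))  # restored values match the trained ones
```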

      September 15, 2020 11:34 AM IST
    There are two parts to the model: the model definition, saved by Supervisor as graph.pbtxt in the model directory, and the numerical values of tensors, saved into checkpoint files like model.ckpt-1003418.

    The model definition can be restored using tf.import_graph_def, and the weights are restored using Saver.

    However, Saver uses a special collection holding the list of variables attached to the model Graph, and this collection is not initialized by import_graph_def, so you can't use the two together at the moment (it's on our roadmap to fix). For now, you have to use Ryan Sepassi's approach: manually construct a graph with identical node names, and use Saver to load the weights into it.

    (Alternatively, you could hack it by using import_graph_def, creating variables manually, and calling tf.add_to_collection(tf.GraphKeys.VARIABLES, variable) for each variable, then using Saver.)
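    A sketch of the two parts side by side, assuming TensorFlow 1.14+ or 2.x through tensorflow.compat.v1. Here tf.train.write_graph stands in for Supervisor (which writes graph.pbtxt on its own), and tf.train.list_variables / tf.train.load_variable are used to look inside the checkpoint; listing the stored names is also a convenient way to check that a manually rebuilt graph uses identical variable names. After import_graph_def the global-variables collection is empty, which is the gap described above.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF >= 1.14 or 2.x via the v1 compat layer
from google.protobuf import text_format

tf.disable_v2_behavior()

model_dir = tempfile.mkdtemp()

# --- Write the two parts separately ---
tf.reset_default_graph()
w = tf.Variable([[1.0, 2.0]], name="w")
b = tf.Variable([0.5], name="b")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Part 1: the model definition as a text GraphDef (graph.pbtxt).
    tf.train.write_graph(sess.graph_def, model_dir, "graph.pbtxt", as_text=True)
    # Part 2: the numerical values of the variables, in a checkpoint.
    ckpt_prefix = tf.train.Saver().save(
        sess, os.path.join(model_dir, "model.ckpt"), global_step=1)

# --- Part 1 back: parse graph.pbtxt and import it (structure only, no values) ---
graph_def = tf.GraphDef()
with open(os.path.join(model_dir, "graph.pbtxt")) as f:
    text_format.Parse(f.read(), graph_def)
tf.reset_default_graph()
tf.import_graph_def(graph_def, name="")
node_names = [n.name for n in graph_def.node]
# The variable nodes exist, but no variables collection was rebuilt,
# so a plain tf.train.Saver() has nothing to restore into.
print(tf.global_variables())  # []

# --- Part 2 back: names, shapes and values stored in the checkpoint ---
for name, shape in tf.train.list_variables(ckpt_prefix):
    print(name, shape, tf.train.load_variable(ckpt_prefix, name))
```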
      September 15, 2020 11:39 AM IST