How to save/restore a model after training?

  •   August 23, 2021 8:02 PM IST
    0
  • In TensorFlow 0.11.0RC1 and later, you can save and restore your model directly by calling tf.train.export_meta_graph and tf.train.import_meta_graph, as described at https://www.tensorflow.org/programmers_guide/meta_graph.
    Save the model
    import tensorflow as tf

    w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
    w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
    tf.add_to_collection('vars', w1)
    tf.add_to_collection('vars', w2)
    saver = tf.train.Saver()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'my-model')
    # The `save` method calls `export_meta_graph` implicitly,
    # so you will get the saved graph file my-model.meta.
    Restore the model
    sess = tf.Session()
    new_saver = tf.train.import_meta_graph('my-model.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('./'))
    all_vars = tf.get_collection('vars')
    for v in all_vars:
        v_ = sess.run(v)
        print(v_)
      August 24, 2021 4:37 PM IST
    0
  • From the docs:
    
    Save
    import tensorflow as tf

    # Create some variables.
    v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer)
    v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer)
    
    inc_v1 = v1.assign(v1+1)
    dec_v2 = v2.assign(v2-1)
    
    # Add an op to initialize the variables.
    init_op = tf.global_variables_initializer()
    
    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()
    
    # Later, launch the model, initialize the variables, do some work, and save the
    # variables to disk.
    with tf.Session() as sess:
      sess.run(init_op)
      # Do some work with the model.
      inc_v1.op.run()
      dec_v2.op.run()
      # Save the variables to disk.
      save_path = saver.save(sess, "/tmp/model.ckpt")
      print("Model saved in path: %s" % save_path)
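The snippet above covers only the save half. Here is a minimal sketch of the full round trip, including the matching restore step. It assumes the TensorFlow 1.x API, reached through `tensorflow.compat.v1` so it can also run on a TensorFlow 2.x install. The temporary directory and the names `ckpt_path`, `restored_v1` and `restored_v2` are illustrative choices, not part of the original answer.

```python
import os
import tempfile

# Assumption: the TF 1.x graph/session API, accessed via the compat module.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Illustrative checkpoint location (the answer above uses /tmp/model.ckpt).
ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# --- Save: build the graph, change the variables, write the checkpoint ---
tf.reset_default_graph()
v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer)
inc_v1 = v1.assign(v1 + 1)
dec_v2 = v2.assign(v2 - 1)
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run([inc_v1, dec_v2])
    saver.save(sess, ckpt_path)

# --- Restore: recreate variables with the same names, then load values ---
tf.reset_default_graph()
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])
saver = tf.train.Saver()
with tf.Session() as sess:
    # No initializer is run here; restore() fills in the saved values.
    saver.restore(sess, ckpt_path)
    restored_v1 = sess.run(v1)
    restored_v2 = sess.run(v2)

# v1 should be all ones and v2 all minus ones.
print(restored_v1, restored_v2)
```

The variables are matched by name, so the restoring graph has to declare `v1` and `v2` with the same names and shapes that were used when saving.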
      August 24, 2021 11:44 PM IST
    0
  • I am improving my answer to add more details for saving and restoring models.

    In TensorFlow 0.11 and later:

    Save the model:

    import tensorflow as tf
    
    # Prepare to feed input, i.e. feed_dict and placeholders
    w1 = tf.placeholder("float", name="w1")
    w2 = tf.placeholder("float", name="w2")
    b1 = tf.Variable(2.0, name="bias")
    feed_dict = {w1: 4, w2: 8}
    
    # Define a test operation that we will restore
    w3 = tf.add(w1, w2)
    w4 = tf.multiply(w3, b1, name="op_to_restore")
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    # Create a saver object which will save all the variables
    saver = tf.train.Saver()
    
    # Run the operation by feeding input
    print(sess.run(w4, feed_dict))
    # Prints 24.0, which is (w1 + w2) * b1 = (4 + 8) * 2
    
    # Now, save the graph
    saver.save(sess, 'my_test_model', global_step=1000)
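
The restore code loads 'my_test_model-1000.meta' because the `global_step` value is appended to the save prefix. Below is a small plain-Python sketch of the file names such a save is expected to leave on disk, assuming the V2 checkpoint format that is the default in TensorFlow 1.x. `expected_checkpoint_files` is a hypothetical helper written for illustration, not a TensorFlow function.

```python
def expected_checkpoint_files(prefix, global_step=None, num_shards=1):
    """Hypothetical helper: names a V2-format save is expected to write.

    Assumes the TensorFlow 1.x default (V2) checkpoint layout.
    """
    # The global step, when given, is appended to the prefix with a dash.
    stem = prefix if global_step is None else "%s-%d" % (prefix, global_step)
    files = [
        "checkpoint",     # small text file recording the latest checkpoint
        stem + ".meta",   # the serialized graph (MetaGraphDef)
        stem + ".index",  # index of the saved tensors
    ]
    # Variable values are stored in one or more sharded data files.
    files += [
        "%s.data-%05d-of-%05d" % (stem, shard, num_shards)
        for shard in range(num_shards)
    ]
    return files


files = expected_checkpoint_files("my_test_model", global_step=1000)
for name in files:
    print(name)
```

`tf.train.latest_checkpoint('./')` reads the `checkpoint` file and returns the prefix (here `./my_test_model-1000`), which is the value `saver.restore` expects, not any single file name.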

     

    Restore the model:

    import tensorflow as tf
    
    sess=tf.Session()    
    #First let's load meta graph and restore weights
    saver = tf.train.import_meta_graph('my_test_model-1000.meta')
    saver.restore(sess,tf.train.latest_checkpoint('./'))
    
    
    # Access saved Variables directly
    print(sess.run('bias:0'))
    # This will print 2, which is the value of bias that we saved
    
    
    # Now, access the restored placeholders and
    # create a feed_dict to feed new data
    
    graph = tf.get_default_graph()
    w1 = graph.get_tensor_by_name("w1:0")
    w2 = graph.get_tensor_by_name("w2:0")
    feed_dict ={w1:13.0,w2:17.0}
    
    #Now, access the op that you want to run. 
    op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
    
    print(sess.run(op_to_restore, feed_dict))
    # This will print 60.0, which is calculated as (w1 + w2) * b1 = (13 + 17) * 2

     

    This and some more advanced use cases are explained well in the tutorial "A quick complete tutorial to save and restore Tensorflow models".

     

      August 25, 2021 2:45 PM IST
    0
  • As Yaroslav said, you can hack restoring from a graph_def and checkpoint by importing the graph, manually creating variables, and then using a Saver.
    I implemented this for my personal use, so I thought I'd share the code here.
    (This is, of course, a hack, and there is no guarantee that models saved this way will remain readable in future versions of TensorFlow.)
      August 25, 2021 11:21 PM IST
    0