
How to get the global_step when restoring checkpoints in Tensorflow?

  • I'm saving my session state like so:

    self._saver = tf.train.Saver()
    self._saver.save(self._session, '/network', global_step=self._time)

    When I later restore, I want to get the value of global_step for the checkpoint I restore from, so I can set some hyperparameters from it.

    The hacky way to do this would be to run through and parse the file names in the checkpoint directory. But surely there has to be a better, built-in way to do this?
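    For reference, the filename-parsing workaround could be as small as the sketch below. The helper name is illustrative; it assumes the default `-<step>` suffix that tf.train.Saver appends to the save path.

    ```python
    import os

    def step_from_checkpoint_path(path):
        # tf.train.Saver appends '-<global_step>' to the save path,
        # e.g. '/network' saved at step 1234 becomes '/network-1234'.
        return int(os.path.basename(path).rsplit('-', 1)[1])

    print(step_from_checkpoint_path('/network-1234'))  # 1234
    ```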

    This post was edited by Viaan Prakash at September 4, 2020 5:08 PM IST
      September 4, 2020 5:07 PM IST
  • The general pattern is to keep a global_step variable to track the step count:

    global_step = tf.Variable(0, name='global_step', trainable=False)
    train_op = optimizer.minimize(loss, global_step=global_step)

    Then you can save with

    saver.save(sess, save_path, global_step=global_step)

    When you restore, the value of global_step is restored as well.
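    A minimal end-to-end sketch of this pattern (TF 1.x graph mode, written against tf.compat.v1 so it also runs under TF 2.x; the checkpoint path and the stand-in train op are illustrative):

    ```python
    import os
    import tempfile
    import tensorflow.compat.v1 as tf

    tf.disable_eager_execution()
    ckpt_prefix = os.path.join(tempfile.mkdtemp(), 'demo_ckpt')  # illustrative path

    # --- training graph: create the step variable and save ---
    tf.reset_default_graph()
    global_step = tf.Variable(0, name='global_step', trainable=False)
    increment = tf.assign_add(global_step, 1)  # stand-in for a real train_op

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(5):
            sess.run(increment)
        # global_step=... suffixes the filename, e.g. '.../demo_ckpt-5'
        save_path = saver.save(sess, ckpt_prefix, global_step=global_step)

    # --- fresh graph: restore and read the step back ---
    tf.reset_default_graph()
    global_step = tf.Variable(0, name='global_step', trainable=False)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, save_path)
        print(global_step.eval(session=sess))  # the restored step value
    ```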
      September 4, 2020 5:10 PM IST
  • You can use a global_step variable to keep track of steps, but if your code initializes or assigns this value to another step variable, the two may not be consistent.

    For instance, you define your global_step using:

    global_step = tf.Variable(0, name='global_step', trainable=False)

    Pass it to your training operation:

    train_op = optimizer.minimize(loss, global_step=global_step)

    Save in your checkpoint:

    saver.save(sess, checkpoint_path, global_step=global_step)

    And restore from your checkpoint:

    saver.restore(sess, checkpoint_path)

    The value of global_step is restored as well, but if you are assigning it to another variable, say step, then you must do something like:

    step = global_step.eval(session=sess)

    The variable step then contains the last global_step saved in the checkpoint.

    It is also nicer to obtain global_step from the graph rather than defining it as a zero-initialized variable (as above):

    global_step = tf.train.get_or_create_global_step()

    This will get the existing global_step if one exists, or create it if not.
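    A short sketch of that behavior (again via tf.compat.v1 so it runs under TF 2.x): repeated calls hand back the same variable rather than creating a new one.

    ```python
    import tensorflow.compat.v1 as tf

    tf.disable_eager_execution()
    tf.reset_default_graph()

    # First call creates the step variable (named 'global_step')...
    global_step = tf.train.get_or_create_global_step()
    # ...later calls return that same variable instead of making another.
    same_step = tf.train.get_or_create_global_step()

    print(global_step is same_step)  # True
    print(global_step.op.name)       # global_step
    ```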
      September 4, 2020 5:15 PM IST
  • The reason that a variable is not restored as expected is most likely due to the fact that it was created after your tf.Saver() object was created.

    The place where you create the tf.Saver() object matters when you don't explicitly specify a var_list, or specify None for var_list. The expected behavior for many programmers is that all variables in the graph are saved when the save() method is called, but this is not the case, and it should perhaps be documented as such. A snapshot of all variables in the graph is saved at the time of object creation.

    Unless you're having any performance issues, it's safest to create the saver object right when you decide to save your progress. Otherwise, make sure to create the saver object after you create all your variables.
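    A small sketch of that pitfall (tf.compat.v1; the temporary checkpoint path and variable names are illustrative): a variable created after the Saver is simply absent from the checkpoint.

    ```python
    import os
    import tempfile
    import tensorflow.compat.v1 as tf

    tf.disable_eager_execution()
    tf.reset_default_graph()

    a = tf.Variable(1, name='a')
    saver = tf.train.Saver()      # snapshots the variables that exist *now*
    b = tf.Variable(2, name='b')  # created afterwards -> not in the snapshot

    ckpt_prefix = os.path.join(tempfile.mkdtemp(), 'partial_ckpt')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        path = saver.save(sess, ckpt_prefix)

    # Inspect what actually landed in the checkpoint
    saved = [name for name, _ in tf.train.list_variables(path)]
    print(saved)  # only 'a' -- 'b' was silently left out
    ```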

    Also, the global_step that is passed to saver.save(sess, save_path, global_step=global_step) is merely a counter used for creating the filename and has nothing to do with whether it will be restored as a global_step variable. This is a parameter misnomer IMO since if you're saving your progress at the end of each epoch, it's probably best to pass your epoch number for this parameter.
      September 4, 2020 5:18 PM IST
  • I had the same issue as Lawrence Du: I could not find a way to get the global_step by restoring the model, so I applied his hack to the Inception v3 training code in the TensorFlow/models GitHub repo I'm using. The code below also contains a fix related to pretrained_model_checkpoint_path.

    If you have a better solution, or know what I'm missing please leave a comment!

    In any case, this code works for me:

    ...

    # When not restoring, start at 0
    last_step = 0
    if FLAGS.pretrained_model_checkpoint_path:
        # A model consists of three files; use the base name of the model in
        # the checkpoint path, e.g. my-model-path/model.ckpt-291500
        #
        # Because we pass the base name, this assert would always fail:
        # assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)

        variables_to_restore = tf.get_collection(
            slim.variables.VARIABLES_TO_RESTORE)
        restorer = tf.train.Saver(variables_to_restore)
        restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path)
        print('%s: Pre-trained model restored from %s' %
              (datetime.now(), FLAGS.pretrained_model_checkpoint_path))

        # HACK: global step is not restored for some unknown reason
        last_step = int(
            os.path.basename(FLAGS.pretrained_model_checkpoint_path).split('-')[1])

        # assign it to the global step variable
        sess.run(global_step.assign(last_step))

    ...

    for step in range(last_step + 1, FLAGS.max_steps):

    ...
      September 4, 2020 5:26 PM IST