I'm saving my session state like so:
self._saver = tf.train.Saver()
self._saver.save(self._session, '/network', global_step=self._time)
When I later restore, I want to get the value of global_step for the checkpoint I restore from, so that I can set some hyperparameters from it.
The hacky way to do this would be to run through and parse the file names in the checkpoint directory. But surely there has to be a better, built-in way to do this?
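(To be concrete, the hacky route would be something like this sketch, where checkpoint_dir is a placeholder for wherever the checkpoints land:)
import os
import tensorflow as tf

# Find the newest checkpoint prefix and parse the step out of its name,
# e.g. '/network-1000' -> 1000.
latest = tf.train.latest_checkpoint(checkpoint_dir)
step = int(os.path.basename(latest).split('-')[-1])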
The standard pattern is to keep the step in a global_step variable that the optimizer increments and the saver writes into the checkpoint:
global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)
saver.save(sess, save_path, global_step=global_step)
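Passing global_step to save() also stamps the step onto the file name (save_path becomes e.g. save_path-1000), but the important part is that the value lives in the checkpoint itself, since global_step is an ordinary variable.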
A complete round trip, saving with the step and reading it back after restoring:
global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)
saver.save(sess, checkpoint_path, global_step=global_step)
saver.restore(sess, checkpoint_path)
step = global_step.eval(session=sess)
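As a minimal end-to-end sketch of that pattern (TF 1.x; checkpoint_dir and the assign_add stand-in for a real train op are made up for this example):
import os
import tensorflow as tf

checkpoint_dir = '/tmp/network'            # made-up location for this sketch
os.makedirs(checkpoint_dir, exist_ok=True)

global_step = tf.Variable(0, name='global_step', trainable=False)
increment = tf.assign_add(global_step, 1)  # stands in for optimizer.minimize(...)
saver = tf.train.Saver()

# Train a little and save; the step value is stored in the variable,
# and also appended to the file name.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(3):
        sess.run(increment)
    saver.save(sess, os.path.join(checkpoint_dir, 'model'),
               global_step=global_step)

# Restore in a fresh session; global_step comes back with the checkpoint.
with tf.Session() as sess:
    latest = tf.train.latest_checkpoint(checkpoint_dir)  # e.g. .../model-3
    saver.restore(sess, latest)
    print('resumed at step', global_step.eval(session=sess))  # prints 3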
Another approach, from a training loop that restores a pre-trained model with slim (FLAGS, slim, and sess come from the surrounding script). Here the global step is not restored with the other variables, so it is recovered from the checkpoint file name instead:
import os
from datetime import datetime

import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
...
# When not restoring, start at 0.
last_step = 0
if FLAGS.pretrained_model_checkpoint_path:
    # A model consists of three files; pass the base name of the model in
    # the checkpoint path, e.g. my-model-path/model.ckpt-291500.
    #
    # Because that base name is not itself a file (the real files carry
    # suffixes), this assert would always fail:
    # assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
    variables_to_restore = tf.get_collection(
        slim.variables.VARIABLES_TO_RESTORE)
    restorer = tf.train.Saver(variables_to_restore)
    restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path)
    print('%s: Pre-trained model restored from %s' %
          (datetime.now(), FLAGS.pretrained_model_checkpoint_path))

    # HACK: the global step is not restored for some unknown reason,
    # so parse it out of the checkpoint file name ...
    last_step = int(
        os.path.basename(FLAGS.pretrained_model_checkpoint_path).split('-')[1])
    # ... and assign it to the global step variable.
    sess.run(global_step.assign(last_step))
...
for step in range(last_step + 1, FLAGS.max_steps):
    ...
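One detail worth noting about that parse (step_from_checkpoint_path below is a hypothetical helper, not from the original code): taking the last '-'-separated field instead of the second one survives base names that themselves contain dashes, e.g. 'my-model.ckpt-291500':
import os

def step_from_checkpoint_path(path):
    # 'model.ckpt-291500' and 'my-model.ckpt-291500' both -> 291500
    return int(os.path.basename(path).split('-')[-1])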