TensorFlow：从最近一次（last）checkpoint 加载模型

    xiaoxiao  2025-05-01  17

    """Save a TF1 variable to checkpoints, then restore it from the latest one.

    Part 1 builds a tiny graph with two int variables. It repeatedly increments
    ``global_step`` and saves a checkpoint after each increment.
    Part 2 rebuilds the same variables in a fresh graph. It then restores them
    from the newest checkpoint found by ``tf.train.latest_checkpoint``.

    NOTE(review): this uses the TensorFlow 1.x graph/session API (tf.Session,
    tf.train.Saver, tf.placeholder, tf.reset_default_graph). Under TensorFlow
    2.x these live under ``tf.compat.v1`` and need eager execution disabled.
    """
    import os

    import tensorflow as tf

    # Windows-style location used by the original snippet; adjust as needed.
    # Save and restore must agree on this directory: Saver.save writes
    # "<ckpt_path>-<step>" files plus a "checkpoint" index file here, and
    # latest_checkpoint reads that index file to find the newest one.
    ckpt_dir = r"E:\temp"
    ckpt_path = os.path.join(ckpt_dir, "global_model")

    # ---------------------------------------------------------------------
    # Part 1: update a counter variable and checkpoint it after each update.
    # ---------------------------------------------------------------------
    tf.reset_default_graph()
    global_step = tf.Variable(1, name="global_step")
    global_step1 = tf.Variable(1, name="global_step1")
    add_0 = global_step + global_step1

    # Build the increment op ONCE and feed the amount through a placeholder.
    # Calling global_step.assign_add(i) inside the loop would add a new op to
    # the graph on every iteration, so the graph (and every saved .meta file)
    # would keep growing.
    increment = tf.placeholder(tf.int32, shape=[], name="increment")
    assign_step = global_step.assign_add(increment)

    saver = tf.train.Saver()
    # Saver.save fails if the parent directory does not exist.
    os.makedirs(ckpt_dir, exist_ok=True)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        begin = global_step.eval(session=sess)
        for i in range(begin, 50):
            # Same arithmetic as the original: global_step += i.
            sess.run(assign_step, feed_dict={increment: i})
            print(i, " : ", sess.run(global_step), sess.run(add_0))
            # The current global_step value becomes the checkpoint suffix.
            saver.save(sess, ckpt_path, global_step=global_step)

    # ---------------------------------------------------------------------
    # Part 2: rebuild the same variables in a fresh graph and restore them
    # from the most recent checkpoint. The variable names must match the
    # saved ones ("global_step", "global_step1").
    # ---------------------------------------------------------------------
    tf.reset_default_graph()
    global_step = tf.Variable(1, name="global_step")
    global_step1 = tf.Variable(1, name="global_step1")
    add_0 = global_step + global_step1

    saver = tf.train.Saver()
    model_file = tf.train.latest_checkpoint(ckpt_dir)
    if model_file is None:
        # latest_checkpoint returns None when no checkpoint index is found.
        raise FileNotFoundError(
            "No checkpoint found in {!r}".format(ckpt_dir))

    with tf.Session() as sess:
        # restore() assigns every saved variable, so no initializer is run.
        saver.restore(sess, model_file)
        print(sess.run(global_step))
        print(sess.run(add_0))

     

    最新回复(0)