TensorFlow基础：模型参数（checkpoint）的保存与加载

    xiaoxiao2022-07-07  200

    保存当前参数最好时的参数 saver_ckpt = tf.train.Saver(tf.trainable_variables()) if max_ndcg < ndcg: max_ndcg = ndcg max_res = cur_res saver_ckpt.save(sess, ckpt_save_path+ckpt_save_file, global_step=epoch_count) print("saved best", epoch_count) 加载该模型 ckpt_save_path = "Pretrain/dropout/%s_embed_%s/" % (args.dataset, "_".join(map(str, model.nc))) # '[32,32,32,32,32,32]' saver_ckpt = tf.train.Saver(tf.trainable_variables()) #通过checkpoint获取最新的文件 ckpt = tf.train.get_checkpoint_state(ckpt_save_path) saver_ckpt.restore(sess, ckpt.model_checkpoint_path)
    最新回复(0)