使用tensorflow训练模型时需要保存训练后得到的模型,并在测试时加载模型。
Tensorflow中可以使用Saver类进行参数保存。
保存参数是可以选择的,如果不传参就是保存所有参数。
weight=[weights['wc1'],weights['wc2'],weights['wc3a']]
saver = tf.train.Saver(weight)#创建一个saver对象,.values是以列表的形式获取字典值
saver.save(sess,'model.ckpt')
保存所有参数:
saver_out=tf.train.Saver()
saver_out.save(sess,'file_name')
注意这里的file_name需要是绝对路径,还有人说不能和代码在一个路径里面(我没试过),另外我自己操作的时候报错PermissionDenied我以为没有写权限,然后查了一下发现是在定义路径时候文件名前必须加上“./”
checkpoint_path = os.path.join(logs_train_dir, './thing.ckpt')
运行完上面的代码之后,我们会发现在当前的程序目录下产生四个文件checkpoint、model.ckpt.data-00000-of-00001、model.ckpt.index、model.ckpt.meta。会产生四个文件的原因,之前有介绍过TensorFlow的程序是由计算图所组成的,所以在持久化的时候TensorFlow会将计算图的结果和图上的参数值分成不同的文件进行保存。
恢复参数:
(这个恢复参数要注意,model_filename是你要恢复的模型文件,整段代码的意思是从model_filename文件里只恢复weight的这些参数,如果model_filename里面没有这些参数,则报错。)
weight=[weights['wc1'],weights['wc2'],weights['wc3a']]
saver = tf.train.Saver(weight)#创建一个saver对象,.values是以列表的形式获取字典值
saver.restore(sess, model_filename)