estimator是tensorflow高度封装的一个类,里面有一些可以直接使用的分类和回归模型,例如tf.estimator.DNNClassifier,但这不是这篇博客的主题,而是怎么使用estimator来实现我们自定义模型的训练。它的步骤主要分为以下几个部分:
构建model_fn,在这个方法里面定义自己的模型以及训练和测试过程要做的事情;构建input_fn,在这个方法里面定义数据的来源和喂给模型的方式;最后,创建estimator对象,然后开始训练模型。还可以添加一些config,比如:loss的输出频率等。
构建model_fn方法
import tensorflow
as tf
def model_fn(features, labels, mode, params):
    """Build the Estimator graph for a single-unit dense regression model.

    Args:
        features: dict of tensors. ``features['inputs']`` feeds the dense
            layer; ``features['labels']`` holds the regression targets and is
            only read in TRAIN and EVAL modes.
        labels: unused; the targets are taken from ``features['labels']``.
        mode: one of the ``tf.estimator.ModeKeys`` values.
        params: dict holding the learning rate under ``'lr'`` and, optionally,
            a checkpoint path under ``'init_checkpoint'`` to warm-start from.

    Returns:
        A ``tf.estimator.EstimatorSpec`` populated for the requested mode.

    Raises:
        TypeError: if ``mode`` is not TRAIN, EVAL or PREDICT.
    """
    lr = params['lr']
    # An absent key means "no warm start".
    init_checkpoint = params.get('init_checkpoint')

    x = features['inputs']
    pre = tf.layers.dense(x, 1)

    if init_checkpoint:
        # Only map checkpoint entries that already exist in the current graph.
        # A checkpoint written during training also stores optimizer slot
        # variables, which are not created yet at this point, and
        # init_from_checkpoint raises on names missing from the graph.
        graph_var_names = {v.op.name for v in tf.global_variables()}
        assignment_map = {
            name: name
            for name, _ in tf.train.list_variables(init_checkpoint)
            if name in graph_var_names
        }
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Return before touching the labels: they are not needed to predict.
        return tf.estimator.EstimatorSpec(
            mode, predictions={'predictions': pre})

    if mode not in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        raise TypeError('Unsupported mode: {!r}'.format(mode))

    # Give the labels the same shape as the predictions so the subtraction is
    # element-wise instead of broadcasting against the trailing unit dimension.
    y = tf.reshape(features['labels'], tf.shape(pre))
    # The tensor name 'loss' is kept so it can be looked up by name.
    loss = tf.reduce_mean(tf.pow(pre - y, 2), name='loss')

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.AdamOptimizer(lr)
        train_op = optimizer.minimize(
            loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    # EVAL: eval_metric_ops values must be (metric_value, update_op) pairs,
    # so the batch loss is wrapped in a streaming mean.
    metrics = {'eval_loss': tf.metrics.mean(loss)}
    return tf.estimator.EstimatorSpec(
        mode, loss=loss, eval_metric_ops=metrics)
提几点需要注意的地方:
model_fn方法返回的是tf.estimator.EstimatorSpec;TRAIN、EVAL和PREDICT模式不可缺少的参数是不一样的。
构建input_fn方法
def input_fn_bulider(inputs_file, batch_size, is_training,
                     shuffle_buffer_size=100):
    """Create an ``input_fn`` that reads batched examples from a TFRecord file.

    Args:
        inputs_file: path of the TFRecord file to read.
        batch_size: number of examples per batch.
        is_training: when true, the dataset is repeated and shuffled.
        shuffle_buffer_size: size of the shuffle buffer used when
            ``is_training`` is true.

    Returns:
        A function taking ``params`` and returning a ``tf.data.Dataset`` whose
        elements are dicts with the keys ``'inputs'`` and ``'labels'``.
    """
    # Parsing schema: a float vector of length 3 and a scalar float label.
    name_to_features = {
        'inputs': tf.FixedLenFeature([3], tf.float32),
        'labels': tf.FixedLenFeature([], tf.float32),
    }

    def input_fn(params):
        """Build the dataset; ``params`` is accepted but not used."""
        d = tf.data.TFRecordDataset(inputs_file)
        if is_training:
            d = d.repeat()
            # Dataset.shuffle requires an explicit buffer size.
            d = d.shuffle(buffer_size=shuffle_buffer_size)
        d = d.apply(tf.contrib.data.map_and_batch(
            lambda record: tf.parse_single_example(record, name_to_features),
            batch_size=batch_size))
        return d

    return input_fn
执行estimator
# Script entry point: configure logging, build the Estimator and train it.
if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)

    # Save a checkpoint and log the step rate on every step.
    runConfig = tf.estimator.RunConfig(save_checkpoints_steps=1,
                                       log_step_count_steps=1)

    estimator = tf.estimator.Estimator(model_fn,
                                       model_dir='your_save_path',
                                       config=runConfig,
                                       params={'lr': 0.01})

    # Logs the tensor named 'loss' on every iteration.
    logging_hook = tf.train.LoggingTensorHook(every_n_iter=1,
                                              tensors={'loss': 'loss'})

    input_fn = input_fn_bulider('test.tfrecord', batch_size=1,
                                is_training=True)

    # The hook must be passed to train() to take effect.
    estimator.train(input_fn, hooks=[logging_hook], max_steps=1000)
欢迎关注同名公众号:“我就算饿死也不做程序员”。 交个朋友,一起交流,一起学习,一起进步。