I recently studied the CRNN network. The overall training procedure is roughly as follows:
1. Prepare the input data and labels; the labels are sparse tensors:
    inputs = tf.placeholder(tf.float32, [batch_size, input_height, input_width, 1])  # single-channel image batch
    label = tf.sparse_placeholder(tf.int32, name='label')                            # sparse label tensor
    seq_len = tf.placeholder(tf.int32, [None], name='seq_len')                       # per-sample sequence length

(In the later snippets these placeholders are referenced through the model object as self._inputs, self._label and self._seq_len.)
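A tf.sparse_placeholder is fed with a tf.SparseTensorValue (an indices/values/shape triple). The post does not show how the label batch is built; a common way to convert already-encoded label sequences into that form is sketched below (sparse_tuple_from and the example sequences are my own names, not from the original code):

    import numpy as np
    import tensorflow as tf

    def sparse_tuple_from(sequences, dtype=np.int32):
        """Convert a list of label index sequences into the sparse value
        that tf.sparse_placeholder expects."""
        indices, values = [], []
        for n, seq in enumerate(sequences):
            indices.extend(zip([n] * len(seq), range(len(seq))))
            values.extend(seq)
        indices = np.asarray(indices, dtype=np.int64)
        values = np.asarray(values, dtype=dtype)
        shape = np.asarray([len(sequences), indices[:, 1].max() + 1], dtype=np.int64)
        return tf.SparseTensorValue(indices, values, shape)

    # e.g. two samples whose characters were already mapped to class indices
    batch_label = sparse_tuple_from([[1, 5, 3], [2, 2]])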
2. Extract features with the CNN network:

    cnn_out = self._cnn(inputs)
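The body of _cnn is not shown in the post. In the original CRNN design the convolutional stack shrinks the feature-map height to 1 so that the remaining width becomes the time axis for the RNN; a minimal stand-in along those lines, with assumed filter counts and a 32-pixel input height, might look like this:

    def cnn_backbone(inputs):
        """Stands in for self._cnn; the layer sizes and pooling schedule are illustrative only."""
        x = tf.layers.conv2d(inputs, 64, 3, padding='same', activation=tf.nn.relu)
        x = tf.layers.max_pooling2d(x, 2, 2)                      # height 32 -> 16, width halved
        x = tf.layers.conv2d(x, 128, 3, padding='same', activation=tf.nn.relu)
        x = tf.layers.max_pooling2d(x, 2, 2)                      # height 16 -> 8, width halved again
        x = tf.layers.conv2d(x, 256, 3, padding='same', activation=tf.nn.relu)
        x = tf.layers.max_pooling2d(x, [2, 1], [2, 1])            # height 8 -> 4, keep width (time axis)
        x = tf.layers.conv2d(x, 512, 3, padding='same', activation=tf.nn.relu)
        x = tf.layers.max_pooling2d(x, [2, 1], [2, 1])            # height 4 -> 2
        x = tf.layers.conv2d(x, 512, 2, padding='valid', activation=tf.nn.relu)  # height 2 -> 1
        return x                                                  # [batch, 1, reduced_width, 512]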
3. Run two stacked bidirectional RNN layers to obtain the network output:

    crnn_model = self._rnn(cnn_out, self._seq_len)
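Similarly, _rnn is not shown. A stand-in sketch of two stacked bidirectional LSTM layers could look like this; the unit size of 256 is an assumption chosen so that the concatenated forward/backward output has the 512 features used in step 4, and it assumes the CNN has already collapsed the height to 1:

    def rnn_layers(cnn_out, seq_len, num_units=256):
        # cnn_out is assumed to be [batch, 1, width, channels]; squeeze to [batch, time, features].
        x = tf.squeeze(cnn_out, axis=1)
        for i in range(2):                                        # two stacked bidirectional layers
            with tf.variable_scope('bilstm_%d' % i):
                cell_fw = tf.nn.rnn_cell.LSTMCell(num_units)
                cell_bw = tf.nn.rnn_cell.LSTMCell(num_units)
                (out_fw, out_bw), _ = tf.nn.bidirectional_dynamic_rnn(
                    cell_fw, cell_bw, x, sequence_length=seq_len, dtype=tf.float32)
                x = tf.concat([out_fw, out_bw], axis=2)           # [batch, time, 2 * num_units]
        return x                                                  # 2 * 256 = 512 features, as used below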
4. Map the features onto the character classes to get the final output:

    logits = tf.reshape(crnn_model, [-1, 512])
    # self._class_num must include one extra class for the CTC blank label
    W = tf.Variable(tf.truncated_normal([512, self._class_num], stddev=0.1), name="W")
    b = tf.Variable(tf.constant(0., shape=[self._class_num]), name="b")
    logits = tf.matmul(logits, W) + b
    logits = tf.reshape(logits, [self._batch_size, -1, self._class_num])
    # network output, transposed to time-major for CTC
    net_output = tf.transpose(logits, (1, 0, 2))
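A quick shape check with assumed sizes (batch_size = 32, 64 time steps, 512 RNN features, class_num = 37; these numbers are illustrative only) shows how the reshapes line up:

    import tensorflow as tf

    batch_size, time_steps, rnn_units, class_num = 32, 64, 512, 37     # assumed sizes
    crnn_model = tf.zeros([batch_size, time_steps, rnn_units])         # stand-in for the RNN output
    x = tf.reshape(crnn_model, [-1, rnn_units])                        # [2048, 512]
    W = tf.zeros([rnn_units, class_num])
    b = tf.zeros([class_num])
    logits = tf.reshape(tf.matmul(x, W) + b, [batch_size, -1, class_num])  # [32, 64, 37]
    net_output = tf.transpose(logits, (1, 0, 2))                       # [64, 32, 37], time-major
    print(net_output.shape)                                            # (64, 32, 37)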
5. Decode the network output; decoded[0] is a sparse tensor of the same type as label:

    decoded, log_prob = tf.nn.ctc_greedy_decoder(net_output, self._seq_len)
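To read the predictions, the sparse decode result can be densified and mapped back to characters. A small sketch (the char_list lookup table and the -1 padding value are my own choices, not from the post):

    # Densify the sparse decode result; -1 marks padding positions.
    dense_decoded = tf.sparse_tensor_to_dense(decoded[0], default_value=-1)

    # Later, inside the session (char_list is a hypothetical index -> character table
    # matching whatever encoding was used to build the labels):
    dense_values = sess.run(dense_decoded, feed_dict=feed_dict)
    predictions = [''.join(char_list[i] for i in row if i != -1) for row in dense_values]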
6. The loss function (CTC loss):

    with tf.name_scope('loss'):
        # labels are the sparse tensor; inputs are the time-major logits from step 4
        loss = tf.nn.ctc_loss(self._label, self._net_output, self._seq_len)
        loss = tf.reduce_mean(loss)
7. The optimizer:

    with tf.name_scope('optimizer'):
        train_op = tf.train.AdamOptimizer(self._learning_rate).minimize(loss)
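If the CTC loss occasionally spikes during training, a common variant (not in the original post) is to clip the global gradient norm before applying the Adam update:

    with tf.name_scope('optimizer'):
        optimizer = tf.train.AdamOptimizer(self._learning_rate)
        grads_and_vars = optimizer.compute_gradients(loss)
        grads, variables = zip(*grads_and_vars)
        clipped, _ = tf.clip_by_global_norm(grads, 5.0)   # 5.0 is an assumed clip norm
        train_op = optimizer.apply_gradients(zip(clipped, variables))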
8. Accuracy:

    with tf.name_scope('accuracy'):
        # edit_distance is normalized by the label length, so this is a rough character-level accuracy
        accuracy = 1 - tf.reduce_mean(tf.edit_distance(tf.cast(self._decoded[0], tf.int32), self._label))
        accuracy_broad = tf.summary.scalar("accuracy", accuracy)
9. Feed the data and train:

    feed_dict = {self._inputs: batch_data,
                 self._label: batch_label,
                 self._seq_len: [self._max_char_count] * self._batch_size}
    sess.run(train_op, feed_dict=feed_dict)
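Putting the pieces together, a minimal training-loop sketch; next_batch is a hypothetical data helper and the iteration count is arbitrary, neither comes from the original post:

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for step in range(10000):                                    # arbitrary number of iterations
            batch_data, batch_label = next_batch(self._batch_size)   # hypothetical data helper
            feed_dict = {self._inputs: batch_data,
                         self._label: batch_label,
                         self._seq_len: [self._max_char_count] * self._batch_size}
            _, batch_loss, batch_acc = sess.run([train_op, loss, accuracy], feed_dict=feed_dict)
            if step % 100 == 0:
                print('step %d  loss %.4f  acc %.4f' % (step, batch_loss, batch_acc))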